Pythonでパチもん画像を300枚作る
前回は見よう見まねで画像の合否を判定するコードを実装したものの精度が悪いので、元データを増やすべくちっさい部品を1枚1枚おいて写真を撮るのも面倒なので、パチもん画像を300枚作って一気に増やすことにした。参考にしたのは、以下のサイト。
https://keras.io/examples/generative/conditional_gan/
まずは、元なるデータセットを作るため、
keras.preprocessing.image_dataset_from_directory(ゴニョゴニョ
でDIRNAME内から64×64 にリサイズされた画像を引っ張り出す。
dataset.mapは、中身の要素を一括処理するときに使うらしい。この場合、すべての要素を255で割る(あまり理解せずにみんなやってたからやったけど、多分サイズを落としてる)。
ちなみに、lambda式は地味に便利な小さな巨人だと思ってます。
例えば、
1. def で関数定義するほどでもない名前もいらない(lambda式に名前はつけられない)単一の式だけの処理文だけど、defみたいに引数を持たせて、記述したいとき
2. 変数xに「引数を持たせたlambda式」を代入することで、その変数に関数の機能を持たせたいとき
ぜひ使ってくださいlambda式
import sys
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# ディレクトリの変更
os.chdir('D:/DIRNAME')
# ファイルパスの設定
dataset = keras.preprocessing.image_dataset_from_directory(
"./DIRNAME", label_mode=None, image_size=(64, 64), batch_size=32
)
dataset = dataset.map(lambda x: x / 255.0)
datasetの中身をチェック、チェックするときは255倍しないと、すごく小さな画像が表示される。
for x in dataset:
plt.axis("off")
plt.imshow((x.numpy() * 255).astype("int32")[0])
break
詳しいことはわからんけど、このDiscriminatorは,生成される画像と元の正解画像の入力から,どちらの画像が生成された画像か識別する2値分類(Y/N)を行う畳み込みニューラル ネットワーク (CNN)らしい。何を畳み込んでいるのかは知らん。読んでもさっぱりわからん。パラメータも見様見真似でやったら上手くいったから、問題が発生するまではこれでよしとした。
後で分かったことだが、今回使用したSequentialは、1層前の出力をそのまま次に入力するという単純にただ層を積み上げるだけのモデルで、最初の
Input (shape=
で入力の次元を指定していた模様。
discriminator = keras.Sequential(
[
keras.Input(shape=(64, 64, 3)),
layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Flatten(),
layers.Dropout(0.2),
layers.Dense(1, activation="sigmoid"),
],
name="discriminator",
)
discriminator.summary()
詳しいことはわからんけど、このGeneratorは,1次元のランダムノイズの入力を、学習画像のドメインに近い画像へ変換してくれるネットワークらしい。ある人は、GeneratorをDiscriminatorの鏡だと言い、ある人は,逆畳み込み層(Deconvolution)と言っていた。なんせノイズから画像を生成してくれることは間違いない。
latent_dim = 128
generator = keras.Sequential(
[
keras.Input(shape=(latent_dim,)),
layers.Dense(8 * 8 * 128),
layers.Reshape((8, 8, 128)),
layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
],
name="generator",
)
generator.summary()
このあとはGANのクラス設計をしているけれども、何もわからん。要点は、generatorを学習させていくことでそれっぽい画像を生成して、discriminatorで画像が本物か、偽物かを識別する。このときどのくらい上手に識別できたかをd_lossに保存。このときにどのくらい騙せたかをg_lossに保存。
class GAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super().__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
self.g_loss_tracker = kera.metrics.Mean(name="g_loss")
~中略~
# Train the discriminator
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_images)
d_loss = self.loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
predictions = self.discriminator(self.generator(random_latent_vectors))
g_loss = self.loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
# Update metrics and return their value.
self.d_loss_tracker.update_state(d_loss)
self.g_loss_tracker.update_state(g_loss)
return {
"d_loss": self.d_loss_tracker.result(),
"g_loss": self.g_loss_tracker.result(),
}
今回便利だったのが、生成された画像を定期的にセーブするコールバックですね。この設定だと、3000回ごと(epoch % 3000 == 0)に、生成した画像を全部(i % 1 == 0)保存している。
class GANMonitor(keras.callbacks.Callback):
def __init__(self, num_img=3, latent_dim=128):
self.num_img = num_img
self.latent_dim = latent_dim
def on_epoch_end(self, epoch, logs=None):
random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
generated_images = self.model.generator(random_latent_vectors)
generated_images *= 255
generated_images.numpy()
result_dir = 'gen_img'
if not os.path.exists(result_dir):
os.mkdir(result_dir)
if epoch % 3000 == 0:
for i in range(self.num_img):
if i % 1 == 0:
img = keras.preprocessing.image.array_to_img(generated_images[i])
img.save("./" + result_dir + "/generated_img_process_%05d_%d.png" % (epoch, i))
最後は、上手くい来そうなところが見つかるまでは、epochsも少なめ100以下、num_imgも少なめ10枚以下でいい感じになるところを探して、上手くいくと信じて、「よろしくお願いします!」と叫んで実行。
epochs = 30001 # In practice, use ~100 epochs
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
loss_fn=keras.losses.BinaryCrossentropy(),
)
gan_history = gan.fit(
dataset, epochs=epochs, callbacks=[GANMonitor(num_img=300, latent_dim=latent_dim)]
)
相変わらずテキストだと何もわからんので、可視化。
ついでに出来上がったモデルも保存。
# 重みを保存
result_dir = 'results'
if not os.path.exists(result_dir):
os.mkdir(result_dir)
# モデルを保存
generator.save(os.path.join(result_dir, 'dcgan_generator.h5'))
discriminator.save(os.path.join(result_dir, 'dcgan_discriminator.h5'))
plt.style.use('ggplot')
def history_plot(history_model):
fig,ax = plt.subplots(1,1,figsize = (15,5))
ax.set_title('Loss Function')
ax.plot(history_model.history['d_loss'],label = "d_Loss")
ax.plot(history_model.history['g_loss'],label = "g_Loss")
ax.legend()
plt.savefig('loss_function.png')
gplot = history_plot(gan_history)
出来上がる途中経過がなかなかおもしろい。
相変わらず可視化すると、見事に過学習してますね。実際の画像を見ても、30000回のうちの12000回でほぼ仕上がってる。15000回以降はどんどん解像度良くなってるけど意味は薄そう。学習完了後の画像はパチもんとはいえ、ほぼモノホン部品の画像なので、見せられないが。。。
この記事が気に入ったらサポートをしてみませんか?