猫の写真をGANで学習させてみる
実家でノルウェージャン・フォレストキャットを飼っていました。かなりおっとりした性格の子で、めちゃくちゃかわいかったんですが、ある嵐の夜に外に出てしましそれ以来帰ってくることがありませんでした。そんな愛しのにゃんこをどうにか再現できないかと思い、GANに頼ることにしました。そこで今回の記事はこちらの記事を参考にGANの実装をします。先に結果を言ってしまいますが、結果は失敗でした。理由は圧倒的なデータ不足かなと思います。
GANとは
GANの詳細はコードの参照先でもあるこの記事をご覧ください。
用意したデータ
実家でノルウェージャン・フォレストキャットを飼っていました。かなりおっとりした性格の子で、めちゃくちゃかわいかったんですが、ある嵐の夜に外に出てしましそれ以来帰ってくることがありませんでした。そんな愛しのにゃんこをどうにか再現できないかと思い、GANに頼ることにしました。そこで今回の記事はこちらの記事を参考にGANの実装をします。先に結果を言ってしまいますが、結果は失敗でした。理由は圧倒的なデータ不足かなと思います。
GANの実装
参考にさせていただいた記事のコードに以下の2つの関数を追加しました。
- get_image_from_directory()
- change_img_size()
コードの全体像はこちらです。
### -*-coding:utf-8-*- from __future__ import print_function, division from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam import matplotlib.pyplot as plt import sys import numpy as np class GAN(): def __init__(self): #mnistデータ用の入力データサイズ self.img_rows = 128 self.img_cols = 128 self.channels = 3 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.img_path = "hogehoe" # 潜在変数の次元数 self.z_dim = 100 optimizer = Adam(0.0002, 0.5) # discriminatorモデル self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Generatorモデル self.generator = self.build_generator() # generatorは単体で学習しないのでコンパイルは必要ない #self.generator.compile(loss='binary_crossentropy', optimizer=optimizer) self.combined = self.build_combined1() #self.combined = self.build_combined2() self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) def build_generator(self): noise_shape = (self.z_dim,) model = Sequential() model.add(Dense(256, input_shape=noise_shape)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) model.summary() return model def build_discriminator(self): img_shape = (self.img_rows, self.img_cols, self.channels) model = Sequential() model.add(Flatten(input_shape=img_shape)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) model.summary() return model def build_combined1(self): self.discriminator.trainable = False model = Sequential([self.generator, self.discriminator]) return model def build_combined2(self): z = Input(shape=(self.z_dim,)) img = self.generator(z) self.discriminator.trainable = False valid = self.discriminator(img) model = Model(z, valid) model.summary() return model def train(self, epochs, batch_size=128, save_interval=50): # mnistデータの読み込み #(X_train, _), (_, _) = mnist.load_data() X_train = self.change_img_size(picture_size=(128, 128)) # 値を-1 to 1に規格化 X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # X_train = np.expand_dims(X_train, axis=3) half_batch = int(batch_size / 2) num_batches = int(X_train.shape[0] / half_batch) print('Number of batches:', num_batches) for epoch in range(epochs): for iteration in range(num_batches): # --------------------- # Discriminatorの学習 # --------------------- # バッチサイズの半数をGeneratorから生成 noise = np.random.normal(0, 1, (half_batch, self.z_dim)) gen_imgs = self.generator.predict(noise) # バッチサイズの半数を教師データからピックアップ idx = np.random.randint(0, X_train.shape[0], half_batch) imgs = X_train[idx] # discriminatorを学習 # 本物データと偽物データは別々に学習させる d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1))) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1))) # それぞれの損失関数を平均 d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Generatorの学習 # --------------------- noise = np.random.normal(0, 1, (batch_size, self.z_dim)) # 生成データの正解ラベルは本物(1) valid_y = np.array([1] * batch_size) # Train the generator g_loss = self.combined.train_on_batch(noise, valid_y) # 進捗の表示 print ("epoch:%d, iter:%d, [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, iteration, d_loss[0], 100*d_loss[1], g_loss)) # 指定した間隔で生成画像を保存 if epoch % save_interval == 0: self.save_imgs(epoch) def save_imgs(self, epoch): # 生成画像を敷き詰めるときの行数、列数 r, c = 5, 5 noise = np.random.normal(0, 1, (r * c, self.z_dim)) gen_imgs = self.generator.predict(noise) # 生成画像を0-1に再スケール gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt, :,:,0]) axs[i,j].axis('off') cnt += 1 fig.savefig("foo/_%d.png" % epoch) plt.close() def change_img_size(self, picture_size=(128, 128)): from PIL import Image image_list = self.get_image_from_directory() resized_img_list = [] for image in image_list: img = Image.open(image) img_resize = img.resize(picture_size) img_resize = np.array(img_resize, np.float) resized_img_list.append(img_resize) return np.array(resized_img_list) def get_image_from_directory(self): """ path:画像ファイルの親ディレクトリまでを指定 デフォルトでjpgのファイルを取得する """ import glob path = self.img_path image_list = glob.glob(path + "/*.jpg") return image_list if __name__ == '__main__': gan = GAN() gan.train(epochs=1000, batch_size=16, save_interval=1)
生成画像
うーん。全然だめですね笑
ちなみに、1500epochあたりから学習が全く進まなくなってノイズみたいな画像しか出力しなくなってしまいました。一定間隔で保存した写真の中で、それなりに猫っぽい姿が見えたものが上の写真です。
次の課題
なんといっても次はデータ不足の問題をどう克服すればいいのかを考えることだと思います。猫ちゃんはもういないので新たに写真を撮ることができません。アイディアとしては他の猫の写真を大量に用意して、そこから猫の特徴を抽出した後に、我が猫ちゃんの特徴を張っつけるみたいな感じなのかなと考えています。多分論文探せば同じようなことやっている人が見つかるでしょう。