猫の写真をGANで学習させてみる

実家でノルウェージャン・フォレストキャットを飼っていました。かなりおっとりした性格の子で、めちゃくちゃかわいかったんですが、ある嵐の夜に外に出てしましそれ以来帰ってくることがありませんでした。そんな愛しのにゃんこをどうにか再現できないかと思い、GANに頼ることにしました。そこで今回の記事はこちらの記事を参考にGANの実装をします。先に結果を言ってしまいますが、結果は失敗でした。理由は圧倒的なデータ不足かなと思います。

GANとは

GANの詳細はコードの参照先でもあるこの記事をご覧ください。

用意したデータ

f:id:memo_dl:20191020135503p:plain
猫ちゃんの写真
実家でノルウェージャン・フォレストキャットを飼っていました。かなりおっとりした性格の子で、めちゃくちゃかわいかったんですが、ある嵐の夜に外に出てしましそれ以来帰ってくることがありませんでした。そんな愛しのにゃんこをどうにか再現できないかと思い、GANに頼ることにしました。そこで今回の記事はこちらの記事を参考にGANの実装をします。先に結果を言ってしまいますが、結果は失敗でした。理由は圧倒的なデータ不足かなと思います。

GANの実装

参考にさせていただいた記事のコードに以下の2つの関数を追加しました。

  • get_image_from_directory()
  • change_img_size()

コードの全体像はこちらです。

### -*-coding:utf-8-*-
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN():
    def __init__(self):
        #mnistデータ用の入力データサイズ
        self.img_rows = 128 
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        
        self.img_path = "hogehoe"

        # 潜在変数の次元数 
        self.z_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # discriminatorモデル
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', 
            optimizer=optimizer,
            metrics=['accuracy'])

        # Generatorモデル
        self.generator = self.build_generator()
        # generatorは単体で学習しないのでコンパイルは必要ない
        #self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

        self.combined = self.build_combined1()
        #self.combined = self.build_combined2()
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        noise_shape = (self.z_dim,)
        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        return model

    def build_discriminator(self):

        img_shape = (self.img_rows, self.img_cols, self.channels)
        
        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        return model
    
    def build_combined1(self):
        self.discriminator.trainable = False
        model = Sequential([self.generator, self.discriminator])
        return model

    def build_combined2(self):
        z = Input(shape=(self.z_dim,))
        img = self.generator(z)
        self.discriminator.trainable = False
        valid = self.discriminator(img)
        model = Model(z, valid)
        model.summary()
        return model


    def train(self, epochs, batch_size=128, save_interval=50):

        # mnistデータの読み込み
        #(X_train, _), (_, _) = mnist.load_data()
        X_train = self.change_img_size(picture_size=(128, 128))

        # 値を-1 to 1に規格化
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        # X_train = np.expand_dims(X_train, axis=3)

        half_batch = int(batch_size / 2)

        num_batches = int(X_train.shape[0] / half_batch)
        print('Number of batches:', num_batches)
        
        for epoch in range(epochs):

            for iteration in range(num_batches):

                # ---------------------
                #  Discriminatorの学習
                # ---------------------

                # バッチサイズの半数をGeneratorから生成
                noise = np.random.normal(0, 1, (half_batch, self.z_dim))
                gen_imgs = self.generator.predict(noise)


                # バッチサイズの半数を教師データからピックアップ
                idx = np.random.randint(0, X_train.shape[0], half_batch)
                imgs = X_train[idx]

                # discriminatorを学習
                # 本物データと偽物データは別々に学習させる
                d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
                # それぞれの損失関数を平均
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


                # ---------------------
                #  Generatorの学習
                # ---------------------

                noise = np.random.normal(0, 1, (batch_size, self.z_dim))

                # 生成データの正解ラベルは本物(1) 
                valid_y = np.array([1] * batch_size)

                # Train the generator
                g_loss = self.combined.train_on_batch(noise, valid_y)

                # 進捗の表示
                print ("epoch:%d, iter:%d,  [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, iteration, d_loss[0], 100*d_loss[1], g_loss))

                # 指定した間隔で生成画像を保存
                if epoch % save_interval == 0:
                    self.save_imgs(epoch)

    def save_imgs(self, epoch):
        # 生成画像を敷き詰めるときの行数、列数
        r, c = 5, 5

        noise = np.random.normal(0, 1, (r * c, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        # 生成画像を0-1に再スケール
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("foo/_%d.png" % epoch)
        plt.close()


    def change_img_size(self, picture_size=(128, 128)):

        from PIL import Image
        
        image_list = self.get_image_from_directory()
        resized_img_list = []

        for image in image_list:
            img = Image.open(image)
            img_resize = img.resize(picture_size)
            img_resize = np.array(img_resize, np.float)
            resized_img_list.append(img_resize)

        return np.array(resized_img_list)


    def get_image_from_directory(self):
        """
        path:画像ファイルの親ディレクトリまでを指定
        デフォルトでjpgのファイルを取得する
        """
        import glob

        path = self.img_path
        image_list = glob.glob(path + "/*.jpg")
        return image_list


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=1000, batch_size=16, save_interval=1)

生成画像

f:id:memo_dl:20191020140655p:plain
うーん。全然だめですね笑
ちなみに、1500epochあたりから学習が全く進まなくなってノイズみたいな画像しか出力しなくなってしまいました。一定間隔で保存した写真の中で、それなりに猫っぽい姿が見えたものが上の写真です。

次の課題

なんといっても次はデータ不足の問題をどう克服すればいいのかを考えることだと思います。猫ちゃんはもういないので新たに写真を撮ることができません。アイディアとしては他の猫の写真を大量に用意して、そこから猫の特徴を抽出した後に、我が猫ちゃんの特徴を張っつけるみたいな感じなのかなと考えています。多分論文探せば同じようなことやっている人が見つかるでしょう。