見出し画像

VAEを利用した生成: Fashion-MNIST

やりたいこと

PyTorchで実装したVAEにFashion-MNISTのデータを読み込ませ,存在しない服飾品の画像を生成する.また,潜在変数の効果を確かめる.

前提: Encoder-Decoderモデルの考え方

Encoder-Decoderモデルは,その名の通り,EncoderとDecoderの二つの部分によって構成されている.

まず,Encoderは,決められた入力(画像,文章など)を暗号化(encode)する.このようにして生成される暗号(code)は,入力を何らかの形で表現するものだが,人間が「暗号」だけを見ても,普通,その意味を理解することはできない.

一方,Decoderは,暗号化された入力を受け取り,復号(decode)を行う.日常的に使われる「復号」という言葉は,暗号を元の平文(=Encoderに対する入力)に戻すことを指すが,ここでは復号結果と元の入力が,必ずしも同じであるとは限らない.元の入力に何らかの変化が施された上で出力されても,やはり「復号」という.ただし,後述するように,Auto-Encoder乃至VAEの枠組みでは,復号結果と元の入力の一致が要求される.

Encoder-Decoderモデルが活用された技術の中で,我々に最も馴染み深いのは機械翻訳だろう.機械が英文を和訳するとき,Encoderは英文を受け取って,その内容を「暗号」にする.Decoderがその暗号を受け取って,日本語に復号することで,英文が和訳されたものが出力されるのである.

Auto-Encoderの仕組み

Auto-Encoder(自己符号化器)は,まさしくこのEncoder-Decoderモデルの考え方に立脚しており,やはりEncoderとDecoderによって構成されている.

Auto-Encoderの目的は,入力(画像など)の特徴を,比較的少ない次元のベクトルで表現する方法を学習することにある.例えば,手書き数字のデータセット(MNIST)は28×28=784ピクセル分の情報を持っているが,これを2次元や3次元などの,低次元な空間上の点として表現したい,ということだ.

このような目的から,Auto-EncoderのEncoderは,高次元の入力(画像など)$${\bm{x}}$$を受け取り,低次元のベクトル$${\bm{z}}$$を出力するものとして与えられる.前節の言い回しを用いれば,Encoderの出力は,元の画像の内容を簡潔に表す「暗号」であり,潜在変数と呼ばれる.

一方,Decoderは,Encoderから与えられた「暗号」を読み込み,元の画像の再構築を試みる.両者の学習にあたっては,元の画像$${\bm{x}}$$と再構築後の画像$${\bm{\hat{x}}}$$の差異(二乗誤差,再構築誤差)を小さくするように方向づければよい.当然,「暗号」の持つ情報は,元の画像の情報と比べると幾分乏しいので,再構築によって元の画像が完璧に再現されることはない.しかし,再構築後の画像と元の画像が充分近くなるようにできたならば,高次元の入力を,比較的低次元な空間上の点として表現することができた,といえるだろう.

Auto-Encoderの生成における問題点

では,学習が完了したAuto-EncoderのDecoderに,自分で適当に考えた(=ランダムに設定した)「暗号」を与えて,新しい画像を作成することはできるのだろうか? Decoderが「暗号」を画像に構築する手法を学習している以上,理屈の上では可能であるはずだ.

しかし,このアプローチには落とし穴がある.確かにAuto-Encoderは暗号化と復号の方法を学習してくれるが,肝心の暗号が,特徴量空間中のどのあたりに,どのように分布しているかということを保証するものではないのだ.したがって,「自分で適当に考えた暗号」が,Decoderにとって,元の画像がどのようなものか,まるで見当のつかないものであるかもしれない.そのような場合,Decoderの出力は全くトンチンカンなものとなってしまうだろう.

この問題を解決するには,Encoderの出力に工夫を施す必要がある.

VAEとは何か

Auto-Encoderの難点は,再構築後の画像と元の画像の差異(二乗誤差)のみを学習の基準としていたために,$${\bm{z}}$$の分布が制御不能になってしまうことであった.

そこで,$${\bm{z}}$$を確率分布とみなして,その形状を標準正規分布に近づける方法を考える.換言すれば,Encoderの出力となる確率分布と標準正規分布の「距離」を,罰則項として付け加えたい

※この記事では,VAEを(参考文献と同様に)Auto-Encoderの拡張として導入したので,VAEの損失関数が「復元誤差+罰則項」という形で説明されている.しかし,Kingma & Welling (2013)の元論文では,対数尤度の下限を目的関数としているので,復元誤差の定義如何によっては,Kingma & Wellingの枠組みと(厳密には)異なる実装になるかもしれない.

まずは,「Encoderが確率分布を出力する」方法について考える.出力される確率分布を正規分布に限定すると,その分布の形状を確定するためには,平均$${\mu}$$と分散$${\sigma^2}$$の,二つのパラメータを決定すればよい.したがって,「暗号」の次元を$${d}$$とすれば,二つの$${d}$$次元のベクトル$${\bm{\mu}}$$と$${\bm{\sigma}}$$をEncoderの出力とすることになる.なお,実装上は,値域が非負に限定される$${\sigma^2}$$の代わりに,$${\log{\sigma}}$$を出力することが多いようだ.

一方,Encoderが確率分布を出力したとしても,Decoderに対する入力は,特定のベクトルでなくてはならない.そのため,標準正規分布に従う乱数$${\bm{\epsilon}}$$を用意して,

$$
\bm{z} = \bm{\mu}+\bm{\sigma}\odot{\bm{\epsilon}}
$$

のようにすれば,Decoderへの入力を適切に決定できる.

次に,「Encoderの出力と標準正規分布の距離」について考える.VAEの枠組みでは,この距離を,KL-divergenceによって扱う(KL-divergenceは対称律を満たさないので,厳密に言えば「距離」ではないのだが,ここでは気にしないでおく).$${N(\bm{\mu}, \bm{\sigma})}$$と標準正規分布のKL-divergenceは,

$$
D_{KL}[N(\bm{\mu}, \bm{\sigma})||N(\bm{0}, \bm{1})] = -\frac{1}{2}\sum_{i=1}^{d}({1+\log{{\sigma_i}^2-{\mu_i}^2}-{\sigma_i}^2})
$$

で表すことができる.

したがって,VAEの損失関数は,二乗誤差とKL-divergenceを適当な重みづけによって足し合わせ,以下のように表現される.

$$
L = \frac{1}{2}\sum(X_{\mathrm{pred}}-X_{\mathrm{original}})^2+\lambda D_{KL}[N(\bm{\mu}, \bm{\sigma})||N(\bm{0}, \bm{1})]
$$

ここで,$${X_{\mathrm{pred}}}$$,$${X_{\mathrm{original}}}$$はそれぞれ,Decoderの出力(再構築後)とEncoderへの入力(元画像)を表している.

適切な$${\lambda}$$の下で学習が完了したVAEのDecoderは,標準正規分布に従う$${\bm{z}}$$を入力されれば,新たに自然な画像を生成するだろう.

VAEの実践――Fashion-MNIST

Fashion-MNISTは,Tシャツ,ズボン,靴,鞄など,十種類の服飾品の白黒(グレースケール)画像からなるデータセットである.これに対してVAEを適用し,新たな画像の生成を試みた.簡単のため,$${\bm{z}}$$の次元$${d=2}$$,重み$${\lambda=1}$$とした.また,Encoderは二層,Decoderは三層の全結合層で構成し,活性化関数にはReLU(Decoderの出力以外)とsigmoid(Decoderの出力)を用い,25エポックの学習を行った.

まずは,画像を再構成する能力を確認する.以下にテストデータから選んだ100件の画像と,その再構成結果を示す.靴は靴,TシャツはTシャツとして再現されているようには見えるが,所々間違いも見つかるうえに,再現の品質は必ずしも高くない.二つの成分のみで多様な画像を表現するのには,やはり限界があるようだ.

画像2

画像3

次に,自前で用意した$${\bm{z}}$$に対するDecoderの出力を確認する.上手くいけば,それらしい服飾品の画像が生成されるはずだ.

生成にあたっては,$${\bm{z}}$$を標準正規分布からランダムに選んでもよいのだが,$${\bm{z}}$$が生成画像の様子に与える影響を観察するために,その各成分を少しずつ変化させた.以下にその結果を示す.縦方向に第一成分,横方向に第二成分を変化させている.生成画像の多様性は乏しいものの,潜在変数の変化によって画像が少しずつ変化する様子を観察できるだろう.

画像1

実装

Pytorchを利用した実装例.

必要なモジュールのimport.

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

データセットの定義など.Fashion-MNISTの各画素値は符号無し8bit整数で与えられるので,前処理の段階で255で割っておいた.


class MyDataset(torch.utils.data.Dataset):
   # 教師無し学習用データセット.targetは無視(Dataと同じ)
   def __init__(self, X):
       self.data = X
       self.target = X
   
   def __len__(self):
       return len(self.target)
   
   def __getitem__(self, idx):
       return self.data[idx], self.target[idx]

def makeScaledDataset(dataset, num):
   # 与えられたデータセットの要素を255で割る
   X = []
   for i in range(num):
       X.append(dataset.data[i]/255)
   return MyDataset(X)

# 訓練データ,テストデータの定義
trainData = torchvision.datasets.FashionMNIST("data/Fashion_MNIST", download=True, transform=transforms.ToTensor())
testData = torchvision.datasets.FashionMNIST("data/Fashion_MNIST", train=False, download=True, transform=transforms.ToTensor())
trainData = makeScaledDataset(trainData, len(trainData))
testData = makeScaledDataset(testData, len(testData))

# データセットの可視化
fig, axes = plt.subplots(10, 10, figsize=(10, 10),\
   subplot_kw={"xticks":[], "yticks":[]}, gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i, ax in enumerate(axes.flat):
   img = (trainData.data[i].detach().numpy()*255).astype(np.uint8)
   ax.imshow(img, cmap="binary")

VAEの損失関数の定義.説明中の$${\lambda}$$はkと表されている.

class VAEloss(nn.Module):
   # VAEの損失関数
   def __init__(self):
       super().__init__()

   def forward(self, orig, res, mu, sigma, k=1e-2):
       eps = 1e-7
       kl_div = -0.5*torch.sum(1+torch.log(sigma**2+eps)-mu**2-sigma**2)
       kl_div_batch = kl_div.mean()
       e = 0.5*torch.sum((orig-res)**2, dim=[1,2])
       e_batch = e.mean()
       return k*kl_div_batch+e_batch

Fashion-MNISTに適用するVAEのクラスを定義.拡張性を考えて,nn.Linearで事足りる部分もnn.Sequentialとして与えている.

class VAE_MNIST(nn.Module):
   def __init__(self, code_dim):
       super().__init__()
       self.code_dim = code_dim
       encoder_fc_dim = 128

       self.encoder_fc = torch.nn.Sequential(
           # Encoderの共通部分
           nn.Flatten(),
           nn.Linear(28*28, encoder_fc_dim),
           nn.ReLU(),
       )
       self.encoder_mu = torch.nn.Sequential(
           # muの出力
           nn.Linear(encoder_fc_dim, code_dim)
       )
       self.encoder_sigma = torch.nn.Sequential(
           # sigmaの出力
           nn.Linear(encoder_fc_dim, code_dim)
       )
       self.decoder = torch.nn.Sequential(
           # Decoder,出力は一次元
           nn.Linear(code_dim, 128),
           nn.ReLU(),
           nn.Linear(128, 256),
           nn.ReLU(),
           nn.Linear(256, 28*28),
           nn.Sigmoid()
       )

   def reproduce(self, x, batch_size):
       # 画像の再構成
       encoder_fc_output = self.encoder_fc(x)
       mu = self.encoder_mu(encoder_fc_output)
       sigma = torch.exp(self.encoder_sigma(encoder_fc_output)/2)
       eps = torch.randn(batch_size, self.code_dim)
       # zを計算
       z = mu+sigma*eps
       # サイズを28x28にして出力
       output = self.decoder(z).view(batch_size, 28, 28)
       return output, mu, sigma

   def generate(self, z, num):
       # 与えられたzから画像群を生成,numは個数(1バッチ)
       output = self.decoder(z).view(num, 28, 28)
       return output

   def train(self, dataset, batch_size=100, eta=1e-3, k=1e-2):
       dataloader = DataLoader(dataset, batch_size, shuffle=True)
       loss_fn = VAEloss()
       optimizer = torch.optim.Adam(self.parameters(), lr=eta)
       loss_all = 0
       
       for batch, (X, y) in enumerate(dataloader, 0):
           res, mu, sigma = self.reproduce(X, batch_size)
           loss = loss_fn(res, X, mu, sigma)
           # 誤差逆伝播
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
           loss_all += loss.item()

       return loss_all/len(dataloader.dataset)

   def train_epochs(self, dataset, epochs=10, batch_size=100, eta=1e-3, k=1e-2):
       loss = np.zeros(epochs)
       for t in range(epochs):
           print("Epoch"+str(t+1), end=" ")
           loss[t] = self.train(dataset, batch_size, eta, k)
           print("loss:", loss[t])
       plt.figure()
       plt.plot(list(range(1, epochs+1)), loss)
       plt.xlabel("Epoch")
       plt.ylabel("Loss")
       plt.savefig("result/VAE_grayscale/k="+str(k)+".jpg")


   def reproduce_example(self, dataset):
       # テストデータから100件の画像を選び,再構成結果を示す
       dataloader = DataLoader(dataset, 100, shuffle=True)
       with torch.no_grad():
           for batch, (X, y) in enumerate(dataloader,0):
               pred = self.reproduce(X, 100)[0]
               res=pred
               if batch==0:
                   break
       # 再構成結果
       fig, axes = plt.subplots(10, 10, figsize=(10, 10),\
           subplot_kw={"xticks":[], "yticks":[]}, gridspec_kw=dict(hspace=0.1, wspace=0.1))
       for i, ax in enumerate(axes.flat):
           img = (res[i].detach().numpy()*255).astype(np.uint8)
           ax.imshow(img, cmap="binary")
       fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)
       # 元画像
       fig, axes = plt.subplots(10, 10, figsize=(10, 10),\
           subplot_kw={"xticks":[], "yticks":[]}, gridspec_kw=dict(hspace=0.1, wspace=0.1))
       for i, ax in enumerate(axes.flat):
           img = (X[i].detach().numpy()*255).astype(np.uint8)
           ax.imshow(img, cmap="binary")
       fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)


   def generate_example(self, filename):
       # 標準正規分布からサンプリングしたzを用い,100件の画像を生成
       z = torch.randn(100, self.code_dim)
       res = self.generate(z, 100)
       fig, axes = plt.subplots(10, 10, figsize=(10, 10),\
           subplot_kw={"xticks":[], "yticks":[]}, gridspec_kw=dict(hspace=0.1, wspace=0.1))
       for i, ax in enumerate(axes.flat):
           img = (res[i].detach().numpy()*255).astype(np.uint8)
           ax.imshow(img, cmap="binary")
       fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)
       fig.savefig("result/VAE_grayscale/"+filename+".jpg")

   def change_var(self, filename):
       # 潜在変数を少しずつ変化させ,影響を観察する
       # 潜在変数としては,varの要素が選ばれる.標準正規分布として「常識的」な値にする
       # zの次元を3以上にしても動くようにしたが,2にすることが推奨される
       const = torch.randn(self.code_dim-2)
       res = []
       var =  torch.tensor([-2, -1.5, -1, -0.5, -0.2, -0.1, 0, 0.1, 0.2, 0.5, 1, 1.5, 2])
       for z1 in var:
           for z2 in var:
               z = torch.cat((torch.tensor([z1, z2]), const), dim=0).view(1, self.code_dim)
               res.append(self.generate(z,1)[0])
       fig, axes = plt.subplots(len(var), len(var), figsize=(10, 10),\
           subplot_kw={"xticks":[], "yticks":[]}, gridspec_kw=dict(hspace=0.1, wspace=0.1))
       for i, ax in enumerate(axes.flat):
           img = (res[i].detach().numpy()*255).astype(np.uint8)
           ax.imshow(img, cmap="binary")
       fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)
       fig.savefig("result/VAE_grayscale/"+filename+".jpg")

実行部分.損失関数の重みを変えながら実験することができる.必要があれば,reproduce_exampleメソッドを実行し,画像が再構成できているか確認してもよい.

# 幾つかの重みについて実験
for k in [1, 1e-1, 1e-2, 1e-3, 1e-4]:
   net = VAE_MNIST(2)
   net.train_epochs(trainData, epochs=25, eta=1e-3, k=k)
   net.generate_example("k="+str(k)+"example")
   net.change_var("k="+str(k)+"morph")

参考文献

VAEの元論文(難しい):
Kingma, D. P., & Welling, M. (2013). Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114. 

Kerasを利用した実装.MNISTやCelebAを対象としている:
David Foster著,松田晃一,小沼千絵訳(2020)『生成Deep Learning――絵を描き、物語や音楽を作り、ゲームをプレイする』,オライリー・ジャパン

実装例は無いが,深層学習について幅広く説明されている:
岡谷貴之(2022)『機械学習プロフェッショナルシリーズ 深層学習 改訂第二版』,講談社






この記事が気に入ったらサポートをしてみませんか?