見出し画像

CycleGANでリアルタイムに映像を変換してみた!

映像を浮世絵風にリアルタイムでAIが変換してくれるアプリを作成していきます。

左がWebカメラに入力されたオリジナルです。
右がCycleGANによって浮世絵風に変換した映像です。

本棚あたりは、特に浮世絵を感じさせる描画となっております。

また、全体図は以下になります。


環境

Google Colabで学習させ、ローカルマシンでアプリを実行していきます。

#Local環境
Python 3.8.10
PyTorch 1.12.1
OpenCV 4.2.0
Ubuntu 20.04.5 LTS


使用するデータセット

こちらからダウンロード可能です。
http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/

馬vsシマウマ、リンゴvsオレンジ、冬vs夏などいくつかCycleGAN用のデータとして公開されていますが、今回はその中から実写vs浮世絵であるukiyoe2photoを学習に使用します。google colab上では、例えば以下のようにして取得します。

!wget http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/ukiyoe2photo.zip
!unzip ukiyoe2photo.zip

ダウンロードした画像ファイルたちは以下のように、定義しておきます。

ROOT_PATH = './ukiyoe2photo/'
TrainX_Path = ROOT_PATH + "trainA/"
TrainY_Path = ROOT_PATH + "trainB/"
TestX_Path = ROOT_PATH + "testA/"
TestY_Path = ROOT_PATH + "testB/"

TrainX_ImgPath_list = list(glob.glob(TrainX_Path + "*.jpg"))
TrainY_ImgPath_list = list(glob.glob(TrainY_Path + "*.jpg"))
TestX_ImgPath_list = list(glob.glob(TestX_Path + "*.jpg"))
TestY_ImgPath_list = list(glob.glob(TestY_Path + "*.jpg"))

サンプルを以下に示します。
1行目が浮世絵で2行目が実写画像になります。画像サイズは (256×256) です。

ここで注意です。
ダウンロードした画像のほとんどはRGBの3chですが、2枚だけ1chが混ざっていたので取り除きます。

# clean dataset
baddata_path_list = []
for path_x, path_y in zip(TrainX_ImgPath_list, TrainY_ImgPath_list):
 x = Image.open(path_x)
 y = Image.open(path_y)
 try:
   imgx, imgy = transform(x, y)
 except:
   print(imgx.shape)
   print(x.mode)
   plt.imshow(x)
   plt.show()
   baddata_path_list.append(path_x)
   TrainX_ImgPath_list.remove(path_x)


CycleGANとは

2017年にUCバークレー校のBAIR研究所によって考案[1]された、2種類のデータセットのドメインを相互変換させるGANベースのモデルです。その論文では以下の画像(Fig.3)が投稿されています。

Fig.3(cite: [1])

CycleGANの仕組みとしては、まずGeneratorとDiscriminatorを各々2つずつ用意します。データセットをXとYとしたときに、空間Xから空間Yへ写像するGeneratorをGとして、その逆写像を行うGeneratorをFとします。入力画像が空間Xに存在するかの識別を行うDiscriminatorをDxとし、空間Yに存在するか識別するDiscriminatorをDyとします。以下の画像(Fig.4)の(a)に該当します。

Fig. 4(cite: [1])

これらが実際にどうやって学習されていくかを知るために、目的関数を見てみます。比較のために最初に、2014年Goodfellow氏によって考案[2]されたオリジナルのGAN(以下単にGANと表記)を簡単に示します。


※ ↑↑間違えた。DiscriminatorもGeneratorもG(z)です(笑)


気を取り直して、
これをひとつの式にまとめると、以下になります。


一方で、CycleGANではGANにはない正則化項を導入しており、その目的関数を以下に示します。


損失関数LGANは、GANと同等の損失関数ですが、CycleGANではGenerator側に正則化項としてcycle consistency lossが加わっています。これはFig.4中の(b), (c)に該当します。画像xを入力とするGの出力である画像yをFに入れたら、元の画像xが復元出来るはずだ、という考えの元で計算されます。逆も然りで、Fについても同様に考えます。ちなみに論文ではλ=10としています。


実装するにあたって「学習の流れ」を改めて確認します。
便宜上、空間Xを実写画像空間、空間Yを浮世絵空間とします。


1. まずはDiscriminatorの学習です。実写画像をGに入力して新たな画像を生成します。新たに生成した画像について、Dyが浮世絵であるかを識別します。正解ラベルは、Dyへの入力が実写画像なら真(1)とし、生成画像なら偽(0)とします。また、出力サイズを1×1とするGANのDiscriminatorと異なり、CycleGANでは、n×n(nは実装者次第)のようにある程度のサイズをもっており、ピクセル単位で真偽判定します。逆も同様にして、浮世絵をFに入力して新たな画像を生成します。その生成画像について、Dxが実写画像かを識別します。正解ラベルはDxへの入力が浮世絵なら真(1)で、生成画像なら偽(0)とします。

2. 上記1に関して、Dx, Dyの損失をバックプロパゲーションして重みを更新します。損失はピクセル単位で計算し、その和ないし平均を損失とします。これでDiscriminatorの学習が完了です。

3. 次はGeneratorの学習です。実装者の狙いとして、生成画像をDiscriminatorが真(1)と誤判断してほしいところです。ですので、Discriminatorへの入力が生成画像でありますが、正解ラベルを真(1)として損失を計算します。

4. 上記3に加えて、Generatorでは正則化項の計算を行います。cycle consistency(reconstraction) lossでした。上記1で生成した画像から各々元の空間に戻した際の損失を計算します。つまり、画像xをGに入力してその出力を更にFに入力すれば、画像yが生成されるという意図です。

5. 実際の実装では、正則化項をさらに拡張して、identity lossを追加します(この正則化項はオプションなので、実装しなくても理論上問題ない)。これは空間Xから空間Yへの写像を担うGに、画像xではなく画像yを入力して画像yを出力させた際の損失を計算します。名前から分かる通り、恒等関数の役割を担っています。逆も然りで、Fに関しても画像xを入力して画像xを生成させた際の損失を計算します。

6. 上記3,4,5をまとめた損失がGeneratorの最終損失で、それをバックプロパゲーションして重みを更新させたら、Generatorの学習終了となります。


実装

それでは以下にCycleGANを実装していきます。

まずはimport文です。

import torch
import torch.nn as nn
import torch.optim as opt
import torchvision.transforms.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image

from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import glob
import numpy as np
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


Transformの定義

class RandomMirror():
 def __call__(self, imgX, imgY):
   if np.random.randint(2):
     imgX = ImageOps.mirror(imgX)
     imgY = ImageOps.mirror(imgY)
   return imgX, imgY

class Normalize_Tensor():
 def __init__(self, mean, std):
   self.mean = mean
   self.std = std

 def __call__(self, imgX, imgY):
   imgX = F.to_tensor(imgX)
   imgX = F.normalize(imgX, self.mean, self.std)
   imgY = F.to_tensor(imgY)
   imgY = F.normalize(imgY, self.mean, self.std)
   return imgX, imgY


class Scale():
 def __init__(self, scale_range=[0.5, 1.5]):
   self.s_range = scale_range

 def __call__(self, imgX, imgY):
   #assert (same size)
   scale_ratio = np.random.uniform(self.s_range[0], self.s_range[1])
   w, h = imgX.size[0], imgX.size[1]
   w_scaled = int(scale_ratio * w)
   h_scaled = int(scale_ratio * h)
   imgX_scaled = imgX.resize((w_scaled, h_scaled), Image.BICUBIC)
   imgY_scaled = imgY.resize((w_scaled, h_scaled), Image.BICUBIC)
   if scale_ratio >1.:
     left = int(np.random.uniform(0, w_scaled - w))
     top = int(np.random.uniform(0, h_scaled - h))
     imgX = imgX_scaled.crop((left, top, left+w, top+h))
     imgY = imgY_scaled.crop((left, top, left+w, top+h))
   else:
     left = int(np.random.uniform(0, w - w_scaled))
     top = int(np.random.uniform(0, h - h_scaled))
     imgX = Image.new(imgX.mode, (w, h), (0,0,0))
     imgY = Image.new(imgY.mode, (w, h), (0,0,0))
     imgX.paste(imgX_scaled, (left, top))
     imgY.paste(imgY_scaled, (left, top))
   return imgX, imgY


class Transform():
 def __init__(self, mean, std):
   self.transform = [
       Scale([1.0, 1.5]),
       RandomMirror(),
       Normalize_Tensor(mean, std),
   ]

 def __call__(self, imgX, imgY):
   for t in self.transform:
     imgX, imgY = t(imgX, imgY)
   return imgX, imgY
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
transform = Transform(mean, std)

データオーグメンテーションは、scale(拡大縮小)とmirror(反転)を使用します。epoch=30と設定した場合、同一データでの学習がそれだけ発生するので、データオーグメンテーションを使用しました。scaleは1〜1.5倍として、厳密には縮小をしておりません。この理由としては、縮小すると以下のようにブラックエリアが発生して、データの特徴が変わってしまうからです。

scaleとmirror以外にも、データオーグメンテーションの手法を加えても良いですが、今回はこの2つで行いました。ちなみにデータオーグメンテーションの有無で比較しましたが、有った方がより良い画像が生成されました。

DataLoderの定義

class My_Dataset(Dataset):
 def __init__(self, path_list:list, transform):
   super().__init__()
   self.imgX_path_list = path_list[0]
   self.imgY_path_list = path_list[1]
   self.transform = transform

 def __len__(self):
   return len(self.imgX_path_list)

 def __getitem__(self, idx):
   imgX = Image.open(self.imgX_path_list[idx])
   imgY = Image.open(self.imgY_path_list[idx])
   imgX, imgY = self.transform(imgX, imgY)
   return imgX, imgY
train_dataset = My_Dataset([TrainX_ImgPath_list, TrainY_ImgPath_list], transform)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)

DataLoaderクラスを使用するために、My_Datasetクラスを定義します。CycleGANの論文にならって、batch_size=1にします。

Generatorの定義

class ConvNormReLU(nn.Module):
 def __init__(self, in_channels, out_channels, kernel_size,
              stride=1, padding=0, dilation=1, bias=True, padding_mode="reflect"):
   super().__init__()
   self.layers = nn.Sequential(
       nn.Conv2d(in_channels=in_channels, 
                 out_channels=out_channels,
                 kernel_size=kernel_size,
                 stride=stride,
                 padding=padding,
                 dilation=dilation,
                 bias=bias, padding_mode=padding_mode),
       nn.InstanceNorm2d(out_channels),
       nn.ReLU(inplace=True)
   )

 def forward(self, x):
   return self.layers(x)

class ResidualBlock(nn.Module):
 def __init__(self, in_channels):
   super().__init__()
   self.layers = nn.Sequential(
       ConvNormReLU(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
   )
  def forward(self, x):
   return self.layers(x) + x


class Generator(nn.Module):
 def __init__(self, in_ch, n_ch=64, n_resblock=9):
   super().__init__()
   l = []
   l.extend([
       ConvNormReLU(in_ch, n_ch, kernel_size=6, stride=1, padding=2),
       ConvNormReLU(n_ch, n_ch*2, kernel_size=3, stride=2, padding=1),
       ConvNormReLU(n_ch*2, n_ch*4, kernel_size=3, stride=2, padding=1),
   ])
   l.extend(list(ResidualBlock(n_ch*4) for _ in range(n_resblock)))
   l.extend([
       nn.ConvTranspose2d(n_ch*4, n_ch*2*4, kernel_size=3, stride=1, padding=1),
       nn.PixelShuffle(2),
       nn.InstanceNorm2d(n_ch*2),
       nn.ReLU(True),
       nn.ConvTranspose2d(n_ch*2, n_ch*4, kernel_size=3, stride=1, padding=1),
       nn.PixelShuffle(2),
       nn.InstanceNorm2d(n_ch),
       nn.ReLU(True),
       nn.ConvTranspose2d(n_ch, in_ch, kernel_size=5, stride=1, padding=2),
   ])
   self.layers = nn.Sequential(*l)

 def forward(self, x):
     return torch.tanh(self.layers(x))

Generatorの構造は、Encoder部とDecoder部に分かれています。

Encoder部には、ConvNormReLU層と称して畳み込みを3回行います。そしてその次にResidualBlockと称して、ConvNormReLUにスキップ結合を組み合わせた層を9回行います。Encoderでは、画像の特徴を捉える役割があります。

次にDecoder部へと処理が移り、捉えた特徴を復元(変換)していきます。Encoderによって画像サイズが小さくなっているので、Decoderではそのサイズを元の256×256サイズに戻していきます。逆畳み込み層(ConvTranspose)のstrideを2以上に設定することで、画像サイズを大きく戻せますが、ここでは超解像度の技術で用いられるPixelShuffle層を使用します。実際に比較してみて、逆畳み込み層stride=2でやったときよりもより良い画像が生成されました。これはPixelShuffleを使用するには、前層のConvTranspose層で、通常(PixelShuffleを使用しなかった場合、つまりstride=2)のときよりも出力chを4倍にする必要があり、パラメータ数が増えるからです。PixelShuffleの詳しい説明はWebサイト[3]をご参考下さい。

forwardメソッドの中で最後の最後に、tanh関数によって画像値を[-1, 1]の範囲内に修正しています。これは上記のTransformクラス内で、入力画像は全て[-1, 1]の範囲に正規化しているためです。

Discriminatorの定義

class Discriminator(nn.Module):
   def __init__(self, in_ch, features=[]):
       super().__init__()
       block1 = nn.Sequential(
               nn.Conv2d(in_ch, features[0], kernel_size=4, stride=1, padding=1),
               nn.LeakyReLU(negative_slope=0.2, inplace=True)
       )   
       l = []
       for i,_ in enumerate(features[:-1]):
           l.extend([
               nn.Conv2d(features[i], features[i+1], kernel_size=4, stride=2, padding=1),
               nn.InstanceNorm2d(features[i+1]),
               nn.LeakyReLU(negative_slope=0.2, inplace=True),
           ])
       block2 = nn.Sequential(*l)
       block3 = nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=1)
       self.laysers = nn.ModuleList([block1, block2, block3])
      
   def forward(self, x):
       for l in self.laysers:
           x = l(x)
       return torch.sigmoid(x)

Discriminatorの構造は、block1、block2、block3に分けられていますが、とてもシンプルです。block1では、最初の入り口として畳み込みとLeakyReLUのみを計算します。block2で畳み込み、正規化、LeakyReLUをひたすら積んでいきます。出口である最後のblock3で、畳み込みだけして終了です。厳密には、最終結果を[0,1]の範囲に収めるためにforwardメソッドのところで、最後の最後にSigmoidを計算しています。ちなみに活性化関数には、ReLUではなくLeakyReLUを用いています。GeneratorではReLUだったのに、なんで?と思われるかもしれませんが、DiscriminatorにはReLUを用いない方が良い理由があります。Generatorの重みを更新するには、Discriminatorを通してその誤差が伝搬してきます。つまりReLUを使用してしまうと、ReLUへの入力が負であった場合、その微分は0になってしまうので、そこで誤差伝搬が止まってしまいます。従ってLeakyReLUに変えることで、止まることなくGeneratorまで誤差を伝えることが出来ます。

Optimazierと損失関数の定義

netGxy = Generator(in_ch=3, n_ch=64, n_resblock=9).to(DEVICE)
netGyx = Generator(in_ch=3, n_ch=64,  n_resblock=9).to(DEVICE)
netDx = Discriminator(in_ch=3, features=[64, 128, 256, 512]).to(DEVICE)
netDy = Discriminator(in_ch=3, features=[64, 128, 256, 512]).to(DEVICE)

optG = opt.Adam(list(netGyx.parameters())+list(netGxy.parameters()),
                lr=2e-4, betas=(0.5, 0.999))
optD = opt.Adam(list(netDx.parameters())+list(netDy.parameters()),
                lr=2e-4, betas=(0.5, 0.999))

L1 = nn.L1Loss()
bce = nn.BCELoss()

CycleGAN論文でも使用されていたので、optimazierにはAdamを使用しています。損失関数には、L1LossとBCELossを使用します。L1LossはMAE(Mean Absolute Error)のことで、正則化項の損失計算に使用されます。BCELossのところは、MSE(Mean Squared Error)を用いる実装パターンもありますが、ここではBCELossを使用した方が良い画像が生成されました。

学習の実行

n_epoch = 30
netGxy.train()
netGyx.train()
netDx.train()
netDy.train()

for e in range(n_epoch):
 for imgX_real, imgY_real in train_dataloader:
   imgX_real = imgX_real.to(DEVICE)
   imgY_real = imgY_real.to(DEVICE)
   #Discriminator
   imgX_fake = netGyx(imgY_real)
   predX_fake = netDx(imgX_fake.detach())
   predX_real = netDx(imgX_real)
   lossX_netD = bce(predX_fake, torch.zeros_like(predX_fake)) + bce(predX_real, torch.ones_like(predX_real))
  
   imgY_fake = netGxy(imgX_real)
   predY_fake = netDy(imgY_fake.detach())
   predY_real = netDy(imgY_real)
   lossY_netD = bce(predY_fake, torch.zeros_like(predY_fake)) + bce(predY_real, torch.ones_like(predY_real))

   loss_netD = (lossX_netD + lossY_netD) / 2.0
   optD.zero_grad()
   loss_netD.backward()
   optD.step()

   #Generator
   reconstX = netGyx(imgY_fake)
   reconstY = netGxy(imgX_fake)
   loss_cycle = L1(reconstX, imgX_real) + L1(reconstY, imgY_real)

   identX = netGyx(imgX_real)
   identY = netGxy(imgY_real)
   loss_ident = L1(identX, imgX_real) + L1(identY, imgY_real)

   predX_fake = netDx(imgX_fake)
   predY_fake = netDy(imgY_fake)
   loss_gan = bce(predX_fake, torch.ones_like(predX_fake)) + bce(predY_fake, torch.ones_like(predY_fake))

   loss_netG = (
       loss_gan
     + loss_cycle * 10.0
     + loss_ident * 10.0       
   )

   optG.zero_grad()
   loss_netG.backward()
   optG.step()

epoch数は30としました。loss_netDとloss_netGがDiscriminatorとGeneratorの各々最終的な損失になります。ここでの流れは、冒頭「CycleGANとは」の中で触れました「学習の流れ」の欄でご確認下さい。

実行結果

最終結果は以下のようになりました。

浮世絵(1行目)を入力したときの出力(2行目)
実写画像(3行目)を入力したときの出力(4行目)

実写画像から浮世絵への変換は上手くいっていますが、その逆に関しては、モデルが努力してくれている感じはしますが、イマイチ完璧には変換しきれていません。実行結果を見ても、浮世絵から実写画像への変換は、浮世絵のまま出力した方が、損失が少なく済んだために学習しきれなかった可能性があります。一方で、以下に示した損失関数を見ていただくと分かりますが、まだ学習の余地は残っています。ハイパーパラメータの変更をいくつか試したり、層を深くしたりしましたが、それよりもモデルの大きな改善の方が、もっと効果があるように感じました。今回のモデルは画像の大局情報を捉える実装は入っておりませんので、その辺も含めると上手く行くように思いました。なおepochはこれ以上増やしても改善はしていきません。

損失の推移

学習過程の様子


リアルタイム処理

ここからはOpenCVを用いて、実際にWebカメラからのデータを学習済みモデルに入力し、リアルタイムで出力させるアプリケーションを構築します。

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

from model import Generator

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


cap = cv2.VideoCapture(0, cv2.CAP_V4L)


netGyx = Generator(3, 64, 9).to(DEVICE)
netGyx.load_state_dict(torch.load("netGyx.pth"))
netGyx.eval()

mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

while True:
   ret, frame_original = cap.read()
   frame = cv2.cvtColor(frame_original, cv2.COLOR_BGR2RGB)
   frame = cv2.resize(frame, (256, 256))
   frame_real = F.to_tensor(frame).to(DEVICE)
   frame_real = F.normalize(frame_real, mean, std).unsqueeze(dim=0)
   frame_fake = netGyx(frame_real).squeeze().to(DEVICE)
   frame_fake = frame_fake.detach().cpu().permute(1,2,0).numpy()
   frame_fake = (frame_fake *0.5 + 0.5) * 255
   frame_fake = frame_fake.astype(np.uint8)
   frame_fake_ = cv2.cvtColor(frame_fake, cv2.COLOR_RGB2BGR)
   cv2.imshow("original", frame_original)
   cv2.imshow("ukiyoe", frame_fake_)
   if cv2.waitKey(1) == ord('q'):
       break


cap.release()
cv2.destroyAllWindows()

from model import Generatorは、同一ディレクトリにmodel.pyを作成し、その中にCycleGANの実装で示したGeneratorクラスを定義します。

while Trueの部分がリアルタイム処理に該当します。cap.read()によって画像を読み込みます。画像はnp.ndarrayで取得されるので、学習済みモデルに入力出来るように、tensor化と正規化します。tensorにするだけでなく、[-1,1]の範囲に画像を正規化しなければいけないのは、モデルを学習させる際にその制約の元で学習を行ったからです。そしてモデルの出力画像を、次はモニタに正しく表示出来るように、画像本来の画素値[0, 255]に戻します。最後にcv2.imshow()で実際に表示してwhileの1回のループは終了です。これをひたすら繰り返します。

以上になります。
もし面白いと思ってくれたらジュース一本分の投げ銭していただけると嬉しいにゃ☆

参考文献

[1]. CycleGAN https://arxiv.org/abs/1703.10593
[2]. GAN https://arxiv.org/abs/1406.2661
[3]. PixelShuffle https://paperswithcode.com/method/pixelshuffle

この記事が気に入ったらサポートをしてみませんか?