見出し画像

終わった?AI技術のコード公開第1弾

上記の記事で紹介されているGANは廃れた技術かと思いますので、サンプルコードを配布します。
画像生成は現在はプロンプトを打ち込むことで、狙った画像を生成できるように進歩したようで、VAEやGANは終わり?を迎えました。



画像生成はGANが強いと言われていました。昔は某研究室でDCGANを用いて癌のHE画像を生成していました。そのコードをまんまおいて置きます。
inputファイルには学習させたい画像をフルパスで指定しておくとよいでしょう。どんなジャンルの画像でも一定の出力が得られると思います。
表紙画はoxford102データセットを学習させた結果です。

コードは以下の本を参考に組んでいます。本内では64x64の入出力ですが、512x512出力へとデチューンしたコードになります。


実コード

import torch
from torch import nn, optim
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import (Dataset, DataLoader, TensorDataset)
from torchvision import models
import torchvision
import tqdm
from matplotlib import pyplot as plt
import pandas as pd

g_losses = []
d_losses = []





img_data = ImageFolder(r"imput_Data_PATH",
    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor()
]))

batch_size= 64
img_loader = DataLoader(img_data, batch_size=batch_size, shuffle=True)

nz =10000
ngf = 32

#潜在ベクトルを10000次元にする。
#512x512の画像を作るモデルを作る
class GNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
        nn.ConvTranspose2d(nz, ngf*64, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf*64),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*64, ngf*32, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf*32),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*32, ngf*16, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf*16),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*16, ngf*8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf*8),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf*4),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf*2),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
        nn.Tanh()

    )


    def forward(self, x):
        out = self.main(x)
        return out


ndf =32

class DNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf *2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*2, ndf *4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*8, ndf *16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*16, ndf*32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*32, ndf*64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*64, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        out = self.main(x)
        return out.squeeze()

d = DNet().to("cuda:0")
g = GNet().to("cuda:0")

#Adamのパラメータは改良の余地あり
opt_d = optim.Adam(d.parameters(),
    lr=0.0002, betas=(0.5, 0.999))

opt_g = optim.Adam(g.parameters(),
    lr=0.0002, betas=(0.5, 0.999))


#クロスエントロピーを計算するための補助変数
ones = torch.ones(batch_size).to('cuda:0')
zeros = torch.zeros(batch_size).to('cuda:0')
loss_f  = nn.BCEWithLogitsLoss()

#モニタリング用のz
fixed_z = torch.randn(batch_size, nz , 1, 1).to('cuda:0')




from statistics import mean

def train_dcgan(g, d, opt_g, opt_d, loader):
    #生成モデル、識別モデルの目的関数の追跡用排列
    log_loss_g = []
    log_loss_d = []
    for real_img, _ in tqdm.tqdm(loader):
        batch_len=len(real_img)

        #実際の画像をGPUにコピー
        real_img = real_img.to('cuda:0')

        #偽画像を乱数と生成モデルから作る
        z = torch.randn(batch_len, nz, 1, 1).to('cuda:0')
        fake_img = g(z)

        #のちで使用するので偽画像の値を取り出しておく
        fake_img_tensor = fake_img.detach()

        #偽画像にたいする生成モデルの評価関数を計算する
        out = d(fake_img)
        loss_g = loss_f(out, ones[: batch_len])
        log_loss_g.append(loss_g.item())
        g_losses.append(log_loss_g[-1])

        #計算グラフが生成モデルと識別モデルの両方に
        #勾配をクリアしてから微分の計算とパラメータ行進を行う

        d.zero_grad(), g.zero_grad()
        loss_g.backward()
        opt_g.step()

        #実際の画像に対する識別モデルの評価関数を計算
        real_out = d(real_img)
        loss_d_real = loss_f(real_out, ones[: batch_len])

        #PyTorchでは同じTensorを含んだ計算グラフに対して
        #2階backwardをおこなうことができない
        fake_img = fake_img_tensor

        #偽画像に対する識別モデルの評価関数の計算
        fake_out = d(fake_img_tensor)
        loss_d_fake = loss_f(fake_out, zeros[: batch_len])

        #実偽の評価関数の合計
        loss_d = loss_d_real + loss_d_fake
        log_loss_d.append(loss_d.item())
        d_losses.append(log_loss_d[-1])

        #識別モデルの微分計算とパラメータ更新
        d.zero_grad(), d.zero_grad()
        loss_d.backward()
        opt_d.step()

    print(epoch, flush=True)

    return mean(log_loss_g), mean(log_loss_d)

for epoch in range(500):
    train_dcgan(g, d, opt_g, opt_d, img_loader)
    #10回の繰り返しごとに学習結果を保存する
    if epoch % 10 ==0:
        #パラメータの保存
        torch.save(g.state_dict(),r"save folder_path/g_{:06d}.prm".format(epoch),pickle_protocol=4)
        torch.save(d.state_dict(),r"save folder_path/d_{:06d}.prm".format(epoch),pickle_protocol=4)

        g_loss_df = pd.DataFrame(g_losses)
        d_loss_df = pd.DataFrame(d_losses)

        g_loss_df.to_csv(r'save folder_path\g_loss_{:06d}.csv'.format(epoch), index=True, header=False)
        d_loss_df.to_csv(r'save folder_path\d_loss_{:06d}.csv'.format(epoch), index=True, header=False)
    #モニタリング用のzから生成した画像を保存
    generated_img = g(fixed_z)
    torchvision.utils.save_image(generated_img,r"save folder_path/{:06d}.jpg".format(epoch))

while True:
    NUMBER = input("グラフを見ますか?y=1 n=0\n")
    if NUMBER not in ("0","1"):
        continue

    if NUMBER == "1":
        plt.plot(g_losses, label='Generator_loss')
        plt.plot(d_losses, label='Discriminator_loss_total')
        plt.legend()
        plt.show()

        break

    if NUMBER == "0":
        break


g_loss_df = pd.DataFrame(g_losses)
d_loss_df = pd.DataFrame(d_losses)

g_loss_df.to_csv(r'save folder_path\g_loss_total.csv', index=True, header=False)
d_loss_df.to_csv(r'save folder_path\d_loss_total.csv', index=True, header=False)


  • imputの画像のサイズは512x512より大きいものを

  • 学習にはGPUがデフォルトで設定されているので、GPUがない場合はcuda0の消去を

  • GPUメモリーが足りない場合はnz=10000→100にする。orバッチサイズを64から下げてください。

  • pytorchを使用しています。当時の実行環境は以下

  • もっと注意事項あるけど、忘れました。

実行環境


前処理のプログラムについては後日


5000x10000のバーチャルスライド画像から病変をマークして引っ張りだすコードがあるのですがそちらの方が需要あると思うので、
この記事のスキ数次第で検討します。

感想

どこに需要あるん?
問題があればこの記事は消します


この記事が参加している募集

最近の学び

この記事が気に入ったらサポートをしてみませんか?