見出し画像

Image GPTの使い方

「Image GPT」の使い方をまとめました。

前回

1. Image GPT

Image GPT」は、半分の画像から残り半分の画像を生成することができるフレームワークです。

GPT-2」は任意の文章に続く、もっともらしい文章を自動的生成できるモデルとして話題になりましたが、このモデルを画像に適用することで、半分の画像に続く、もっともらしい画像を自動生成できるようにしました。

画像1

「Image GPT」のモデルは、以下の3種類が提供されています。いずれも「ImageNet」で学習しています。

・iGPT-S: 7600万パラメータ
・iGPT-M: 4億5500万パラメータ
・iGPT-L: 14億パラメータ

2. Image GPTの使い方

「Google Colab」での「Image GPT」の使い方は、次のとおりです。

(1) データの永続化。

# データの永続化
from google.colab import drive 
drive.mount('/content/drive')
!mkdir -p '/content/drive/My Drive/work/'
%cd '/content/drive/My Drive/work/'

(2) TensorFlow 1.xへの切り替え。
「Image GPT」はTensorFlow 1.xで動きます。

# TensorFlow 1.xへの切り替え
%tensorflow_version 1.x

(3) 「openai/image-gpt」リポジトリのクローン。

# openai/image-gptのクローン
!git clone https://github.com/openai/image-gpt.git

(4) 「事前学習済みモデル」と「カラーパレット」のダウンロード。

# 事前学習済みモデルとカラーパレットのダウンロード
!python image-gpt/download.py --model=s --ckpt 1000000 --cluster --download_dir=download

downloadフォルダにダウンロードされます。

・model.ckpt-1000000.XXXX: 事前学習モデル
・kmeans_centers.npy: カラーパレット(512色)

(5) 作業フォルダ(work)に画像をアップロード。
推論を行う画像は、正方形のpng24(透過色なし)です。

・image.png

画像5

(6) 画像をパレット画像に変換。
推論を行う画像を、カラーパレットを適用した32x32のNumPy配列に変換します。

from PIL import Image
from imageio import imwrite
import numpy as np

# カラーパレットの読み込み
clusters = np.load("download/kmeans_centers.npy")

# RGB画像の読み込み
im = Image.open('image.png')
im = im.resize((32,32)) # リサイズ
imr = np.array(im) # NumPy配列への変換

# 2点間の距離の計算
def dist(x, y):
   return np.sqrt(np.sum((x-y)**2))

# RGB色をカラーパレットのインデックスに変換
def find_index(a):
   mind = 10000  # 最小距離
   minidx = -1   # 最小距離のインデックス
   for i in range(len(clusters)): 
       d = dist(a/127.5-1.0, clusters[i])
       if mind > d:
           mind = d
           minidx = i
   return minidx

# RGB画像をパレット画像に変換
result = []
for y in range(32):
   for x in range(32):
       result.append(find_index(imr[y, x]))

# パレット画像の保存
samples = np.array(result) # NumPy配列への変換
np.save("palette_image.npy",samples) # 保存

# 確認用にパレット画像をRGB画像に変換して表示
samples = np.reshape(np.rint(127.5*(clusters[samples]+1.0)), [32, 32, 3]).astype(np.uint8)
imwrite('check_image.png', samples)
Image.open('check_image.png')

画像6

最後に、確認用にパレット画像をRGB画像に変換して表示しています。元画像とほぼ同じ画像が表示されることを確認してください。

(7) image-gpt/src/run.pyのsample()を以下のように変更。
image-gpt/src/run.pyは、「Image GPT」の推論を行うサンプルコードです。ImageNetの画像を推論するコードなので、編集して先程作成したパレット画像を推論するように変更します。

def sample(sess, X, gen_logits, n_sub_batch, n_gpu, n_px, n_vocab, clusters, save_dir):
    samples = np.zeros([n_gpu * n_sub_batch, n_px * n_px], dtype=np.int32)

    # 8枚の半分画像の生成
    for k in range(n_sub_batch):
        samples[k] = np.load("palette_image.npy") # パレット画像の読み込み
        samples[k,n_px*16:] = 0 # 半分画像の生成

    # 残り半分の画像の推論
    for i in tqdm(range(n_px * 16, n_px * n_px), ncols=80, leave=False): 
        np_gen_logits = sess.run(gen_logits, {X: samples})
        for j in range(n_gpu):
            p = softmax(np_gen_logits[j][:, i, :], axis=-1)
            for k in range(1, n_sub_batch):
                c = np.random.choice(n_vocab, p=p[k])
                samples[j*n_sub_batch+k, i] = c
   
    # パレット画像をRGB画像に変換
    samples = [np.reshape(np.rint(127.5*(clusters[s]+1.0)), [32, 32, 3]).astype(np.uint8) for s in samples]

    # ファイルに保存
    samples = np.asarray(samples).reshape(32*n_sub_batch, 32, 3)
    imwrite('result.png', samples)

(8) 推論の実行

%%time

# 画像の生成
!python image-gpt/src/run.py --sample --n_embd=512 --n_head=8 --n_layer=24 --ckpt_path=download/model.ckpt-1000000 --color_cluster_path=download/kmeans_centers.npy --n_gpu=1 --n_sub_batch=16

(9) 結果画像の確認。

# 結果画像の表示
Image.open('result.png')

画像2

【おまけ】 download.pyのパラメータ

「download.py」の主なパラメータは、次のとおりです。

--model: ダウンロードするモデルのサイズ(s,m,l)
--ckpt: ダウンロードするモデルのチェックポイント
--clusters: カラーパレットのダウンロード
--download_dir: ダウンロードフォルダ

【おまけ】 run.pyのパラメータ

「run.py」のパラメータ「--sample」「--n_embed」「--n_head」「--n_layer」は、モデルサイズに応じて以下ように指定します。

・s: run.py --sample --n_embd 512 --n_head 8 --n_layer 24
・m: run.py --sample --n_embd 1024 --n_head 8 --n_layer 36
・l: run.py --sample --n_embd 1536 --n_head 16 --n_layer 48

その他の主なパラメータは、次のとおりです。

--sample: サンプリングの実行
--ckpt_path
: 事前学習済みモデルのパス 
--color_cluster_path: カラーパレットのパス
--n_gpu: GPUの数
--n_sub_batch: バッチサイズ



この記事が気に入ったらサポートをしてみませんか?