Waifu diffusion式の追加学習のやり方。(VRAM14GB~)(12.2GB~?)

 過去2回追加学習のやり方を記事にしましたが、元のリポジトリが変わったり、いろいろと改善したのでメモしておきます。一部を除いてcolabハイメモリでの操作を想定しています。

追記:Waifu-diffusionのリポジトリが更新され、旧来のコードは全て削除されたので、この記事の内容はもう使えません。前のバージョンをくろーんすればできますけど、そこまでして昔の方法をやる意味もないです。

Git clone

 ドライブにマウントしてgit cloneする。

%cd /content/drive/MyDrive
!git clone https://github.com/harubaru/waifu-diffusion.git
%cd waifu-diffusion

環境構築

  requirements.txtのeinopsのバージョンを削除、torch.metrics==0.10.0に変更、bitsandbytesを追加してインストール。

###requirements.txt###
einops #==0.3.0を削除
torch.metrics==0.10.0 #変更
bitsandbytes #追加

#追加後
!pip install -r requirements.txt

学習済みモデルのダウンロード

 logs/originalディレクトリを入れてそこに学習済みモデルを入れます。例えばWD v1-3の場合は以下の通り。

!mkdir -p logs/original
!wget -P logs/original https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt

学習済みモデルの変換

 そのままではエラーが起こるのでちょっといじります。モデルによっては別の操作をしないといけないと思います(fullだと3,4行目いらなかったりするかも)。NovelAIのモデルの場合は重みだけをWDのモデルにコピーしたりしています。リークされたものだから詳しく書かないけど。

import torch
ckpt = torch.load("logs/original/wd-v1-3-float16.ckpt")
del ckpt["state_dict"]["model_ema.decay"]
del ckpt["state_dict"]["model_ema.num_updates"]
ckpt["optimizer_states"] = []
torch.save(ckpt,"logs/original/wd-v1-3-float16-del.ckpt")
del ckpt

学習データのスクレイピング

 私のようなものがいっぱい現れたせいかスクレイピング対策が強化されて、Colabやローカルではアクセスが拒否されるようになりました。なぜかGCPではできたけどいずれ対策されそうですね。(追記:Colabだとなんかできることもあります。リージョンの問題とかかな?)

#Colabじゃできなかったよ
python danbooru_data/scrape.py -user <user> -key <api-key> -t "検索タグ"
python danbooru_data/download.py

 コードそのままだと画像が中央でトリミングされてしまい、モデルが見切れた画像ばっかり出力するようになりごみでした。download.pyを書き換えた方がいいと思います。あとキャラクタータグもキャラクター名以外を勝手に削除してしまうので、別衣装とかを再現しづらくなります。直した方がいいかもしれません。
 データはlinks.tarファイルとして保存されます。waifu-diffusionの親ディレクトリにdatasetディレクトリを作って展開します。

!tar xvf links.tar -C ../dataset

8bit Adamの適用

 VRAM節約のために最適化関数を変更します。

#ldm/models/diffusion/ddpm.py

import bitsandbytes as bnb #上の方に追加

#420行目くらいと1370行目くらい
torch.optim.AdamW(...) → bnb.optim.Adam8bit(...)
#追記:bnb.optim.AdamW8bitの方がいいかも?

xformersによるVRAM節約(追記)

 xformersによるAttentionのメモリ効率化により12.2GBで動作することを確認しました。(学習が進むことを確認したただけなので非推奨→試しましたが、変な感じにはなってなかったです。)

xformersをインストールする。多分環境依存度がめちゃくちゃ高い。

%pip install -q https://github.com/metrolobo/xformers_wheels/releases/download/1d31a3ac_various_6/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl

ldm/modules/attention.pyのCrossAttentionクラスのforwardを以下に変更する。(ほぼwebUIのコードのぱくり。)

import xformers.ops #上の方

class CrossAttention(nn.Module):
    
    ...

    def forward(self, x, context=None, mask=None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
    
        k = self.to_k(context)
        v = self.to_v(context)
    
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()

        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    
        out = rearrange(out, '(b h) n d-> b n (h d)', h=h)
        return self.to_out(out)

Configファイルの変更

 configs/stable-diffusion/v1-4-finetune-test.yamlを変更します。testなどとファイル名に書いてあるくらいなのですぐに変更されちゃいそうですね。

#2行目(追加や削除しながらだと行数が変わるので目安)
base_learning_rate: n #batch sizeを増やす場合小さくしたほうがいいかもしんない。

#18行目
use_ema: false #上の行と空白を合わせて追加

#51行目
ckpt_path: ... #削除する。もしくはVAEファイルのパスを指定するといいかも?

#78行目、79行目
batch_size: n #T4なら1、A100なら少なくとも4までは増やせる。増やせば学習時間節約できる。
num_workers: n  #batch_sizeと同じ数

#87行目、95行目
ucg: 0.1 #目的によりますが削除した方がいいかも。(後述)

#87行目、95行目
flip_p: 0.0 #左右非対称のキャラを覚えさせたい場合は追加する。

#100行目、#105行目
every_n_train_steps: n #モデルの保存頻度。every_n_epochsにすればエポックごとになる。
batch_frequency: n #検証用画像作成頻度。見ないならでっかい値でもいいかも。

#101行目
save_weights_only: True #追加すればoptimizer_statesが保存されなくなる。
#追記:続きから学習できなくなるので追加しない方がいいかもです!!!!

#122行目
max_epochs(またはmax_steps): n #いつ終わるか設定できる。

 ucgはようするに学習途中でたまに(設定した数値の確率で)captionを空文に置き換えることで、短いプロンプトでもいい画像を生成できるようにする手法のようです。ただしキャラクターの追加学習をしたい場合は、特定のプロンプトにキャラの特徴を結び付けたいので、いらない操作かなと思います(あくまで推測)。

学習開始

 学習します。少し経つとwandbというやつを使うかどうか聞かれます。学習の様子をグラフ化してくれるwebアプリです。興味ない人は3を押せばすすみます。

!python main.py -n "fine"  --resume_from_checkpoint logs/original/wd-v1-3-float16-del.ckpt --base ./configs/stable-diffusion/v1-4-finetune-test.yaml -t --no-test --seed 25 --scale_lr False --gpus 0, 

学習結果

 前回記事参照。

おわりに

 こんなことやってる人他にいないのかな?みんなdreamboothとか使ってるよね。
 コードに#をつかったらハッシュタグが勝手に設定されて大変なことになってたんだけどなにこれえ。