Stable DiffusionのLoRAを試してみた(新しい追加学習手法)。

 新しいファインチューニング方式が出てきたので試してみます。

LoRAについて

Google翻訳を通したざっくりとした理解ですが、この方式のポイントは以下の3つです。

  1. モデルの誤差を学習する

  2. n次元からm次元への全結合層をn次元→r次元→m次元への全結合層とすることで、パラメータを減らす(rがランク)

  3. Attention layerの一部の全結合層のみを学習対象とする

 この3つにより一部のパラメータの誤差のみを保存すればいいので、学習結果がr×1MB程度で収まります。さらに誤差のみを保存しているので、他のモデルにも使いまわせたり、複数の学習結果をマージできるといったこともできます。今回は他のモデルに使いまわすやつも試してみます。

Dreamboothがモデル全体を学習、Hypernetworkがモデルそのものは固定して、Attention layerにオリジナルの全結合層を差し込む、textual inversionがモデルを固定して単語に意味を付与する(ようなイメージ?)に対して、LoRAはモデルの一部を学習し、残差だけ持っておくという感じですね。

実装

 diffusersにパッチを充てる形になっているため、diffusersベースのトレーニングコードを少し変えるだけでできるようです。githubにある説明と同じようなもんですが、UNetのみの場合下のような感じでできます。ただしUNetをGPUに移動するタイミング(.to("cuda"))に気を付けないとエラーが起こるようです。

from lora_diffusion import inject_trainable_lora, extract_lora_ups_down, save_lora_weight
unet.requires_grad_(False)
unet_lora_params, _ = inject_trainable_lora(unet,r=<ランク>)

#optimizerにunet_lora_paramsを渡す
optim(itertools.chain(*unet_lora_params),...)

#保存(わざわざpipeline通す必要あるのか分からないけど、diffusersベースのコードだったらsaveするときpipelineつくるっしょ^^)
save_lora_weight(pipe.unet, "<output path>")

学習について

 データセットは12000枚のごちうさ画像です。学習率は1e-4推奨らしいです。5e-6だと全然進みませんでした。ランクは今回64でやってみます。デフォルトは4なのですが、結果が微妙で上げ続けたら64になりました。rankを増やしても学習時間はほとんど変わらなかったです。学習時間はローカルの安定しない環境でやったのでざっくりとしか分かりませんが、純粋なファインチューニングに比べて4割くらい減ったと思います。
追記:どうやらバグで勾配計算が途中で途切れてしまってたようです。そのため計算時間が減ったのだと思われます。

 学習元のモデル。そんなにできのいいモデルじゃないので結果はあまり期待しないでね。

結果(r=64)

例のごとく学習で満足しちゃってプロンプトはてきとう。体がちゃんと書けていなかったりするので目を細めてみてください。

 キャラの特徴は学習できてますね。r=4だとチノちゃんができてるくらいでした。ココアさんは髪の色とか髪型を特徴づけるタグがなくてちょっと難しいですね。

 ちなみに上半身だけにすればこれくらいのクオリティになります。

 作例として上半身だけの画像をあげているものが多いですが、そんなの簡単にできるので何の意味もないです。全身の絵を出さないのは逃げだと思います。

結果(別のモデルで使いまわし)

学習結果を別のモデルで使いまわします。対象のモデルはSD2.1系に13万枚の画像を学習したものです。ごちうさの画像はほとんど入ってないです。

 うーんまあ別のモデルでもある程度学習結果を反映させることはできるようですね。実用性は・・・なさそう。

結果(up blocksファインチューン)

 実は少し前にUNetのup blocksだけを学習したらどうなるんだろうと思い立ちやってみましたので比較してみます。LoRAと同じようにモデル全体を学習対象にするのは広げすぎじゃないか?という発想です。up blocksだけにすると学習時間が2割ほど減りました。まあ本当はモデル全体のファインチューニングと比較するべきなんですがね。

LoRAに比べて再現できてるというほどでもないですね。LoRAの方が学習効率はよさそうです。

おわりに

必要ステップ数など全く調査できてないので、なんともいえないです。ああ有馬記念始まっちゃう誤字とかチェックしてないけどいいや!