Latent Consistency Modelによる蒸留を試してみた

 1~8ステップくらいで画像生成ができるようになるLatent Consistency Modelをつくります。まずはLatent Consistency Modelの説明をし、次に蒸留体験記を書いていきます。

性能は微妙ですが、個人的にはうまくいくことが分かっただけで満足です。生成なんてしないしー。

拡散モデル

 目を瞑って想像してみてください、あなたは深い森の中に迷い込んでしまいました。その森には1時間に1回の頻度で鐘を鳴らす教会があり、あなたはその音を頼りに教会へ向かうことになりました。果たしてあなたは何時間で教会に到着するでしょうか・・・。
 さて、ここで歩くスピードとか体力は無視して、シンプルに鐘の音から方向が正確に予測することができるかを問題とします。1回の鐘の音を頼りにどんどん進んでしまうと、見当違いの方向に向かってしまいたどり着かないかもしれません。しかし何回も音を聞きながら慎重に進むと日が暮れてしまいます。
 このたとえ話は拡散モデルに似ています。拡散モデルは、完全なノイズ(森の中)から意味のある画像(教会)へ向かって歩き出します。その際方向を予測するのがUNetで、予測を元にどういった戦略で歩くかを決めるのがサンプラーです。たとえばEulerはUNetが予測する方向通りに歩きますが、Heunは二回分の予測を元に歩きます。またDPM++2Mという前回聞いた音を参考にする方法もあります。この辺のたとえ話はあくまで私の勝手な妄想であり理論的に正しいかは分かりませんが、ある程度あってるんじゃないかな。

早く教会にたどり着くために

 Stable Diffusionでは1000回音を聞いてたどり着くよう学習されていますが、時間がかかりすぎて話になりません。
 教会にたどり着く時間を削減するための方法として、二種類(それ以上あるかもだけど)あります。一つはサンプラーを工夫することです。上記であげたDPM++2Mとかがそうで、戦略を工夫して削減します。こちらは学習不要ですが削減には限界があり、今のところ少なくとも10~20回も音を聞かなければいけません。
 UNet側(音を聞いた時の予測)を強化する方法が蒸留です。蒸留とはあるモデルの性能を抽出して、別のモデルに移すことです。一般的には、より軽量なモデルに移すことで計算時間の削減をするという方法として使われますが、今回はより音を聞く回数が少なく済むようモデルを蒸留します。今回はこちらの手法を試しました。

既存の蒸留手法

Progressive Distillation

 拡散モデルの蒸留について話すとき、必ず出てくるのがProgressive Distillationです。この手法では、学習済みのモデルを教師、学習対象のモデルを生徒として、生徒を強化していきます。

 あなたは生徒です。教師から、「えー、今日は鐘の音を頼りに教会へ向かう授業をやるぞ。まずは先生が見本を見せてやるからな。」といわれどんな授業だよと思いながら観察していました。あなたは先生が必死に教会へ向かう様子をみて、「鐘の音を2回聞いて歩いた進路を、1回聞いただけで再現できるようにすれば半分の時間で済むじゃんwww」と思いました。まあそんなことできたら苦労しないだろと人間だったら思うんですが、AIだとできちゃうんですよね。生徒は教師の進路を学習し、教師の半分の時間で教会に着くようになりました。こうなったらもう教師交代ですね。あなたは得意げな顔で前任の教師よりも半分の時間で教会に着いてみせます。すると今度はあなたの生徒が・・・この後の流れは分かりますよね。という手法です。
(ちなみに教師はDDIMサンプラーの戦略をとります。)

Consistency Distillation

 Progressive Distillationより優れた方法としてでてきたのが、Consistency distillationです。

 教会に向かう進路を見ると、音を聞くたびに方向を修正しながら進むと考えられます。もし方向を修正せず、一直線に進んでたどり着くことができたらどうでしょう。この場合最初に予測した方向があっていたので、1回音を聞いただけでたどり着くことになります。Consistency distillationではこの状態を目指します。

 今回、あなたは生徒として教師と一緒に森の中にいます。一回鐘の音を聞き、教会までの方向を予測してみました。しかし自信がなかったので教師の言う通りに進みました。次に音がしたとき、先ほどより教会に近づいていたため、少し自信をもって方向を予測することができました。そこで1回目の予測結果を2回目の予測結果に合わせるように学習しました。1回目の予測結果を2回目に合わせ、2回目の予測結果を3回目に合わせ・・・ということを繰り返すことで、何回目でも同じ予測結果を出せるようになり、進路が一貫性を持つようになります。

発展:
 実は2回目に聞いた音は、少し前の自分による予測結果を使います。少し前の自分とは、発展なので特に説明せずに言いますが、EMAモデルのことです。これは強化学習におけるDouble DQNからきてるみたいですね。
 DQNでは、今の時刻における報酬予測を、現状の最適行動を行った後の時刻の報酬予測に合わせるように学習します。Consistency distillationと似てますね。しかし最適行動を決める報酬関数と、次の時刻の報酬を予測する報酬関数が一緒だと、自己採点となるので過大評価してしまいます。そこで正解データとなる次の時刻の報酬を予測する関数は更新を停止し、定期的に同期するようにします。
 Consistency Modelの論文ではこれによって学習が安定するとしか言及されていないので同じ理由かはわかりません。

条件付き生成への対応

 ここまでの話は目標を教会の一つに定めていました。しかしStable Diffusionはプロンプトによる指示があるのでそれに対応する必要があります。森の中にはさまざまな教会があります。たとえば金髪ツインテ教、ふたなり教の他、おねショタ逆転に代表される危険なカルトもあります。あなたは自分の教義に基づいた教会に向かい、逆におねショタ逆転のようなカルトには近づかないようにしなければいけません(※ただの一例です)。
 といってもやることが大きく変わったりはしません。拡散モデルではClassifier Free Guidance(CFG)という手法がありそれだけ注意が必要です。割と過去の記事でも説明しているような気がするのでもうめんどくさいから書きませんが、プロンプトによる予測と空文による予測を組み合わせ、プロンプトによる予測を強調することで、よりプロンプトに忠実な画像を生成できるようになる手法です。二回分の予測が必要なので、計算時間も二倍になります。そのためこれもなんとかする必要があります。

On Distillation of Guided Diffusion Models

 この論文の手法は二段階に分けられます。まずはCFGによる二回分の予測の組み合わせを、1回で予測できるようなモデルになるよう蒸留します。これによりCFGが不要になるため、後はProgressive Distillationでさらに計算時間を削減します。
一段階目はやってみたことがあります。

Latent Consistency distillation

 ようやく今回やった方法がでてきました。これはConsistency ModelにCFGを対応したものです。やり方は簡単で、教師による進路(前の図の赤い矢印)に対して、CFGを適用するだけです。CFG_scaleは学習中にランダムに選びますが、生徒側はCFGを適用しないため、時刻埋め込みと同じような方法でCFG_scaleを入力として受け付けられるようにしています。また教師が進む距離を長くする(20ステップ分を一気に進む)ことで、学習を高速化しています。論文では、二段階の学習が必要な先ほどの手法より圧倒的に早く安定して学習できると主張しています。またConsistency distillationでは教師の戦略としてEulerやHeunなどを使っていましたが、こちらはDDIMやDPM系を使っているらしい。なんでかは知らんけどStable Diffusionの場合実装的にはDDIMの方がやりやすいですね。

蒸留してみた

 蒸留には学習済みのモデル(教師)・学習中のモデル(生徒)・生徒のEMAの3つ分が必要です。VRAM使用量が大変なことになるので、LoRAを使いました。LoRAであれば学習済みモデルとLoRAとLoRAのEMAだけで済みます。本来はCFG_scaleのためのモジュールを新しく学習しなければいけないのでその辺の実装もすべきですが、めんどうだったのでweightが0のまま固定しています。CFG_scaleも7.0で固定しました。また論文では教師側の進路に無条件生成を使っていますが、私はここにネガティブプロンプトを適用してみました。効果があったのかはわかりません。
結果はまあ、それっぽいのはでてきますね。

コードはこんな感じ

損失は以下のようにどんどんあがっていくようですね。ふしぎ。

生成結果

prompt:anime, masterpiece, best quality, 1girl, solo, blush, sitting, twintails, blonde hair, bowtie, school uniforme, nature
seed:4545

step=4

まあ見れる画像にはなってます。

コードの詳しい説明

 主に学習ループの中身について詳しく書いておきます。まずは必要な関数の説明から、

f_lcm関数について:
論文中に出てくる関数fのことです。モデルによる元の画像の予測結果を出力します。ノイズ予測モデルや、velocity予測モデルの場合は元の画像の予測に変換します。ただしt=0のときはモデルへの入力をそのまま返す関数になります。c_skipやc_outは今回の設定ではデルタ関数になっておりあまり意味はありません。diffusersではここにあります。

pred_original_sample_noise_predについて:
書いてて気づいたけど何だこの関数名・・・。モデルによる元画像の予測とノイズの予測を計算する関数です。

k = 20
timesteps = torch.randint(k, 1000, (bsz,), device=latents.device)
timesteps = timesteps.long()

noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

guidance_scale = torch.tensor(7).repeat(bsz)
w_embedding = get_w_embedding(guidance_scale, embedding_dim=256, dtype=latents.dtype).to(latents.device)

  ここでは時刻をランダムで選び、それに応じたノイズをデータセットから取り出した潜在変数に加えます。また本来はguidance_scaleをランダムに選ぶべきですが、さぼって固定しています。(time_condを学習対象としていないため)。

model_output = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states,
    timestep_cond=w_embedding,
    added_cond_kwargs=added_cond_kwargs,
).sample

pred = f_lcm(noise_scheduler, model_output, timesteps, noisy_latents)

モデルの予測結果です。

ここからは勾配を止めて推論します。

with network.set_temporary_multiplier(0.0):
    cond = unet(
             noisy_latents,
             timesteps,
             encoder_hidden_states,
             timestep_cond=w_embedding,
             added_cond_kwargs=added_cond_kwargs,
        ).sample

    uncond = unet(
            noisy_latents,
            timesteps,
            uncond_hidden_states,
            timestep_cond=w_embedding,
            added_cond_kwargs=uncond_added_cond_kwargs,
        ).sample

    cfg = uncond + 7.0 * (cond - uncond)

    pred_original_sample, noise_pred = pred_original_sample_noise_pred(noise_scheduler, cfg, timesteps, noisy_latents)
    prev_noisy_latents = noise_scheduler.add_noise(pred_original_sample, noise_pred, timesteps-k) # ddim one step

この部分は、教師モデルによるkステップ分の生成を再現しています。cfgを適用した予測結果に対して、元画像の予測とノイズ予測を計算し、時刻t-kで元画像の予測に対してノイズ予測を加えます。これは分かりづらいですが、DDIMのeta=0(デフォルト設定)です。DPM++などを教師としたい場合もうちょっと工夫する必要がありますね。

with network.set_temporary_ema():
    prev_model_output = unet(
          prev_noisy_latents,
            timesteps - k,
            encoder_hidden_states,
            timestep_cond=w_embedding,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

target = f_lcm(noise_scheduler, prev_model_output, timesteps - k, prev_noisy_latents)

時刻t-kの画像に対して、EMAモデルで推論します。

loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="mean")

損失は時刻tの推測と、時刻t-kにおけるEMAモデルの推論の二乗誤差です。Consistency Modelの方では損失として二乗誤差以外も議論されているようですが、Latent Consistency Modelでは特に見当たらないので二乗誤差にしました。