Stable Diffusionの学習コードを作る:4.LCM-LoRA編
今回は、LCM-LoRAの学習についてやっていくよ。じつはコードの作り直しをした最大の理由はこのLCM-LoRAを簡潔に実装するためだったんだよ。
https://github.com/laksjdjf/sd-trainer/blob/main/modules/lcm/lcm_trainer.py
スケジューラー
実装はめっちゃ簡単です。学習時にddimを使うので、分岐しています。こうしてみるとddimとの違いが分かりやすいですね。
class LCMScheduler(BaseScheduler):
# x_t -> x_prev_t
def step(self, sample, model_output, t, prev_t, use_ddim=False):
pred_original_sample = self.pred_original_sample(sample, model_output, t)
if use_ddim: # for training step
noise = self.pred_noise(sample, model_output, t)
else:
noise = torch.randn_like(sample)
return self.add_noise(pred_original_sample, noise, prev_t)
lcmのc_skip, c_outとかはt=0に近いときだけしか変わらないので見なかったことにします。またtimestepについても本家は学習した時刻だけでスケジューリングするようになっていると思いますが、面倒なので無視します(学習も任意の時刻でやります)。
Config
LCMでは追加で学習時のDDIMステップ数、guidance_scale、ネガティブプロンプトが必要です。Trainerの設定の中にadditional_confという自由に使える設定を追加したので、そこにLCMに関する設定を置きます。
trainer:
module: modules.lcm.lcm_trainer.LCMTrainer
additional_conf:
lcm:
guidance_scale: 7.0
num_inference_steps: 50
negative_prompt: "low quality"
各設定の意味は実装見ればなんとなくわかると思います。
LCMTrainer
BaseTrainerを継承して実装します。コンストラクタではスケジューラーを先ほどのものに置き換えます。またprepare_modules_for_trainingでは学習中に使うネガティブプロンプトの計算を行っておきます。
class LCMTrainer(BaseTrainer):
def __init__(self, config, diffusion, text_model, vae, scheduler, network):
super().__init__(config, diffusion, text_model, vae, scheduler, network)
self.scheduler = LCMScheduler(self.scheduler.v_prediction) # overwrite
def prepare_modules_for_training(self, device="cuda"):
super().prepare_modules_for_training(device)
self.text_model.to(device)
self.negative_encoder_hidden_states, self.negative_pooled_output = self.text_model([self.config.additional_conf.lcm.negative_prompt])
self.text_model.to(self.te_device)
loss
損失の計算にはUNetの計算が4回分必要です。一つ目は通常の学習と同じで学習データにノイズを加えたものを入力します。二つ目はLoRAなしで同じものを入力します。三つ目はLoRAなしでネガティブプロンプト(元論文では空文)を入力します。そして二つ目と三つ目の結果を用いてCFGを適用したDDIMの1ステップ分でノイズを除去します。四つ目はLoRAありで先ほどノイズを除去した潜在変数を入力します。勾配が有効なのは一つ目だけです。そして一つ目と四つ目の誤差(x0予測)を損失にします。
def loss(self, batch):
# batch取り出しコードは省略
num_inference_steps = self.config.additional_conf.lcm.num_inference_steps
interval = 1000 // num_inference_steps
timesteps = torch.randint(interval, 1000, (self.batch_size,), device=latents.device)
prev_timesteps = timesteps - interval
noise = torch.randn_like(latents)
noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
with torch.autocast("cuda", dtype=self.autocast_dtype):
model_output = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition)
pred_original_sample = self.scheduler.pred_original_sample(noisy_latents, model_output, timesteps)
with torch.no_grad():
# one step ddim
negative_encoder_hidden_states = self.negative_encoder_hidden_states.repeat(self.batch_size, 1, 1)
negative_pooled_output = self.negative_pooled_output.repeat(self.batch_size, 1)
with self.network.set_temporary_multiplier(0.0):
uncond = self.diffusion(noisy_latents, timesteps, negative_encoder_hidden_states, negative_pooled_output, size_condition)
cond = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition)
cfg_model_output = uncond + self.config.additional_conf.lcm.guidance_scale * (cond - uncond)
prev_noisy_latents = self.scheduler.step(noisy_latents, cfg_model_output, timesteps, prev_timesteps, use_ddim=True)
# target
target_model_output = self.diffusion(prev_noisy_latents, prev_timesteps, encoder_hidden_states, pooled_output, size_condition)
target_original_sample = self.scheduler.pred_original_sample(prev_noisy_latents, target_model_output, prev_timesteps)
loss = nn.functional.mse_loss(pred_original_sample.float(), target_original_sample.float(), reduction="mean")
return loss
なんとこれでおわりです。
ちなみにLCMの元論文ではEMAを使いますが、学習が遅くなるだけのように思います。少なくとも既存のLCM-LoRAをファインチューニングする分にはいらないと思います。
LCM学習について
SD1やSDXL等のLCM-LoRAを、ファインチューニングモデル様に適合するよう学習するのは難しくなさそうです。どうやらLCM-LoRAはrank=1でも十分らしいので、rank=1にリサイズしたLoRAから学習を始めるとよさそう。