Stable Diffusionの学習コードを作る:4.LCM-LoRA編

 今回は、LCM-LoRAの学習についてやっていくよ。じつはコードの作り直しをした最大の理由はこのLCM-LoRAを簡潔に実装するためだったんだよ。

https://github.com/laksjdjf/sd-trainer/blob/main/modules/lcm/lcm_trainer.py

スケジューラー

 実装はめっちゃ簡単です。学習時にddimを使うので、分岐しています。こうしてみるとddimとの違いが分かりやすいですね。

class LCMScheduler(BaseScheduler):
     # x_t -> x_prev_t
    def step(self, sample, model_output, t, prev_t, use_ddim=False):
        pred_original_sample = self.pred_original_sample(sample, model_output, t)
        if use_ddim: # for training step
            noise = self.pred_noise(sample, model_output, t)
        else:
            noise = torch.randn_like(sample)
        return self.add_noise(pred_original_sample, noise, prev_t)

 lcmのc_skip, c_outとかはt=0に近いときだけしか変わらないので見なかったことにします。またtimestepについても本家は学習した時刻だけでスケジューリングするようになっていると思いますが、面倒なので無視します(学習も任意の時刻でやります)。

Config

 LCMでは追加で学習時のDDIMステップ数、guidance_scale、ネガティブプロンプトが必要です。Trainerの設定の中にadditional_confという自由に使える設定を追加したので、そこにLCMに関する設定を置きます。

trainer:
  module: modules.lcm.lcm_trainer.LCMTrainer
  additional_conf:
    lcm:
      guidance_scale: 7.0
      num_inference_steps: 50
      negative_prompt: "low quality"

各設定の意味は実装見ればなんとなくわかると思います。

LCMTrainer

 BaseTrainerを継承して実装します。コンストラクタではスケジューラーを先ほどのものに置き換えます。またprepare_modules_for_trainingでは学習中に使うネガティブプロンプトの計算を行っておきます。

class LCMTrainer(BaseTrainer):
    def __init__(self, config, diffusion, text_model, vae, scheduler, network):
        super().__init__(config, diffusion, text_model, vae, scheduler, network)
        self.scheduler = LCMScheduler(self.scheduler.v_prediction) # overwrite
        
    def prepare_modules_for_training(self, device="cuda"):
        super().prepare_modules_for_training(device)

        self.text_model.to(device)
        self.negative_encoder_hidden_states, self.negative_pooled_output = self.text_model([self.config.additional_conf.lcm.negative_prompt])
        self.text_model.to(self.te_device)

loss

 損失の計算にはUNetの計算が4回分必要です。一つ目は通常の学習と同じで学習データにノイズを加えたものを入力します。二つ目はLoRAなしで同じものを入力します。三つ目はLoRAなしでネガティブプロンプト(元論文では空文)を入力します。そして二つ目と三つ目の結果を用いてCFGを適用したDDIMの1ステップ分でノイズを除去します。四つ目はLoRAありで先ほどノイズを除去した潜在変数を入力します。勾配が有効なのは一つ目だけです。そして一つ目と四つ目の誤差(x0予測)を損失にします。

    def loss(self, batch):
       # batch取り出しコードは省略

        num_inference_steps = self.config.additional_conf.lcm.num_inference_steps
        interval = 1000 // num_inference_steps
        timesteps = torch.randint(interval, 1000, (self.batch_size,), device=latents.device)

        prev_timesteps = timesteps - interval

        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
        
        with torch.autocast("cuda", dtype=self.autocast_dtype):
            model_output = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition)
            pred_original_sample = self.scheduler.pred_original_sample(noisy_latents, model_output, timesteps)

            with torch.no_grad():
                # one step ddim
                negative_encoder_hidden_states = self.negative_encoder_hidden_states.repeat(self.batch_size, 1, 1)
                negative_pooled_output = self.negative_pooled_output.repeat(self.batch_size, 1)
                
                with self.network.set_temporary_multiplier(0.0):
                    uncond = self.diffusion(noisy_latents, timesteps, negative_encoder_hidden_states, negative_pooled_output, size_condition)
                    cond = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition)
                    cfg_model_output = uncond + self.config.additional_conf.lcm.guidance_scale * (cond - uncond)

                prev_noisy_latents = self.scheduler.step(noisy_latents, cfg_model_output, timesteps, prev_timesteps, use_ddim=True)

                # target
                target_model_output = self.diffusion(prev_noisy_latents, prev_timesteps, encoder_hidden_states, pooled_output, size_condition)
                target_original_sample = self.scheduler.pred_original_sample(prev_noisy_latents, target_model_output, prev_timesteps)

        loss = nn.functional.mse_loss(pred_original_sample.float(), target_original_sample.float(), reduction="mean")

        return loss

なんとこれでおわりです。
 ちなみにLCMの元論文ではEMAを使いますが、学習が遅くなるだけのように思います。少なくとも既存のLCM-LoRAをファインチューニングする分にはいらないと思います。

LCM学習について

 SD1やSDXL等のLCM-LoRAを、ファインチューニングモデル様に適合するよう学習するのは難しくなさそうです。どうやらLCM-LoRAはrank=1でも十分らしいので、rank=1にリサイズしたLoRAから学習を始めるとよさそう。