Stable Diffusionの学習コードを作る:2.学習編

 前回の続きとして、学習のために必要なコードを紹介していきます。今回はLoRAではなく、フルファインチューニングができるようにします。


データセット(BaseDataset)

データセットのフォルダ構造は以下のような感じです。フォルダ名はデフォルトの名前であって、設定で自由に変えられるようにします。

Dataset/
 buckets.json # bucketingのメタデータ
 original_size.json # 任意でSDXLのみ必要
 images/ # 学習画像 拡張子はpng 
 latents/ # 学習画像をVAEエンコードしたもの 拡張子はnpy
 captions/ # キャプションのテキストファイル 拡張子はcaption
 text_embs/ # キャプションの埋め込み 拡張子はnpz

 imagesとlatentsはどちらかがあればいいという感じです。captionsとtext_embsもどっちかがあればいいです。画像とキャプションはファイル名で対応付けします。

 画像の前処理は古いコードをずっと使いまわしており、あんまりいいコードじゃないけど変える意欲がわきません。

Aspect ratio Bucketing

https://github.com/laksjdjf/sd-trainer/tree/main/preprocess

 NovelAIが提唱した複数のアスペクト比のデータで学習する方法です。UNetは任意の解像度に対応できますが、Pytorchは異なるサイズのテンソルで並列計算することはできません。そのためいくつかの解像度候補(bucket)を選び、画像をそれぞれ一番近い解像度にリサイズ・切り抜きを行います。そして同じbucketからミニバッチを取り出します。詳細は学習設定編で書いてあります。今回は実装面を見てみましょう。
 bucketの作成については、基本的には幅を64単位で広げながら一番大きくできる高さを計算してbucketに追加します。アスペクト比が極端な画像は入れないようにしてます。

def make_buckets():
    # モデルの構造からして64の倍数が推奨される。(VAEで8分の1⇒UNetで8分の1)
    increment = 64
    # 最大ピクセル数
    max_pixels = args.resolution*args.resolution

    # 正方形は手動で追加
    buckets = set()
    buckets.add((args.resolution, args.resolution))

    # 最小値から~
    width = args.min_length
    # ~最大値まで
    while width <= args.max_length:
        # 最大ピクセル数と最大長を越えない最大の高さ
        height = min(args.max_length, (max_pixels // width) - (max_pixels // width) % increment)
        ratio = width/height

        # アスペクト比が極端じゃなかったら追加、高さと幅入れ替えたものも追加。
        if 1/args.max_ratio <= ratio <= args.max_ratio:
            buckets.add((width, height))
            buckets.add((height, width))
        width += increment  # 幅を大きくして次のループへ

 画像は一番アスペクト比が近いbucketに割り当て、リサイズおよび中央切り抜きをします。

def resize_image(file):
    image = Image.open(file)
    image = image.convert("RGB")
    ratio = image.width / image.height
    ar_errors = ratios - ratio
    indice = np.argmin(np.abs(ar_errors))  # 一番近いアスペクト比のインデックス
    bucket_width, bucket_height = buckets[indice]
    ar_error = ar_errors[indice]
    if ar_error <= 0:  # 幅<高さなら高さを合わせる
        temp_width = int(image.width*bucket_height/image.height)
        image = image.resize((temp_width, bucket_height))  # アスペクト比を変えずに高さだけbucketに合わせる
        left = (temp_width - bucket_width) / 2  # 切り取り境界左側
        right = bucket_width + left  # 切り取り境界右側
        image = image.crop((left, 0, right, bucket_height))  # 左右切り取り
    else:  # 幅高さを逆にしたもの
        temp_height = int(image.height*bucket_width/image.width)
        image = image.resize((bucket_width, temp_height))
        upper = (temp_height - bucket_height) / 2
        lower = bucket_height + upper
        image = image.crop((0, upper, bucket_width, lower))
    image.save(os.path.join(args.output_dir, os.path.basename(file)))
    return [os.path.splitext(os.path.basename(file))[0], str((bucket_width, bucket_height))]

 この手法は学習開始時に行う方法もありますが、学習スクリプトが複雑になるのが嫌なので学習前にやってストレージに保存します。同じデータセットを複数の設定で学習したいときに何度もやらずに済みますしね。
 メタデータで、各bucketに入っているファイル名を辞書形式で保存しておきます。これによって学習時にデータを精査せずともミニバッチを取りだせるようにしておきます。

オリジナルサイズ

 SDXLの学習では、学習データの元々の解像度情報が必要です。他にも切り取り位置に関する情報も必要ですが、今回のコードは中央切り取りしかしないんで、元の解像度から逆算可能です。original_size.jsonというファイルでkeyがファイル名でvalueが元の解像度の辞書となる辞書形式で保存しておきます。辞書in辞書だね。

潜在変数のキャッシュ

https://github.com/laksjdjf/sd-trainer/blob/main/preprocess/latent.py

 学習データはVAEによって潜在変数にエンコードする必要がありますが、VAEは学習対象ではないので、1枚の画像を何度もエンコードするのは非効率です。そのため潜在変数をキャッシュします。上と同じく学習前に行いストレージに保存します(これはキャッシュといっていいのか?)。
 コードはVAEでエンコードしてnp.saveするだけなので特に紹介しないよ。

テキスト埋め込みのキャッシュ

https://github.com/laksjdjf/sd-trainer/blob/main/preprocess/text_embedding.py

 潜在変数のキャッシュと同じですが、こちらの方が計算量が圧倒的に低いのでVRAMがぎりぎり足りないとかじゃなければやらなくてもいいです。というわけで今回は省略します。

BaseDatasetクラス

https://github.com/laksjdjf/sd-trainer/blob/main/modules/dataset.py

 データセットはtorch.utils.data.Datasetを継承します。通常Datasetは要素を1個ずつ取り出し、ミニバッチとしてまとめるのはDataloaderなんですが、Bucketingの関係上Dataset段階でミニバッチを作成します。
 キャプションなどのフォルダ名は任意に選べるようにします。
 設定について、prefixはキャプションの文頭に共通のワードを追加するための設定です。トリガーワードみたいなものに使えます。shuffleはエポックごとにデータをシャッフルするためのものなんですが、うまくいってるのかわかりません。ucgはキャプションを一定確率で空文にするというものです。CFGを使うためのものですが、大規模な学習以外では必要ないと思います。ucgを適用しつつテキスト埋め込みのキャッシュをする場合、空文のテキスト埋め込みが追加で必要になるため、コンストラクタで作成して保存しておきますが、このためだけにtext_modelを引数に取るのはよくないきがする。

class BaseDataset(Dataset):
    def __init__(
        self,
        text_model: TextModel,
        batch_size: int,
        path: str,
        metadata: str="buckets.json",
        original_size: Optional[str] = None,
        latent: Optional[str] = "latents",
        caption: Optional[str] = "captions",
        image: Optional[str] = None,
        text_emb: Optional[str] = None,
        prompt: Optional[str] = None,
        prefix: str = "",
        shuffle: bool = False,
        ucg: float = 0.0
    ):

        with open(os.path.join(path, metadata), "r") as f:
            self.bucket2file = json.load(f)

        if original_size is not None:
            with open(os.path.join(path, original_size), "r") as f:
                self.original_size = json.load(f)
        else:
            self.original_size = {}

        self.path = path
        self.batch_size = batch_size
        self.text_model = text_model
        self.latent = latent
        self.caption = caption
        self.image = image
        self.text_emb = text_emb
        self.prompt = prompt  # 全ての画像のcaptionをpromptにする
        self.prefix = prefix  # captionのprefix
        self.shuffle = shuffle  # バッチの取り出し方をシャッフルするかどうか(データローダー側でシャッフルした方が良い^^)
        self.ucg = ucg  # captionをランダムにする空文にする確率

        # 空文の埋め込みを事前に計算しておく
        if self.ucg > 0.0 and self.text_emb:
            text_device = self.text_model.device
            self.text_model.to("cuda")
            with torch.no_grad():
                self.uncond_hidden_state, self.uncond_pooled_output = self.text_model([""])
            self.uncond_hidden_state.detach().float().cpu()
            self.uncond_pooled_output.detach().float().cpu()
            self.text_model.to(text_device)
            logger.info(f"空文の埋め込みを計算したよ!")

        self.init_batch_samples()
        logger.info(f"データセットを作ったよ!")

init_batch_samples

 Bucketからミニバッチの取り出し方を決定します。keyがBucket, valueがファイル名のリストとなっているメタデータを作ってあるので、各Bucketのファイル名をシャッフルしてミニバッチにわけ、全部のミニバッチを再度シャッフルします。

# バッチの取り出し方を初期化するメソッド
def init_batch_samples(self):
    self.batch_samples = []
    for key in self.bucket2file:
        random.shuffle(self.bucket2file[key])
        self.batch_samples.extend([self.bucket2file[key][i:i+self.batch_size]
                                  for i in range(0, len(self.bucket2file[key]), self.batch_size)])
    random.shuffle(self.batch_samples)

get_item

 ミニバッチを辞書形式でつくります。画像と潜在変数、キャプションとテキスト埋め込みはどっちかだけ読み込むようにします。size_conditionはSDXLに必要な画像の解像度情報で、次の節で説明します。

def __getitem__(self, i):
    if i == 0 and self.shuffle:
        self.init_batch_samples()

    batch = {}
    samples = self.batch_samples[i]

    if self.image:
        batch["images"] = self.get_images(samples, self.image if isinstance(self.image, str) else "images")
        target_height, target_width = batch["images"].shape[2:]
    else:
        batch["latents"] = self.get_latents(samples, self.latent)
        target_height, target_width = batch["latents"].shape[2]*8, batch["latents"].shape[3]*8

    batch["size_condition"] = self.get_size_condition(samples, target_height, target_width)

    if self.text_emb:
        batch["encoder_hidden_states"], batch["pooled_outputs"] = self.get_text_embeddings(samples, self.text_emb if isinstance(self.text_emb, str) else "text_emb")
    else:
        batch["captions"] = self.get_captions(samples, self.caption)

    return batch

get_size_condition

 学習時に画像をリサイズしたり、切り抜きをしますが、当然その際に情報が失われます。size_conditionはモデルにそのことを伝えるためにあります。具体的には学習データの元々の解像度縦横、切り抜き位置縦横、前処理後の解像度縦横の6要素からなります。生成時は基本的に二つの解像度は生成解像度に設定して、切り抜き位置は0にします。このコードではbucketingに基づいてリサイズ・中央切り抜きするため元々の解像度があれば切り抜き位置は計算できます。この切り抜き情報の計算が厄介で、実装があっているかもわかりませんが、✟kohya_tech✟さんが解説してくれています。

    def get_size_condition(self, samples, target_height, target_width):
        size_condition = []
        for sample in samples:
            if sample in self.original_size:
                original_width = self.original_size[sample]["original_width"]
                original_height = self.original_size[sample]["original_height"]
            
                original_ratio = original_width / original_height
                target_ratio = target_width / target_height

                if original_ratio > target_ratio: # 横長の場合
                    resize_ratio = target_width / original_width # 横幅を合わせる
                    resized_height = original_height * resize_ratio # 縦をリサイズ
                    crop_top = (target_height - resized_height) // 2 # 上部の足りない分がcrop_top
                    crop_left = 0
                else:
                    resize_ratio = target_height / original_height
                    resize_width = original_width * resize_ratio 
                    crop_top = 0
                    crop_left = (target_width - resize_width) // 2
            else:
                original_width, original_height = target_width, target_height
                crop_top = 0
                crop_left = 0
            size_list = [original_height, original_width, crop_top, crop_left, target_height, target_width]    
            size_condition.append(torch.tensor(size_list))
        return torch.stack(size_condition)

get_その他

 各データをフォルダから取り出すメソッドです。いっぱいあってきりがないのでとばしますが、基本はTensorにしてミニバッチでまとめます。キャプションの場合は文字列のリストにします。その際ucgやprefixの処理を行います。

Dataloader

 データセット側がミニバッチまで作成するため、データローダーはあんまりやることないです。データローダー側のバッチサイズは1にして、ミニバッチの作成を定義するcollate_fnは以下のようにすることでDatasetの__get_item__()の出力がそのままでてきます。

def collate_fn(x):
    return x[0]

Config

 学習設定についてやっていきます。configはmain, trainer, dataset, dataloader, network(次回用)に分かれます。以下がConfigの構造ですが、ChatGPTに作ってもらったのでよく分からないです。デフォルト値を設定しておくことで、設定項目が増えても設定ファイルを更新しなくていいようにしていきます。
 設定ファイルでdatasetやtrainerなどのモジュール名を直接指定できるようにすることで、拡張性を高くしてます。代わりに開発者にしか使い方が理解できないようになりました。

https://github.com/laksjdjf/sd-trainer/blob/main/modules/config.py

from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from omegaconf import MISSING

@dataclass
class MainConfig:
    model_path: str = MISSING
    output_path: str = MISSING
    seed: Optional[int] = 4545
    sdxl: bool = MISSING
    clip_skip: Optional[bool] = None
    steps: Optional[int] = None
    epochs: Optional[int] = None
    save_steps: Optional[int] = None
    save_epochs: Optional[int] = 1
    sample_steps: Optional[int] = None
    sample_epochs: Optional[int] = 1
    log_level: str = "loggging.WARNING"
    wandb: Optional[str] = None

@dataclass
class OptimizerConfig:
    module: str = "torch.optim.AdamW"
    args: Optional[Any] = None

@dataclass
class TrainerConfig:
    module: str = "modules.trainer.BaseTrainer"
    train_unet: bool = MISSING
    train_text_encoder: bool = MISSING
    te_device: Optional[str] = None
    vae_device: Optional[str] = None
    train_dtype: str = MISSING
    weight_dtype: str = MISSING
    autocast_dtype: Optional[str] = None
    vae_dtype: Optional[str] = None
    lr: str = MISSING
    lr_scheduler: str = "constant"
    gradient_checkpointing: bool = False
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    validation_num_samples: int = 4
    validation_seed: int = 4545
    validation_args: Dict[str, Any] = field(default_factory=dict)

@dataclass
class DatasetArgs:
    batch_size: int = MISSING
    path: str = MISSING
    metadata: str = "buckets.json"
    original_size: Optional[str] = None
    latent: Optional[str] = "latents"
    caption: Optional[str] = "captions"
    image: Optional[str] = None
    text_emb: Optional[str] = None
    prompt: Optional[str] = None
    prefix: str = ""
    shuffle: bool = False
    ucg: float = 0.0

@dataclass
class DatasetConfig:
    module: str = MISSING
    args: DatasetArgs = field(default_factory=DatasetArgs)

@dataclass
class DataLoaderArgs:
    num_workers: int = 0
    shuffle: bool = True

@dataclass
class DataLoaderConfig:
    module: str = MISSING
    args: DataLoaderArgs = field(default_factory=DataLoaderArgs)

@dataclass
class NetworkArgs:
    module: str = MISSING
    unet_key_filters: Optional[List[str]] = None
    module_args: Optional[Dict[str, Any]] = None
    conv_module_args: Optional[Dict[str, Any]] = None
    text_module_args: Optional[Dict[str, Any]] = None

@dataclass
class NetworkConfig:
    train: bool = MISSING
    args: NetworkArgs = field(default_factory=NetworkArgs)

@dataclass
class Config:
    main: MainConfig = field(default_factory=MainConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
    network: Optional[NetworkConfig] = None

yaml

実際の設定ファイルは以下のような感じです。上のConfigでMISSINGになっていない項目は省略可能です。

main:
  model_path: "model_path"
  output_path: "output_path"
  seed: 4545
  sdxl: false
  clip_skip: null
  steps: null
  epochs: 20
  save_steps: null
  save_epochs: 2
  sample_steps: null
  sample_epochs: 2
  log_level: "logging.INFO"
  wandb: sd-trainer

trainer:
  module: modules.trainer.BaseTrainer
  train_unet: true
  train_text_encoder: false
  te_device: null
  vae_device: null
  train_dtype: torch.float32
  weight_dtype: torch.bfloat16
  autocast_dtype: null
  vae_dtype: torch.bfloat16
  lr: "1e-5"
  lr_scheduler: "cosine"
  gradient_checkpointing: false
  optimizer:
    module: torch.optim.AdamW
    args: null
  validation_num_samples: 4
  validation_seed: 4545
  validation_args:
    prompt: "waifu, anime, exceptional, best aesthetic, new, newest, best quality, masterpiece, extremely detailed, astolfo, fate, solo, white cloak, armor, gauntlets, garter strap, black thighhighs, skirt, nature"
    negative_prompt: "realistic, real life, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
    width: 640
    height: 896

dataset:
  module: modules.dataset.BaseDataset
  args:
    batch_size: 4
    path: "dataset_path"
    metadata: "buckets.json"

dataloader:
  module: torch.utils.data.DataLoader
  args:
      num_workers: 4
      shuffle: true


get_attr_from_config

 設定ファイルの文字列からモジュールをインポートする関数も用意しています。

def get_attr_from_config(config_text: str):
    if config_text is None:
        return None
    module = ".".join(config_text.split(".")[:-1])
    attr = config_text.split(".")[-1]
    return getattr(importlib.import_module(module), attr)

main.py

 学習スクリプトの中核コードです。コマンドライン引数で設定ファイルのパスを受け取り、Omegaconfでロードします。そしてtrainerやdatasetを作成して、学習ループを回します。ループ中に決められた頻度でモデルのセーブやサンプル画像の生成をしていきます。
 orを多用していますが、実はorはTrueやFalseを返すわけではありません。A if A else Bと同じ意味です(多分)。Aの値を優先するけど、AがNoneのときはBにするみたいなときに便利ですね。
 WandBによるログ管理ができるようになっています。損失やサンプル画像をアップロードします。VRAM使用量なんかも実装せずとも勝手に記録されていて便利だよね。

from omegaconf import OmegaConf
import sys
import math
from accelerate.utils import set_seed
from modules.utils import get_attr_from_config, collate_fn
from modules.config import Config
from tqdm import tqdm
import logging
import wandb

logger = logging.getLogger("メインちゃん")

def main(config):

    set_seed(config.main.seed)
    logger.info(f"シードは{config.main.seed}だよ!")
    
    logger.info(f"モデルを{config.main.model_path}からロードしちゃうよ!")
    trainer_cls = get_attr_from_config(config.trainer.module)
    trainer = trainer_cls.from_pretrained(config.main.model_path, config.main.sdxl, config.main.clip_skip, config.trainer)

    dataset_cls = get_attr_from_config(config.dataset.module)
    dataset = dataset_cls(trainer.text_model, **config.dataset.args)

    dataloder_cls = get_attr_from_config(config.dataloader.module)
    dataloader = dataloder_cls(dataset, collate_fn=collate_fn, **config.dataloader.args)

    trainer.prepare_modules_for_training()

    trainer.prepare_network(config.network)

    trainer.prepare_optimizer()

    if config.main.wandb is not None:
        wandb_run = wandb.init(project=config.main.wandb, name=config.main.output_path, dir="wandb")
    else:
        wandb_run = None

    steps_per_epoch = len(dataloader)
    total_steps = config.main.steps or steps_per_epoch * config.main.epochs
    total_epochs = config.main.epochs or math.floor(total_steps / steps_per_epoch)
    logger.info(f"トータルのステップ数は{total_steps}だよ!")

    trainer.prepare_lr_scheduler(total_steps)

    save_interval = config.main.save_steps or config.main.save_epochs * steps_per_epoch
    sample_interval = config.main.sample_steps or config.main.sample_epochs * steps_per_epoch
    logger.info(f"モデルを{save_interval}ステップごとにセーブするよ!")
    logger.info(f"サンプルは{sample_interval}ステップごとに生成するよ!")

    progress_bar = tqdm(total=total_steps, desc="Training")
    current_step = 0

    for epoch in range(total_epochs):
        for batch in dataloader:
            logs = trainer.step(batch)

            progress_bar.update(1)
            progress_bar.set_postfix(logs)
            if wandb_run is not None:
                wandb_run.log(logs, step=current_step)

            if current_step % save_interval == 0 or current_step == total_steps - 1:
                trainer.save_model(config.main.output_path)
            if current_step % sample_interval == 0 or current_step == total_steps - 1:
                images = trainer.sample_validation(current_step)
                if wandb_run is not None:
                    images = [wandb.Image(image, caption=config.trainer.validation_args.prompt) for image in images]
                    wandb_run.log({'images': images}, step=current_step)
                else:
                    [image.save(f"image_logs/{current_step}_{i}.png") for i, image in enumerate(images)]

            current_step += 1

            if current_step == total_steps:
                logger.info(f"トレーニングが終わったよ!")
                if wandb_run is not None:
                    wandb_run.finish()
                return

        logger.info(f"エポック{epoch+1}が終わったよ!")

if __name__ == "__main__":
    config = OmegaConf.load(sys.argv[1])
    config = OmegaConf.merge(OmegaConf.structured(Config), config)
    logging.basicConfig(level=logging.INFO)
    print(OmegaConf.to_yaml(config))
    main(config)

Trainer

 mainで使われている学習用メソッドについて紹介していきます。

prepare_modules_for_training

 モデルをデバイスに移動して型を指定してgrad_scalerを定義してtrainやrequired_grad_を設定してgradient_checkpointingを設定するメソッドです。
 指定する型はなんと4つあります。train_dtypeは学習対象モデルの型です。基本的にfloat32にした方がいいと思います。bfloat16にするとfull_bf16になってVRAM使用量が下がります。weight_dtypeは学習対象以外のモデルの型です。AMPを利用する場合bfloat16やGPUが対応していない場合float16がいいでしょう。float8_e4m3fnを指定するとさらにVRAM使用量を削減できますが、現状fp8ではほとんどの計算ができないので、autocast_dtypeをbfloat16とかに設定する必要があります。vae_dtypeはfloat16で(NaN;)になるSDXLのVAEを使う場合に別の型を指定できるようになっています。
 テキストエンコーダはEmbedding層をfp8にするとエラーが起きるので、そこだけautocast_dtypeにしておきます(KohakuBlueLeafさんがsd-scriptsにfp8を適用したコードを参考にしています。)
 GradScalerはAMP利用時に勾配のアンダーフロー(値が小さすぎて0になってしまう)を防ぐためのものです。勾配は通常float32(train_dtypeと同じ)で計算されますが、途中の計算がfloat16だとアンダーフローが起きてしまう可能性が起きます。それを防ぐためにlossを大きな値でスケーリングして、勾配を計算した後再スケーリングして元の値に戻します。この処理はダイナミックレンジがfloat32より小さいfloat16でのみ必要であって、bfloat16では必要ないので(多分)、enabledでfloat16のときだけ適用するようにしています(そうしないとfull_bf16でエラーが起きる)。
 train()やrequired_grad_()に関しては学習対象だったらTrueに、そうでなかったらFalseにするだけですね。ただしgradient_checkpointingはDiffusersのコードだとtrain(True)のときにしか適用されないので、そうします。

def prepare_modules_for_training(self, device="cuda"):
    config = self.config
    self.device = device
    self.te_device = config.te_device or device
    self.vae_device = config.vae_device or device

    self.train_dtype = get_attr_from_config(config.train_dtype)
    self.weight_dtype = get_attr_from_config(config.weight_dtype)
    self.autocast_dtype = get_attr_from_config(config.autocast_dtype) or self.weight_dtype
    self.vae_dtype = get_attr_from_config(config.vae_dtype) or self.weight_dtype
    self.te_dtype = self.train_dtype if config.train_text_encoder else self.weight_dtype

    logger.info(f"学習対象モデルの型は{self.train_dtype}だって。")
    logger.info(f"学習対象以外の型は{self.weight_dtype}だよ!")
    logger.info(f"オートキャストの型は{self.autocast_dtype}にしちゃった。")

    self.grad_scaler = torch.cuda.amp.GradScaler(enabled=self.autocast_dtype == torch.float16)

    self.diffusion.unet.to(device, dtype=self.train_dtype if config.train_unet else self.weight_dtype)
    
    self.text_model.to(self.te_device, dtype=self.te_dtype)
    if  hasattr(torch, 'float8_e4m3fn') and self.te_dtype== torch.float8_e4m3fn:
        self.text_model.set_embedding_dtype(self.autocast_dtype) # fp8時のエラー回避

    self.vae.to(self.vae_device, dtype=self.vae_dtype)

    self.diffusion.unet.train(config.train_unet)
    self.diffusion.unet.requires_grad_(config.train_unet)
    self.text_model.train(config.train_text_encoder)
    self.text_model.requires_grad_(config.train_text_encoder)
    self.vae.eval()

    if config.gradient_checkpointing:
        self.diffusion.enable_gradient_checkpointing()
        self.text_model.enable_gradient_checkpointing()
        self.diffusion.unet.train() # trainでないと適用されない。
        self.text_model.train()
        logger.info("勾配チェックポイントを有効にしてみたよ!")

prepare_optimizer

 最適化関数を用意するメソッドです。学習率はカンマ区切りの文字列で設定することでunetとテキストエンコーダの学習率を変えられます。[0]と[-1]を参照することで、"1e-5"みたいな感じで一つの数字のみ指定しても対応できるようになってます。
 self.networkに関する行は次回やります。
 最適化関数は文字列でモジュールを指定します。"torch.optim.AdamW"とか"bitsandbytes.optim.AdamW8bit"とかね。引数も指定できます。

def prepare_optimizer(self):
    lrs = [float(lr) for lr in self.config.lr.split(",")]
    unet_lr, text_lr = lrs[0], lrs[-1]
    logger.info(f"UNetの学習率は{unet_lr}、text_encoderの学習率は{text_lr}にしてみた!")

    params = []

    if self.config.train_unet:
        params += [{"params":self.diffusion.unet.parameters(), "lr":unet_lr}]
    if self.config.train_text_encoder:
        params += [{"params":self.text_model.parameters(), "lr":text_lr}]
    if self.network:
        params += self.network.prepare_optimizer_params(text_lr, unet_lr)

    optimizer_cls = get_attr_from_config(self.config.optimizer.module)
    self.optimizer = optimizer_cls(params, **self.config.optimizer.args or {})

    logger.info(f"オプティマイザーは{self.optimizer}にしてみた!")
    total_params = sum(p.numel() for group in self.optimizer.param_groups for p in group['params'] if p.requires_grad)
    logger.info(f"学習対象のパラメーター数は{total_params:,}だよ!")

    return params

prepare_lr_scheduler

 これはdiffusersのget_schedulerの力をお借りします。num_warmup_stepsとか適当なので改善が必要なきがする。

def prepare_lr_scheduler(self, total_steps):
    self.lr_scheduler = get_scheduler(
        self.config.lr_scheduler,
        optimizer=self.optimizer,
        num_warmup_steps=int(0.05 * total_steps),
        num_training_steps=total_steps
    )
    logger.info(f"学習率スケジューラーは{self.lr_scheduler}にした!")

step

 学習ループ1ステップ分の処理をまとめています。ここはPytorchの定型的なコードですね。ログで損失値の指数移動平均や学習速度等が出力できるようにしておきます。

def step(self, batch):
    b_start = time.perf_counter()

    self.optimizer.zero_grad()
    loss = self.loss(batch)
    self.grad_scaler.scale(loss).backward()
    self.grad_scaler.step(self.optimizer)
    self.grad_scaler.update()
    self.lr_scheduler.step()

    if hasattr(self, "loss_ema"):
        self.loss_ema = self.loss_ema * 0.99 + loss.item() * 0.01
    else:
        self.loss_ema = loss.item()

    b_end = time.perf_counter()
    samples_per_second = self.batch_size / (b_end - b_start)

    logs = {"loss_ema":self.loss_ema, "samples_per_second":samples_per_second, "lr":self.lr_scheduler.get_last_lr()[0]}

    return logs

loss

 ミニバッチを受け取って損失を計算します。Diffusion modelの損失は潜在変数にノイズを加えてUNetに入力し、ノイズ予測を出して加えたノイズとの平均二乗誤差になります。
 潜在変数はvae.scaling_factorでスケーリングする必要があります。ここはWARNINGが出てunet.config.scaling_factorを使えと言われるので変えるかもしれません。
 SD2はv_predictionモデルなので損失はvelocityとの誤差になります。それを計算するのがschedulerのget_targetメソッドです。

def loss(self, batch):
    if "latents" in batch:
        latents = batch["latents"].to(self.device) * self.vae.scaling_factor
    else:
        with torch.autocast("cuda", dtype=self.vae_dtype), torch.no_grad():
            latents = self.vae.encode(batch['images'].to(self.device)).latent_dist.sample() * self.vae.scaling_factor
    
    self.batch_size = latents.shape[0] # stepメソッドでも使う

    if "encoder_hidden_states" in batch:
        encoder_hidden_states = batch["encoder_hidden_states"].to(self.device)
        pooled_output = batch["pooled_outputs"].to(self.device)
    else:
        with torch.autocast("cuda", dtype=self.autocast_dtype):
            encoder_hidden_states, pooled_output = self.text_model(batch["captions"])

    if "size_condition" in batch:
        size_condition = batch["size_condition"].to(self.device)
    else:
        size_condition = None

    timesteps = torch.randint(0, 1000, (self.batch_size,), device=latents.device)
    noise = torch.randn_like(latents)
    noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
    
    with torch.autocast("cuda", dtype=self.autocast_dtype):
        model_output = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition)

    target = self.scheduler.get_target(latents, noise, timesteps) # v_predictionの場合はvelocityになる

    loss = nn.functional.mse_loss(model_output.float(), target.float(), reduction="mean")

    return loss

save_model,  sample_validation

 モデルは学習対象のみをセーブするようにしています。検証画像の生成オプションは設定ファイルから決められます。コードは大したことないので省略します。

さあ学習だ!

 学習データを用意して設定ファイルを定義して、

python main.py config/config.yaml

でできます。コマンドライン引数をいっぱいならべるのはきらいです。

よくある質問(妄想)

Q. collate_fnってlambda式つかえばよくねwww
A. と思ってやったらnum_workersを1以上にしたときにmultiprocessで使う
pickleにlambda式が対応してないとかでエラーが起こるらしい。

Q. キャプションのシャッフルはしないんですか
A. ほんとうはするべきだとおもいますが、めんどくさいのでしません。

Q. 各ステップでモデルを上書きせず個別に保存しないんですか
A. 前のコードはそういう機能もありましたが、私はほとんど使いませんでした。サンプル画像の生成しているんだからよくね?