Stable Diffusionの学習コードを作る:2.学習編
前回の続きとして、学習のために必要なコードを紹介していきます。今回はLoRAではなく、フルファインチューニングができるようにします。
データセット(BaseDataset)
データセットのフォルダ構造は以下のような感じです。フォルダ名はデフォルトの名前であって、設定で自由に変えられるようにします。
Dataset/
buckets.json # bucketingのメタデータ
original_size.json # 任意でSDXLのみ必要
images/ # 学習画像 拡張子はpng
latents/ # 学習画像をVAEエンコードしたもの 拡張子はnpy
captions/ # キャプションのテキストファイル 拡張子はcaption
text_embs/ # キャプションの埋め込み 拡張子はnpz
imagesとlatentsはどちらかがあればいいという感じです。captionsとtext_embsもどっちかがあればいいです。画像とキャプションはファイル名で対応付けします。
画像の前処理は古いコードをずっと使いまわしており、あんまりいいコードじゃないけど変える意欲がわきません。
Aspect ratio Bucketing
https://github.com/laksjdjf/sd-trainer/tree/main/preprocess
NovelAIが提唱した複数のアスペクト比のデータで学習する方法です。UNetは任意の解像度に対応できますが、Pytorchは異なるサイズのテンソルで並列計算することはできません。そのためいくつかの解像度候補(bucket)を選び、画像をそれぞれ一番近い解像度にリサイズ・切り抜きを行います。そして同じbucketからミニバッチを取り出します。詳細は学習設定編で書いてあります。今回は実装面を見てみましょう。
bucketの作成については、基本的には幅を64単位で広げながら一番大きくできる高さを計算してbucketに追加します。アスペクト比が極端な画像は入れないようにしてます。
def make_buckets():
# モデルの構造からして64の倍数が推奨される。(VAEで8分の1⇒UNetで8分の1)
increment = 64
# 最大ピクセル数
max_pixels = args.resolution*args.resolution
# 正方形は手動で追加
buckets = set()
buckets.add((args.resolution, args.resolution))
# 最小値から~
width = args.min_length
# ~最大値まで
while width <= args.max_length:
# 最大ピクセル数と最大長を越えない最大の高さ
height = min(args.max_length, (max_pixels // width) - (max_pixels // width) % increment)
ratio = width/height
# アスペクト比が極端じゃなかったら追加、高さと幅入れ替えたものも追加。
if 1/args.max_ratio <= ratio <= args.max_ratio:
buckets.add((width, height))
buckets.add((height, width))
width += increment # 幅を大きくして次のループへ
画像は一番アスペクト比が近いbucketに割り当て、リサイズおよび中央切り抜きをします。
def resize_image(file):
image = Image.open(file)
image = image.convert("RGB")
ratio = image.width / image.height
ar_errors = ratios - ratio
indice = np.argmin(np.abs(ar_errors)) # 一番近いアスペクト比のインデックス
bucket_width, bucket_height = buckets[indice]
ar_error = ar_errors[indice]
if ar_error <= 0: # 幅<高さなら高さを合わせる
temp_width = int(image.width*bucket_height/image.height)
image = image.resize((temp_width, bucket_height)) # アスペクト比を変えずに高さだけbucketに合わせる
left = (temp_width - bucket_width) / 2 # 切り取り境界左側
right = bucket_width + left # 切り取り境界右側
image = image.crop((left, 0, right, bucket_height)) # 左右切り取り
else: # 幅高さを逆にしたもの
temp_height = int(image.height*bucket_width/image.width)
image = image.resize((bucket_width, temp_height))
upper = (temp_height - bucket_height) / 2
lower = bucket_height + upper
image = image.crop((0, upper, bucket_width, lower))
image.save(os.path.join(args.output_dir, os.path.basename(file)))
return [os.path.splitext(os.path.basename(file))[0], str((bucket_width, bucket_height))]
この手法は学習開始時に行う方法もありますが、学習スクリプトが複雑になるのが嫌なので学習前にやってストレージに保存します。同じデータセットを複数の設定で学習したいときに何度もやらずに済みますしね。
メタデータで、各bucketに入っているファイル名を辞書形式で保存しておきます。これによって学習時にデータを精査せずともミニバッチを取りだせるようにしておきます。
オリジナルサイズ
SDXLの学習では、学習データの元々の解像度情報が必要です。他にも切り取り位置に関する情報も必要ですが、今回のコードは中央切り取りしかしないんで、元の解像度から逆算可能です。original_size.jsonというファイルでkeyがファイル名でvalueが元の解像度の辞書となる辞書形式で保存しておきます。辞書in辞書だね。
潜在変数のキャッシュ
https://github.com/laksjdjf/sd-trainer/blob/main/preprocess/latent.py
学習データはVAEによって潜在変数にエンコードする必要がありますが、VAEは学習対象ではないので、1枚の画像を何度もエンコードするのは非効率です。そのため潜在変数をキャッシュします。上と同じく学習前に行いストレージに保存します(これはキャッシュといっていいのか?)。
コードはVAEでエンコードしてnp.saveするだけなので特に紹介しないよ。
テキスト埋め込みのキャッシュ
https://github.com/laksjdjf/sd-trainer/blob/main/preprocess/text_embedding.py
潜在変数のキャッシュと同じですが、こちらの方が計算量が圧倒的に低いのでVRAMがぎりぎり足りないとかじゃなければやらなくてもいいです。というわけで今回は省略します。
BaseDatasetクラス
https://github.com/laksjdjf/sd-trainer/blob/main/modules/dataset.py
データセットはtorch.utils.data.Datasetを継承します。通常Datasetは要素を1個ずつ取り出し、ミニバッチとしてまとめるのはDataloaderなんですが、Bucketingの関係上Dataset段階でミニバッチを作成します。
キャプションなどのフォルダ名は任意に選べるようにします。
設定について、prefixはキャプションの文頭に共通のワードを追加するための設定です。トリガーワードみたいなものに使えます。shuffleはエポックごとにデータをシャッフルするためのものなんですが、うまくいってるのかわかりません。ucgはキャプションを一定確率で空文にするというものです。CFGを使うためのものですが、大規模な学習以外では必要ないと思います。ucgを適用しつつテキスト埋め込みのキャッシュをする場合、空文のテキスト埋め込みが追加で必要になるため、コンストラクタで作成して保存しておきますが、このためだけにtext_modelを引数に取るのはよくないきがする。
class BaseDataset(Dataset):
def __init__(
self,
text_model: TextModel,
batch_size: int,
path: str,
metadata: str="buckets.json",
original_size: Optional[str] = None,
latent: Optional[str] = "latents",
caption: Optional[str] = "captions",
image: Optional[str] = None,
text_emb: Optional[str] = None,
prompt: Optional[str] = None,
prefix: str = "",
shuffle: bool = False,
ucg: float = 0.0
):
with open(os.path.join(path, metadata), "r") as f:
self.bucket2file = json.load(f)
if original_size is not None:
with open(os.path.join(path, original_size), "r") as f:
self.original_size = json.load(f)
else:
self.original_size = {}
self.path = path
self.batch_size = batch_size
self.text_model = text_model
self.latent = latent
self.caption = caption
self.image = image
self.text_emb = text_emb
self.prompt = prompt # 全ての画像のcaptionをpromptにする
self.prefix = prefix # captionのprefix
self.shuffle = shuffle # バッチの取り出し方をシャッフルするかどうか(データローダー側でシャッフルした方が良い^^)
self.ucg = ucg # captionをランダムにする空文にする確率
# 空文の埋め込みを事前に計算しておく
if self.ucg > 0.0 and self.text_emb:
text_device = self.text_model.device
self.text_model.to("cuda")
with torch.no_grad():
self.uncond_hidden_state, self.uncond_pooled_output = self.text_model([""])
self.uncond_hidden_state.detach().float().cpu()
self.uncond_pooled_output.detach().float().cpu()
self.text_model.to(text_device)
logger.info(f"空文の埋め込みを計算したよ!")
self.init_batch_samples()
logger.info(f"データセットを作ったよ!")
init_batch_samples
Bucketからミニバッチの取り出し方を決定します。keyがBucket, valueがファイル名のリストとなっているメタデータを作ってあるので、各Bucketのファイル名をシャッフルしてミニバッチにわけ、全部のミニバッチを再度シャッフルします。
# バッチの取り出し方を初期化するメソッド
def init_batch_samples(self):
self.batch_samples = []
for key in self.bucket2file:
random.shuffle(self.bucket2file[key])
self.batch_samples.extend([self.bucket2file[key][i:i+self.batch_size]
for i in range(0, len(self.bucket2file[key]), self.batch_size)])
random.shuffle(self.batch_samples)
get_item
ミニバッチを辞書形式でつくります。画像と潜在変数、キャプションとテキスト埋め込みはどっちかだけ読み込むようにします。size_conditionはSDXLに必要な画像の解像度情報で、次の節で説明します。
def __getitem__(self, i):
if i == 0 and self.shuffle:
self.init_batch_samples()
batch = {}
samples = self.batch_samples[i]
if self.image:
batch["images"] = self.get_images(samples, self.image if isinstance(self.image, str) else "images")
target_height, target_width = batch["images"].shape[2:]
else:
batch["latents"] = self.get_latents(samples, self.latent)
target_height, target_width = batch["latents"].shape[2]*8, batch["latents"].shape[3]*8
batch["size_condition"] = self.get_size_condition(samples, target_height, target_width)
if self.text_emb:
batch["encoder_hidden_states"], batch["pooled_outputs"] = self.get_text_embeddings(samples, self.text_emb if isinstance(self.text_emb, str) else "text_emb")
else:
batch["captions"] = self.get_captions(samples, self.caption)
return batch
get_size_condition
学習時に画像をリサイズしたり、切り抜きをしますが、当然その際に情報が失われます。size_conditionはモデルにそのことを伝えるためにあります。具体的には学習データの元々の解像度縦横、切り抜き位置縦横、前処理後の解像度縦横の6要素からなります。生成時は基本的に二つの解像度は生成解像度に設定して、切り抜き位置は0にします。このコードではbucketingに基づいてリサイズ・中央切り抜きするため元々の解像度があれば切り抜き位置は計算できます。この切り抜き情報の計算が厄介で、実装があっているかもわかりませんが、✟kohya_tech✟さんが解説してくれています。
def get_size_condition(self, samples, target_height, target_width):
size_condition = []
for sample in samples:
if sample in self.original_size:
original_width = self.original_size[sample]["original_width"]
original_height = self.original_size[sample]["original_height"]
original_ratio = original_width / original_height
target_ratio = target_width / target_height
if original_ratio > target_ratio: # 横長の場合
resize_ratio = target_width / original_width # 横幅を合わせる
resized_height = original_height * resize_ratio # 縦をリサイズ
crop_top = (target_height - resized_height) // 2 # 上部の足りない分がcrop_top
crop_left = 0
else:
resize_ratio = target_height / original_height
resize_width = original_width * resize_ratio
crop_top = 0
crop_left = (target_width - resize_width) // 2
else:
original_width, original_height = target_width, target_height
crop_top = 0
crop_left = 0
size_list = [original_height, original_width, crop_top, crop_left, target_height, target_width]
size_condition.append(torch.tensor(size_list))
return torch.stack(size_condition)
get_その他
各データをフォルダから取り出すメソッドです。いっぱいあってきりがないのでとばしますが、基本はTensorにしてミニバッチでまとめます。キャプションの場合は文字列のリストにします。その際ucgやprefixの処理を行います。
Dataloader
データセット側がミニバッチまで作成するため、データローダーはあんまりやることないです。データローダー側のバッチサイズは1にして、ミニバッチの作成を定義するcollate_fnは以下のようにすることでDatasetの__get_item__()の出力がそのままでてきます。
def collate_fn(x):
return x[0]
Config
学習設定についてやっていきます。configはmain, trainer, dataset, dataloader, network(次回用)に分かれます。以下がConfigの構造ですが、ChatGPTに作ってもらったのでよく分からないです。デフォルト値を設定しておくことで、設定項目が増えても設定ファイルを更新しなくていいようにしていきます。
設定ファイルでdatasetやtrainerなどのモジュール名を直接指定できるようにすることで、拡張性を高くしてます。代わりに開発者にしか使い方が理解できないようになりました。
https://github.com/laksjdjf/sd-trainer/blob/main/modules/config.py
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from omegaconf import MISSING
@dataclass
class MainConfig:
model_path: str = MISSING
output_path: str = MISSING
seed: Optional[int] = 4545
sdxl: bool = MISSING
clip_skip: Optional[bool] = None
steps: Optional[int] = None
epochs: Optional[int] = None
save_steps: Optional[int] = None
save_epochs: Optional[int] = 1
sample_steps: Optional[int] = None
sample_epochs: Optional[int] = 1
log_level: str = "loggging.WARNING"
wandb: Optional[str] = None
@dataclass
class OptimizerConfig:
module: str = "torch.optim.AdamW"
args: Optional[Any] = None
@dataclass
class TrainerConfig:
module: str = "modules.trainer.BaseTrainer"
train_unet: bool = MISSING
train_text_encoder: bool = MISSING
te_device: Optional[str] = None
vae_device: Optional[str] = None
train_dtype: str = MISSING
weight_dtype: str = MISSING
autocast_dtype: Optional[str] = None
vae_dtype: Optional[str] = None
lr: str = MISSING
lr_scheduler: str = "constant"
gradient_checkpointing: bool = False
optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
validation_num_samples: int = 4
validation_seed: int = 4545
validation_args: Dict[str, Any] = field(default_factory=dict)
@dataclass
class DatasetArgs:
batch_size: int = MISSING
path: str = MISSING
metadata: str = "buckets.json"
original_size: Optional[str] = None
latent: Optional[str] = "latents"
caption: Optional[str] = "captions"
image: Optional[str] = None
text_emb: Optional[str] = None
prompt: Optional[str] = None
prefix: str = ""
shuffle: bool = False
ucg: float = 0.0
@dataclass
class DatasetConfig:
module: str = MISSING
args: DatasetArgs = field(default_factory=DatasetArgs)
@dataclass
class DataLoaderArgs:
num_workers: int = 0
shuffle: bool = True
@dataclass
class DataLoaderConfig:
module: str = MISSING
args: DataLoaderArgs = field(default_factory=DataLoaderArgs)
@dataclass
class NetworkArgs:
module: str = MISSING
unet_key_filters: Optional[List[str]] = None
module_args: Optional[Dict[str, Any]] = None
conv_module_args: Optional[Dict[str, Any]] = None
text_module_args: Optional[Dict[str, Any]] = None
@dataclass
class NetworkConfig:
train: bool = MISSING
args: NetworkArgs = field(default_factory=NetworkArgs)
@dataclass
class Config:
main: MainConfig = field(default_factory=MainConfig)
trainer: TrainerConfig = field(default_factory=TrainerConfig)
dataset: DatasetConfig = field(default_factory=DatasetConfig)
dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
network: Optional[NetworkConfig] = None
yaml
実際の設定ファイルは以下のような感じです。上のConfigでMISSINGになっていない項目は省略可能です。
main:
model_path: "model_path"
output_path: "output_path"
seed: 4545
sdxl: false
clip_skip: null
steps: null
epochs: 20
save_steps: null
save_epochs: 2
sample_steps: null
sample_epochs: 2
log_level: "logging.INFO"
wandb: sd-trainer
trainer:
module: modules.trainer.BaseTrainer
train_unet: true
train_text_encoder: false
te_device: null
vae_device: null
train_dtype: torch.float32
weight_dtype: torch.bfloat16
autocast_dtype: null
vae_dtype: torch.bfloat16
lr: "1e-5"
lr_scheduler: "cosine"
gradient_checkpointing: false
optimizer:
module: torch.optim.AdamW
args: null
validation_num_samples: 4
validation_seed: 4545
validation_args:
prompt: "waifu, anime, exceptional, best aesthetic, new, newest, best quality, masterpiece, extremely detailed, astolfo, fate, solo, white cloak, armor, gauntlets, garter strap, black thighhighs, skirt, nature"
negative_prompt: "realistic, real life, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
width: 640
height: 896
dataset:
module: modules.dataset.BaseDataset
args:
batch_size: 4
path: "dataset_path"
metadata: "buckets.json"
dataloader:
module: torch.utils.data.DataLoader
args:
num_workers: 4
shuffle: true
get_attr_from_config
設定ファイルの文字列からモジュールをインポートする関数も用意しています。
def get_attr_from_config(config_text: str):
if config_text is None:
return None
module = ".".join(config_text.split(".")[:-1])
attr = config_text.split(".")[-1]
return getattr(importlib.import_module(module), attr)
main.py
学習スクリプトの中核コードです。コマンドライン引数で設定ファイルのパスを受け取り、Omegaconfでロードします。そしてtrainerやdatasetを作成して、学習ループを回します。ループ中に決められた頻度でモデルのセーブやサンプル画像の生成をしていきます。
orを多用していますが、実はorはTrueやFalseを返すわけではありません。A if A else Bと同じ意味です(多分)。Aの値を優先するけど、AがNoneのときはBにするみたいなときに便利ですね。
WandBによるログ管理ができるようになっています。損失やサンプル画像をアップロードします。VRAM使用量なんかも実装せずとも勝手に記録されていて便利だよね。
from omegaconf import OmegaConf
import sys
import math
from accelerate.utils import set_seed
from modules.utils import get_attr_from_config, collate_fn
from modules.config import Config
from tqdm import tqdm
import logging
import wandb
logger = logging.getLogger("メインちゃん")
def main(config):
set_seed(config.main.seed)
logger.info(f"シードは{config.main.seed}だよ!")
logger.info(f"モデルを{config.main.model_path}からロードしちゃうよ!")
trainer_cls = get_attr_from_config(config.trainer.module)
trainer = trainer_cls.from_pretrained(config.main.model_path, config.main.sdxl, config.main.clip_skip, config.trainer)
dataset_cls = get_attr_from_config(config.dataset.module)
dataset = dataset_cls(trainer.text_model, **config.dataset.args)
dataloder_cls = get_attr_from_config(config.dataloader.module)
dataloader = dataloder_cls(dataset, collate_fn=collate_fn, **config.dataloader.args)
trainer.prepare_modules_for_training()
trainer.prepare_network(config.network)
trainer.prepare_optimizer()
if config.main.wandb is not None:
wandb_run = wandb.init(project=config.main.wandb, name=config.main.output_path, dir="wandb")
else:
wandb_run = None
steps_per_epoch = len(dataloader)
total_steps = config.main.steps or steps_per_epoch * config.main.epochs
total_epochs = config.main.epochs or math.floor(total_steps / steps_per_epoch)
logger.info(f"トータルのステップ数は{total_steps}だよ!")
trainer.prepare_lr_scheduler(total_steps)
save_interval = config.main.save_steps or config.main.save_epochs * steps_per_epoch
sample_interval = config.main.sample_steps or config.main.sample_epochs * steps_per_epoch
logger.info(f"モデルを{save_interval}ステップごとにセーブするよ!")
logger.info(f"サンプルは{sample_interval}ステップごとに生成するよ!")
progress_bar = tqdm(total=total_steps, desc="Training")
current_step = 0
for epoch in range(total_epochs):
for batch in dataloader:
logs = trainer.step(batch)
progress_bar.update(1)
progress_bar.set_postfix(logs)
if wandb_run is not None:
wandb_run.log(logs, step=current_step)
if current_step % save_interval == 0 or current_step == total_steps - 1:
trainer.save_model(config.main.output_path)
if current_step % sample_interval == 0 or current_step == total_steps - 1:
images = trainer.sample_validation(current_step)
if wandb_run is not None:
images = [wandb.Image(image, caption=config.trainer.validation_args.prompt) for image in images]
wandb_run.log({'images': images}, step=current_step)
else:
[image.save(f"image_logs/{current_step}_{i}.png") for i, image in enumerate(images)]
current_step += 1
if current_step == total_steps:
logger.info(f"トレーニングが終わったよ!")
if wandb_run is not None:
wandb_run.finish()
return
logger.info(f"エポック{epoch+1}が終わったよ!")
if __name__ == "__main__":
config = OmegaConf.load(sys.argv[1])
config = OmegaConf.merge(OmegaConf.structured(Config), config)
logging.basicConfig(level=logging.INFO)
print(OmegaConf.to_yaml(config))
main(config)
Trainer
mainで使われている学習用メソッドについて紹介していきます。
prepare_modules_for_training
モデルをデバイスに移動して型を指定してgrad_scalerを定義してtrainやrequired_grad_を設定してgradient_checkpointingを設定するメソッドです。
指定する型はなんと4つあります。train_dtypeは学習対象モデルの型です。基本的にfloat32にした方がいいと思います。bfloat16にするとfull_bf16になってVRAM使用量が下がります。weight_dtypeは学習対象以外のモデルの型です。AMPを利用する場合bfloat16やGPUが対応していない場合float16がいいでしょう。float8_e4m3fnを指定するとさらにVRAM使用量を削減できますが、現状fp8ではほとんどの計算ができないので、autocast_dtypeをbfloat16とかに設定する必要があります。vae_dtypeはfloat16で(NaN;)になるSDXLのVAEを使う場合に別の型を指定できるようになっています。
テキストエンコーダはEmbedding層をfp8にするとエラーが起きるので、そこだけautocast_dtypeにしておきます(KohakuBlueLeafさんがsd-scriptsにfp8を適用したコードを参考にしています。)
GradScalerはAMP利用時に勾配のアンダーフロー(値が小さすぎて0になってしまう)を防ぐためのものです。勾配は通常float32(train_dtypeと同じ)で計算されますが、途中の計算がfloat16だとアンダーフローが起きてしまう可能性が起きます。それを防ぐためにlossを大きな値でスケーリングして、勾配を計算した後再スケーリングして元の値に戻します。この処理はダイナミックレンジがfloat32より小さいfloat16でのみ必要であって、bfloat16では必要ないので(多分)、enabledでfloat16のときだけ適用するようにしています(そうしないとfull_bf16でエラーが起きる)。
train()やrequired_grad_()に関しては学習対象だったらTrueに、そうでなかったらFalseにするだけですね。ただしgradient_checkpointingはDiffusersのコードだとtrain(True)のときにしか適用されないので、そうします。
def prepare_modules_for_training(self, device="cuda"):
config = self.config
self.device = device
self.te_device = config.te_device or device
self.vae_device = config.vae_device or device
self.train_dtype = get_attr_from_config(config.train_dtype)
self.weight_dtype = get_attr_from_config(config.weight_dtype)
self.autocast_dtype = get_attr_from_config(config.autocast_dtype) or self.weight_dtype
self.vae_dtype = get_attr_from_config(config.vae_dtype) or self.weight_dtype
self.te_dtype = self.train_dtype if config.train_text_encoder else self.weight_dtype
logger.info(f"学習対象モデルの型は{self.train_dtype}だって。")
logger.info(f"学習対象以外の型は{self.weight_dtype}だよ!")
logger.info(f"オートキャストの型は{self.autocast_dtype}にしちゃった。")
self.grad_scaler = torch.cuda.amp.GradScaler(enabled=self.autocast_dtype == torch.float16)
self.diffusion.unet.to(device, dtype=self.train_dtype if config.train_unet else self.weight_dtype)
self.text_model.to(self.te_device, dtype=self.te_dtype)
if hasattr(torch, 'float8_e4m3fn') and self.te_dtype== torch.float8_e4m3fn:
self.text_model.set_embedding_dtype(self.autocast_dtype) # fp8時のエラー回避
self.vae.to(self.vae_device, dtype=self.vae_dtype)
self.diffusion.unet.train(config.train_unet)
self.diffusion.unet.requires_grad_(config.train_unet)
self.text_model.train(config.train_text_encoder)
self.text_model.requires_grad_(config.train_text_encoder)
self.vae.eval()
if config.gradient_checkpointing:
self.diffusion.enable_gradient_checkpointing()
self.text_model.enable_gradient_checkpointing()
self.diffusion.unet.train() # trainでないと適用されない。
self.text_model.train()
logger.info("勾配チェックポイントを有効にしてみたよ!")
prepare_optimizer
最適化関数を用意するメソッドです。学習率はカンマ区切りの文字列で設定することでunetとテキストエンコーダの学習率を変えられます。[0]と[-1]を参照することで、"1e-5"みたいな感じで一つの数字のみ指定しても対応できるようになってます。
self.networkに関する行は次回やります。
最適化関数は文字列でモジュールを指定します。"torch.optim.AdamW"とか"bitsandbytes.optim.AdamW8bit"とかね。引数も指定できます。
def prepare_optimizer(self):
lrs = [float(lr) for lr in self.config.lr.split(",")]
unet_lr, text_lr = lrs[0], lrs[-1]
logger.info(f"UNetの学習率は{unet_lr}、text_encoderの学習率は{text_lr}にしてみた!")
params = []
if self.config.train_unet:
params += [{"params":self.diffusion.unet.parameters(), "lr":unet_lr}]
if self.config.train_text_encoder:
params += [{"params":self.text_model.parameters(), "lr":text_lr}]
if self.network:
params += self.network.prepare_optimizer_params(text_lr, unet_lr)
optimizer_cls = get_attr_from_config(self.config.optimizer.module)
self.optimizer = optimizer_cls(params, **self.config.optimizer.args or {})
logger.info(f"オプティマイザーは{self.optimizer}にしてみた!")
total_params = sum(p.numel() for group in self.optimizer.param_groups for p in group['params'] if p.requires_grad)
logger.info(f"学習対象のパラメーター数は{total_params:,}だよ!")
return params
prepare_lr_scheduler
これはdiffusersのget_schedulerの力をお借りします。num_warmup_stepsとか適当なので改善が必要なきがする。
def prepare_lr_scheduler(self, total_steps):
self.lr_scheduler = get_scheduler(
self.config.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=int(0.05 * total_steps),
num_training_steps=total_steps
)
logger.info(f"学習率スケジューラーは{self.lr_scheduler}にした!")
step
学習ループ1ステップ分の処理をまとめています。ここはPytorchの定型的なコードですね。ログで損失値の指数移動平均や学習速度等が出力できるようにしておきます。
def step(self, batch):
b_start = time.perf_counter()
self.optimizer.zero_grad()
loss = self.loss(batch)
self.grad_scaler.scale(loss).backward()
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
self.lr_scheduler.step()
if hasattr(self, "loss_ema"):
self.loss_ema = self.loss_ema * 0.99 + loss.item() * 0.01
else:
self.loss_ema = loss.item()
b_end = time.perf_counter()
samples_per_second = self.batch_size / (b_end - b_start)
logs = {"loss_ema":self.loss_ema, "samples_per_second":samples_per_second, "lr":self.lr_scheduler.get_last_lr()[0]}
return logs
loss
ミニバッチを受け取って損失を計算します。Diffusion modelの損失は潜在変数にノイズを加えてUNetに入力し、ノイズ予測を出して加えたノイズとの平均二乗誤差になります。
潜在変数はvae.scaling_factorでスケーリングする必要があります。ここはWARNINGが出てunet.config.scaling_factorを使えと言われるので変えるかもしれません。
SD2はv_predictionモデルなので損失はvelocityとの誤差になります。それを計算するのがschedulerのget_targetメソッドです。
def loss(self, batch):
if "latents" in batch:
latents = batch["latents"].to(self.device) * self.vae.scaling_factor
else:
with torch.autocast("cuda", dtype=self.vae_dtype), torch.no_grad():
latents = self.vae.encode(batch['images'].to(self.device)).latent_dist.sample() * self.vae.scaling_factor
self.batch_size = latents.shape[0] # stepメソッドでも使う
if "encoder_hidden_states" in batch:
encoder_hidden_states = batch["encoder_hidden_states"].to(self.device)
pooled_output = batch["pooled_outputs"].to(self.device)
else:
with torch.autocast("cuda", dtype=self.autocast_dtype):
encoder_hidden_states, pooled_output = self.text_model(batch["captions"])
if "size_condition" in batch:
size_condition = batch["size_condition"].to(self.device)
else:
size_condition = None
timesteps = torch.randint(0, 1000, (self.batch_size,), device=latents.device)
noise = torch.randn_like(latents)
noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
with torch.autocast("cuda", dtype=self.autocast_dtype):
model_output = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition)
target = self.scheduler.get_target(latents, noise, timesteps) # v_predictionの場合はvelocityになる
loss = nn.functional.mse_loss(model_output.float(), target.float(), reduction="mean")
return loss
save_model, sample_validation
モデルは学習対象のみをセーブするようにしています。検証画像の生成オプションは設定ファイルから決められます。コードは大したことないので省略します。
さあ学習だ!
学習データを用意して設定ファイルを定義して、
python main.py config/config.yaml
でできます。コマンドライン引数をいっぱいならべるのはきらいです。
よくある質問(妄想)
Q. collate_fnってlambda式つかえばよくねwww
A. と思ってやったらnum_workersを1以上にしたときにmultiprocessで使う
pickleにlambda式が対応してないとかでエラーが起こるらしい。
Q. キャプションのシャッフルはしないんですか
A. ほんとうはするべきだとおもいますが、めんどくさいのでしません。
Q. 各ステップでモデルを上書きせず個別に保存しないんですか
A. 前のコードはそういう機能もありましたが、私はほとんど使いませんでした。サンプル画像の生成しているんだからよくね?