sd-scriptsで任意の画像とキャプションで学習する

はじめに

sd-scriptsのリポジトリでは画像ファイルとキャプションファイルを指定してStable DiffusionやLoRA等を学習できますが、場合によってはより柔軟な学習をしたいこともあるかと思います。たとえば画像を動的に生成したい、augmentationを自由に行いたい、画像やキャプションをステップに応じて変化させたい、などです。
この記事ではそのような場合に、任意のDatasetを定義して学習する方法について記述します。

Datasetクラスの作成

クラスの定義

任意のPythonスクリプトファイルに定義します。
PyTorchのDatasetがそのまま使えれば良いのですが、学習用スクリプトがメタデータ等をDatasetから取得しているため、特にLoRAの学習ではそのままでは動作しません。最低限のメタデータを(一部ダミーを含めて)提供する基底クラス、train_util.MinimalDatasetクラスが用意してありますので、そちらを継承して作成してください。

以下にサンプルを置きますので、実装の詳細はこちらのコード、コメントをご覧ください(常用漢字LoRA・白黒版を作るためのDatasetクラスです)。

実際に学習する場合はGoogle Fonts等でフォントファイルを用意してください。他に必要なファイル、学習したLoRA、学習設定は末尾に書きました。

学習スクリプトの起動オプションで指定

定義したクラスをtrain_network.pyなどの学習スクリプトに、引数--dataset_classに、「package.module.ClassName」という形で指定します(.tomlファイルでの指定も可能です)。

※train_textual_inversion_XTI.pyは任意のDataset指定に対応していません。

たとえば以下のようになります。

accelerate launch --num_cpu_threads_per_process 1 train_network.py 
    --config_file ..\cfg_ja_sd.toml --output_name ja-sd-1 
    --dataset_class logs.joyo_kanji_dataset.JoyoKanjiDataset

ファイルの置き場所はどこでも良いと思われますが、私は適当にlogsフォルダ内に配置しています(git pullの妨げにならないように)。

制限

bucketing、latentのキャッシュなどには対応していませんので、もし必要な場合にはDatasetクラス内で独自に行ってください(基底クラスBaseDatasetのメソッドをうまく使えば行けるかもしれませんが未検証です)。

おわりに

この記事では、任意のDatasetを定義して学習する方法について解説しました。画像の動的生成やaugmentationの自由な実施、ステップに応じた画像やキャプションの変化など、より柔軟な学習が可能です。また、LoRAの学習では、最低限のメタデータを提供する基底クラスtrain_util.MinimalDatasetクラスが用意されているため、そちらを継承して作成することができます。本記事を通して、より効率的で自由度の高い学習を行えることをご理解いただけたかと思います。引き続き、このサイトでの更新や、他の記事にもぜひ関心を持っていただけると幸いです。

↑このまとめはnoteの「AIアシスタント(β)」により生成されました(;・∀・) 

余談ですがリポジトリ名のsd-scriptsはもう少しいい感じにしておけば良かったですね……。

サンプルのLoRAおよび学習設定

LoRAのモデルファイルはこちらに置きました。実際に使用する場合のキャプションは「letter X, in serif」や「with sans, the letter X」などになります。学習解像度が192x192ですので、この解像度で生成してください。

たまに字形がおかしくなる……

使用できるフォント名は以下です。存在しない"bold brush"などもそれなりに動作します。

"pop"
"handwritten neat"
"maru sans"
"bold sans"
"sans"
"light sans"
"bold serif"
"serif"
"light serif"
"brush pop"
"brush playful"
"brush"
"handwritten simple"

実際に学習する場合は、Datasetから参照する文字一覧letters.txtを適切な場所に置いてください。

学習設定は以下になります。

pretrained_model_name_or_path = "path/to/v1-5-pruned-emaonly.ckpt"
output_dir = "path/to/lora_output"
output_name = "ja_sd"
resolution = "192,192"
vae_batch_size = 32
cache_latents = false
save_precision = "fp16"
save_every_n_epochs = 8
xformers = true
max_train_epochs = 512
max_data_loader_n_workers = 4
persistent_data_loader_workers = true
seed = 42
gradient_checkpointing = false # true
mixed_precision = "fp16"
sample_every_n_epochs = 8
sample_prompts = "path/to/prompts_ja_sd.txt"
sample_sampler = "k_euler_a"
save_model_as = "safetensors"
optimizer_type = "adamw8bit"
learning_rate = 5e-4 # 1e-3
train_unet_only = true
network_module = "networks.lora"
network_dim = 64
network_args = [ "conv_dim=64" ] # , "rank_dropout=0.25" ]
noise_offset = 0.1

学習はかなり時間が掛かり、バッチサイズ64、学習率5e-4で512エポックでは足りず、さらに学習率2e-4で160エポック学習しました。それでもまだ足りないかもしれません。

この記事が気に入ったらサポートをしてみませんか?