(コード読み) AUTOMATIC1111 での、モデル読み込み/保存における副作用

TL;DR

日々お世話になっている AUTOMATIC1111氏の Web UI だが、モデルを読み込むときに少し 中身をいじる ので、何をしているのかをまとめておく。
また、保存方法に少し癖があるので、まとめておく。

Note:本記事の確認は 2022/12/24。それ以降の更新については考慮していない。

記載の方針

できるだけ、githubのコードの当該行へのリンクをつけるようにする

モデルの読み込み

モデルは、sd_models.py 内で読み込まれている。

ttps://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_models.py


エラーを黙らせる

まずは、"五月蝿いエラーを黙らせた"。ちょっとおもしろかったので。
#pragma warning(disable:4996)

try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

    from transformers import logging, CLIPModel

    logging.set_verbosity_error()
except Exception:
    pass

モデル内のキーを書き換えている

chckpoint_dict_replacements = {
  'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
  'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
  'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}

これ、地味に大事なポイントだと思う。
ここで行っている処理は、上の3行の dict について、「左側のキー(の一部)をみつけたら、右側に書き換える」ということをしている。
`cond_stage_model.` で始まる行は Text Encoder に属しており、例えば SD15 で探してみると、以下の3行が見つかる。(AUTO1111でロードして探したので、キー名変換が必要であれば、変換済み)

cond_stage_model.transformer.text_model.embeddings.position_ids		
cond_stage_model.transformer.text_model.embeddings.token_embedding.weight
cond_stage_model.transformer.text_model.embeddings.position_embedding.weight

git blame からこの変更処理が追加された commit を探すと、メッセージに以下のように記載されている。

コミットメッセージ:
more careful loading of model weights (eliminates some issues with checkpoints that have weird cond_stage_model layer names)

"モデルの重みの読み込み処理を、より「注意深く」行うようにした。"
"(「cond_stage_model レイヤーの奇妙な名前(複数)」を持つモデルデータに関する、いくつかの問題を除去した)"

※"weird" を「奇妙な」と訳したが、恐らくの意図として「なにやってんの、これ」位の、(良くある)軽い罵倒を感じる。

https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/10aca1ca3e81e69e08f556a500c3dc603451429b

この対策が行われたということは、なにかのモデルで、これらのキー名が異なる場合があったということか?(経緯は見つけられなかった)

この対策の重要なポイントは、以下の2点。

「AUTO1111 でモデルをロードすると、(使用している間のみ)モデルの Text Encoder の一部の機能が復活する可能性がある」

「AUTO1111 でモデルをロードしてマージすると、当該モデルの Text Encoder の一部の機能が復活 (または喪失) する可能性がある」


`state_dict` キーの位置を変えている

普段 AUTO1111 を使っていてエラーで見ることもある "state_dict" だが、AUTO1111 ではモデルロード時に、ちょっとデータの持ち方を変えている。

def get_state_dict_from_checkpoint(pl_sd):
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_models.py#L150

コミットメッセージ:
make it possible to save nai model using safetensors

"NAI モデルも safetensor 形式で保存できるようにした"

https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/0376da180c81a11880a2587903d69d85541051e7

やっていることは、
・pl_sd が辞書型なので、キー "state_dict" のデータを抜き出して、pl_sd 自体に代入する。取得できなかったら、pl_sd はそのまま。
・pl_sd から "state_dict" キーのデータを抜き出す(つまり除去)。無いなら無いで良い。

意図が不明だが、NAI 由来のデータには "state_dict" が二重で入ってたりするんだろうか?どう見ても、下記のような構造を想定しているとしか読めない。

pl_sd = {
    "state_dict": {
        "state_dict": {
        },
        ... まともなデータ ...
    }
}

モデルの保存

AUTO1111 のマージにおけるデータ保存では、上述した `state_dict` を使用していない。

torch.save(theta_0, output_modelname)

https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/5927d3fa95e9ae43252d598f7791ca26cfcad5e3/modules/extras.py#L339

diffusers や、外部で見つかるマージスクリプトでは、`state_dict` を使用して保存している。(下記の例は、diffuser の github のモデル変換スクリプト)

state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path)

https://github.com/huggingface/diffusers/blob/9be94d9c6659f7a0a804874f445291e3a84d61d4/scripts/convert_diffusers_to_original_stable_diffusion.py#L309

この記事が気に入ったらサポートをしてみませんか?