見出し画像

[調査] Smile Test: Elysium_Anime_V3 問題を調べる #3

前々回、前回、と「1 token目が無視される」問題について調べてきたが、恐らく破損箇所と修正方法がわかった。
さらに、修正の効果の検証と、修正の副作用についてもまとめた。


クイックまとめ

・「1token目が無視される」問題について、モデル内の破損箇所を特定した
・複数の修正方法を確認した(記事内参照)
・修正した場合の影響をまとめた
・モデル内のチェック用 Extension を追記した (06/02)


破損箇所を発見?

Arenaさん(sd-tagging-helper の作者様)が、最近はモデル内部構造について調べておられるのだが、モデル内部データの破損・汚損の例として上げられていた内容が気になりお話を伺ってみた。

cond_stage_model.transformer.text_model.embeddings.position_ids

このキーの内容が torch.float32 に変わってしまっているが、本来は torch.int64。Add Difference されると倍率合計として 1 倍ではないため、本来整数のこの値が整数からずれている。さらに、

※bbcが聞き取った内容を記憶でまとめたもの
正確な引用ではありません

例:Elysium_Anime_V3 の値

Elysium_Anime_V3.safetensors
tensor([[ 0.0000,  1.0000,  2.0000,  2.9995,  3.9999,  4.9985,  5.9990,  6.9994,
          7.9999,  9.0004,  9.9971, 11.0013, 11.9979, 13.0021, 13.9988, 15.0031,
         15.9998, 16.9965, 18.0007, 19.0049, 19.9942, 20.9984, 22.0025, 23.0068,
         23.9958, 25.0000, 26.0042, 26.9935, 27.9976, 29.0019, 30.0061, 30.9954,
         31.9996, 33.0037, 33.9930, 35.0121, 36.0015, 36.9907, 38.0098, 38.9991,
         39.9884, 41.0074, 41.9967, 42.9861, 44.0051, 44.9944, 46.0137, 47.0028,
         47.9917, 49.0107, 50.0000, 50.9897, 52.0084, 52.9983, 53.9869, 55.0066,
         55.9953, 56.9852, 58.0039, 58.9935, 60.0122, 61.0019, 61.9908, 63.0102,
         63.9991, 64.9880, 66.0075, 67.0262, 67.9860, 69.0047, 70.0242, 70.9832,
         72.0030, 73.0214, 73.9815, 75.0000, 76.0197]])

このずれて float になった値だが、モデルがロードされる時に int64 へのキャストされるために小数点が切り捨てられ違う値になる。fp16 化の時は四捨五入が走るため、直っている(場合がある?)と考えられる。

※bbcが聞き取った内容を記憶でまとめたもので、正確な引用ではありません

この現象の直し方、前回・前々回と調べてきたものに酷似している。
これなら「fp16 にすると直る」というメカニズムにも説明がつく。


確認方法 (06/02 追記)

Extension 化して公開していたけど案内していなかったので、追記しておく。修正機能はありませんので、下記の記事内を参照ください。


修正方法

この発見に基づいて、8528d-fix 及び Elysium_Anime_V3 を修正してみた。
具体的には、上記のキー内の tensor を強引に初期化する。
結果、正しく修正されているように見える。

8525d-fix(上段)と、上記のキーを整数値で修正した keychange 版 (下段) の出力
与えたプロンプトは「smile sleepy girl」
上段では1token目の「smile」が無視されていたが、keychange 版(下段)では笑顔が見られる
同じく、Elysiumと修正版

修正方法1: Arenaさんの Extension

※ Arena さんが公開している Extension (※2) に CLIP-fix オプションが追加されました。Web UI から処理したい場合は、そちらも利用できるそうです。(自分は試していないので、確認の上でご使用ください)
(※2) arenatemp/stable-diffusion-webui-model-toolkit

修正方法2: 修正スクリプト

一応の別の方法として、修正スクリプトを置いておきます。
Extensionではなく、CLI です。

・初版 (4.19 kB) novelai 特有Key対応が抜けていたため、削除
・修正版 @ 01/15. (4.16 kB)

修正スクリプトの使用方法

CLI用です。コマンドラインから使用してください。
argparse で作ってるので、`python fix_postion_ids.py -h` でまずはコマンド確認ください。

  • --model <読み込むモデルのフルパス(必須)> 
    model 指定のみだと、判定だけ行います

  • --verbose
    出力がもうちょい増えます

  • --out <修正後モデル出力先フルパス、モデル名含む>
    出力先ファイル名を入力すると、モデルの修正と出力を行います。モデル名末尾が .safetensors になっている場合、safetensors で保存します

  • コマンド例

python fix_position_ids.py --model E:\tool\sd\model_sd\need_fix\Elysium_Anime_V3.safetensors --verbose

修正方法3: Merge Block Weighted の "Skip/Reset CLIP key" を使う

"Merge Block Weighted" Extension
https://github.com/bbc-mc/sdweb-merge-block-weighted-gui

設定項目
None: 何もしない(これまでと同じ計算方法)
Skip: 対象のキーの計算をスキップする
Force Reset: 修正スクリプトと同じく、整数へ書き換える

調査1:モデル内の対象キーの状態

説明が前後するが、まずは各モデルの状態の確認を行う。

各モデルについて、対象キー(torch.float32) の状態と、それを .to(torch.int64) でキャストしたらどうなるかを調査し比較した。
また、異常判定のための出力を追加し、どこがずれているか・おかしいかを確認した。

出力結果の読み方

出力結果は、テンソルを4つ表示したものになる。
上から
1.現状のテンソル:モデルから読み込んだままのデータ
2.torch.int64に変換したテンソル:torch.int64 に型変換したらどうなるか
3.修正内容候補(一定)
4.判定結果:現在のデータと、修正の内容候補に差があるか(True で差なし、False で差あり)

さらに、異常点が見つかった場合は、以下の2つの情報を表示する。
corrupt:何番目のtoken が"ずれて" いるかのリスト
missing:データに記載されている番号のうち、ずれた結果として無くなっている番号のリスト

stable diffusion 1.5

loading ... sd-v1-5-pruned.ckpt
# current data is:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
         72, 73, 74, 75, 76]])
<class 'torch.Tensor'>
torch.int64

== if changed to torch.int64 ==
<class 'torch.Tensor'>
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
         72, 73, 74, 75, 76]])
torch.int64

# change to:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
         72, 73, 74, 75, 76]])
<class 'torch.Tensor'>
torch.int64
#
tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True]])

tensor内のデータの型は torch.int64 になっている。
値もきれいに順番に並んでおり、欠けたり飛ばしたりしていない。

8528d-fix

loading ... 8528d-fix.ckpt
# current data is:
tensor([[ 0.0000,  1.0000,  1.9999,  2.9996,  3.9998,  5.0000,  5.9992,  6.5492,
          7.9996,  9.0008, 10.0000, 10.9992, 11.9984, 12.9977, 13.5477, 15.0000,
         15.9992, 16.9984, 18.0016, 18.5516, 20.0000, 20.9953, 21.9984, 23.0016,
         23.9969, 25.0000, 25.9953, 26.9984, 27.5484, 28.9969, 30.0000, 30.9953,
         31.9984, 32.5484, 33.9969, 35.0000, 36.0031, 36.9906, 37.5406, 38.9969,
         40.0000, 41.0031, 41.9906, 42.5406, 43.9969, 45.0000, 46.0031, 46.9906,
         47.9938, 48.9969, 50.0000, 51.0031, 51.9906, 52.9938, 53.9969, 55.0000,
         55.5500, 56.9906, 57.9938, 58.9969, 60.0000, 60.5500, 61.9906, 62.9938,
         63.9969, 64.5469, 65.5500, 66.9906, 67.9938, 68.9969, 70.0000, 70.5500,
         72.0062, 72.9938, 73.9812, 75.0000, 75.5500]])
<class 'torch.Tensor'>
torch.float32

# == if changed to torch.int64 ==
<class 'torch.Tensor'>
tensor([[ 0,  0,  1,  2,  3,  5,  5,  6,  7,  9, 10, 10, 11, 12, 13, 15, 15, 16,
         18, 18, 20, 20, 21, 23, 23, 25, 25, 26, 27, 28, 30, 30, 31, 32, 33, 35,
         36, 36, 37, 38, 40, 41, 41, 42, 43, 45, 46, 46, 47, 48, 50, 51, 51, 52,
         53, 55, 55, 56, 57, 58, 60, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70, 70,
         72, 72, 73, 75, 75]])
torch.int64

# change to:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
         72, 73, 74, 75, 76]])
<class 'torch.Tensor'>
torch.int64

#
tensor([[ True, False, False, False, False,  True, False, False, False,  True,
          True, False, False, False, False,  True, False, False,  True, False,
          True, False, False,  True, False,  True, False, False, False, False,
          True, False, False, False, False,  True,  True, False, False, False,
          True,  True, False, False, False,  True,  True, False, False, False,
          True,  True, False, False, False,  True, False, False, False, False,
          True, False, False, False, False, False, False, False, False, False,
          True, False,  True, False, False,  True, False]])
corrupt token indexes : [1, 2, 3, 4, 6, 7, 8, 11, 12, 13, 14, 16, 17, 19, 21, 22, 24, 26, 27, 28, 29, 31, 32, 33, 34, 37, 38, 39, 42, 43, 44, 47, 48, 49, 52, 53, 54, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 
73, 74, 76]
missing token numbers : [4, 8, 14, 17, 19, 22, 24, 29, 34, 39, 44, 49, 54, 59, 69, 71, 74, 76]

まず1.現状のテンソルから。
tensor の型が torch.float32 に変化している。
また、値がところどころで若干マイナスに振れており、一部は -0.5付近(index=76, value=75.55) まで影響を受けている。

次に 2.torch.int64テンソル
前から見るとすぐに異常点が見つかるが、tensor([[0, 0, 1, 2, 3 … となっている。SD15 のデータと見比べればわかるが、ここは 0,1,2,3 … となっていないとおかしい。
さらに 4.ずれの判定 や、その下に並んでいるずれている番号リスト無くなっている値リストから、多くの範囲でこのズレが発生している事がわかる。

Elysium_Anime_V3

# current data is:
tensor([[ 0.0000,  1.0000,  2.0000,  2.9995,  3.9999,  4.9985,  5.9990,  6.9994,
          7.9999,  9.0004,  9.9971, 11.0013, 11.9979, 13.0021, 13.9988, 15.0031,
         15.9998, 16.9965, 18.0007, 19.0049, 19.9942, 20.9984, 22.0025, 23.0068,
         23.9958, 25.0000, 26.0042, 26.9935, 27.9976, 29.0019, 30.0061, 30.9954,
         31.9996, 33.0037, 33.9930, 35.0121, 36.0015, 36.9907, 38.0098, 38.9991,
         39.9884, 41.0074, 41.9967, 42.9861, 44.0051, 44.9944, 46.0137, 47.0028,
         47.9917, 49.0107, 50.0000, 50.9897, 52.0084, 52.9983, 53.9869, 55.0066,
         55.9953, 56.9852, 58.0039, 58.9935, 60.0122, 61.0019, 61.9908, 63.0102,
         63.9991, 64.9880, 66.0075, 67.0262, 67.9860, 69.0047, 70.0242, 70.9832,
         72.0030, 73.0214, 73.9815, 75.0000, 76.0197]])
<class 'torch.Tensor'>
torch.float32

# == if changed to torch.int64 ==
<class 'torch.Tensor'>
tensor([[ 0,  0,  1,  2,  3,  4,  5,  6,  7,  9,  9, 11, 11, 13, 13, 15, 15, 16,
         18, 19, 19, 20, 22, 23, 23, 25, 26, 26, 27, 29, 30, 30, 31, 33, 33, 35,
         36, 36, 38, 38, 39, 41, 41, 42, 44, 44, 46, 47, 47, 49, 50, 50, 52, 52,
         53, 55, 55, 56, 58, 58, 60, 61, 61, 63, 63, 64, 66, 67, 67, 69, 70, 70,
         72, 73, 73, 74, 76]])
torch.int64

# change to:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
         72, 73, 74, 75, 76]])
<class 'torch.Tensor'>
torch.int64

#
tensor([[ True, False, False, False, False, False, False, False, False,  True,
         False,  True, False,  True, False,  True, False, False,  True,  True,
         False, False,  True,  True, False,  True,  True, False, False,  True,
          True, False, False,  True, False,  True,  True, False,  True, False,
         False,  True, False, False,  True, False,  True,  True, False,  True,
          True, False,  True, False, False,  True, False, False,  True, False,
          True,  True, False,  True, False, False,  True,  True, False,  True,
          True, False,  True,  True, False, False,  True]])
corrupt token indexes : [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 17, 20, 21, 24, 27, 28, 31, 32, 34, 37, 39, 40, 42, 43, 45, 48, 51, 53, 54, 56, 57, 59, 62, 64, 65, 68, 71, 74, 75]
missing token numbers : [8, 10, 12, 14, 17, 21, 24, 28, 32, 34, 37, 40, 43, 45, 48, 51, 54, 57, 59, 62, 65, 68, 71, 75]
  • 1.現状のテンソル tensor の型が torch.float32 になっている。
    値はマイナスに変化している所と、プラスに変化している所がある。

  • 2. int64 0,0,1,2,3,4,5,6,7,9 … となっており、値がずれている。

  • Missingリスト 8, 10, 12 となっているが、8528d-fix とは異なる番号が無くなっており、特に法則性は見受けられない。

8528d-final (fp16)

ついでに調査。
上段から、現状、現状を int64へキャストした場合の結果、各値がずれているかどうか。

loading ... 8528d-final.ckpt
no state_dict. direct model.
# current data is:
tensor([[ 0.0000,  1.0000,  2.0000,  3.0000,  4.0000,  5.0000,  6.0000,  6.5508,
          8.0000,  9.0000, 10.0000, 11.0000, 12.0000, 13.0000, 13.5469, 15.0000,
         16.0000, 17.0000, 18.0000, 18.5469, 20.0000, 21.0000, 22.0000, 23.0000,
         24.0000, 25.0000, 26.0000, 27.0000, 27.5469, 29.0000, 30.0000, 31.0000,
         32.0000, 32.5625, 34.0000, 35.0000, 36.0000, 37.0000, 37.5312, 39.0000,
         40.0000, 41.0000, 42.0000, 42.5312, 44.0000, 45.0000, 46.0000, 47.0000,
         48.0000, 49.0000, 50.0000, 51.0000, 52.0000, 53.0000, 54.0000, 55.0000,
         55.5625, 57.0000, 58.0000, 59.0000, 60.0000, 60.5625, 62.0000, 63.0000,
         64.0000, 64.5625, 65.5625, 67.0000, 68.0000, 69.0000, 70.0000, 70.5625,
         72.0000, 73.0000, 74.0000, 75.0000, 75.5625]], dtype=torch.float16)
<class 'torch.Tensor'>
torch.float16

# == if changed to torch.int64 ==
<class 'torch.Tensor'>
tensor([[ 0,  1,  2,  3,  4,  5,  6,  6,  8,  9, 10, 11, 12, 13, 13, 15, 16, 17,
         18, 18, 20, 21, 22, 23, 24, 25, 26, 27, 27, 29, 30, 31, 32, 32, 34, 35,
         36, 37, 37, 39, 40, 41, 42, 42, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 55, 57, 58, 59, 60, 60, 62, 63, 64, 64, 65, 67, 68, 69, 70, 70,
         72, 73, 74, 75, 75]])
torch.int64

# change to:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
         72, 73, 74, 75, 76]])
<class 'torch.Tensor'>
torch.int64

#
tensor([[ True,  True,  True,  True,  True,  True,  True, False,  True,  True,
          True,  True,  True,  True, False,  True,  True,  True,  True, False,
          True,  True,  True,  True,  True,  True,  True,  True, False,  True,
          True,  True,  True, False,  True,  True,  True,  True, False,  True,
          True,  True,  True, False,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True, False,  True,  True,  True,
          True, False,  True,  True,  True, False, False,  True,  True,  True,
          True, False,  True,  True,  True,  True, False]])
corrupt token indexes : [7, 14, 19, 28, 33, 38, 43, 56, 61, 65, 66, 71, 76]
missing token numbers : [7, 14, 19, 28, 33, 38, 43, 56, 61, 66, 71, 76]

1.現在のテンソル 内部の値が torch.float16 に変化しているが、これは fp16 化の処理が行われているので、妥当な変化。
2. int64 前半は 0,1,2,3,4,5,6,6,8, と、前半は順調だがややずれがある。
Corrupt/Missingリスト 8528d-fix と比較するとだいぶリストが短くなったが、やはりズレがある。また、無くなっている値がある。

調査1の結果

上記の調査結果から、対象キーの値がずれている、ことはわかった。
また、値のずれ方がプラスの場合・マイナスの場合が同じモデル内でも混在している事がわかった。


調査2:対象キー内の値がずれた事の影響

今回の調査で分かった tensor 内の値のずれだが、これを直すことによって、「1 token目を無視する」問題を解消できる事は、冒頭にさらっと述べた。

この事から、この値がプロンプト~モデル間の関係性についての設定値であることはわかっている。しかし、具体的に、どのようにずれたら、どのように影響が出るのかは判明していない。

そこで、上記修正スクリプト内の修正用 tensor を意図的に壊すことで、どの位置の値がどこにどのような影響を与えているか調べた。

画像の出力条件

プロンプトは以下の通り。Seed は固定、CLIP skip=1, eta=0

smile sleepy girl standing bear
Steps: 40, Sampler: Euler a, CFG scale: 7.5, Seed: 651296271, Face restoration: CodeFormer, Size: 512x512

使用した生成条件

調査2-1. 各場所を 0 にしてみる

各ファイル名の末尾の数字の意味は、以下の通り。

1-0:tensor([[ 0,0,2,3,4,5 …
2-0:tensor([[ 0,1,0,3,4,5 …
3-0:tensor([[ 0,1,2,0,4,5 …
3-1:tensor([[ 0,1,2,1,4,5 …

※かわいい

1.一番上が、通常の 8528d-fix。smile が無視されている。
2.keychange_8528dfix は、修正スクリプトを適用したもの。smile が戻っている。
3.break_1-0 再び smile が失われた
4.break_2-0 sleepy が失われ、はっきりと目が覚めている。
5.break_3-0 girl が失われ、クマになった。
6.break_3-1 girl 部の値を 1 にしたもの。girl は失われていないし、1のsmile も失われていない。

break_3-1 の結果をどのように読み取るかが難しい。

break_3-1仮説#1."3token目として受け取るのは、3.girl に代わって、1.smile" という動き?

元  : smile sleepy girl  standing bear
新?: smile sleepy smile standing bear

=> この仮説は否定できそう。
3-0 で girl を失うとクマになるのはわかっているから、3-1 でも girl を失うのであれば、おなじくクマにならないとおかしい。

break_3_1仮説#2."3token目は、1 番目のレーンに行きなさい"


元  : smile sleepy girl   standing bear
新?: smile sleepy (null) standing bear
      girl

これはありえるかも。
この場合、girl が "まるで 1token目かのように" 振る舞う可能性がある。
そうだとすると、girl の影響力が高まっているはず。
また (null) と記載したが、何も渡されなくなった 3token目がどのように振る舞うのかが不明。

調査2-2. 影響力のありそうな値を index=1 へ移動させてみる

そこで、以下のような設定で実験を行った。

元 : 0, 1, 2, 3, 4, 5
4-1: 0, 1, 2, 3, 1, 4
5-1: 0 ,1, 2, 3, 4, 1

プロンプトは「smile sleepy girl standing bear」であるから、
4-1 では、standing が影響を受けるはず
5-1 では、bear が影響を受けるはず

4-1 の結果:
standing が 1token目に合流すると、より…立つようになる?
差が分からない。ちょっと実験の設定がまずい気がする。

5-1 の結果:
bear が 1 token目に合流すれば、よりクマ度が上がるはず?
1列目はクマ耳がリボンになった。
2列目はややズームしたので、どちらかといえば smile が強調されたか?
3列目はクマ耳大型化。
4列目は、やや笑顔抑えめ・やや顔アップ。

3-1 ~ 5-1 では、プロンプトの該当番号箇所の意味が強化された。
(と言い切れるのか?!!!!!)
よく分からないので追試。


1-5 を作成する。仮説が正しければ、1.smile が (null) になり、5.bear に重ねるのだから、笑顔は弱まるはず。

(比較のため再掲) keychange_8528d_fix
keychange_8528d_fix_break_1-5

(笑顔が、やや控えめになったようにも見えるが。。。)

この事から、このキーのデータを用いて「n 番目の token は、x 番目として扱う」という処理がされていると考えられる。
この仮説を念頭に 1-5 の結果を見ると、5:bear が 1:smile の位置へ移動しているが、1:smile の効果が無くなってはいない事・笑顔が弱くなっている事が分かる。

以上から、今回のキー内容の数値がずれる問題による影響は、
「本来のプロンプトであれば前から順番に影響力が高い順となるはずが、
数値がずれた位置については影響力が移動先の場所で算定されるため、
プロンプトの場所ごとの影響力勾配がなだらかではなくなっている」

と言える。
ラフに言えば、
『これまでの感覚でプロンプトを並べると、時々効果に波や荒れが出る』ことになる。


詳細な破損の流れの再現

これまでの知見から、より具体的に破損の流れを再現する。
再現手順とレシピは以下の通り。

  1. Add Difference マージ

  2. Add Difference マージ

重要な点

上記の再現手順を見て、最初からわかっとるやないかーい!と言いたくなるが、『Add Difference したら必ずモデルが壊れる』わけでは無い点が疑問点だった。
本手順ではその部分も含めて再現する。

ステップ1. Add Difference する

O1 = model_A + model_B + model_C, 1.0

ここはどのような組み合わせでも良い。
倍率もいくつでもかまわない。
モデル A, B, C は全て正常なキーを持つとする。

この配合で対象キーに問題は生じない。
O1 は(倍率にもよるが)正常な出力を得られる事が多いはずだ。

ステップ2. Add Difference する

O2 = model_D + model_E + O1, 1.0
※モデルD, E は全て正常なキーを持つとする。

ここが重要。
Add Difference でいう所の (C) に Add Difference 済みモデル O1 を持ってきた。この時に内部のキーでは以下のような計算が行われる。
簡単のために、キー内の値が 1 の部分について考える。
(つまり「1 token 目」に該当する部分。ここが 0 になって問題は生じていた。)

$$ O2 = D + 1.0 * (E - O1) $$

ここで、O1 でどのような計算が行われたか見てみると、

$$
O1 = A + 1.0 * (B - C) \\
= 1 + 1.0 * (1 - 1)
$$

「つまり O1 = 0 でしょ????」となるはずだが、ここが異なる。
1.0 * (1-1) は少数点の演算のため、0 の表現が厳密に 0 ではなくなっている。

この事を確認するため、以前の記事の手法で「差分抽出」したモデルについて、同キーを調べた結果が以下の通り。

(調査結果)

これにより、O1 に関する計算は、実態としては以下の通りとなる。
$ 1.0 * (1-1) $ で得られる 0 に近いナニカを $ α $ と置くと、

$$
O1 = A + 1.0 * ( B - C ) \\
= 1 + 1.0 * ( 1 - 1 ) \\
= 1 + α
$$

では、この計算で得られた O1 を、O2 の式へ適用する。
O2 の計算でも再び 1-1 の計算が行われるが、その際に残留する値を仮に $ β $ と表現すると、

$$
O2 = D + 1.0 * ( E - O1) \\
= 1 + 1.0 * ( 1 - 1 - α ) \\
= 1 + 1.0 * ( β - α ) \\
= 1 + ( β - α ) 
$$

となる。
α および β の定義は、非常に小さい値、正負は未定、相互の大小は不定、であるから、1 + (β - α) は、1よりはわずかに小さい『事がある』
※α と βはどちらも非常に小さい値で、どちらかが安定的に大きくなるかは自分に知見がない。

この、1 より小さいということが、本稿の最初で示した現象につながる。
「モデルを読み込むときに torch.int64 に変換するが、小数点以下は切り捨てで計算される」
これにより、キー内の値がどのように当初の数値より小さくなるかが示された。


まとめ

破損箇所

今回の「1 token目が無視される」問題について、モデルの構造内の破損箇所までほぼ特定できたと考える。

修正方法の問題点

また、暫定的な修正方法は既にあったが、それぞれに問題があった。

  1. CLIP を入れ替えるというのはやや破壊な方法であり、DreamBoothなど CLIP にも学習を施している場合、学習成果が失われる可能性があった。

  2. fp16 化では修正が不十分な可能性がある。さらに、上記のキー内の数値がずれたままになっている事から、この対処をしたモデルは次世代にこの破損を引き継いでしまう、という問題があった。

今回作成した修正方法のメリット

今回提示したキーを修正する方法であれば、
1.CLIP 内の学習成果である重み等は保持できる
2.プロンプトとの対応付は正常化される

それでも残る問題点

修正を施した場合、プロンプトの効き方が変わるため、これまでとは絵柄が変わったと感じる可能性がある。
また、特に training/fine-tune している場合、この狂った CLIP 前提で学習が進んでいた場合、モデルの描画性能の変化を引き起こす可能性がある。

いずれにしても、絵作りにおいては「絵柄が変わる」というのは一大事なので、修正の適用はよく考えて行う必要がある。
(それぞれ、ちゃんと対応が出来るということは示せた)

以上。

この記事が気に入ったらサポートをしてみませんか?