見出し画像

CLIP Text Deprojectorを使って画像生成してみる ~正則化を試す~

前回は、CLIP Text Deprojectorのモデルの訓練データ中の増強データの最適な割合を探しましたが、今回は、訓練中の損失関数に正則化項を追加してみました。

前回の記事

他のStable Diffusionの関連記事

Layered Diffusion Pipelineを使うためのリンク集

正則化項

正則化は、モデルの過学習を制御するために、経験的に最適なモデルが存在する可能性の高いパラメーター空間に誘導する手法のことで、すでに使用しているearly stoppingも正則化の一種ということができます。

広く行われている正則化の手法としては他に、L2正則化やdropoutなどが存在します。

今回試したのは、モデルの出力のlast hidden stateのトークン位置ごとのベクトルの分散が大きく(あるいは、小さく)なるような正則化項を損失関数に付け加えるという方法でした。これは、各トークンに対応するベクトルがお互いに異なる情報を持っている方がよい(あるいは、情報を共有している方がよい)という経験則に対応します。

分散の計算は、平均値を計算して、平均値からのMSE(平均二乗誤差)を計算すればよいため、pytorchのライブラリを使って容易に計算できます。正則化項としては、分散が大きいほど損失関数が小さくなるようにするため、分散の符号を反転させて損失関数に足し合わせました。

なお、分散を計算するトークン位置は、EOSの位置までとしています。EOSよりも後のstateは新しい入力がないため追加の情報が含まれないと考えられるためです。

訓練中のモデル出力の分散の変化

正則化項の重みを変化させて学習を行ったときに、モデル出力の分散(=正則化項の値)がどのように変化するかをグラフにしました。

モデル出力の分散の変化

正則化項の重みが正の値であれば分散が大きくなり、負の値であれば分散が小さくなっていますが、どちらも一定の値で安定していて、発散していく様子はありませんでした。

生成画像 - 単一embedding

最左列から正則化項の重みが -0.2、0.0、0.2 の順に並んでいます。最上段は比較用で、2段目以降が訓練データ2万件、3万件、4万件、5万件と並んでいます。

使用したプロンプトはこれまでと同じ3種類です。

  • cat maid (猫耳メイド)

  • 1girl red hair blue eye black skirt(赤髪 青目 黒スカート)

  • 1boy 1girl in class room(少年 少女 教室)

cat maid (猫耳メイド)
1girl red hair blue eye black skirt(赤髪 青目 黒スカート)
1boy 1girl in class room(少年 少女 教室)

当初の期待に反して、分散が大きいよりは小さい方が生成画像の質が安定するようです。しかし、正則化項がないモデル(中列、重み=0.0)の画像が最もプロンプトに適した画像を出力しています。

まとめ

過学習を抑制するための正則化項として、モデル出力のlast hidden stateの各位置に対応するベクトルの分散を導入してみましたが、正則化項を使わないモデルの生成画像が最もよいという結果になりました。

この記事が気に入ったらサポートをしてみませんか?