見出し画像

CLIP Text Deprojectorを使って画像生成してみる ~モデルの大幅な簡略化~

前回まで、LSTMベースのモデルのアーキテクチャを検討し、最後に訓練データの変更を試しました。今回は、そこで得た知見から、さらにモデルアーキテクチャを簡略化してみます。

前回の記事

他のStable Diffusionの関連記事

Layered Diffusion Pipelineを使うためのリンク集


背景

これまでに、Vicinity-Transformerモデルの実験により、

  • 入力embeddingの重要性

  • 近傍の最終隠れ層状態の重要性

が判明しています。また、LSTMモデルでの実験により、

  • 再帰的ネットワークの有効性

  • 最終正規化層の固定の効果

  • EOS以降の最終隠れ層状態の情報量

が分かりました。

特に、最後の「EOS以降の最終隠れ層状態の情報量」の判明によって、学習対象となる最終隠れ層状態のシーケンスの複雑性が大きく減ったことで、モデルアーキテクチャの大幅な簡略化の可能性が生まれました。

つまり、EOS以前とEOS以後では、最終隠れ層状態の性質が変わるため、これまでのモデルはEOSの位置を推測する必要があったのですが、EOS以後のみの生成で十分であれば、EOSの位置を推測する必要がなくなり、その分のネットワークを削減できるということです。

新たなアーキテクチャ

新たなアーキテクチャは、次の性質を持ちます。

  • EOS位置以降の最終隠れ層状態を生成する

  • 再帰的ネットワーク

  • 各ステップの入力に次のものを受け取る

    • EOS位置の最終隠れ層状態 ← New!

    • 近傍の最終隠れ層状態(=直前N個のモデルの出力)

  • ゲート機構は使わない

  • 内部再帰接続は使わない

    • 代わりに、各ステップでN個前までの近傍の状態を入力に使用

  • 最終正規化層は固定

  • 2つのネットワークの組み合わせ

    • 1つ目はEOSを含む最初の最終隠れ層状態を生成する ← New!

    • 2つ目は再帰的に残りの最終隠れ層状態を生成する

アーキテクチャを2つのネットワークに分けた点が、これまでのアーキテクチャとの大きな違いになります。この理由は、

  • EOS位置の状態の生成には、近傍の最終隠れ層状態が利用できないので、アーキテクチャが異なる

  • 複数embeddingの合成で得られた入力embeddingを修正する専用ネットワークが欲しい

という点にあります。EOS位置以降の最終隠れ層状態を生成するようになり、EOS位置が常にシーケンスの最初にあることが決まっているために、ネットワークを分離することができるようになりました。

下に、上記の点を考慮に入れた新たなアーキテクチャの模式図を示します。

新たなアーキテクチャ

この模式図の中で、モデル#1と#2のアーキテクチャと、vicinity #1と#2の値、さらにEOS位置の出力(出力#2)をモデル#2の入力に追加で使うかどうかは、変更可能なパラメータとなっています。

実験詳細

今回は、このアーキテクチャの有効性を確認するため、最低限のシンプルな構成で実験を行います。

まず、モデル#1とモデル#2のネットワークですが、シンプルな線形変換とします。レイヤー正規化も活性化関数も適用しません。

vicinityの値と出力#2の使用については、まず、以下の最小限の組み合わせを試します。

  1. vicinity #1 = 1、vicinity #2 = 1、出力#2は追加で使用しない

  2. vicinity #1 = 1、vicinity #2 = 1、出力#2を追加で使用する

  3. vicinity #1 = 1、vicinity #2 = 2、出力#2は追加で使用しない

3つ目の組み合わせは、モデル#2への入力次元数を2つ目の組み合わせと同じにして比較するために用意しました。

生成画像

画像は上から次の順に並んでいます。

  1. Deprojectorなし

  2. vicinity #1 = 1、vicinity #2 = 1、出力#2は追加で使用しない

  3. vicinity #1 = 1、vicinity #2 = 1、出力#2を追加で使用する

  4. vicinity #1 = 1、vicinity #2 = 2、出力#2は追加で使用しない

使用したプロンプトはこれまでと同様、次の通りです。

単一embedding

  • cat maid (猫耳メイド)

  • 1girl red hair blue eye black skirt(赤髪 青目 黒スカート)

  • 1boy 1girl in class room(少年 少女 教室)

複数embeddingの合成

  • cat, maid (猫 メイド)

  • 1girl, red hair, blue eye, black skirt (赤髪 青目 黒スカート)

  • 1boy, 1girl, in class room (少年 少女 教室)

単一embedding
複数embeddingの合成

以上の結果から、出力#2を追加入力するモデルが最もよい画像を生成し、かつそのクオリティがこれまでのモデルと大きく変わらないことが確認できました。

さらに、以下の記事で選んだLSTMベースの最良モデルのモデルファイルのサイズは47.3MBでしたが、今回上で試したモデルのモデルファイルサイズは9.02MBで、モデルファイルサイズの大幅な削減(約5分の1)に成功しました。

追加実験

上の実験は成功のように見えますが、詳細を見ると問題点があります。それは、モデル#2に比べ、モデル#1の学習がうまく進んでいないという点です。これは、出力#2の正解とモデル出力で計算したコサイン類似度を見ると分かります。

  1. [2段目] vicinity=(1,1), 出力#2を使用しない : 0.821

  2. [3段目] vicinity=(1,1), 出力#2を使用する : 0.514

  3. [4段目] vicinity=(1,2), 出力#2を使用しない : 0.817

このように、生成画像3段目の出力#2の追加入力を使うアーキテクチャの時だけ、コサイン類似度が大きく下回っていることが分かります。

これは、出力#2の状態が、モデル全体の出力として学習データの正解と比較されるだけでなく、それ以降の出力の生成にも入力として使われるため、そちらからも勾配が流入して学習が妨げられていると考えられます。

そこで、モデルの出力を再帰的に次のステップの入力として渡す時、ネットワークから切り離して勾配が流入しないようにして学習を行いました。

生成画像

実験には、先の実験の3段目の「vicinity #1 = 1、vicinity #2 = 1、出力#2を追加で使用する」モデルを使用します。

生成画像は次の順に並んでいます。

  1. Deprojectorなし

  2. 全ての再帰的入力をネットワークから切り離さない

  3. 出力#2の追加入力だけネットワークから切り離す

  4. 全ての再帰的入力をネットワークから切り離す

単一embedding
複数embeddingの合成

以上の結果から、生成画像の質は3つのバリエーションで大きな違いはないように思われます。

出力#2のコサイン類似度の計算結果は次のようになりました。

  1. [2段目] 切り離さない : 0.514

  2. [3段目] 出力#2のみ切り離さない : 0.877

  3. [4段目] すべて切り離す : 0.902

ここから、4段目の「全ての再帰的入力をネットワークから切り離す」学習方法が最も良いコサイン類似度となりました。

テストデータにおける損失関数の値も比較してみます。損失関数にはMLEを使用しています。

  1. [2段目] 切り離さない : 0.2158

  2. [3段目] 出力#2のみ切り離さない : 0.2365

  3. [4段目] すべて切り離す : 0.2147

損失関数においても、4段目の「全ての再帰的入力をネットワークから切り離す」学習方法が最も良い結果となりました。

まとめ

今回は、新たな再帰的ネットワークのアーキテクチャを考案して実験しました。

新たなアーキテクチャは、これまでのVicinity TransformerモデルやLSTMモデルと比較しても、同程度の性能を示し、かつ、パラメータサイズの大幅な削減ができました。

学習時の再帰的入力では、ネットワークから切り離して勾配の流入を止めることが、良い学習結果につながることがわかりました。

この記事が気に入ったらサポートをしてみませんか?