見出し画像

CLIP Text Deprojectorを使って画像生成してみる ~モデルを書き直してみる~

前回、モデルアーキテクチャの改良で生成画像の質が向上することが示されましたが、そこから得た知見をより生かすため、拡張性のあるモデルに書き直すことにしました。

前回の記事

他のStable Diffusionの関連記事

Layered Diffusion Pipelineを使うためのリンク集


初期モデルからの書き直し

初期モデル

現行のモデルアーキテクチャは、何度かの改良を経たもので構造が複雑になっています。そこで一度初期モデルに戻って拡張性のある形に書き直し、そこへ機能追加することにしました。

初期モデルとしてこの記事で導入した次のモデルアーキテクチャを採用しました。

初期モデルアーキテクチャ

また、モデルの性能を比較するための参照用として、5万件の訓練データを使用して学習した単体モデルを使用することにしました。このモデルは、次の記事の時に作成されたものを再利用します。

書き直し

書き直すにあたって、初期モデルでは使用していない1つ目のトークンに対応するスロットを削除することとしました。また、transformerレイヤーをライブラリに頼らず、再実装することとしました。ただし、attention部分については、既存の実装をそのまま利用することとしました。

その結果、新しいモデルアーキテクチャは次のようになりました。

書き直したモデルアーキテクチャ

生成画像

モデルの性能を比較するため、5万件の訓練データを使用して、8エポック学習させたモデル同士で画像生成を行って比較しました。

生成画像の比較は、これまでの記事と同じプロンプトを使用しています。

単一embedding

  • cat maid (猫耳メイド)

  • 1girl red hair blue eye black skirt(赤髪 青目 黒スカート)

  • 1boy 1girl in class room(少年 少女 教室)

複数embeddingの合成

  • cat, maid (猫 メイド)

  • 1girl, red hair, blue eye, black skirt (赤髪 青目 黒スカート)

  • 1boy, 1girl, in class room (少年 少女 教室)

画像は、最上段から、

  1. Deprojectorを使用しない(参照用)

  2. 初期モデル(学習1回目)

  3. 初期モデル(学習2回目)

  4. 書き直した新モデル

となっています。

単一embedding
複数embeddingの合成

書き換え後のモデルの性能は、元の初期モデルとほぼ同様であることが分かりました。

Vicinity(近傍)モデル

近傍に注目

書き換え後のモデルを拡張するにあたって、追加部分のモデルアーキテクチャをどのように実装するかを検討する上で、まず、Attentionの部分を別のアーキテクチャで置き換えたモデルを作って、性能を検証してみました。

前回の記事のモデルを詳細に分析したところ、注目しているトークン位置の直前の隠れ状態が重要で、遠く離れたトークン位置の隠れ状態はあまり重要でないと分かったため、CNNのような近傍のみを参照するモデルでも十分に機能するのではないかと仮説を立てました。

柔軟にネットワークの構造を変えられるように、一般的なCNNを使うのではなく、独自のアーキテクチャを実装しました。これを、Vicinity(近傍)モデルと呼ぶことにします。これに対し、上のAttentionを用いたモデルをTransformerモデルと呼んで区別します。

Vicinityモデルは、2種類のバリエーションを用意しました。1つは一括適用版と呼び、もう1つは個別適用版と呼びます。以下、それぞれのアーキテクチャを説明します。

一括適用

Vicinityモデルの特徴は、TransformerモデルのAttention部分を、入力embeddingとトークン位置の周辺の隠れ状態(vicinity)を入力にしたネットワークに置き換えた点にあります。

一括適用版の場合、このネットワークは、MLP(多層パーセプトロン)となっています。

このMLPは、vicinityの値だけトークン位置の周辺の隠れ状態を入力として受付ます。もし、トークン位置が左端でvicinityの値の分の隠れ状態がなければ、ゼロテンソル(全ての値がゼロのテンソル)を入力として与えます。

Vicinityモデル(一括適用、vicinity=2)

個別適用

個別適用版では、Attentionを置き換えるネットワークが一般的なMLPではなく、入力embeddingと各トークン位置の隠れ状態毎に、個別に線形変換を適用し、その後全て結合してから2段目の線形変換を適用するという形式となります。

この方式は、Attentionが各隠れ状態毎に個別に線形変換を適用してからAttentionの中心操作を行う手法を参考にしたもので、Attention部分の置き換えとして似たような性質を示すことを期待したものです。

Vicinityモデル(個別適用、vicinity=2)

生成画像

Vicinityモデルに対しても同様に、5万件の訓練データを使用して、8エポック学習させたモデル同士で画像生成を行って比較しました。使用したプロンプトは上と同じです。

実験で使用したモデルは、注目しているトークン位置を含む3つのトークン位置の隠れ状態を入力として使用しました。(vicinity = 3)

画像は、最上段から、

  1. Deprojectorを使用しない(参照用)

  2. 書き直した新モデル(Transformerモデル)

  3. Vicinityモデル(一括適用)

  4. Vicinityモデル(個別適用)

となっています。

単一embedding
複数embeddingの合成

これより、Transformerモデルと2種類のVicinityモデルは、どれも同程度の性能を持っていることが推定されます。

まとめ

初期モデルと同等の処理を行うモデル(Transformerモデル)を新たに書き直し、性能が元のモデルと同等であることを確認しました。

また、モデル拡張の一歩として、Vicinity(近傍)モデルを考案し、2種類の実装を検討しました。そのどちらも、Transformerモデルと同程度の性能があることが分かりました。

今後の記事では、TransformerモデルとVicinityモデルを組み合わせて、モデルを拡張して性能を向上させることを試みます。

この記事が気に入ったらサポートをしてみませんか?