見出し画像

CLIP Text Deprojectorを使って画像生成してみる ~Embedding演算~

前回は、CLIP Text Deprojectorのモデルをデータ増強を用いて改良するという記事を書きましたが、今回は、主にEmbeddingの演算を工夫してモデルを改良したことについて書きます。

前回の記事

他のStable Diffusionの関連記事

Layered Diffusion Pipelineを使うためのリンク集

Embeddingの合成方法の改良

CLIP Text Deprojectorは、最終Embeddingからlast hidden state(最終隠れレイヤーの状態)を予測するモデルですが、モデルの入力として与えられるEmbeddingに複数のEmbeddingを合成したものを使うことができるというのが、特徴の1つとして挙げられます。

その際、複数のEmbeddingの合成を行う演算としてどのような演算が最も適当かということが、未解決問題として存在していました。これまではこれを、ベクトルの平均を取ることで行っていましたが、これが最適であるという確証はありませんでした。

今回、この点についていくつかの試行錯誤を経て、空文字列に対応するEmbeddingと、対象となる文字列のEmbeddingとの差を取って、この差を合計する方法が、経験的に生成画像の質を上げるようでした。下図は、この合成方法を図示したものです。

Embeddingの合成方法

生成画像

ベクトルの平均を取る方法と、空文字列との差を合計する方法の2種類の方法で生成したembeddingを元に画像生成を行って比較してみました。

上2段が従来のベクトルの平均を取る方法で、下2段が新しい空文字列との差を合計する方法を用いた画像です。

cat + maid
1girl + red hair + blue eye + black skirt
1boy + 1girl + in class room

ベクトルの平均を取る方法(上2段)では、人物が切れてしまったり、人数が増えたり、表情に不自然さがあったりしましたが、空文字列との差を合計する方法(下2段)では、それらの問題が解消しています。

傍証:コサイン類似度

さらに、空文字列との差を合計する方法がより適していることの傍証として、コサイン類似度を使った検証を行いました。コサイン類似度はベクトルの類似度を評価する手法の1つで、2つのベクトルが似ているほど1.0に近く、異なるほど-1.0に近い値を取ります。

2つの文字列"a"と"b"がある時に、2つを並べ替えて連結させた"a b"と"b a"はよく似た意味を持つと考えられます。その場合、2つの文字列に対するembeddingのコサイン類似度は1.0に近い値になることが望ましいです。

そこで、いくつかの文字列の組に対して、embeddingのコサイン類似度と、空文字列とのembeddingの差のベクトルのコサイン類似度をそれぞれ計算してみました。

  • 'cat maid' vs 'maid cat'

    • そのままのコサイン類似度:0.9056

    • 差のベクトルのコサイン類似度:0.9339

  • 'red hair blue eye' vs 'blue eye red hair'

    • そのままのコサイン類似度:0.9567

    • 差のベクトルのコサイン類似度:0.9717

  • '1girl classroom' vs 'classroom 1girl'

    • そのままのコサイン類似度:0.8832

    • 差のベクトルのコサイン類似度:0.9089

それぞれ、差のベクトルのコサイン類似度の方が高い値になり、語順によるembeddingへのノイズが低減されていると考えられます。

訓練方法の改良

さらに、モデルの訓練方法も前回のデータ増強に引き続き改良を続けています。今回の改良点は主に2つの点にあります。

  • 増強なしデータと増強ありデータで、同じテキストを共有しない

  • モデルへの入力データ形式を変更する

以下、それぞれについて解説します。

増強なしデータと増強ありデータで、同じテキストを共有しない

前回記事で増強データを作った時、LAION 400Mから取ったテキストをそのまま変化させずに訓練データに加えた後、同じテキストを加工して増強データを作って訓練データに加えていました。

こうすると、訓練データには、非常によく似たデータが2回出現してしまうことになり、訓練データに偏りが生まれてしまいます。前回の記事で指摘した、単一embeddingと合成embeddingでの学習速度の差が生まれる原因の1つがこれなのではないかと考えたのです。

そこで、新しい増強方式ではこのような重複が生まれないよう、増強データの生成に使ったテキストは、増強なしデータとして訓練データに加えることはしないように変更を加えました。

この変更の副産物として、重複がなくなったことで以前より訓練に使用するテキストの数が増え、訓練データの多様性が高まりました。

モデルへの入力データ形式を変更する

これまでの入力データの形式は次の図のように、1つ目に入力embeddingを配置し、残りはモデルの出力を1つずつずらして与えていました。これは、CLIPのテキストモデルを再利用するために考えられたモデル構造でした。

従来の入力データ形式

しかし、embeddingの合成方法の改良で見たように、embeddingそのままよりも、空文字列embeddingとの差を取る方が、embeddingの意味をはっきりさせることができると考えられました。

さらに、last hidden stateの1つ目の出力は、入力トークンのSOSに対応した出力で、どんな文字列に対しても厳密に同じベクトルとなることが分かっています。そのため、上図の入力の2つ目は、実質的に空きスロットとなっていました。

つまり、このモデルは、入力スロットとして2つ使える箇所があるところ、1つしか使用していないという状態となっていたのです。

改良された入力データ形式は次のような図で表されます。

改良された入力データ形式

1つ目のスロットには、空文字列から生成されたpooled stateが入ります。さらにこのベクトルは、2つ目以降のスロットの入力に対して差を計算するのに使われます。

2つ目のスロットには、モデルへの入力embeddingが与えられます。ただし、transformerに入力する前に、1つ目のスロットへの入力に使った空文字列embeddingとの差を計算するように変更しました。

3つ目以降のスロットには、モデルの出力を1つずつずらして入力します。ただし、2つ目のスロットと同じように、空文字列embeddingとの差が計算されるようになりました。

このように、それぞれの入力スロットで空文字列embeddingとの差を計算することは、transformerの最初の線形結合で固定のbiasを加えていることと同じなので、数学的な意味は全くありませんが、事前学習パラメータの一部として学習を補助していると考えることができると思います。

また、1つ目のスロットに与えた空文字列embeddingは入力に依らない固定値なので、やはりこのスロットも事実上の空きスロットと考えることもできます。そのため、このスロットのより有効な活用方法は依然として今後の課題として残されています。

生成画像 - 単一embedding

以上の改良を比較検証するために、次の3種類のモデルを作って同じ乱数シードから画像を生成しました。

  • v9: 前回記事のモデル

  • v15: v9から重複を生まない増強方式に変更したモデル

  • v20: v15から入力データ形式を変更したモデル

また、v15, v20では訓練データの元テキストのデータセットを、LAION 400MからLAION 2B enへと変更していますが、この変更はモデルの質的な変更をもたらすものではないと考えられます。

さらに、前回記事では訓練量の比較として、訓練データ生成の元となるテキストの数を数えていましたが、本記事では増強後の訓練データの件数を用いて訓練量の比較とするようにしました。

以下、左からv9, v15, v20と並んで、最上段がDeprojectorなしの比較用画像、2段目以降から訓練データ件数が1万件、2万件、3万件、4万件と並んでいます。

使用したプロンプトは、前回と同様、次のテキストに画質調整用のタグをいくつか追加したものを使用しています。

  • cat maid (猫耳メイド)

  • 1girl red hair blue eye black skirt(赤髪 青目 黒スカート)

  • 1boy 1girl in class room(少年 少女 教室)

猫耳メイド
赤髪 青目 黒スカート
少年 少女 教室

最右列(v20:重複なし、新入力形式)の生成画像に明らかな改善が見られることが見て取れます。

  • 人物の枠内への収まりが改善

  • 訓練を進めても、人物と背景の位置関係といった構図への影響が少ない

  • 背景の描き込みがしっかりして、無地背景になりにくい

  • プロンプトの特徴がより反映されている(猫耳、青目など)

生成画像 - 複数embeddingの合成

最後に、複数embeddingの合成を用いた画像生成の結果を見ます。前回と同様、プロンプトには、次の3種類を使用しました。

  • cat, maid (猫 メイド)

  • 1girl, red hair, blue eye, black skirt (赤髪 青目 黒スカート)

  • 1boy, 1girl, in class room (少年 少女 教室)

以下、左からv9, v15, v20と並んで、上2段が学習件数2万件のモデルを用いた画像、下2段が学習件数4万件のモデルを用いた画像となります。

猫 メイド
赤髪 青目 黒スカート
少年 少女 教室

複数embeddingの合成には、まだ解決しなければならない課題が多く残されているように思われます。

まとめ

今回行ったのは以下のことです。

  • embeddingの合成方法として、空文字列embeddingとの差を取る方法を考案しました。

  • データ増強でテキストの重複を含めない方式に変更しました。

  • モデルの入力データの形式を変更し、空文字列embeddingとの差を用いるようにしました。

今後の課題としては、次のようなことが挙げられます。

  • 訓練データ件数を増やした時の生成画像の変化

  • 有効な正則化の考案とその効果の検証

  • 増強データの最適な比率とより効果的なデータ増強方法の考案

  • 複数embeddingの合成からの生成画像の改善方法の探索

この記事が気に入ったらサポートをしてみませんか?