SD3でRegional Prompt

 SD3はSDXL以前と異なりMMDiTというモデルが使われています。そこでSD3でも前にやったやつと同じことをできるか試してみました。
 SD3の設計については以下の記事がくわしーので省略します。


SDXLまでとの違い

 SDXLはSelf AttentionとCross Attentionに分かれていましたが、SD3ではその二つが合体しているようなイメージです。そのためSDXLまででやっていたCross Attentionをプロンプトごとに計算する方法は使えません。合体したAttentionを複数回計算することもできますが、計算量が大きいので非効率です。そこでAttentionマスクを使うことにしました。Attentionマスクはqk^T行列に足されます。Softmaxが適用されることを考えると、マイナス無限大を足してあげることで、一部分を無効化することができます。

マスクの設計

 簡単のため二つのプロンプトを用意して、左右に分けることを考えます。このとき、画像の左側に対して右側用プロンプトとの関係性を断ち切ってあがればおっけーです(左右逆でも同様)。
 画像サイズは1024×1024、左右のプロンプトはclipとT5合わせて154トークンとします。するとAttentionのトークン数は画像(64×64=4096)と左右のプロンプト(154×2=308)を結合して、4096+308=4404になります。
※画像はVAEで8分の1、パッチ分割で2分の1になります。
 以下がAttentionマスクです。灰色がマイナス無限大にすべき領域です。

あたまのたいそうになります

 赤い部分は画像のパッチ間の関係性なので変えません。紫や青の部分が画像とプロンプトの関係性を計算する部分で、ここをうまいこと左右分かれるようにマスクします。緑の部分はプロンプトのトークン間の関係性を計算する部分で、左右が完全に分離するようにマスクします。

実装

 今回はDiffusersという誰も使ってないライブラリにあるAttentionProcessorというものを利用します。まずはチュートリアル通りロードします。

import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

 Pytorch2.0以上の場合、たぶんJointAttnProcessor2_0というよくわからんやつがAttention processorとしてセットされています。これを書き換えてあげます。

pipe.transformer.attn_processors
>{'transformer_blocks.0.attn.processor': <diffusers.models.attention_processor.JointAttnProcessor2_0 at 0x7efbc55cfad0>,...

 置き換えるprocessorは以下のような実装になります。__init__でマスクを設定しています。__call__の部分はF.scaled_dot_product_attentionの引数にマスクを指定しただけで、ほぼコピペです。

import torch
import torch.nn.functional as F
class AttnCoupleProcessor:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        
        image_size = 4096
        num_tokens = 154
        self.left_mask = torch.zeros((64, 64))
        self.left_mask[:, 32:] = -float('inf') # 左側プロンプト用、右側をマスク
        self.left_mask = self.left_mask.flatten() # 1次元に並べなおす
        self.right_mask = torch.zeros((64, 64)) 
        self.right_mask[:, :32] = -float('inf') # 右側プロンプト用、左側をマスク
        self.right_mask = self.right_mask.flatten() # 1次元に並べなおす
        
        self.attn_mask = torch.zeros((image_size+num_tokens * 2, image_size+num_tokens * 2))
        
        # 紫部分
        self.attn_mask[image_size:image_size+num_tokens, :image_size] = self.left_mask.unsqueeze(0)
        self.attn_mask[image_size+num_tokens:, :image_size] = self.right_mask.unsqueeze(0)
        
        # 青部分
        self.attn_mask[:image_size, image_size:image_size+num_tokens] = self.left_mask.unsqueeze(1)
        self.attn_mask[:image_size, image_size+num_tokens:] = self.right_mask.unsqueeze(1)
        
        # 緑部分
        self.attn_mask[image_size:image_size+num_tokens, image_size+num_tokens:] = -float("inf") 
        self.attn_mask[image_size+num_tokens:, image_size:image_size+num_tokens] = -float("inf")

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states
        self.attn_mask = self.attn_mask.to(hidden_states)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=self.attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states

以下のようなコードでprocessorを上書きできます。

pipe.transformer.set_attn_processor(AttnCoupleProcessor())

 一応matplotでマスクを可視化してみると、以下のような感じでした。黒がマスクする(マイナス無限大)領域です。

生成

 まずpipelineには二つのプロンプトに対応する機能なんかないので、プロンプトのエンコード部分を自前で行います。SD3のプロンプトはどうすればいいかわからないので適当です。

left = "2girl are walking side by side on park, yuri, black hair, red eyes, maid, maid apron, maid headdress, white thighhigh"
right = "2girl are walking side by side on park, yuri, white hair, blue eyes, white school uniform, red bow, blue sailor collar, blue skirt, black thighhigh"
left_emb = pipe.encode_prompt(left,left,left, max_sequence_length=77)
right_emb = pipe.encode_prompt(right,right,right, max_sequence_length=77)

prompt_embeds = torch.cat([left_emb[0], right_emb[0]], axis=1)
negative_prompt_embeds = torch.cat([left_emb[1], right_emb[1]], axis=1)
pooled_prompt_embeds = (left_emb[2] + right_emb[2]) / 2
negative_pooled_prompt_embeds = left_emb[3]

 pooled_embedsは適当に平均とってみました。ここも頑張れば分割できそうな気もしますが、めんどくさそうです。
 以下のように生成できます。

generator = torch.Generator()
generator.manual_seed(4545)
image = pipe(
    num_inference_steps=28,
    guidance_scale=4.0,
    prompt_embeds = prompt_embeds,
    negative_prompt_embeds = negative_prompt_embeds,
    pooled_prompt_embeds = pooled_prompt_embeds,
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds,
    generator = generator
).images[0]

結果

 てきとうにつくったLoRAを使っています。SD3の使い方は全く分からないので質はびみょうですが、プロンプトの分離自体はできていそうですね。

ノースリーブメイドさんとか珍しいですね。

おまけ

 紫、青、緑のいずれかの領域をマスクしないで生成してみます。コメントアウトしただけです。

紫領域をマスクしない

青領域をマスクしない

緑領域をマスクしない

 紫領域はマスクしなくてもうまく分離されていますね。この領域はプロンプト側が画像を参考に変形する領域で、画像側には直接影響しないからなのかな。反対に青領域は画像側がプロンプトを参考にする領域なので、マスクしないと全く分離されなくなります。緑部分はメイドさんのニーハイが黒になってしまっています。この領域はプロンプトが自分自身を参考にするもので、左右のプロンプトの意味が混じってしまう可能性があります。もしかしたらthighhighという共通ワードがあったため、そこだけ混同してしまったとか?書いてて思ってたけど複数形にすべきなのでは?

まとめ

ComfyUIさんmmditにもattn_patchみたいなの実装してください・・・