ComfyUIにおけるUNet改造ノードの作り方
誰が得するんだろうこんな記事
ComfyUIのカスタムノードについて
新しいノードを作るには、以下の記事が詳しいです。
UNetを改造するとき、たとえばMODELを受け取って好き勝手いじって、MODELを出力するという話なら簡単です。しかし入力時にコピーされるわけではないので、適当に改変してしまうと入力側も変わってしまいます。ComfyUIはノードベースの生成UIであり、複数のノードにMODELを分岐させることがあります。たとえば新しく作ったカスタムノードを通ったMODELと通らなかったMODELの二つでそれぞれ生成して、効果を比較するみたいなことを自然にやりたくなるので、ちゃんと入力側は影響を受けないように改変することになります。ComfyUIにはそのための仕掛けが色々されています。
MODELについて
ComfyUIのMODELの正体は、ModelPatcherです。このクラスにはUNetの様々な部分にパッチをあてる機能があります。LoraLoaderノード等では、このパッチ部分だけをコピーすることで、巨大なUNetをコピーせずに様々なLoRA設定で並列に生成できるというわけです。ModelPatcherにBaseModelがあり、その中にUnetがあります。入力されたMODELからUNetにアクセスするためには、以下のような方法になります。
# 例として、SDXLかどうか判定するコード(他にもいい方法ありそうだけど・・・)
is_sdxl = hasattr(new_model.model.diffusion_model, "label_emb")
いっぱい改変するぞ
例として、DeepShrinkの実装ノードを見てみましょう。
m = model.clone()
if downscale_after_skip:
m.set_model_input_block_patch_after_skip(input_block_patch)
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
model.clone()するといい感じにパッチだけコピーしたmodelが返ってくるっぽいです。これをせずにmodelに対して直接改変してしまうと、入力側の
modelも改変されてしまいます。
.set_model_*_patchみたいな関数がいっぱい用意されてるので、目的に応じてパッチを当てるという感じです。それでは各パッチについて説明していきます。
ちなみに渡すパッチは関数っぽく使えればいいので、__call__メソッドがあるクラスのインスタンス変数でもいいです。
set_model_attn1_patch, set_model_attn2_patch
attn1はself attentionの直前で適用されるパッチです。self attntionのq,k,vそれぞれへの入力とextra_options(後述)を受け取って、改変したq,k,vへの入力を返します。self attentionですから、基本的に三つの入力は同じものになっているはずですね。
attn2もほとんど同じです。k,vへの入力は基本的にテキストエンコーダの出力になっています。
全てのattentionに適用されるため、限定するには後述のextra_optionsの情報を使います。
利用例:hypertile
set_model_attn(1or2)_output_patch
これは出力に対するパッチです。さっきと同じ感じです。
set_model_attn(1or2)_replace
こちらはパッチではなくAttentionを置き換えます。例えば意味ないけどそのままAttentionを実行するには以下のような関数を渡せばいいです。
from comfy.ldm.modules.attention import optimized_attention
def replace_function(q, k, v, extra_options):
return optimized_attention(q, k, v, extra_options["n_heads"])
optimized_attentionはxformersなど、設定に応じて適切なattentionが呼び出されます。入力は既にto_q,to_k,to_vを通ったもので、出力はto_outの前です。先ほどまでのパッチと違い、複数適用はできません。そのためできればこちらは使いたくないですね(使いまくってるけど)。
また適用には一工夫必要です。replaceはpatchと違い、各モジュールごとに適用します。block_nameとnumberが必要です。block_nameは"input", "middle", "output"のどれかで、numberは何番目のブロックかを表します。これはSDXLに対応していません・・・。SDXLには各ブロックごとにAttentionが複数個ありますからね。SDXLに対応するためには以下のような関数が必要です。
# attn2の場合
def set_model_patch_replace(model, patch, key):
to = model.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if "attn2" not in to["patches_replace"]:
to["patches_replace"]["attn2"] = {}
to["patches_replace"]["attn2"][key] = patch
ここでkeyはSD1系の場合、(block_name, number)の二要素タプル、SDXLの場合そこからAttention層の何個目かを表す整数を加え、(block_name, number, , transformer_index)の形にする必要があります。
ここでkeyについての情報を書いておきます。
SD1系(SD2系も同様)のkeyは
("input", [1,2,4,5,7,8]), ("middle", 0), ("output", [3,4,5,6,7,8,9,10,11])の三種類計16個あります。
SDXL系のkeyは
("input", [4,5], [0,1]), ("input", [7,8], [0,…,9])
("middle", [0,…,9])
("output", [0,1,2], [0,…,9]), ("output", [3,4,5], [0,1])
となります。
Attentionのモジュールを直接参照したいときは、以下のようなコードになります。
# ("input", block_id, transformer_index)
new_model.model.diffusion_model.input_blocks[block_id][1].transformer_blocks[transformer_index].attn2
例としてattention_coupleではCross Attention層の全モジュールを置き換える必要があるので、以下のような実装を行っています。
self.sdxl = hasattr(new_model.model.diffusion_model, "label_emb")
if not self.sdxl:
for id in [1,2,4,5,7,8]: # id of input_blocks that have cross attention
set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.input_blocks[id][1].transformer_blocks[0].attn2), ("input", id))
set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn2), ("middle", 0))
for id in [3,4,5,6,7,8,9,10,11]: # id of output_blocks that have cross attention
set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.output_blocks[id][1].transformer_blocks[0].attn2), ("output", id))
else:
for id in [4,5,7,8]: # id of input_blocks that have cross attention
block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
for index in block_indices:
set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.input_blocks[id][1].transformer_blocks[index].attn2), ("input", id, index))
for index in range(10):
set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.middle_block[1].transformer_blocks[index].attn2), ("middle", id, index))
for id in range(6): # id of output_blocks that have cross attention
block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
for index in block_indices:
set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.output_blocks[id][1].transformer_blocks[index].attn2), ("output", id, index))
いやいくらなんでももっといい方法はないのかよ。
set_model_input_block_patch
input_blockの各出力に対してパッチをあてます。KohyaさんのDeep Shrink実装のために用意されたものです。skip connectionにも適用されます。
使用例:Deep Shrink HiresFix
set_model_input_block_patch_after_skip
こちらはskip connecition側に適用されません。
set_model_output_block_patch
output_blockの各出力に対してパッチをあてます。前層からの入力とskip connection側の入力を受け取って、何らかの改変をして出力する感じです。skip connectionはcontrolnetがすでに適用されています。
extra_options, transformer_options
説明中に、リンクをいっぱい貼り付けましたが、引数としてこれらが使われていると思います。これはパッチに渡す情報を示す辞書です。内容を見てみましょう。二つはだいたい内容が同じですが、extra_optionsについてはAttention層特有の情報が追加されています(headの次元とか)
"block":(block_name, block_id)のタプルです。UNetの位置を確認するために使えます。
"block_index": Attentionの何番目かで、SDXLでは使うかもしれません。
"original_shape": UNetへの入力サイズです。基本的には(バッチサイズ, 4, latent_height, latent_width)になります。(バッチサイズはcfgの場合設定の2倍になります。)
"cond_or_uncond": バッチの要素がcond(=0)なのか、uncond(=1)なのかを示す列です。通常の生成では(1, 0)になっているはずです。ConditionalCombine等の特殊なワークフローの場合要素が3つになったりします。
"sigmas": ノイズの強さです。バッチサイズで拡張されたテンソルであるため、実数値を見たい場合はtransformer_options["sigmas"][0].item()のようにします。model.model.model_sampling.percent_to_sigmaという[0,1]の時刻からsigmaへ置き換える関数と組み合わせると、時刻を制限したパッチがつくれます。
他にもなんかありそうですが、使ったことないです。
set_model_unet_function_wrapper
最終手段として、UNetの関数をらっぷしちゃいます。適用位置はここ。apply_modelとUNetへの入力情報を受け取って、ノイズが除去された潜在変数を返します。ノイズ予測ではないことに注意。
apply_modelとUNet.forward()を自前で作ってしまえば、好き放題できることになります。ただComfyUIのアップデートに対応しづらい実装なので、なるべくやりたくないですけどね。
適用例:DeepCache
またこのラッパーを使ってモデルを改変⇒元に戻すことで安全に何らかの処理が適用できたりもします。cd-tunerではそれをやっています。