見出し画像

FLUX1.0の中身を覗いてみる

用意するもの

import torch
from diffusers import  FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", 
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

print(pipe.components)

出力結果

{'vae': AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (2): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (3): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
      )
    )
    (mid_block): UNetMidBlock2D(
      (attentions): ModuleList(
        (0): Attention(
          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
          (to_q): Linear(in_features=512, out_features=512, bias=True)
          (to_k): Linear(in_features=512, out_features=512, bias=True)
          (to_v): Linear(in_features=512, out_features=512, bias=True)
          (to_out): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
    )
    (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
    (conv_act): SiLU()
    (conv_out): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (decoder): Decoder(
    (conv_in): Conv2d(16, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (up_blocks): ModuleList(
      (0-1): 2 x UpDecoderBlock2D(
        (resnets): ModuleList(
          (0-2): 3 x ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (upsamplers): ModuleList(
          (0): Upsample2D(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
      )
      (2): UpDecoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
            (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (1-2): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (upsamplers): ModuleList(
          (0): Upsample2D(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
      )
      (3): UpDecoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
            (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
          )
          (1-2): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
      )
    )
    (mid_block): UNetMidBlock2D(
      (attentions): ModuleList(
        (0): Attention(
          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)
          (to_q): Linear(in_features=512, out_features=512, bias=True)
          (to_k): Linear(in_features=512, out_features=512, bias=True)
          (to_v): Linear(in_features=512, out_features=512, bias=True)
          (to_out): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
    )
    (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
    (conv_act): SiLU()
    (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
), 'text_encoder': CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
), 'text_encoder_2': T5EncoderModel(
  (shared): Embedding(32128, 4096)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 4096)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=4096, out_features=4096, bias=False)
              (k): Linear(in_features=4096, out_features=4096, bias=False)
              (v): Linear(in_features=4096, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
              (wo): Linear(in_features=10240, out_features=4096, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
              (act): NewGELUActivation()
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (1-23): 23 x T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=4096, out_features=4096, bias=False)
              (k): Linear(in_features=4096, out_features=4096, bias=False)
              (v): Linear(in_features=4096, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=4096, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
              (wo): Linear(in_features=10240, out_features=4096, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
              (act): NewGELUActivation()
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (final_layer_norm): T5LayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
), 'tokenizer': CLIPTokenizer(name_or_path='/root/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-dev/snapshots/fcb137eff9cee3c5fda4e908a1d3da105323a5bc/tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	49407: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}, 'tokenizer_2': T5TokenizerFast(name_or_path='/root/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-dev/snapshots/fcb137eff9cee3c5fda4e908a1d3da105323a5bc/tokenizer_2', vocab_size=32100, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>', '<extra_id_46>', '<extra_id_47>', '<extra_id_48>', '<extra_id_49>', '<extra_id_50>', '<extra_id_51>', '<extra_id_52>', '<extra_id_53>', '<extra_id_54>', '<extra_id_55>', '<extra_id_56>', '<extra_id_57>', '<extra_id_58>', '<extra_id_59>', '<extra_id_60>', '<extra_id_61>', '<extra_id_62>', '<extra_id_63>', '<extra_id_64>', '<extra_id_65>', '<extra_id_66>', '<extra_id_67>', '<extra_id_68>', '<extra_id_69>', '<extra_id_70>', '<extra_id_71>', '<extra_id_72>', '<extra_id_73>', '<extra_id_74>', '<extra_id_75>', '<extra_id_76>', '<extra_id_77>', '<extra_id_78>', '<extra_id_79>', '<extra_id_80>', '<extra_id_81>', '<extra_id_82>', '<extra_id_83>', '<extra_id_84>', '<extra_id_85>', '<extra_id_86>', '<extra_id_87>', '<extra_id_88>', '<extra_id_89>', '<extra_id_90>', '<extra_id_91>', '<extra_id_92>', '<extra_id_93>', '<extra_id_94>', '<extra_id_95>', '<extra_id_96>', '<extra_id_97>', '<extra_id_98>', '<extra_id_99>']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<extra_id_99>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32001: AddedToken("<extra_id_98>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32002: AddedToken("<extra_id_97>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32003: AddedToken("<extra_id_96>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32004: AddedToken("<extra_id_95>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32005: AddedToken("<extra_id_94>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32006: AddedToken("<extra_id_93>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32007: AddedToken("<extra_id_92>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32008: AddedToken("<extra_id_91>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32009: AddedToken("<extra_id_90>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32010: AddedToken("<extra_id_89>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32011: AddedToken("<extra_id_88>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32012: AddedToken("<extra_id_87>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32013: AddedToken("<extra_id_86>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32014: AddedToken("<extra_id_85>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32015: AddedToken("<extra_id_84>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32016: AddedToken("<extra_id_83>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32017: AddedToken("<extra_id_82>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32018: AddedToken("<extra_id_81>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32019: AddedToken("<extra_id_80>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32020: AddedToken("<extra_id_79>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32021: AddedToken("<extra_id_78>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32022: AddedToken("<extra_id_77>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32023: AddedToken("<extra_id_76>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32024: AddedToken("<extra_id_75>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32025: AddedToken("<extra_id_74>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32026: AddedToken("<extra_id_73>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32027: AddedToken("<extra_id_72>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32028: AddedToken("<extra_id_71>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32029: AddedToken("<extra_id_70>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32030: AddedToken("<extra_id_69>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32031: AddedToken("<extra_id_68>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32032: AddedToken("<extra_id_67>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32033: AddedToken("<extra_id_66>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32034: AddedToken("<extra_id_65>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32035: AddedToken("<extra_id_64>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32036: AddedToken("<extra_id_63>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32037: AddedToken("<extra_id_62>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32038: AddedToken("<extra_id_61>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32039: AddedToken("<extra_id_60>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32040: AddedToken("<extra_id_59>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32041: AddedToken("<extra_id_58>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32042: AddedToken("<extra_id_57>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32043: AddedToken("<extra_id_56>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32044: AddedToken("<extra_id_55>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32045: AddedToken("<extra_id_54>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32046: AddedToken("<extra_id_53>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32047: AddedToken("<extra_id_52>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32048: AddedToken("<extra_id_51>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32049: AddedToken("<extra_id_50>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32050: AddedToken("<extra_id_49>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32051: AddedToken("<extra_id_48>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32052: AddedToken("<extra_id_47>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32053: AddedToken("<extra_id_46>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32054: AddedToken("<extra_id_45>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32055: AddedToken("<extra_id_44>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32056: AddedToken("<extra_id_43>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32057: AddedToken("<extra_id_42>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32058: AddedToken("<extra_id_41>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32059: AddedToken("<extra_id_40>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32060: AddedToken("<extra_id_39>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32061: AddedToken("<extra_id_38>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32062: AddedToken("<extra_id_37>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32063: AddedToken("<extra_id_36>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32064: AddedToken("<extra_id_35>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32065: AddedToken("<extra_id_34>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32066: AddedToken("<extra_id_33>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32067: AddedToken("<extra_id_32>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32068: AddedToken("<extra_id_31>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32069: AddedToken("<extra_id_30>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32070: AddedToken("<extra_id_29>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32071: AddedToken("<extra_id_28>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32072: AddedToken("<extra_id_27>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32073: AddedToken("<extra_id_26>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32074: AddedToken("<extra_id_25>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32075: AddedToken("<extra_id_24>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32076: AddedToken("<extra_id_23>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32077: AddedToken("<extra_id_22>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32078: AddedToken("<extra_id_21>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32079: AddedToken("<extra_id_20>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32080: AddedToken("<extra_id_19>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32081: AddedToken("<extra_id_18>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32082: AddedToken("<extra_id_17>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32083: AddedToken("<extra_id_16>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32084: AddedToken("<extra_id_15>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32085: AddedToken("<extra_id_14>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32086: AddedToken("<extra_id_13>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32087: AddedToken("<extra_id_12>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32088: AddedToken("<extra_id_11>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32089: AddedToken("<extra_id_10>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32090: AddedToken("<extra_id_9>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32091: AddedToken("<extra_id_8>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32092: AddedToken("<extra_id_7>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32093: AddedToken("<extra_id_6>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32094: AddedToken("<extra_id_5>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32095: AddedToken("<extra_id_4>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32096: AddedToken("<extra_id_3>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32097: AddedToken("<extra_id_2>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32098: AddedToken("<extra_id_1>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32099: AddedToken("<extra_id_0>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}, 'transformer': FluxTransformer2DModel(
  (pos_embed): EmbedND()
  (time_text_embed): CombinedTimestepGuidanceTextProjEmbeddings(
    (time_proj): Timesteps()
    (timestep_embedder): TimestepEmbedding(
      (linear_1): Linear(in_features=256, out_features=3072, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)
    )
    (guidance_embedder): TimestepEmbedding(
      (linear_1): Linear(in_features=256, out_features=3072, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)
    )
    (text_embedder): PixArtAlphaTextProjection(
      (linear_1): Linear(in_features=768, out_features=3072, bias=True)
      (act_1): SiLU()
      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)
    )
  )
  (context_embedder): Linear(in_features=4096, out_features=3072, bias=True)
  (x_embedder): Linear(in_features=64, out_features=3072, bias=True)
  (transformer_blocks): ModuleList(
    (0-18): 19 x FluxTransformerBlock(
      (norm1): AdaLayerNormZero(
        (silu): SiLU()
        (linear): Linear(in_features=3072, out_features=18432, bias=True)
        (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      )
      (norm1_context): AdaLayerNormZero(
        (silu): SiLU()
        (linear): Linear(in_features=3072, out_features=18432, bias=True)
        (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      )
      (attn): Attention(
        (norm_q): RMSNorm()
        (norm_k): RMSNorm()
        (to_q): Linear(in_features=3072, out_features=3072, bias=True)
        (to_k): Linear(in_features=3072, out_features=3072, bias=True)
        (to_v): Linear(in_features=3072, out_features=3072, bias=True)
        (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (to_out): ModuleList(
          (0): Linear(in_features=3072, out_features=3072, bias=True)
          (1): Dropout(p=0.0, inplace=False)
        )
        (to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
        (norm_added_q): RMSNorm()
        (norm_added_k): RMSNorm()
      )
      (norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (ff): FeedForward(
        (net): ModuleList(
          (0): GELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
      (norm2_context): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (ff_context): FeedForward(
        (net): ModuleList(
          (0): GELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
    )
  )
  (single_transformer_blocks): ModuleList(
    (0-37): 38 x FluxSingleTransformerBlock(
      (norm): AdaLayerNormZeroSingle(
        (silu): SiLU()
        (linear): Linear(in_features=3072, out_features=9216, bias=True)
        (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      )
      (proj_mlp): Linear(in_features=3072, out_features=12288, bias=True)
      (act_mlp): GELU(approximate='tanh')
      (proj_out): Linear(in_features=15360, out_features=3072, bias=True)
      (attn): Attention(
        (norm_q): RMSNorm()
        (norm_k): RMSNorm()
        (to_q): Linear(in_features=3072, out_features=3072, bias=True)
        (to_k): Linear(in_features=3072, out_features=3072, bias=True)
        (to_v): Linear(in_features=3072, out_features=3072, bias=True)
      )
    )
  )
  (norm_out): AdaLayerNormContinuous(
    (silu): SiLU()
    (linear): Linear(in_features=3072, out_features=6144, bias=True)
    (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
  )
  (proj_out): Linear(in_features=3072, out_features=64, bias=True)
), 'scheduler': FlowMatchEulerDiscreteScheduler {
  "_class_name": "FlowMatchEulerDiscreteScheduler",
  "_diffusers_version": "0.30.0.dev0",
  "base_image_seq_len": 256,
  "base_shift": 0.5,
  "max_image_seq_len": 4096,
  "max_shift": 1.15,
  "num_train_timesteps": 1000,
  "shift": 3.0,
  "use_dynamic_shifting": true
}
}

解釈してみる

このモデルの出力から、FluxPipelineがどのようなコンポーネントで構成されているかが見えてきました。ここでは、主要なコンポーネントとその機能について簡単に説明し、それを理解するために数学的にどのように解析できるかを説明します。

1. VAE (Variational Autoencoder)

VAEはエンコーダーとデコーダーで構成されるニューラルネットワークで、通常は画像やその他のデータを低次元の潜在空間にエンコードし、それをデコーダーによって元の空間に復元します。

  • Encoder: 画像データを入力として受け取り、それを潜在ベクトルにエンコードします。ダウンサンプリングとResNetブロックを使用して特徴を抽出しています。

  • Decoder: 潜在ベクトルを受け取り、元の画像データに復元します。アップサンプリングとResNetブロックを使用して、入力画像の復元を行います。

数学的な解析

  • エンコーダーとデコーダーは、畳み込み層やダウンサンプリング(ストライド付き畳み込み)を使用して、入力データを徐々に圧縮(ダウンサンプリング)し、潜在空間を形成します。この操作は基本的に畳み込み演算と非線形関数(SiLU)によって行われます。

  • VAEの潜在空間は通常、正規分布に従うことを前提とし、そのためにKLダイバージェンスを使用して分布の整合性を保つことが求められます。

2. Text Encoder (CLIP and T5)

この部分はテキストをエンコードする部分で、CLIPとT5モデルを使用しています。

  • CLIPTextModel: これは、テキストをエンコードし、視覚情報との関連性を学習するために使用されるモデルです。注意機構(self-attention)を使用して、テキストの各トークンが他のトークンにどのように依存しているかを学習します。

  • T5EncoderModel: 大規模なトランスフォーマーベースのエンコーダーで、複雑な言語パターンをキャプチャします。

数学的な解析

  • Attention機構を使用することで、テキストの自己相関関係を学習します。この操作は、行列の積とソフトマックス関数によるスケーリングによって実現されます。

  • トランスフォーマーモデルの各層では、行列の積や非線形関数(GELU)を使って、入力を次の層に渡します。

3. Transformer (FluxTransformer2DModel)

この部分は、主に画像データや他の時系列データの処理に使用されるトランスフォーマーモデルです。複数のトランスフォーマーブロックから構成されており、入力データの時系列依存性や特徴量を学習します。

数学的な解析

  • 各トランスフォーマーブロックは、自己注意層とフィードフォワードネットワークから成り立っています。自己注意層では、入力の各部分が他の部分に対してどのように関連しているかを計算します。これも行列の積を利用しています。

  • AdaLayerNormやRMSNormといった正規化技術は、層ごとの出力を安定させ、学習を容易にします。

4. Scheduler (FlowMatchEulerDiscreteScheduler)

スケジューラーは、モデルが生成タスクを行う際に、ステップごとの更新手法を管理する部分です。数値解法として、オイラー法に基づく手法が使用されています。

数学的な解析

  • オイラー法は、微分方程式の数値解法の一つで、各ステップで微小な変化を積み重ねて最終的な結果を得ます。これは、シフトや動的シフトといったパラメータを用いて調整されます。

まとめ

このモデル全体を数学的に解析するためには、畳み込み、行列積、ソフトマックス、非線形活性化関数など、ニューラルネットワークで使用される基本的な数学的操作を理解し、それがどのように連鎖して大規模なモデルの出力を形成するかを理解する必要があります。

これらの操作を理解した上で、Pythonで数式を実装し、モデルが実際にどのように動作しているかをシミュレートすることができます。これにより、モデルの挙動をより深く理解することが可能になります。

この記事が気に入ったらサポートをしてみませんか?