8x7B=47B?

Mixtral「7×8は56じゃないぞ!オレたちは7×8で47だ!6.7倍だぞ6.7倍!」

 MoEモデルのパラメータ数について気になったので調べてみました。


Mistral-7Bのパラメータ数

 まずはただの7Bモデルについてみていきます。何番煎じだ?
実装はここみたい
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py

 何も考えずtransformersでモデルをロードして標準出力を見てみたら、めっちゃ単純でした。ここから計算すればいいですね。

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

config.jsonの中身と比べてみると、
vocab_size(トークン数):32000
hidden_size(だいたいのチャンネル数):4096
intermediate_size(FeedForwardの真ん中のチャンネル数):14336
num_hidden_layers(層の数):32
num_attention_heads(queryのヘッド数):32
num_kv_heads(key, valueのヘッド数):8
となっています。

計算に使う定数をまとめておく

vocab_size = 32000
hidden_size = 4096
num_attention_heads = 32
num_kv_heads = 8
head_dim = hidden_size // num_attention_heads
intermediate_size = 14336
num_hidden_layers = 32

embed_tokens、lm_head

 embed_tokensはトークンをhidden_sizeベクトルにするのでvocab_size ×hidden_size、lm_headはhidden_sizeのベクトルから各トークンの確率(正確にはソフトマックスをとったら確率になるlogitsとかいうやつ)にするのでhidden_size×vocab_size です。掛け算の順番には意味がない派だけどこういうときは意味ある気がするね。

embed_tokens = vocab_size * hidden_size
lm_head = hidden_size * vocab_size
print(embed_tokens, lm_head) # 131,072,000 131,072,000

MistralSdpaAttention

 あてんちょんなので、q,k,v,oがあります。Grouped query attentionとかいうやつで、kvのヘッド数がqより少ないです。実装的にはkvをqのヘッド数に合わせるように複製してるだけみたいですね(多分)。

q = hidden_size * head_dim * num_attention_heads # = hidden_size * hidden_size
k = hidden_size * head_dim * num_kv_heads
v = hidden_size * head_dim * num_kv_heads
o = hidden_size * head_dim * num_attention_heads
attn = q + k + v + o
print(attn) # 41,943,040

 位置エンコーディング(MistralRotaryEmbedding)にはパラメータはありません。

MistralMLP

 あてんちょんのあとあるやつ。ふぃーどふぉわーどと呼ばれることの方が多い気がします。hidden_size→intermediate_sizeにする二つの全結合層(gate,up)があって、片方だけSiLUを適用した後掛けあわせて、intermediate_size→hidden_sizeにする全結合層(down)をかけます。

gate = hidden_size * intermediate_size
up = hidden_size * intermediate_size
down = intermediate_size * hidden_size
mlp = gate + up + down
print(mlp) # 176,160,768

MistralRMSNorm

 RMSを使ったNorm層らしい。なんだかよく分かってないがパラメータ数はhidden_sizeと同じ。

norm = hidden_size

MistralDecoderLayer

 一つの層にattnとmlp、normが入力に一つ、attnとmlpの間に一つあります。

layer = norm + attn + norm + mlp
# 218,112,000

 これが32個あります。

ぜんぶ

emb→layer×32→norm→lm_headという流れです。

mistral = embed_tokens + layer * num_hidden_layers + norm + lm_head
print(mistral) # 7,241,732,096

 うおー、7Bというか正確には7.24Bくらいなんですねえ。
こんな頑張らなくてもpytorchでは以下のように1行でパラメータ数が分かります。

sum(p.numel() for p in model.parameters()) == mistral # Trueだ!やったあ^q^

Mixtral-8x7Bのパラメータ数

  ほんばんです。ふつ―のモデルはふつーのPCでは読み込めないので以下の4bit版をロードしました。

これでもVRAM27GBも使ってたけど・・・

 構造は以下のような感じでした。

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (act_fn): SiLU()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

 クラス名が代わってたりするけどだいたい形は同じです。異なるのはMLP部分だけです。MLPレイヤーが8個あるMoEブロックに変わってます。それぞれのMLPをエキスパートと呼びます。MoEでは8個のエキスパート全部を使うのではなく、gate層で各トークンに対してどれを使うか決めます。
config.jsonで
num_local_experts: 8
num_experts_per_tok: 2
となってます。8個のエキスパートがあって、各トークンで2個のエキスパートを選んで使うという意味ですね。計算量の増加を抑えながらパラメータ数を大きくできる手法です。VRAMが絶望的に足らない我々庶民にとってはあまりいい方法とは思えないのですが・・・。

 じゃあパラメータ数を数えていきます。といってもmlpだけ変わります。
各トークンに対してどのエキスパートを使うべきか決める全結合層(moe_gate)と、エキスパートとなるmlpが8個があります。

moe_gate = hidden_size * num_local_experts
mlp_moe = mlp * num_local_experts + moe_gate
print(mlp_moe) # 1,409,318,912

 あとは同じです。

layer_moe = norm + attn + mlp_moe + norm
mixtral = embed_tokens + layer_moe * num_hidden_layers + norm + lm_head
print(mixtral) # 46,702,792,704

 46.7Bということで前情報と一致しました。それでは確認・・・ん?

sum([p.numel() for p in model.parameters()]) == mixtral # False・・なんで>q<

 torchのnumelを使った計算方法では、半分くらいの値になってしまいました。これはbitsandbytesの4bitでロードしてたからみたいです。
 ためしに一番最初のattn.q_projを見てみると、

print(model.model.layers[0].self_attn.q_proj.weight.shape) # torch.Size([8388608, 1]) 
print(model.model.layers[0].self_attn.q_proj.weight.dtype) # torch.uint8
print(4096*4096) # 16777216

 こんな感じでした。サイズが半分になっています。そして型はuint8ですね。torchには4bitの型がないため、2つの4bitパラメータを1つの8bitパラメータにまとめているみたいですね。あたまいい~。ということで8bitの型はパラメータ数を2倍にして計算してみます。

import torch
param_size = 0
for p in model.parameters():
    if p.dtype == torch.uint8:
        param_size += p.numel() * 2
    else:
        param_size += p.numel()
print(param_size == mixtral) # Trueだ!やったあ^q^

結論

1 = 2です。
46702792704 = 7241732096 * 8
両辺46702792704を引くと
0 = 11231064064
両辺11231064064で割ると
0 = 1
両辺1足すと
1 = 2 (証明終了)

おまけ

 さいきんでたMixtral 8x22Bも計算してみる。どうせロードできないので確認しようがないですが、最初の定数を変えるだけでいいはず・・・。

vocab_size = 32000
hidden_size = 6144
num_attention_heads = 48
num_kv_heads = 8
head_dim = hidden_size // num_attention_heads
intermediate_size = 16384
num_hidden_layers = 56

結果は140,620,634,112でした。140Bですね。
ちなみにMoEなしだと22,237,845,504となって、本当に22Bでした。