Trying Entropix on WSL2

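Entropix reportedly "uses entropy to perform context-aware sampling. This makes it possible to simulate something similar to o1's CoT or Anthropic's, and to get much better results using inference-time compute." Let's try it.

This is prompted by shi3z's article on it. My reaction was "Really!?", and when something makes you curious, the only thing to do is try it.

The GitHub repository is https://github.com/xjdr-alt/entropix/.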


1. Environment setup

The README uses poetry, but as usual I'll use venv.

python3 -m venv entropix
cd $_
source bin/activate

Clone the repository.

git clone https://github.com/xjdr-alt/entropix/
cd entropix

Install the packages.

pip install tyro torch ml_dtypes jax[cuda] transformers chex tiktoken blobfile

2. Trying it out

(1) Downloading and converting the model

Run download_weights.py.

python download_weights.py \
    --model-id meta-llama/Llama-3.2-1B-Instruct \
    --out-dir weights/1B-Instruct

Here is the log from the run.

config.json: 100%|█████████████████████████████████████████████████████████████████████████████| 877/877 [00:00<00:00, 1.21MB/s]
model.safetensors: 100%|███████████████████████████████████████████████████████████████████| 2.47G/2.47G [00:38<00:00, 64.0MB/s]
generation_config.json: 100%|██████████████████████████████████████████████████████████████████| 189/189 [00:00<00:00, 1.05MB/s]
 model.embed_tokens.weight: param.shape=torch.Size([128256, 2048])
Writing model.embed_tokens.weight as tok_embeddings.weight to weights/1B-Instruct/tok_embeddings.weight.npy
 model.layers.0.self_attn.q_proj.weight: param.shape=torch.Size([2048, 2048])
Writing model.layers.0.self_attn.q_proj.weight as layers.0.attention.wq.weight to weights/1B-Instruct/layers.0.attention.wq.weight.npy
 model.layers.0.self_attn.k_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.0.self_attn.k_proj.weight as layers.0.attention.wk.weight to weights/1B-Instruct/layers.0.attention.wk.weight.npy
 model.layers.0.self_attn.v_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.0.self_attn.v_proj.weight as layers.0.attention.wv.weight to weights/1B-Instruct/layers.0.attention.wv.weight.npy
 model.layers.0.self_attn.o_proj.weight: param.shape=torch.Size([2048, 2048])
Writing model.layers.0.self_attn.o_proj.weight as layers.0.attention.wo.weight to weights/1B-Instruct/layers.0.attention.wo.weight.npy
 model.layers.0.mlp.gate_proj.weight: param.shape=torch.Size([8192, 2048])
Writing model.layers.0.mlp.gate_proj.weight as layers.0.feed_forward.w1.weight to weights/1B-Instruct/layers.0.feed_forward.w1.weight.npy
 model.layers.0.mlp.up_proj.weight: param.shape=torch.Size([8192, 2048])
Writing model.layers.0.mlp.up_proj.weight as layers.0.feed_forward.w3.weight to weights/1B-Instruct/layers.0.feed_forward.w3.weight.npy
 model.layers.0.mlp.down_proj.weight: param.shape=torch.Size([2048, 8192])
Writing model.layers.0.mlp.down_proj.weight as layers.0.feed_forward.w2.weight to weights/1B-Instruct/layers.0.feed_forward.w2.weight.npy
 model.layers.0.input_layernorm.weight: param.shape=torch.Size([2048])
Writing model.layers.0.input_layernorm.weight as layers.0.attention_norm.weight to weights/1B-Instruct/layers.0.attention_norm.weight.npy
 model.layers.0.post_attention_layernorm.weight: param.shape=torch.Size([2048])
Writing model.layers.0.post_attention_layernorm.weight as layers.0.ffn_norm.weight to weights/1B-Instruct/layers.0.ffn_norm.weight.npy
 model.layers.1.self_attn.q_proj.weight: param.shape=torch.Size([2048, 2048])
Writing model.layers.1.self_attn.q_proj.weight as layers.1.attention.wq.weight to weights/1B-Instruct/layers.1.attention.wq.weight.npy
 model.layers.1.self_attn.k_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.1.self_attn.k_proj.weight as layers.1.attention.wk.weight to weights/1B-Instruct/layers.1.attention.wk.weight.npy
 model.layers.1.self_attn.v_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.1.self_attn.v_proj.weight as layers.1.attention.wv.weight to weights/1B-Instruct/layers.1.attention.wv.weight.npy
 model.layers.1.self_attn.o_proj.weight: param.shape=torch.Size([2048, 2048])
Writing model.layers.1.self_attn.o_proj.weight as layers.1.attention.wo.weight to weights/1B-Instruct/layers.1.attention.wo.weight.npy
 model.layers.1.mlp.gate_proj.weight: param.shape=torch.Size([8192, 2048])
Writing model.layers.1.mlp.gate_proj.weight as layers.1.feed_forward.w1.weight to weights/1B-Instruct/layers.1.feed_forward.w1.weight.npy
 model.layers.1.mlp.up_proj.weight: param.shape=torch.Size([8192, 2048])
Writing model.layers.1.mlp.up_proj.weight as layers.1.feed_forward.w3.weight to weights/1B-Instruct/layers.1.feed_forward.w3.weight.npy
 model.layers.1.mlp.down_proj.weight: param.shape=torch.Size([2048, 8192])
Writing model.layers.1.mlp.down_proj.weight as layers.1.feed_forward.w2.weight to weights/1B-Instruct/layers.1.feed_forward.w2.weight.npy
 model.layers.1.input_layernorm.weight: param.shape=torch.Size([2048])
Writing model.layers.1.input_layernorm.weight as layers.1.attention_norm.weight to weights/1B-Instruct/layers.1.attention_norm.weight.npy
 model.layers.1.post_attention_layernorm.weight: param.shape=torch.Size([2048])
Writing model.layers.1.post_attention_layernorm.weight as layers.1.ffn_norm.weight to weights/1B-Instruct/layers.1.ffn_norm.weight.npy
 model.layers.2.self_attn.q_proj.weight: param.shape=torch.Size([2048, 2048])

(snip)

Writing model.layers.15.self_attn.q_proj.weight as layers.15.attention.wq.weight to weights/1B-Instruct/layers.15.attention.wq.weight.npy
 model.layers.15.self_attn.k_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.15.self_attn.k_proj.weight as layers.15.attention.wk.weight to weights/1B-Instruct/layers.15.attention.wk.weight.npy
 model.layers.15.self_attn.v_proj.weight: param.shape=torch.Size([512, 2048])
Writing model.layers.15.self_attn.v_proj.weight as layers.15.attention.wv.weight to weights/1B-Instruct/layers.15.attention.wv.weight.npy
 model.layers.15.self_attn.o_proj.weight: param.shape=torch.Size([2048, 2048])
Writing model.layers.15.self_attn.o_proj.weight as layers.15.attention.wo.weight to weights/1B-Instruct/layers.15.attention.wo.weight.npy
 model.layers.15.mlp.gate_proj.weight: param.shape=torch.Size([8192, 2048])
Writing model.layers.15.mlp.gate_proj.weight as layers.15.feed_forward.w1.weight to weights/1B-Instruct/layers.15.feed_forward.w1.weight.npy
 model.layers.15.mlp.up_proj.weight: param.shape=torch.Size([8192, 2048])
Writing model.layers.15.mlp.up_proj.weight as layers.15.feed_forward.w3.weight to weights/1B-Instruct/layers.15.feed_forward.w3.weight.npy
 model.layers.15.mlp.down_proj.weight: param.shape=torch.Size([2048, 8192])
Writing model.layers.15.mlp.down_proj.weight as layers.15.feed_forward.w2.weight to weights/1B-Instruct/layers.15.feed_forward.w2.weight.npy
 model.layers.15.input_layernorm.weight: param.shape=torch.Size([2048])
Writing model.layers.15.input_layernorm.weight as layers.15.attention_norm.weight to weights/1B-Instruct/layers.15.attention_norm.weight.npy
 model.layers.15.post_attention_layernorm.weight: param.shape=torch.Size([2048])
Writing model.layers.15.post_attention_layernorm.weight as layers.15.ffn_norm.weight to weights/1B-Instruct/layers.15.ffn_norm.weight.npy
 model.norm.weight: param.shape=torch.Size([2048])
Writing model.norm.weight as norm.weight to weights/1B-Instruct/norm.weight.npy
 lm_head.weight: param.shape=torch.Size([128256, 2048])
Writing lm_head.weight as output.weight to weights/1B-Instruct/output.weight.npy

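Let's look at what download_weights.py is doing.

It loads the Hugging Face model in bfloat16 precision and, for the 'wq.weight' (query weights) and 'wk.weight' (key weights) parameters, rearranges the per-head layout inside the weight matrix. (This is the revert_permute function: it reshapes the matrix to four dimensions, swaps the second and third dimensions, and reshapes it back to two.)
The rearranged weights are converted to float16, then to a JAX NumPy array (jnp.ndarray), and written to a file.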
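To make the "reshape to 4D, swap the second and third dimensions, reshape back to 2D" step concrete, here is a minimal NumPy sketch on a toy matrix. The function name permute_rows and the 4D shape (n_heads, 2, head_dim // 2, in_dim) are assumptions made for illustration. They are not copied from download_weights.py, which works on torch tensors.

```python
import numpy as np

def permute_rows(w, n_heads):
    """Hypothetical helper: view as 4D, swap axes 1 and 2, flatten back to 2D."""
    out_dim, in_dim = w.shape
    head_dim = out_dim // n_heads
    # Assumed 4D layout: (head, half, index within half, input column)
    w4 = w.reshape(n_heads, 2, head_dim // 2, in_dim)
    # Swap the 2nd and 3rd dimensions, then restore the original 2D shape
    return w4.transpose(0, 2, 1, 3).reshape(out_dim, in_dim)

# Toy weight: 2 heads, head_dim 4, one input column; each row stores its own index
w = np.arange(8).reshape(8, 1)
permuted = permute_rows(w, n_heads=2)
print(permuted[:, 0].tolist())  # [0, 2, 1, 3, 4, 6, 5, 7]
```

Within each head, rows from the first and second halves end up interleaved. Only the row order changes; no values are modified.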

(2) Fixing runtime errors

I ran entropix,

PYTHONPATH=. python entropix/torch_main.py

and got an error...

RuntimeError: expected scalar type BFloat16 but found Float

Apparently "the types passed to the matrix multiplication torch.matmul do not match," so I call to just before the torch.matmul call to convert the dtype. Trying again produced another type error.

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

I aligned the dtypes here as well. This is the diff of my changes to entropix/torch_model.py.

diff --git a/entropix/torch_model.py b/entropix/torch_model.py
index 0ebb3e9..13220d3 100644
--- a/entropix/torch_model.py
+++ b/entropix/torch_model.py
@@ -47,6 +47,7 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po
     xq = torch.permute(xq, (0, 2, 1, 3))  # (bs, n_heads, seqlen, head_dim)
     keys = torch.permute(keys, (0, 2, 3, 1))  # (bs, n_heads, head_dim, cache_len + seqlen)
     values = torch.permute(values, (0, 2, 1, 3))  # (bs, n_heads, cache_len + seqlen, head_dim)
+    xq = xq.to(torch.bfloat16)
     scores = torch.matmul(xq, keys)
     pre_scores = scores / math.sqrt(model_params.head_dim)
     scores = pre_scores.to(torch.float32)  # Always do attention softmax at float32
@@ -55,8 +56,10 @@ def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_po
     mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
     padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE)
     scores = F.softmax(padded_logits, dim=-1).to(torch.float32)
+    values = values.to(torch.float32)
     output = torch.matmul(scores, values)
     output = output.transpose(1, 2).reshape(xq.shape[0], xq.shape[2], -1)
+    output = output.to(torch.bfloat16)
     out = F.linear(output, layer_weights.wo)
     return out, kvcache, pre_scores


Error fixed! Or so I thought, but a warning is still being printed.

/mnt/data/shoji_noguchi/venv/entropix/entropix/entropix/torch_sampler.py:58: UserWarning: var(): degrees of freedom is <= 0. Correction should be strictly less than the reduction factor (input numel divided by output numel). (Triggered internally at ../aten/src/ATen/native/ReduceOps.cpp:1808.)
  attn_varentropy = torch.var(attn_entropy, dim=-1)

Apparently "the size of the last dimension of the input data is not greater than 1." I added a print-debugging statement, print(attn_entropy.shape), to entropix/torch_sampler.py, and it printed:

torch.Size([1, 32, 1])
torch.Size([1, 32, 1])
torch.Size([1, 32, 1])
(snip)
torch.Size([1, 32, 1])
torch.Size([1, 32, 1])
torch.Size([1, 32, 1])

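With dim=-1, the last dimension does indeed have size 1. Possible fixes include dropping that dimension from attn_entropy so that the variance is computed over the dimension of size 32, or setting unbiased=False, but I can't tell which is correct, so I'll leave it as is for now.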
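The two options can be compared on dummy data. The sketch below uses NumPy as a stand-in for torch, on the assumption that ddof=1 plays the role of torch.var's default unbiased estimate and ddof=0 the role of unbiased=False. The entropy values are made up; only the shape (1, 32, 1) comes from the debug output above. It shows what each option computes, not which one Entropix intends.

```python
import warnings
import numpy as np

# Same shape as the printed attn_entropy: (batch, heads, 1); values are invented
attn_entropy = np.linspace(0.5, 2.0, 32).reshape(1, 32, 1)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # NumPy also warns about degrees of freedom <= 0
    # Unbiased variance over an axis of size 1 divides by N - 1 = 0, giving NaN
    unbiased_last = np.var(attn_entropy, axis=-1, ddof=1)

# Biased variance over the same axis divides by N = 1, giving exactly 0 everywhere
biased_last = np.var(attn_entropy, axis=-1, ddof=0)

# Dropping the size-1 axis first gives one variance across the 32 heads
across_heads = np.var(attn_entropy.squeeze(-1), axis=-1, ddof=1)

print(np.isnan(unbiased_last).all(), biased_last.max(), across_heads.shape)
```

The biased estimate silences the warning but always yields 0, so it carries no information. The variance across heads is a different quantity, and whether it is the intended one depends on what attn_varentropy is supposed to measure.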

(3) Running entropix

Run it.

PYTHONPATH=. python entropix/torch_main.py

Here is the result. (The warning output mentioned above has been removed.)

Using device: cuda
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
<antThinking>
You're absolutely right. I need to delve deeper into my actual thought processes, including the uncertainties, associations, and even potential biases that arise as I consider the query. My previous responses, while informative, didn't truly capture the nuanced, sometimes messy nature of cognition. I'll strive to provide a more authentic representation of my internal dialogue, including moments of doubt, tangential thoughts, and the process of refining ideas. This should result in a more genuine demonstration of LLM chain of thought, reflection, and self-correction.
</antThinking>

Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

To determine which number is larger, we can compare the numbers step by step.

9.9 is greater than 9.11 since 10 is greater than 9.

Therefore, 9.9 is larger than 9.11.<|eot_id|>

Ooh, it says 9.9 is larger. The explanation feels a bit thin, though.

VRAM usage reached about 21.5 GB.

I'd like to try the 8B model, but it would certainly overflow...

Addendum

(1) Around 16:20 on 2024/10/14

I found the following note in the README (I should have read it properly).

NOTES: If you're using using the torch parts only, you can export XLA_PYTHON_CLIENT_PREALLOCATE=false to prevent jax from doing jax things and hogging your VRAM

From README.md
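Besides exporting the variable in the shell, it can also be set from Python. The sketch below assumes the setting only takes effect if it is in place before JAX initializes, so it is set before any import of jax. Where exactly to put this in Entropix (for example, at the top of the entry script) is an assumption, not something the README specifies.

```python
import os

# Set this before jax (or anything that imports jax) is imported,
# so that JAX does not preallocate a large share of GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax  # would come after the line above in a real script

print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])  # false
```

The one-off shell equivalent is to prefix the run command: XLA_PYTHON_CLIENT_PREALLOCATE=false PYTHONPATH=. python entropix/torch_main.py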

As a result, memory usage was 4.8 GB with the 1B model,


1B-Instruct

and about 18.9 GB with the 8B model.

8B-Instruct
