OpenAI o1のコード理解能力を探る
o1はコードの理解能力が飛躍的に向上したという。
本当なのかどうか確かめたい。
そこでなんかentropixの作者がしょげてるみたいだったので、entropixを題材にコードを紐解いてみることにする。というのもentropixは論文がなくコードが先に出ているという珍しいケースなので、scam(詐欺)とまで言われて作者が怒っていたからだ。
まず、entropixのgithubのreadmeを読ませてみる。
その結果は以下(以下、線と線の間がo1の回答)
以下は、提示された主張・計画内容を要約したものです。
プロジェクト概要: 「entropix」は、エントロピーに基づくサンプリングと並列的なChain-of-Thought(CoT)デコード手法を用いて、推論時の計算リソースのみで高度な文脈対応を実現し、より優れた生成結果を得ようとする研究プロジェクトである。
現状と課題: 現在は開発途中であり、コードベースの一部が不安定・未完成。大型モデル向け実装やマルチGPU、TPU対応など、スケールアップに伴う複雑な課題がある。
今後の方針:
entropix-local:単一GPU(例:4090)やApple Metalで小規模モデルを対象に、研究・プロトタイピング向けの軽量版リポジトリを整備する。JAX、PyTorch、MLX版のコードをメンテナンスし、専用UIも提供予定。
entropix(大規模版):8xH100 / TPU v4-16(70Bモデル)からTPU v4-64(405Bモデル)に対応する大規模推論実装を目指す。アンソロピック風のチャットUIやOpenAI互換のサービングレイヤーを搭載。より高度なサンプラー実装とデプロイ手法が求められる。JAXはTPU向け、PyTorchはGPU向けに最適化。
entropix-trainer(将来的な新リポジトリ):強化学習(RL)など、さらなる手法を試すためのトレーナー環境の提供を検討中。
運用方針: 当面、コードの大きな改修・統合作業を進めるため、Pull Requestへの対応を一時的に制限。安定化後は、entropix-localに対するコミュニティ参加を奨励する。大規模版リポジトリへのコントリビュートは、デプロイ/シャーディングなどの複雑さゆえ、しばらく限定的になる見込み。
総評: 「entropix」は、エントロピーを活用した新しいサンプリング戦略で高品質なテキスト生成を目指す前衛的プロジェクトであり、ローカル版と大規模版の二本立てで開発を進める計画。現時点では不安定であり、開発者は時間を要する整備と評価プロセスを経た上で、ユーザやコントリビュータを受け入れる体制を整える方針である。
次に、engine.pyを読ませてみる。
このコードは、JAX 上で大規模言語モデル推論(生成)を行うためのエンジン(EntropixEngine)の実装例である。以下にコードが行っている主な処理・機能を説明する。
全体概要
役割: このクラスと関数群は、トークナイズ済みの入力系列(トークン列)に対してモデル推論を実行し、KVキャッシュ(Attention用キー・バリュー格納)を管理しながら次のトークンを生成するためのフレームワークを提供する。
対象: 大規模言語モデル(例えばLLM系モデル)のデコード過程で使われる。JAXのShardingやjitなどを駆使し、高速・分散推論に対応するインタフェースを整備している。
コード構造と主要クラス・関数
NamedTupleによる定義
LayerWeights / XfmrWeights: トランスフォーマーモデルの重み(パラメータ)構造を表すクラス。XfmrWeightsはLayerWeightsのリストなどを持ち、モデルのレイヤー重みやトークンエンベッディング、出力層の重みなどを格納する。
Params: モデルの層数やヘッド数、最大シーケンス長、ROPE関連パラメータといったモデル構成を記述するデータクラス。
DecodeState, SlotData, ResultTokens: デコード過程で用いる中間状態や結果を表すクラス。DecodeStateは現在のキャッシュ状態や生成位置を示し、ResultTokensは生成されたトークンや有効性、長さを一括管理するための構造。
EntropixEngineクラス
初期化 (__init__):
引数で受け取ったモデルパラメータ(params)、モデル重み(xfmr_weights)、デバイスメッシュ(mesh)、トークナイザ(tokenizer)、およびモデルのforward関数 (xfmr_fn)、サンプリング関数 (sample_fn)をクラス内部に保持する。
precompute_freqs_cisメソッドで計算されるROPE(Rotary Positional Embedding)用の複素数配列 freqs_cis をデバイス上に配置。
prefillメソッド:
与えられたトークン列について、モデルの前処理("prefill")を行い、対応するKVキャッシュを構築する。
トークン列をモデルに通してロジット(logits)を計算し、そのロジットから次トークン候補(top-kやargmax)が得られる。
結果として、キャッシュ・ロジット・次トークン情報などを含む辞書(Prefix)と、ResultTokensによる結果構造を返す。
generateメソッド:
既存のdecode_state(KVキャッシュや現在のポジションなどの状態)を用いて新たなトークンを生成する。
xfmr_fnで次ステップのロジットを計算し、sample_fnで次トークンをサンプリングし、ResultTokensへ格納。decode_stateも更新する。
insertメソッド:
新しいリクエスト(prefix)の情報を既存のdecode_stateに挿入する処理。並列デコード(複数バッチ)の場合、完了したスロットに新しいデコード要求を割り当てるためのハンドリング。
ROPE関連 (precompute_freqs_cis, apply_scaling):
Rotary Position Embeddingで用いる周波数(freqs)を事前計算する。
apply_scalingでは特定の条件下で周波数をスケーリングし、元の周波数テーブルを補正するカスタムな処理を実装。
build_attn_maskメソッド:
デコード時の因果的アテンションマスクを作成。
上下三角行列で未来トークンへのアテンションを不可にする一般的なマスクを構築。
JAX関連:
@functools.partial(jax.jit, ...)によるJITコンパイルで、高速なGPU/TPU上での実行を可能にしている。
meshやPartitionSpecを用い、モデルパラメータや計算負荷を並列化している。
要約
このコードは、JAXを用いた大規模言語モデル推論エンジンのコア部分であり、以下のことを行う:
モデルパラメータ・重み・トークナイザ・分散計算用メッシュを初期化
ROPEなどの事前計算を実行
prefillで初期入力に対する前向き計算(KVキャッシュ生成)と初回トークン生成
generateで逐次的に1トークンずつ生成していくステップを管理
KVキャッシュやロジット、生成トークンをDecodeStateやResultTokensで整理・保持
並列リクエスト処理のためのinsertなど補助的機能も備える
以上により、コードは大規模トランスフォーマーモデルを用いたテキスト生成を、効率的かつ分散的に行う基盤を提供している。
さらに、sampler.pyを読ませてみる。samplerはentropixの革新部分と言える。
このコードは、モデルの生成トークン分布(logits)におけるエントロピー(entropy)や「変動エントロピー」(varentropy)などの指標に基づいて、サンプリング戦略を動的に切り替えるカスタムなサンプリングアルゴリズムを実装している。
全体の流れ
adaptive_dirichlet_step による初期サンプル取得
adaptive_dirichlet_stepという関数を用いて、与えられたlogitsから一度トークンをサンプリングすると共に、naked_ent("naked"なlogits分布上のエントロピー)
naked_varent("naked"なlogits分布上のvarentropy、変動エントロピー)
scaffold_ent(補助的指標とみられるscaffoldベースのエントロピー)
scaffold_varent(同上scaffoldベースのvarentropy)
などのメトリクスが計算される。
条件分岐による4つのケース判定
サンプリング結果とメトリクスに基づいて、4つのケースに分岐する。LELV (Low Entropy, Low Varentropy): エントロピーと変動エントロピーがともに低い。
→ 現状の選択トークンをそのまま採用し、「流れに乗っている」ような生成と解釈している。HELV (High Entropy, Low Varentropy): エントロピーが高く、変動エントロピーが低い。
→ 意味的には「不確実性は低いが、選択肢は多彩」な状態と解釈し、「明確化の質問トークン」(固定ID=2564)に差し替える戦略を取る。LEHV (Low Entropy, High Varentropy): エントロピーが低く、変動エントロピーが高い。
→ 「枝分かれする道を探索している」ような状態。現時点では単純に元のトークンを採用するが、TODOコメントがあるように将来的には別の戦略(ツリー探索的な再サンプル)を想定している。HEHV (High Entropy, High Varentropy): エントロピー・変動エントロピーの両方が高い。
→ 「混沌とした霧の中で再サンプリングする」状態。ここでは一旦、先ほどサンプルしたトークンに対してlogitsを修正(そのトークンを落とす)し、もう一度 adaptive_dirichlet_step を用いて再サンプリングを行う。
ベクトル化(vmap)によるバッチ処理 sample_one関数をjax.vmapでバッチ方向に適用することで、複数のバッチに対して同時にこのサンプリング戦略を実行する。
ポイント
このアルゴリズムは、通常の「softmax後にランダムサンプリング」や「greedyデコード」等と異なり、分布特性(エントロピー・varentropy)に応じてサンプリングの振る舞いを変える適応的サンプリング手法である。
エントロピー/varentropyを用いて、「生成があまり揺らぎなく進んでいるか」や「生成に不確実性はあるが選択肢が多いか」、「不確実性が大きく次手を再考すべきか」といった状況判断を行い、その状況に応じてトークン再サンプルや特定トークンへの書き換えを行う点が特徴的。
要するに、このサンプリングアルゴリズムは、生成分布の統計的特徴に基づいて動的に戦略を切り替えることで、出力品質や多様性、文脈明確化を試みる先進的な手法の一例と言える。
以下は、適応的サンプリング(adaptive sampling)をエントロピーやヴァリエントロピー(varentropy)といった統計量に基づいて動的にサンプリング手法を切り替える様子を、数式的に簡略化して示した一例である。
確率分布とメトリクス定義
モデルから得られる出力ロジットをソフトマックス関数で確率分布に変換する。語彙集合を $${\mathcal{V}}$$ とし、トークン $${x \in \mathcal{V}}$$ に対する確率を $${p(x)}$$ とする。ソフトマックス変換後、
$${p(x) = \frac{\exp(\text{logit}(x))}{\sum_{x' \in \mathcal{V}}\exp(\text{logit}(x'))}}$$
エントロピー $${H}$$ は以下で定義される(自然対数使用):
$${H(p) = -\sum_{x \in \mathcal{V}} p(x)\ln p(x)}$$
ヴァリエントロピー(varentropy) $${vH}$$ は分布の $${\ln p(x)}$$ の分散と解釈できる。
$${vH(p) = E[(\ln p(x))^2] - (E[\ln p(x)])^2}$$
ここで期待値 $${E[\cdot]}$$ は $${p(x)}$$ に関する期待値として
$${E[f(x)] = \sum_{x \in \mathcal{V}} p(x)f(x)}$$.
適応的閾値による条件分岐
適応的サンプリングでは、あらかじめ以下のような閾値を設定する:$${H_{\text{low}}, H_{\text{medium}}, H_{\text{high}}}$$(エントロピー用閾値)
$${{vH_{\text{low}}, vH_{\text{high}}}}$$(ヴァリエントロピー用閾値)例えば、
LELV状態(Low Entropy, Low Varentropy)条件:
$${H(p) < H_{\text{low}} \quad \wedge \quad vH(p) < vH_{\text{low}}}$$
HELV状態(High Entropy, Low Varentropy)条件:
$${H(p)<vHlowH(p) > H_{\text{high}} \quad \wedge \quad vH(p) < vH_{\text{low}}}$$
LEHV状態(Low Entropy, High Varentropy)条件:
$${H(p) < H_{\text{high}} \quad \wedge \quad vH(p) > vH_{\text{high}}}$$
HEHV状態(High Entropy, High Varentropy)条件:
$${H(p) > H_{\text{medium}} \quad \wedge \quad vH(p) > vH_{\text{high}}}$$
条件に応じたサンプリング手続き
各状態により、サンプリング方法を切り替える。$${x_{\text{new}} = \begin{cases} \displaystyle \arg\max_{x} p(x) & \text{(例) LELVの場合、確定的なトークン選択や最初のサンプルを維持}\\[6pt] t_{\text{clarify}} & \text{(例) HELVの場合、特定の「明確化質問トークン」を代入}\\[6pt] \text{sample from }p(x)& \text{(例) LEHVの場合、元の分布からそのままサンプリング}\\[6pt] \text{re-sample from a modified distribution }p'(x)& \text{(例) HEHVの場合、再サンプリングを行う} \end{cases}}$$
ここで再サンプリングとは、
$${p'(x) \propto \begin{cases} p(x) & x \neq x_{\text{previous}} \\ 0 & x = x_{\text{previous}} \end{cases}}$$
のような修正分布から改めてサンプルする、あるいは他の手法(例:Dirichlet分布ベースの調整)を用いることを示す。
以上のように、適応的サンプリングとは、確率分布 $${p(x)}$$ の統計特性(エントロピーとヴァリエントロピー)に基づき、どのサンプリング戦略を用いるかを切り替えることであり、上記の条件分岐式と再サンプリング方針によって数学的にも示すことができる。
数式をプログラムに変換するのもまあまあ難しいが、プログラムを数式に還元するのはもっと難しいのでかなり高度なことができていると想像できる。
entropixの場合、熱力学に由来するからか、プログラムでいきなり見せられるより数式で説明してもらった方がわかりやすい気さえする。
entropixをQwQとかに適応できるようになったりしないかなー。