見出し画像

Jambaの紹介: Mamba, トランスフォーマー, MoEを組み合わせた、進化したLLM

トランスフォーマーアーキテクチャの根本的な課題

これまでのところ、LLMの作成は主に伝統的なトランスフォーマーアーキテクチャの使用によるものであり、その堅牢な能力で知られています。しかし、これらの構造には2つの重要な制限があります。

演算処理とメモリが二乗に必要

ChatGPT、Gemini、またはClaudeのようなモデルはすべて、Transformerブロックの連結に基づいています。

各ブロックには2つの要素が含まれています:

  1. アテンション層

  2. フィードフォワード層(画像中のMLP)

アテンションメカニズムは、入力シーケンスを処理し、各単語が重要な単語に注意を払うのを支援する混合操作を強制します。これにより、動詞が名詞に、副詞が動詞に、代名詞が名詞に変換されます。このメカニズムは、入力シーケンスを効果的に処理するために不可欠です。それにより、シーケンス内の各単語が重要な単語に焦点を当てることが保証されます。アテンションメカニズムは、入力シーケンスの理解を向上させる上で重要な役割を果たします。
一方、フィードフォワード層は、データから主要な特徴と関係性を抽出するために重要です。これにより、全体モデルに非線形性が加わり、非線形関数の近似に不可欠です。この非線形性は、テキストなどの複雑なデータを扱う際に基本的です。フィードフォワードニューラルネットワークは、プロセスにおいて重要な役割を果たします。LLMが複雑なデータを効果的に処理するために不可欠です。
RMSNorm は、トランスフォーマーのトレーニングにおいて重要なトレーニングの安定化を行うために、層の重みを正規化しますが、マンバやハイエナのようなより不安定な演算子においては基本的です。
アテンションは二次コストの複雑さを持っています。入力シーケンスが倍になると、処理コストは4倍になります。短いシーケンスでは問題はそれほど重要ではありませんが、長いシーケンスでは計算とメモリの要件が急増します。これは特にKVキャッシュによるものです。問題は、長いシーケンスに対する計算とメモリの要求の増加にあります。

KV Cacheによるメモリ消費ボトルネック

ChatGPTなどのモデルのKVキャッシュは、注意力(Attention)スコアを保存することで、著しい計算の節約をもたらすことで、一種のメモリとして機能します。KVキャッシュの問題は、トランスフォーマーが圧縮された状態を持たないため、より長いシーケンスではより大きなキャッシュサイズが生じることです。この制限は、シーケンスの長さとともにキャッシュサイズが増加するため、より長いシーケンスに対してより重要になります。最初は、モデルサイズの制限により、短いシーケンスではKVキャッシュの影響は最小限です。しかし、シーケンスが長くなるにつれて、KVキャッシュの影響がより顕著になり、効率的なメモリ管理に課題をもたらします。

Source: UC Berkeley
人気のあるLLaMa-2 7BやMistral-7Bなどのモデルは、その小ささにもかかわらず、256kトークンのシーケンスを処理するために大量のメモリが必要です。このタスクでは、LLaMa-2 7Bは128GBのメモリを、Mistral-7Bは32GBのメモリを必要とします。なお、Mistral-7BはGrouped-Query Attentionを利用しています。これらのメモリ要件は、約192kワードの長さの256kトークンのシーケンスに適用されることに注目すると驚くべきことです。
Transformerは、高いコストのために大きなシーケンス長にスケールするのに苦労しています。Ring AttentionやGPU間の通信を削減するイノベーションも問題を解決していません。二次の障壁は、それに対処しようとする努力にもかかわらず、依然として存在しています。高いコストがTransformerのスケーラビリティを数年間にわたって妨げてきました。Ring Attentionなどの新しいアプローチやGPU間の通信オーバーヘッドを削減することも、二次コストの問題を解決するようには見えません。

上記に課題を解決する試み

Mambaの登場

Mambaは、Transformerに代わるアーキテクチャであり、その主な違いは、それが状態を持つアーキテクチャであることです。言い換えると、Transformerが常に全文脈を考慮する必要があるのに対し、前述したように、Mambaは固定サイズのメモリまたは状態を持っています。
本を書いていると想像してみてください。今、101ページを書いています。しかし、新しい単語を書くたびに、毎回前のページ全てを再度読み直して文脈を把握しなければなりません。これがTransformerの動作方法であり、彼らは前の文脈の更新された状態を保持する能力がないためです。
一方で、Mambaのアーキテクチャは、すべての前の文脈の更新可能な圧縮されたメモリを保持します。この状態はサイズが固定されているため(シーケンスの長さに比例して増加するKVキャッシュとは異なり、計算効率がはるかに高い)、新しい入力ごとに、Mambaブロックはその文脈に適切かどうかを判断しなければなりません。 たとえば、次の単語が「um」であれば、メモリを更新したくないでしょう。
現実は、Mambaは品質面でTransformerに劣っていますので、研究者はなんらかの形でTransformerを保持する必要があります。完全に証明されていませんが、Mambaが品質面で劣る理由は、帰納ヘッドを作成できないこと、つまり、データからのパターンの「コピー&ペースト」能力が、インコンテキスト学習、LLMの主要なスーパーパワーであると考えられているからです。

MoE(Mixture of Experts)

単純に言えば、各予測に対してモデルの一部のみが実行され、研究者は計算をかなり削減しながらも巨大なサイズにスケーリングできます。Jambaの場合、520億のパラメータがあるにもかかわらず、単一の予測ごとに120億しか活性化されません。
これに加えて、JambaがTransformerとMambaブロックの比率が1:7であることから、必要なキャッシュは非常に小さいということです。

また、MambaとMoEを併用することで、LLaMa2 13Bのようなはるかに小さなモデルよりもはるかに大きな改善されたスループット(1秒あたりのトークン予測数)を持つことができます。

モデルは、計算とメモリの要件がかなり削減されているにもかかわらず、標準のアテンションベースのモデルに対して十分に競争力があります。


Jambaについて

トランスフォーマーアーキテクチャは、LLM(Large Language Models)において卓越した進歩を遂げ、AI研究開発に最も使われる技術となっています。トランスフォーマーがAGI(Artificial General Intelligence)に到達する最終的なアーキテクチャとなるか、それとも新しいアーキテクチャパラダイムが実現可能性があるかという問いは、AIコミュニティにおいて熱い議論の的となっています。最近、プリンストン大学とカーネギーメロン大学の研究者たちは、状態空間モデル(State Space Model: SSMs)に基づくMambaアーキテクチャを提案し、これがトランスフォーマーに代わる最も実現可能な選択肢となっています。
SSMsとトランスフォーマーの対立ではなく、両者を組み合わせることはできないか?これがAI21 Labsによって提案された新しいモデルJambaのテーマです。Jambaは、トランスフォーマーとSSMsを組み合わせた単一のアーキテクチャであり、LLMの未来に新たな可能性を切り開くことができるかもしれません。
最近、Mambaなどが採用してる状態空間モデル(SSM)と呼ばれる新しいタイプのモデルが、トランスフォーマーモデルには及ばないものの、訓練効率が高く、テキストの長距離関係に対処する能力が向上していることが示されています。

Jamba アーキテクチャの特徴

Jambaの主要な革新は、トランスフォーマーとMambaレイヤーをMoEコンポーネントと組み合わせたハイブリッド設計です。このJambaブロックと呼ばれるユニークなブレンドにより、低メモリ使用、高処理速度、および高品質の出力を管理する柔軟なアプローチが可能となります。一般的な誤解である大きなモデルほど多くのメモリが必要とされることに対し、MoEの使用により、モデルのパラメータのほんの一部しか常にアクティブであるため、メモリ要求が大幅に削減されます。さらに、一部のTransformerレイヤーをMambaレイヤーに置き換えることで、Jambaは伝統的なTransformerと比較して、処理に必要なキー値(KV)キャッシュのサイズを大幅に削減し、最大8倍の減少を実現します。最近のモデルとの比較により、Jambaは256,000トークンまで処理する際にも、より小さなKVキャッシュを維持する効率性を示しています。
Jambaの中核をなすのは、Mambaとアテンションメカニズムを組み合わせた層のシーケンスであるJambaブロックです。各ブロックには、それぞれの後にマルチレイヤーパーセプトロン(MLP)が続きます。これらのブロック内では、アテンションとMambaレイヤーの比率を調整して、特に長いシーケンスに対してメモリ使用量と計算速度の適切なバランスを取ることができます。一部のMLPはMoEレイヤーと交換することができ、計算オーバーヘッドを低く保ちながらモデルの容量を向上させます。このモジュラーデザインにより、Jambaはコアコンポーネントのミックスを調整することで、計算効率とメモリ使用量の優先順位を柔軟に設定することができます。

Jambaが特に革新的なところ

Jambaの主要な革新点は、トランスフォーマーとMambaレイヤーをMoEコンポーネントと組み合わせたハイブリッド設計です。このJambaブロックと呼ばれるユニークなブレンドにより、低メモリ使用、高処理速度、および高品質の出力を管理する柔軟なアプローチが可能となります。一般的な誤解である大きなモデルほど多くのメモリが必要とされるという点にもかかわらず、MoEの使用により、モデルのパラメータのほんの一部しか常にアクティブでないため、メモリ要求が大幅に削減されます。さらに、一部のTransformerレイヤーをMambaレイヤーに置き換えることで、Jambaは処理に必要なキー値(KV)キャッシュのサイズを大幅に削減し、従来のTransformerと比較して最大8倍の減少を実現します。最近のモデルとの比較により、Jambaは256,000トークンまで処理する際にも、より小さなKVキャッシュを維持する効率性を示しています。
Jambaの中核をなすのは、Mambaとアテンションメカニズムを組み合わせた層のシーケンスであるJambaブロックです。これらのブロック内では、メモリ使用量と計算速度の適切なバランスを取るために、アテンションとMambaレイヤーの比率を調整することができます。一部のMLPはMoEレイヤーに置き換えることができ、計算オーバーヘッドを低く保ちながらモデルの容量を向上させます。このモジュラーデザインにより、Jambaは異なるタスクに優先順位を付ける柔軟性を持っています。

Jambaのパフォーマンス改善は驚異的

Jambaの初期のパフォーマンスは、次の図で示されているように、さまざまなベンチマークで非常に注目に値します。
さらに興味深いのは、このパフォーマンスの向上が異なる側面で現れているという事実です。

効率性

Jambaの設計により、単一の80GB GPUで動作することが可能であり、高品質と高速処理のバランスを提供しています。4つのJambaブロックをセットアップすると、Mixtralの2倍のコンテキスト長をサポートし、Llama-2-70Bの7倍のコンテキスト長をサポートします。

スループット

モデルのスループット、つまりテキストを処理する速さは、少量のテキストと大量のテキストの両方を利用する設定で著しい利点を示しています。たとえば、A100 80 GB GPUを1つ使用する場合、Jambaは大規模なバッチに対してMixtralの3倍のスループットを達成します。さらに、複数のGPUを使用する長いテキストのシナリオでは、Jambaはその優れたパフォーマンスを維持し、特に128,000トークンまでの処理において顕著です。

コスト

Jambaの効率性により、1つのGPUで最大140,000トークンを処理できます。これにより、高度なテキスト処理モデルを広範囲のアプリケーションで利用するために大規模なハードウェアが必要なくなり、より多くの用途で利用できるようになります。

ミストラルが最近、MoEsとLLMsを組み合わせることで大きな改善を見せたように、Jambaは生成AI領域において、大きな改善を見せています。トランスフォーマー、SSMs、MoEsの組み合わせは新しいLLMsの標準を設定することができます。


参照文献:

この記事が気に入ったらサポートをしてみませんか?