【論文要約】Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models【メモ】

イントロダクション

今回は『Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models』という以下の論文を要約する。論文のpdfをClaude 3 Opusに渡して要約させた。


研究の目的と背景

<purpose>

本研究の目的は、より効率的で高性能な言語モデルを開発することである。具体的には、以下の2つの新しいモデルアーキテクチャを提案している。

  1. Hawk: ゲート付き線形リカレントレイヤー(RG-LRU)を用いた純粋なリカレントニューラルネットワーク(RNN)モデル

  2. Griffin: RG-LRUとローカルアテンションを組み合わせたハイブリッドモデル

Transformerは自然言語処理のデファクトスタンダードとなっているが、長い文脈を扱う際の計算効率や推論速度に課題がある。一方、RNNは長い文脈を効率的に扱えるが、学習が難しく大規模化が困難という問題がある。本研究ではこれらの課題に取り組み、Transformerに匹敵する性能を持ちつつ、より効率的で高速な言語モデルの実現を目指している。

提案手法の意義は、大規模言語モデルの学習と推論の効率を大幅に改善できる点にある。これにより、より長い文脈を扱うことが可能となり、言語生成や機械翻訳など様々なタスクの性能向上が期待できる。また、計算リソースの削減にもつながり、言語モデルの幅広い活用に貢献すると考えられる。

本研究の新規性は、RG-LRUという新しいリカレントレイヤーを提案し、それをTransformerと組み合わせたGriffinというハイブリッドモデルを考案した点にある。RNNとTransformerの長所を組み合わせることで、これまでにない効率的かつ高性能な言語モデルを実現している。

<background>

大規模言語モデルは自然言語処理の様々なタスクで成功を収めてきたが、主にTransformerアーキテクチャに基づいている。Transformerはアテンション機構により並列化が可能で、RNNに比べて学習が容易だが、文脈長に対して二次の計算量を要するため、長い文脈を扱う際の効率が悪い。また、系列長に比例してキャッシュサイズが増大するため、推論速度が低下するという問題もある。

一方、RNNは固定サイズの隠れ状態に系列情報を圧縮するため、長い文脈に対しても効率的だが、勾配消失の問題があり学習が難しい。また、並列化が困難なため、大規模化が容易ではない。Mambaなどの最近の研究では、RNNの学習と推論の効率を改善する試みがなされているが、まだTransformerには及ばない。

本研究は、RNNとTransformerの長所を組み合わせることで、両者の問題点を克服することを目指している。提案するRG-LRUレイヤーは、ゲート機構を備えた新しい線形リカレントレイヤーであり、効率的な学習と推論を可能にする。さらにGriffinモデルでは、RG-LRUとローカルアテンションを組み合わせることで、グローバルアテンションを用いずにTransformerと同等以上の性能を達成している。

使用した手法の概要

<methods>

本研究では、以下の2つの新しい言語モデルアーキテクチャを提案している。

  1. Hawk: ゲート付き線形リカレントレイヤー(RG-LRU)を用いた純粋なRNNモデル

  2. Griffin: RG-LRUとローカルアテンションを組み合わせたハイブリッドモデル

RG-LRUは、Linear Recurrent Unit (LRU) (Orvieto et al., 2023b)に着想を得た新しいリカレントレイヤーである。LRUの線形リカレンスに、LSTMやGRUなどの非線形RNNにヒントを得たゲート機構を組み込んでいる。RG-LRUの更新式は以下の通りである。

$${r_t = \sigma(W_a x_t + b_a)}$$ (recurrence gate)
$${i_t = \sigma(W_x x_t + b_x)}$$ (input gate)
$${a_t = a^{cr_t}}$$
$${h_t = a_t \odot h_{t-1} + \sqrt{1-a_t^2} \odot (i_t \odot x_t)}$$

ここで、$${\sigma}$$はシグモイド関数、$${\odot}$$は要素ごとの積、$${c}$$はスカラー定数(8に設定)である。recurrence gate $${r_t}$$はRG-LRU独自のゲートで、uninformativeな入力の影響を減らすことで指数オーダー以上のメモリを実現すると考えられる。対角行列$${a}$$はパラメータ$${\Lambda}$$を用いて$${a=\sigma(\Lambda)}$$と表現され、リカレンスの安定性を保証する。

RG-LRUを用いたrecurrent blockは、2つの線形層とConv1Dを組み合わせた構造になっている(論文Figure 2(c))。HawkモデルはMLPとrecurrent blockを交互に積層することで構成される。

Griffinモデルは、recurrent blockとローカルアテンション(Beltagy et al., 2020)を組み合わせたハイブリッドモデルである。論文には、「Griffinは、2つのrecurrent blockの後に1つのローカルアテンションブロックが続く、階層的な構造を採用している」と記載されている。ローカルアテンションは、各位置が過去の一定トークンにのみアテンドするため、系列長によらず一定のキャッシュサイズで計算が可能である。

HawkとGriffinは、共通の残差ブロック(論文Figure 2(a))とゲート付きMLPブロック(論文Figure 2(b))を用いる。これにより、様々なスケールでTransformerと同等の性能を実現している。

全てのモデルをAdamWオプティマイザで学習し、小さいモデルでハイパーパラメータをチューニングした。長さ2048のシーケンスを用い、7Bパラメータまでスケールした。Griffin-14Bを含む一部の大規模モデルでは、最大8192トークンでの学習も行った。

<comparison>

本研究のRG-LRUは、ゲート機構の導入によりLRUを拡張したものである。$${r_t}$$ゲートはLSTMのforget gateに似ているが、隠れ状態に依存しない点が特徴的である。これによりデバイス上で効率的な計算が可能になっている。

Griffinは、RNNとTransformerの長所を組み合わせたハイブリッドモデルである。RNNブロックが長期の情報を、ローカルアテンションが短期の情報を捉えると期待される。類似のアプローチとしてgated-SSMブロック(GSS)(Mehta et al., 2022)やMambaのブロック(Gu and Dao, 2023)があるが、アテンションを組み込んでいない点が異なる。Griffinはアテンション使用量を抑えつつTransformerと同等の性能を達成しており、効率と性能のバランスに優れている。

提案手法は工夫により、学習時にTransformerと同等の効率を実現している。推論時は、固定サイズの隠れ状態により、Transformerを大きく上回る効率と高速性を達成した。さらに訓練シーケンス長を超える長さでの外挿性能も示しており、柔軟性の高さがうかがえる。

得られた主な結果

<main_results>

本研究では、提案するHawkとGriffinモデルについて、言語モデリングタスクでの性能を様々な角度から評価し、以下の主要な結果を得た。

  1. HawkとGriffinは、学習に使用するコンピュート量(FLOPs)と検証用データでの損失の間にベき乗則のスケーリング関係を示した(論文Figure 1(a))。これは、Transformerで観測されてきた関係と同様であり、提案モデルが大規模化に適していることを示唆している。

  2. Hawk、Griffin、MQA Transformerを300Bトークンで学習し、ダウンストリームタスクで評価したところ(論文Table 1)、Hawk-3BはMamba-3Bを、Griffin-7BとGriffin-14BはLlama-2-7B,Llama-2-13Bを上回る性能を示した。特筆すべきは、MambaやLlamaが本研究の6〜7倍のトークン数で学習されているのに対し、HawkとGriffinは300Bトークンのみで同等以上の性能を達成したことである。

  3. TPU-v3上での学習時、HawkとGriffinはTransformerと同等の効率を実現した。Pallas (Bradbury et al., 2018)で実装したRG-LRUレイヤーのカーネルにより、メモリ転送を最小化したことが効率化に寄与している。

  4. 推論時、HawkとGriffinはMQA Transformerを大幅に上回るスループットを達成し(論文Figure 1(b))、長いシーケンスのサンプリング時には低レイテンシも実現した。

  5. Griffinは訓練時より長いシーケンスでTransformerを上回る性能を示し、訓練データからコピーやリトリーバルのタスクを効率的に学習することができた。ただし、ファインチューニングを行わずにコピーと正確なリトリーバルを評価した場合、事前学習済みのHawkとGriffinはTransformerほど高い性能を示さなかった。

<details>

HawkとGriffinのスケーリング則に関する重要な知見は、損失が学習のFLOPsに対してベキ乗則に従って減少するというパターンである。論文Figure 1(a)が示すように、このスケーリング関係は、Transformerで過去に観測されてものと一致している(Kaplan et al., 2020)。この結果は、HawkとGriffinが大規模モデルへのスケーリングに適していることを示唆している。ただし、本研究ではパラメータ数が最大14B(Griffin-14B)までの検証にとどまっており、さらなる大規模化の効果は未知数である。

ダウンストリームタスクの評価では、Hawk-3BがMamba-3Bを、Griffin-7BとGriffin-14BがLlama-2を上回った(論文Table 1)。このことから、提案モデルの実践的な有効性が示された。ただし、Mambaは600Bトークン、Llama-2は2Tトークンで学習されているのに対し、HawkとGriffinの学習トークン数は300Bにとどまっている。学習データの違いを考慮する必要があるが、提案モデルが少ないデータで効率的に学習できる可能性が示唆される。

推論速度に関しては、論文Figure 1(b)に示されるように、1Bパラメータのモデルを用いてバッチサイズ1で推論した場合、文長が長くなるほどHawkとGriffinの優位性が顕著になった。Transformerのキャッシュサイズが文長に比例して増大するのに対し、HawkとGriffinは固定サイズの隠れ状態を用いるためである。ただし、この優位性はモデルサイズにも依存すると考えられる。

論文内での結果の解釈や考察

<structure>

本論文では、結果の解釈や考察は主に4章から7章にかけて行われている。4章ではHawkとGriffinのスケーリング特性と学習効率について、5章では推論速度について、6章では長文脈の取り扱いとコピー・リトリーバル能力について、7章では関連研究との比較について議論されている。

著者は提案モデルのスケーリング特性とTransformerに匹敵する性能を重視しており、4章と6章で多くの紙幅を割いている。また、5章では推論時の優位性を強調している。7章では提案モデルを関連研究の文脈に位置づけ、新規性と有効性を主張している。全体として、提案モデルの優位性を多角的に示す構成になっている。

<interpretations>

4章で示されたHawkとGriffinのスケーリング則は、提案モデルが大規模化に適していることを示唆している。著者は、このスケーリング則がTransformerと同様のものであると指摘し、「Hawk and Griffin exhibit power law scaling between held-out loss and training FLOPs, up to and beyond 7B parameters, as previously observed for Transformers (Kaplan et al., 2020).」と述べている。この結果は、RNNベースのモデルが大規模言語モデルとして有望であるという著者の主張を裏付けるものである。

また、4章ではGriffinがTransformerを上回る性能を示したことが強調されている。「Griffin achieves slightly lower held-out loss than strong Transformer baselines at all model scales.」という記述から、著者はGriffinの優位性を示唆していると解釈できる。この結果は、RNNとアテンションの組み合わせが有効であるという著者の仮説を支持するものである。

6章では、HawkとGriffinが長文脈からの学習とコピー・リトリーバルにおいて優れた性能を示したことが報告されている。著者は「Griffin performs better than Transformers when evaluated on sequences longer than those seen during training, and can also efficiently learn copying and retrieval tasks from training data」と述べ、提案モデルの長文脈処理能力の高さを主張している。この結果は、RNNが長期依存関係のモデリングに適しているという先行研究の知見と整合するものである。

一方で、著者は事前学習済みモデルのコピー・リトリーバル能力の限界についても指摘している。「Hawk and Griffin perform less well than Transformers when we evaluate pre-trained models on copying and exact-retrieval tasks without fine-tuning.」という記述から、提案モデルがこれらのタスクに必要な能力を事前学習のみで獲得することの難しさが示唆されている。この結果は、事前学習済みモデルの能力とタスク固有の適応の必要性についての議論に関連するものである。

<arguments>

本論文の主要な主張は、提案するHawkとGriffinが大規模言語モデルとして有望であり、Transformerに匹敵する性能を達成できるというものである。4章のスケーリング則の結果から、著者は提案モデルが大規模化に適していると結論づけている。また、ダウンストリームタスクでの評価結果から、HawkとGriffinがTransformerと同等以上の性能を発揮できると主張している。

さらに著者は、HawkとGriffinが推論時の効率と速度においてTransformerを上回ることを強調している。5章の結果から、提案モデルが少ないメモリで高速な推論を実現できると論じている。この主張は、大規模言語モデルの実用化に向けた重要な知見であると位置づけられる。

また、6章の結果から、著者はHawkとGriffinが長文脈からの学習に優れていることを示唆している。この知見は、提案モデルが長期依存関係のモデリングに適していることを示唆するものであり、言語モデルの文脈処理能力の向上に寄与すると考えられる。

著者は、これらの結果が大規模言語モデルの発展に向けた新しい方向性を示すものであると主張している。Transformerの問題点を克服しつつ、その利点を取り入れた新しいアーキテクチャの可能性を示したと言える。今後は、より大規模なモデルの構築と、様々なタスクへの応用が期待される。

JambaとGriffinの比較

Jambaと本研究のGriffinは、どちらもTransformerとSSMを組み合わせたハイブリッドアーキテクチャを採用している点で共通しているが、モデルの具体的な構成や規模、評価実験などにはいくつかの違いが見られる。以下、それぞれの特徴を比較しながら、両者の類似点と相違点を論じる。

<アーキテクチャの比較>

Jambaは、TransformerレイヤーとMambaレイヤー(SSMの一種)を交互に積層し、さらにMoEを組み込んだアーキテクチャを採用している。各Jambaブロックは、アテンションレイヤーとMambaレイヤーを1:7の比率で混合し、2レイヤーごとにMoEを適用している。一方、Griffinは2つのRG-LRUブロック(SSM)の後に1つのローカルアテンションブロックを配置する階層的な構造を採用しており、MoEは用いていない。

両者に共通するのは、TransformerとSSMを組み合わせることで、メモリ使用量、スループット、性能のバランスを取ろうとしている点である。アテンションレイヤーの割合を減らすことでメモリ使用量を抑え、長い文脈の処理を可能にしている。一方、Mambaレイヤーを増やすことで、特に長いシーケンスにおけるスループットを改善している。

ただし、JambaはMoEを導入することでさらにモデル容量を増やしつつ、アクティブパラメータ数を抑えている点が特徴的である。Griffinはより単純なアーキテクチャを採用しているが、RG-LRUとローカルアテンションの組み合わせが有効であることを示している。

<モデルスケールの比較>

Jambaの最大構成は、アクティブパラメータ数が12B、利用可能な総パラメータ数が52Bであり、80GBのGPU1基に収まるように設計されている。一方、Griffinの最大構成は14Bパラメータであり、Jambaと比べるとやや小規模である。

ただし、JambaはMoEを用いることで、アクティブパラメータ数を抑えつつ大規模なモデルを構築できるというメリットがある。Griffinは純粋なTransformerやMambaと比べて少ないパラメータ数で同等以上の性能を達成しており、パラメータ効率の高さが特徴と言える。

<評価実験の比較>

Jambaは、標準的な言語モデルのベンチマークや長文脈の評価において、同規模のMixtral-8x7BやLlama-2 70Bと同等の性能を達成している。また、256Kトークンまでの長さの文脈を扱うことができ、現在公開されている大規模言語モデルの中で最長の文脈長をサポートしている。

Griffinも同様に、様々なタスクにおいてTransformerと同等以上の性能を示している。特に、学習したシーケンス長を超える長さでの外挿性能に優れ、訓練データからのコピーやリトリーバルタスクを効率的に学習できることが示されている。

両者とも、ハイブリッドアーキテクチャによってSSMの長所を活かしつつ、Transformerに匹敵する性能を達成できることを実証している点で共通している。一方、Jambaは特に推論時の効率とスループットの高さを重視しているのに対し、Griffinは少ないパラメータ数での高性能化に重点を置いている印象がある。

この記事が気に入ったらサポートをしてみませんか?