見出し画像

RWKVの学習の安定性について



はじめに

RWKVに関して二つのブログを書いてきました(その1その2)が,追加で学習(勾配)の安定性についても日本語で書いておこうと思います.というのも,RWKV-4の学習では以下の画像のように,LLMでありがちなLoss Spikeが確認されなかったという実験結果があります.

screenshot from: "no loss spikes!" - BlinkDL/RWKV-LM GitHub 2024/02/21

これがなぜそうなのか,という根拠を若干不十分な説明でありますが,Appendix Hに書いてあったので,文脈を補いながら解説します.

前準備:勾配と学習の安定性

RWKVにおいて,TransformerのAttentionに相当する機構であるwkvは以下の数式でした.

$$
wkv_{t} = \frac{\sum\limits_{i=1}^{t-1} e^{-(t-1-i)w+k_{i}}v_{i} + e^{u+k_{t}}v_{t}}{\sum\limits_{i=1}^{t-1} e^{-(t-1-i)w+k_{i}} + e^{u+k_{t}}}
             (1)
$$

ここで,この式を簡略化するために,valueは$${v_t = W_v x_t}$$,指数関数部分は$${K_t^e = e^{W_k x_t + w_{T,t}}}$$と置き換えます.また,トークンは最終トークン$${T}$$を見ることとします.すると,wkvは次のようになります.Eは平均,Sは重み付きの総和を表す関数です.wkvは$${v_t}$$の平均とも見れ,その分母と分子を,重み$${K_t^e}$$の総和と,重みつき$${v_t}$$の総和とも見ることができます.

$$
wkv_T =\frac{\sum_{t=1}^{T} e^{W_k x_t + w_{T,t}}  v_t}{\sum_{t=1}^{T} e^{W_k x_t + w_{T,t}}}= \frac{\sum_{t=1}^{T} K_t^e v_t}{\sum_{t=1}^{T} K_t^e} = E(v_t) = \frac{S(v_t)}{S(1)}             (2)
$$

wkv層を出ると,以降の層$${f(wkv_t)}$$に入り,それが正解ラベル$${y_t}$$との誤差が評価されます.今回,最終トークン$${T}$$の損失を見るとすると,以下の式が得られます.

$$
L_T=l(f(wkv_T),y_T)            (3)
$$

機械学習における学習は,勾配の逆向きの更新により行われます(勾配降下法).従って,RWKVにおいて損失を減らす向きに,ある隠れ層$${a}$$のパラメータ$${(W_a)_{i,j}}$$を更新をすると,以下のようになります.

$$
{ (W_a)_{i,j}} ← { (W_a)_{i,j}} - γ\frac{\partial L_T}{\partial (W_a)_{i,j}}   (4)
$$

発散しない勾配

式(4)を元に,$${W_v}$$の更新がどのような上限が加わるかを見ます.

wkv層に対して合成関数の偏微分の連鎖律を考えると,以下のように損失は損失→wkv層→valueと,経由できます.

$$
\frac{\partial L_T}{\partial (W_v)_{i,j}} = \frac{\partial L_T}{\partial (wk v_T)_i} \cdot \frac{\partial (wk v_T)_i}{\partial (W_v)_{i,j}}  (5)
$$

連鎖律はRWKVのアーキテクチャ図を見ると理解しやすいと思います.

RWKV: Reinventing RNNs for the Transformer Era  Fig2より.ピンク色は著者が追加

勾配の発散は正負関係無しに発生するため,式(5)のwkv層から伝わる勾配に絶対値をとります.

$$
\frac{\partial L_T}{\partial (W_v)_{i,j}} = \frac{\partial L_T}{\partial (wk v_T)_i} \cdot \frac{\partial (wk v_T)_i}{\partial (W_v)_{i,j}}→\frac{\partial L_T}{\partial (wk v_T)_i} \cdot \left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right|  (6)
$$

この絶対値をさらに見ていくと,式(2)を用いて

$$
\left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right|=\left|\frac{\partial \left( \frac{\sum_{t=1}^{T} K_t^e v_t}{\sum_{t=1}^{T} K_t^e} \right)_i}{\partial (W_v)_{i,j}}\right|  (7)
$$

式(7)の$${wkv_T}$$に式(2)を代入し,最初に書いたように$${v_t=W_vx_t}$$であることを踏まえると

$$
\left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right|=\left| \frac{\partial E_i[(v_t)_i]}{\partial (W_v)_{i,j}} \right| = \left| \frac{\partial E_i[( W_v x_t)_i]}{\partial (W_v)_{i,j}} \right| = \left| \frac{ E_i[(\partial W_v x_t)_i]}{\partial (W_v)_{i,j}} \right|=\left| E[(x_t)_j]\right|    (8)
$$

式(8)を見たらわかる通り,入力$${x_t}$$の平均の絶対値となっています.平均の絶対値は,その要素の最大値の絶対値を超えない ことから,以下の不等式が成り立ちます.

$$
\left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right|=\left| E[(x_t)_j]\right|  = \left| \frac{x_{1j} + \ldots + x_{tj} + \ldots + x_{Tj}}{T} \right| \leq \max_t \left| (x_t)_j \right|     (9)
$$

つまり,wkv層における$${W_v}$$を更新する勾配は,入力列の最大要素を超えない,という上限がつくわけです.

加えて,式(9)の上限に$${T}$$が出てこないため,系列の長さによって勾配の上限は制限されない,系列長により勾配が消失がしずらいとわかります.

消失しない勾配

式(4)から,同様に$${W_k}$$の更新にどのような下限が加わるかを見ます.

$$
\frac{\partial L_T}{\partial (W_k)_{i,j}} = \frac{\partial L_T}{\partial (wk v_T)_i} \cdot \frac{\partial (wk v_T)_i}{\partial (W_k)_{i,j}}  (10)
$$

式(10)のwkvに式(2)を代入すると

$$
\frac{\partial (wk v_T)_i}{\partial (W_k)_{i,j}}= \frac{\partial S_i[(v_t)_i]}{\partial (W_k){i,j}} \cdot \frac{S_i(1)}{\partial (W_k){i,j}}=\frac{\partial}{\partial (W_k){i,j}} \left( \frac{\sum_{t=1}^T (K^t)_i(v_t)_i}{S_i(1)} \right)     (11)
$$

商の偏微分公式$${\frac{∂}{∂x} \left( \frac{f(x)}{g(x)} \right) = \frac{f'(x) \cdot g(x) - f(x) \cdot g'(x)}{g(x)^2}}$$から

$$
= \frac{\left( \frac{\partial}{\partial (W_k)_{i,j}} \sum_{t=1}^T (K^t)_i(v_t)_i \right) \cdot S_i(1) - \left( \sum_{t=1}^T (K^t)_i(v_t)_i \right) \cdot \frac{\partial}{\partial (W_k)_{i,j}} S_i(1)}{S_i(1)^2}   (12)
$$

部分ごとの偏微分を計算すると,最初に簡素化した式を参照して,

  • $${\frac{\partial (v_t)_i}{\partial (W_v)_{i,j}} = (x_t)_j}$$

  • $${\frac{\partial ({K}_t^e)_i}{\partial ({W}k){i,j}} = (x_t)_j ({K}_t^e)_i}$$.

これらを式(12)に戻すと

$${= \frac{(\sum_{t=1}^T (x_t)_j (K^t)_i(v_t)_i) \cdot (\sum_{t=1}^T (K^t)_i(v_t)_i) - (\sum_{t=1}^T (x_t)_j (K^t)_i)}{S_i(1)^2} \\= \frac{S_i[(x_t)_j(v_t)_i]}{S_i(1)} - \frac{S_i[(x_t)_j]S_i[(v_t)_i]}{S_i(1)^2} \\= E_i[(x_t)_j(v_t)_i] - E_i[(x_t)_j]E_i[(v_t)_i] }$$

これは共分散の式であるため,

$${= \text{cov}_i((x_t)_j, (v_t)_i)    (13)}$$

(2)式よりwkv層の計算において,$${x_t}$$と$${v_t}$$の共分散は0にはならないことから,

$$
\frac{\partial (wk v_T)_i}{\partial (W_k)_{i,j}} = \text{cov}_i((x_t)_j, (v_t)_i) ≠0  (14)
$$

よって,wkv層における$${W_k}$$を更新する勾配は0にならない,消失しないとなります.

まとめ

以上,RWKVを構成するwkvのパラメータは勾配と発散,両者の対策がなされていることが数式からわかります.故に最初に示した画像のように,loss spikeが発生しなかったことが説明できる,とされています.

誤り等あれば,コメント or @gojiteji まで連絡お願いします🙇.

参考文献

関連ブログ


この記事が気に入ったらサポートをしてみませんか?