Continual backpropagation:AIは刻一刻と変化する世界で学び続けられるのか?
2020年、世界はコロナウイルスによるパンデミックに襲われた。約1年の期間、多くの人が自宅から必要最低限な外出しかできなかった。この生活への影響は甚大であったが、一方で、このパンデミックは、環境が生まれ育ったものから大きく変化してもその環境に行動を順応し、生活を継続していくというヒトの柔軟性の高さを浮き彫りにする機会でもあった。
AIは言語処理、生物学、ゲーム、ロボットなど、様々な分野で成功した。これらの成功の技術的カギはディープラーニングであり、ニューロン間の重みの学習則として確率的勾配降下法(SGD)と誤差逆伝搬アルゴリズムが用いられる。ほとんどの適用例において、ディープラーニングは重みを学習するトレーニングフェーズと重みを固定してモデルを評価したり実問題にモデルを適用したりするテストフェーズに分けて使用される。
では、ディープラーニングに基づくAIは刻一刻と変化する環境で新しいデータを次々と得た時に、古いデータと新しいデータの重要性を評価し、継続的に学習することができるのだろうか?
コロナ禍でAIは既存の尺度での幸福度を最大化することをあきらめつつも、生活水準を一定に保つような、または新たな楽しみを見つけるような意思決定を行うことができるのだろうか?
本研究では、既存のディープラーニング手法は継続的な学習が必要な課題で徐々に可塑性を失い、学習ができなくなることを示した。AIの継続的な学習を達成するために、継続的に多様性を維持する、継続誤差逆伝搬アルゴリズム(Continual backpropagation)を開発し、教師あり学習、強化学習の課題で使用可能なことを示した。
Dohare, S., Hernandez-Garcia, J.F., Lan, Q. et al. Loss of plasticity in deep continual learning. Nature 632, 768–774 (2024). https://doi.org/10.1038/s41586-024-07711-7
Methods
後のResultsセクションで出てくる手法と開発手法の説明を行う。継続学習に効果のある既存の手法としては、L2正則化とShrink and Perturbがある。L2正則化は各ステップで重みが大きくなりすぎないように0に近づける手法である。Shrink and Perturbでは、L2正則化に加え、結合の重みに小さいランダムな値を加える。
提案手法の継続的誤差逆伝搬アルゴリズムでは、トレーニングの過程で最も使われていない少数のニューロンの重みをランダムに初期化するという操作を行い、人工的に重みのばらつきを大きくし、いくつかの重みを小さく保つ。ChatGPTに出力させたこのアルゴリズムのPython風の疑似コードは以下のようである。
import numpy as np
# ハイパーパラメータを設定:入れ替え率 (ρ)、減衰率 (η)、成熟閾値 (m)
rho = 0.1 # 例として入れ替え率を設定
eta = 0.01 # 例として減衰率(学習率)を設定
maturity_threshold = 10 # 例として成熟閾値を設定
# 各層の重みを初期化。Lは層の数と仮定
L = 5 # 例として層の数を設定
weights = [np.random.normal(size=(10, 10)) for _ in range(L - 1)] # 各層の重みを正規分布からランダムに初期化
# 各層のユーティリティ、年齢、置換対象ユニットの数を初期化
utilities = [np.zeros(10) for _ in range(L - 1)] # 各ユニットのユーティリティを0で初期化
ages = [np.zeros(10) for _ in range(L - 1)] # 各ユニットの年齢を0で初期化
units_to_replace = [0 for _ in range(L - 1)] # 各層で置換されるユニットの数を0で初期化
# データストリームからの入力ごとにループ処理を行う
for xt in data_stream: # data_streamは事前に定義されているものと仮定
# フォワードパス:入力 xt をネットワークに通して予測を取得
activations = xt
for l in range(L - 1):
activations = np.dot(activations, weights[l]) # 各層での線形変換
# 必要に応じて、ReLUやシグモイドなどの活性化関数を適用
# 順伝播後に損失を計算すると仮定
loss = compute_loss(activations, target) # 予測に基づく損失関数を計算
# バックプロパゲーション:勾配を計算し、SGDまたはそのバリエーションで重みを更新
for l in reversed(range(L - 1)):
grad = compute_gradient(loss, weights[l]) # この層の勾配を計算
weights[l] -= eta * grad # SGDによる重み更新
# 出力層を除く各層でユーティリティ、年齢、置換ロジックを更新
for l in range(1, L - 1):
# 層lの各ユニットの年齢を更新
ages[l] += 1
# ユーティリティを式(1)に基づいて更新
utilities[l] = update_utility(utilities[l], activations) # ユーティリティを更新する関数
# 成熟閾値を超えるユニットを特定(年齢が閾値を超えるユニットの数)
eligible_units = np.where(ages[l] > maturity_threshold)[0] # 閾値を超えたユニット
n_eligible = len(eligible_units)
# 層lでの置換対象ユニットの数を更新
units_to_replace[l] += n_eligible * rho
# 置換対象ユニットが存在する場合
while units_to_replace[l] > 1:
# 最もユーティリティが小さいユニットを見つけ、そのインデックスを取得
r = np.argmin(utilities[l])
# 入力重みを再初期化:このユニットの重みを再サンプリング
weights[l - 1][:, r] = np.random.normal(size=weights[l - 1][:, r].shape)
# 出力重みも再初期化:該当ユニットの出力側の重みを0にする
weights[l][r, :] = 0
# 該当ユニットのユーティリティと年齢をリセット
utilities[l][r] = 0
ages[l][r] = 0
# 置換対象ユニットの数を減少させる
units_to_replace[l] -= 1
Results
図1は、継続的な教師あり学習でディープラーニングの性能が低下していくことを示している。ImageNetデータセットで2値分類タスクを畳み込みニューラルネットワーク(CNN)で継続学習させると(図1A)、古典的なディープニューラルネットワーク(DNN)は85%程度の高性能を達成することができるが、その後タスクが増えるにつれて徐々に性能が線形のネットワークと同程度まで下がることが分かった(図1B)。一方で、提案手法の継続的誤差逆伝搬法やShrink and Perturb手法、L2正則化を用いると、タスクを増やし続けても性能が落ちることがなかった(図1C)。
図3は、環境が変化する強化学習課題でディープラーニングの性能が低下していくことを示している。シミュレーター上のAntと呼ばれるエージェントの関節にかかるトークを制御し、前に進ませる強化学習課題を用いた(図3A)。この際にゆかの抵抗を途中で変化させ(図3B)、その攪乱にエージェントが順応できるか評価すると、一般的なPPOと呼ばれる強化学習アルゴリズムは性能が低下していくが、継続的誤差逆伝搬に加え、L2正則化と調整を行ったPPOアルゴリズムを用いると、モデルの性能を攪乱後も維持することができた(図3C)。
AdamやDropout、正規化は可塑性の減少に効果がなかったか悪化させた(Extended Data 図4A)。さらに、古典的ディープラーニングアルゴリズムが継続学習を行えない要因として、ほとんど活動しないニューロンが継続学習で増加していくこと(図2C、図4C)、ステーブルランクが学習に伴い下がっていく(ニューロンの多様性が下がっていく)(図2D、図4B)ことが挙げられている。
Discussion
様々なモデル構造、学習アルゴリズム、パラメーター空間において、ディープラーニングがトレーニングとテストに分けられる課題設定では効果的であるものの、継続的に学習が必要な課題では最終的に浅いニューラルネットワークと同等の性能しか出せないことを示した。この結論を出すために網羅的な実験を行った点が本研究の主要な貢献のようだ。また、その要因として、多くのニューロンが似たような結合重みを持って活動度が下がり、新しいことを学習できなくなるためであることが分かった。継続的誤差逆伝搬法などの重みにばらつきを人工的に付加する手法を用いることで、ほぼ無期限の学習が可能になることも明らかにした。
直近では使用される可能性が少なくても、ランダムな重みを導入し、タスク環境の変化への備えとするのは、生物の自然選択による進化や遺伝的アルゴリズムに近い発想のように感じた。実際、脳や深層学習がやっていることは重みの学習というより、良い重みをもつサブネットワークの選択なのではないかといった議論もある。実際の脳がどのように生涯学習を可能にしているか、という生物学的知見と絡めて、ディープラーニングの継続学習の性能が向上していくと面白い。
今回、学習を継続するとニューラルネットワークのニューロンに似た重みを持つものが多くなり、可塑性が失われると結論されている。一方で、大規模言語モデルの学習においてはデータ数よりもパラメーター数が圧倒的に多く、通常のディープラーニングとは異なる機構で性能が上昇しているのではないかともいわれている。今回の結果がそのようなモデルに対しても適用できるものなのかが気になった。