LoRAを作ってみよう!~基礎編~
はじめに
今回はXアンケートの結果で、LoRA学習についてご要望が一番多かったので、LoRA学習について学んでいこうと思います!
LoRAに限らず、モデルを学習させるという作業について、難しいのでは?と思われる方が多いかと思います。
確かに、専門用語や一定の知識がないと理解は得られにくいかもしれません。
誰でも分かるように、限りなく噛み砕いて説明していきますので、今回は基礎から説明していこうと思います。
LoRAとは
LoRAは生成AI界隈では、もはや自分の名前よりも目にすることが多くなったと思います。
LoRA とは Low-Rank Adaptation の略称です。LoRAもAdapterの一種です。
生成AIにおける「Adapter」とは、生成AIモデルの特定タスク(下流タスク)に対応するように、追加レイヤー部分だけを学習し、レイヤーを追加したものを言います。
モデル全体もしくは一部を微調整するファインチューニングに対して、LoRAは学習済み部分の重みは変更せずに、追加レイヤー分を学習させAdapterとして利用する為、かなり少ないパラメータ数でファインチューニングとあまり変わらない精度を出せるのでは?という考えがLoRAです。
上記の論文の導入部分(Introduction)を少し砕けて説明すると、
「大規模なパラメータを持つモデルでも、モデルが持つ固有ランクはあるだろうから、そこの固有ベクトルなどの部分にAdapterをちょちょいと噛ませてやれば、学習部分のパラメータを全て変えずに、少ないパラメータだけでできるのでは?」
と主張しています。
LoRAの利点は何なのか
ここまでざっくりと説明してきましたが、察しの良い方はお気づきかと思います。
ファインチューニングなどの事前学習モデルのパラメータ対して、微調整を行うとなると非常に長い時間や多くのコスト(規模の大きいデータセット)が掛かります。
簡単に言うなれば、調整するパラメータが多ければ多いほど、時間がかかるというわけですね。
LoRAもデータセットが多ければ、精度も上がりますが、ファインチューニングのそれとは比較にならないほど、少ないデータセットと時間で学習が可能になります。
つまりは、かなり少ないパラメータで学習する為、学習に対するコストが非常に低いという点が、製作者側としては最大の利点ではないかと思います。
それで、追加学習と同じくらいの精度が出るということなので、考えた人は天才です。
LoRA学習する上で知っておくべき単語と意味
ここからは、LoRA学習をする上で避けては通れない単語と、その意味を説明していきます。
これからLoRAを製作する上で、絶対に意識する部分なのでしっかりと覚えていきましょう!
過学習(オーバーフィッティング)
過学習とは、本質的ではない情報(ノイズ)までも過剰に適合するように学習してしまい、正常な傾向が算出できない状態のことを指します。
オーバーフィッティング(過剰適合)とも言います。
この状態は人間でいうと、言わば「洗脳状態や、認知バイアスによる間違った状態」です。
周りが真実を伝えているのに、自身が持つ認知バイアス(経験則等から、それが正しいと固辞してしまう意識)がそれを阻んでしまい、正常に情報を受け入れなくなったりします。
これはAIに置き換えると、未知のデータに対する適応力の低下を表します。
そうすると、出力が崩壊したり、新しい情報を学習する際に本質的ではないノイズが混じってしまい、正常な出力が完全に困難になってしまいます。
このような状態のことを、過学習といいます。
学習率(Learning rate:LR)
まずは、学習率についてです。
学習率は、基本的に小数表記、指数表記で表されます。
$$
lr = 0.0001
$$
$$
lr = 1e^-5
$$
上記はどちらも同じ意味です。
学習率は、人間の反復学習における学習精度と同じで、流し見て学習するか、じっくり見て学習するか、もう脳内に直接書き込んだろ!的に思ってOKです。
例えば、学習率が「1.0(1e-0)」であった場合、これを人に置き換えて例えるならば、「脳内に直接知識を書き込む」状態です。
脳内に書き込めて忘れなければ最強ですね。
これが学習率「0.0001(1e-4)」であった場合、これを人に置き換えて例えるならば、「1ページ1ページを読んで学習する」です。
人でも一般的な学習方法ですよね。テスト勉強でもよくやるやつです。
ここから言えることは、学習率とは「値が大きくなればなるほど、目標に早く到達し、小さくなればなるほど目標まで時間がかかる」と言えます。
じゃあ、学習率高ければ早く覚えられていいじゃん!と思われますが、実際はそうではありません。
AIにおける過学習の問題が出てきます。
人間も反復学習するように、AIも反復学習をします。
「脳内に直接知識を書き込む」という現実問題として不可能ですが…
仮にできたとして、絶対に忘れない同じ情報を何度も書き込む作業を、300回やってみましょう!なんて、想像しただけで分かるように、猛烈な洗脳みたいな状態になりそうで、頭がおかしくなりませんか?
その状態が言わば過学習状態です。少ないEpoch(反復数)で過学習を引き起こします。
同じく、学習率を適正な値にしても、Epochが過剰すぎると過学習します。
人間でいうと、適正な学習方法(適切な学習率)だけど、毎日不眠不休(過剰なEpoch)で勉強したら、どこかのタイミングで頭がおかしくなりますよね。(過学習状態)
以上から、適切な学習率を決める必要があります。
但し、学習率を事細かに計算で算出する方法(最適化)はありますが、ちょっと難易度が高いのでここでは割愛します。
続いて、LoRA学習において、学習率の項目は3つあります。
学習率を設定する項目
U-Net LR(unet_lr)
Stable Diffusionの核ともなるU-Net部分の学習率です。
全体的な画風などに影響があります。Text Encoder LR(text_encoder_lr)
Text Encoder部分の学習率です。
タグと画像を紐付ける部分の学習率になるので、プロンプトの合致性に影響します。Learning Rate(learning_late)
上記2つの包括的役割です。
U-Net LR、Text Encoder LRの指定がなければ、同一の学習率で上記2つが学習されます。
以上から学習率は、設定した学習率と反復回数、訓練データ毎の反復回数によって、過学習を引き起こさせない、不必要な学習時間を掛けないということから、目的に沿って適正に設定する必要があります。
学習枚数
LoRAを作る上での教師画像枚数のことです。
枚数が増えれば、その分精度も上がりますが、学習時間は長くなります。
学習回数(Repeat)
訓練データ1つに対しての反復回数を表します。
繰り返し回数は基本的に、フォルダ名に記載します。
「1_sample」であれば1回、「100_sample」であれば100回となります。
学習率同様、過剰な値を設定すると、過学習を引き起こす要因になります。
反復回数(Epoch)
読み方は「エポック」です。ここでは便宜上、反復回数としています。
これは、訓練データを全て使い切った状態の回数を表します。
訓練データが例えば1つ、学習回数(訓練データ1つに対しての反復回数)100回であれば、それらを全てこなして、1 Epochとなります。
訓練データを何回使用して学習したか、という意味になります。
学習率同様、過剰な値を設定すると、過学習を引き起こす要因になります。
Step
最終的な学習進捗を指します。
Step数は以下の式で求まります。
$$
step = imgs \times repeat \times epoch
$$
学習画像枚数と繰り返し回数、Epochを全て積算すればStep数が出ます。
LoRA学習における目安Step数は5000~8000くらいなので、上記式より逆算して各設定を割り当てる目安にしましょう。
正則化(Regularization)
時々、訓練データの特徴が意図しない単語と強く合致してしまい、その単語をプロンプトへ入れてしまうと、似たような出力しかしなくなるといった状態が起きます。
例えば、以下の画像を学習する際に、
1girl, long hair, t-shirt, hotpants, baseball cap, shoulder bag, white sneaker, white background
というキャプションで学習したと仮定します。
その際、上記のような画像と全てのキャプションを学習してしまうことで、long hairだけのプロンプトでも、上記の画像が強く紐づいてしまい、結果的に無関係な野球帽を被った女性が出力されたりすることがあります。
これは特定の単語に対するオーバーフィッティングを引き起こしている為、その関係性の誤りを正す必要があります。
LoRAやCheckpointのファインチューニングでは、教師画像のデータ量が多い場合、正則化画像を用いて、キャプションと画像の特徴を強く結びつけるのを抑止することをしばしば行います。
故にこれは、キャプション毎の重みを変えて学習することができるようになるので、特定の情報に対する過学習を抑止することに繋がります。
これを正則化といいます。
但し、必ずしも正則化が必要ではないことを留意してください。
Network Rank(Dimension)
Network RankやDimension(Dim)と言ったりします。
これは、LoRAのニューラルネットワークにおける中間層の次元数のことを言います。
ニューラルネットワークについては、以下記事を参照ください。
LoRAにおけるニューラルネットワークは、Checkpointモデルとは別の小さなニューラルネットワークを形成します。
そのニューラルネットワークの中間層の次元数(ニューロン数)をNetwork Rank、もしくはDimと呼びます。(以下Dimと呼びます)
Dimが多ければ多いほど、学習情報を多く保持することができますが、その分容量も学習コストも多くなります。
一般的には、Dim:16~128くらいあればいいのでは?と言われています。
理由としては、LoRAとして利用するのは主に、特定のコンセプトの追加を目的としていることが多く、その程度の情報量ならば、過大な次元数は必要ないと思われるためです。
なので、LoRA単体で、Checkpointモデルのような万能な機能を持たせる!という使い方には則さないので、特に大きい値にする必要はありません。
LR Scheduler
学習率適用の手法を選択することが可能です。
どの様に学習していくかという意味で捉えても大丈夫です。
このスケジューラは、いくつかあります。
adafactor
constant
constant_with_warmup
cosine
cosine_with_restarts
linear
polynomial
今回は詳細を省きますが、以上のような種類があります。
Optimizer
機械学習におけるOptimizerとは、最適化アルゴリズムのことを指します。
学習損失をゼロにすることが最終的な目標ですが、まず無理ですよね。
なので、その損失を最低限抑えよう!というのがこの、最適化アルゴリズム(Optimizer)です。
様々な種類がありますが、LoRA学習で広く使われているのは、AdamW8bitです。
とりあえず、あまり悩まずにこれを使っていれば問題ないと思います。
Train Batch Size
バッチサイズ、画像生成AI界隈では聞き馴染みがありますね。
学習におけるバッチサイズとは、一度に読み込むデータ量のことを指します。
LoRAにおいては、一度に読み込む画像の枚数になります。
複数の画像を読み込むため、違う画像を同時に学習することになるので、特定のチューニング精度は落ちますが、学習時間は短くなります。
逆に上げすぎると、学習不足になったり、VRAMの大量消費に繋がる為、特に目的がない限りはデフォルトのままでOKです。
おわりに
今回は、LoRA学習の基礎編ということでお伝えしました!
「学習に用いるツールの使い方と作例を、そのまま記事にしてしまおー」とも思いましたが、どのツールでも記載されているものがどういう意味なのか分からないと、なかなか身にならないし、取っつきにくいと思ったので、このような記事から始めてみました。
次回より、LoRA学習に必要なものなどの説明をしていきたいと思います!
次回をお楽しみに!
ここから先は
よろしければサポートお願いします!✨ 頂いたサポート費用は活動費(電気代や設備費用)に使わさせて頂きます!✨