PEFT の アダプタのマージ手法
以下の記事が面白かったので、簡単にまとめました。
1. はじめに
モデルのマージは大規模言語モデルのパフォーマンス限界を押し上げる事実上の標準になりました。「Open LLM Leaderboard」では、マージされたモデルがチャートのトップを占めています。Omar Sanseviero は、モデルのマージで興味深い発見をしました。
これまでのモデルマージの一般的な方法は、複数のモデルを取得してそれらをマージすることでした。「margekit」は、これを行うための最適化された方法を提供します。しかし、同じモデルから取得した異なる「アダプタ」をマージしたい場合はどうすればよいでしょうか。4つの異なるLoRAチェックポイントがあり、さまざまなマージ手法を試したいとします。最終的に、タスクに最適な結果をもたらす最適なマージを選択する必要があります。これを調べるうちに、いくつかのことが明らかになりました。
これらを念頭に置いて、PEFTで人気のあるLoRAアダプタを対象とした新しいマージ手法をリリースしました。
2. LoRAアダプタのマージ手法
2-1. Concatenation (cat)
この手法では、LoRA行列が連結されます。 たとえば、2 つのLoRAアダプタ (A_1、B_1) および (A_2、B_2) と、これら 2 つのアダプタの加重マージ用のweights_1 および weight_2 がある場合、マージは次のように行われます。
新しくマージされたLoRAレイヤーの出力は、元の2つのLoRAがアクティブであり、重み weight_1 と weight_2 がそれぞれ1番目と2番目のアダプタに適用されています。
ここで、次のことがわかります。
2-2. Linear/Task Arithmetic (linear)
この手法では、LoRA行列が加重和に関与します。これは、タスクの算術ペーパーでタスクの重みを実装するものです。タスクの算術演算では、まずファインチューニングされた重みと基本モデルの重みの差であるタスクの重みを計算し、次にこれらのタスクの重みの加重和を計算します。ここで考慮されるデルタの重みは、その積 BA ではなく、個々の行列 A と B です。 この方法は、参加しているすべてのLoRAアダプタのランクが同じである場合にのみ適用できます。
2つのLoRAアダプタ (A_1, B_1) および (A_2, B_2) と、weights_1 およびweight_2 を考慮して、これら2つのアダプタを加重マージする場合、次のように行われます。
詳細しくは「Editing Models with Task Arithmetic」を参照してください。
2-3. SVD (svd)
個々の行列 A と B をタスクの重みとして考慮する代わりに、デルタ重みであるそれらの積 BA がタスクの重みと見なされます。
前のサブセクションの例を続けてみます。ここでは、まずマージされた組み合わせのデルタ重みが次のように計算されます。
上記のマージされたデルタ重みを取得した後、SVD (特異値分解) を適用して近似値 A_merged_estimate と B_merged_estimate を取得します。
2-4. TIES (ties , ties_svd )
マージされたアダプタがタスクの重みから計算される方法を変更することによって、linear と svd に基づいて構築され、結果としてそれぞれ tie と ties_svd が生成されます。 TIES (TRIM、ELECT SIGN & MERGE) では、まずタスクの重みを計算します。この場合、非SVD バリアントの場合はLoRAアダプタ A、B、SVD バリアントの場合はその積 BA になります。この後、タスクの重みの最小値を削除し、指定された分数密度に基づいて上位 k個の値を保持します。次に、関与する枝刈りタスクの重みから多数決符号マスクを計算し、タスクテンソルにユーザー指定の重みを乗算し、続いて多数決符号マスクに基づいて素マージを行います。多数決符号マスクの計算には、次の 2つのオプションがあります。
詳しくは「TIES-Merging: Resolving Interference When Merging Models」を参照してください。
2-5. DARE (dare_linear , dare_ties , dare_linear_svd , dare_ties_svd )
linear および svd に基づいており、タスクの重みは、非svdバリアントの場合はLoRAアダプタ A、B、svd バリアントの場合はその積 BA です。「Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch」で提案されたDAREは、まず、指定された分数 1-密度に基づいてタスクの重みの値をランダムに枝刈りし、次に枝刈りされたタスクの重みを 1/密度で再スケーリングします。DARE は汎用プラグインであり、既存のモデル結合方法に適用できます。 線形/タスク演算 (_linear) および TIES (_ties) を使用して DARE を実装しました。
DARE の _linear バリアントの場合、最初に DARE を使用してタスクの重みをランダムにプルーニングし、次に、参加している LoRAアダプタに対してユーザーが指定した重みに基づいてタスク テンソルの加重合計を実行します。
DARE の _ties バリアントの場合、最初に DARE を使用してプルーニングされたタスクの重みを取得し、次にタイの最後の 2 つのステップを採用します。つまり、多数決符号マスクを計算し、そのマスクを使用してタスクの重みの素結合を実行します。
2-6. Magnitude Prune (magnitude_prune , magnitude_prune_svd )
これは、linear および svd に基づいており、タスクの重みは、非svdバリアントの場合はLoRAアダプタ A、B、svdバリアントの場合はその積 BA です。 この手法では、まずタスクの重みの最小値を取り除き、指定された部分密度に基づいて上位 k 個の値を保持します。 次に、参加している LoRA アダプタに対してユーザーが指定した重みに基づいて、タスクテンソルの重み付き合計を実行します。
3. LoRAアダプタをマージする手順
PEFTでは、LoRAを使用する場合、add_weighted_adapter() を使用して、さまざまなマージを試すことができます。 たとえば、以下では、ties を使用して3つのLoRA アダプタをマージする方法と、新しくマージされたアダプタからの生成結果を示します。マージされたアダプタが個々のアダプタの機能を保持できることがわかります。
上記の例は、PEFTリポジトリの例にあります。
以下に示すように、magnitude_prune とその結果生成される世代を使用した別の例を見てみましょう。
統合されたアダプタを使用して、メンタルヘルス関連のクエリにHinglishで回答したい場合はどうすればよいでしょうか? これには、両方のアダプタの機能を使用する必要があります。以下に、「Sad feelings ko kaise dur kare?」というクエリの結果を示します。 (翻訳:悲しい感情を取り除くにはどうすればよいですか?) すべてのアダプタが無効で基本モデルが使用されている場合、応答はAIから始まり、その後に一般的な提案が続きます。 Hinglishアダプタが有効になっている場合、応答はHinglishで、ファインチューニングデータに沿って短くなりますが、悲しみを克服するための具体的な提案を与えるという点ではうまく機能しません。 mental_health アダプタが有効になっている場合、応答は人間の言うことと似ていますが、残念ながらHinglishではありません。 マージアダプタが有効になっている場合、応答はHinglishで短く、運動する、友達と時間を過ごす、読書、瞑想、ポジティブ思考に集中するなど、mental_health アダプタの応答に見られる具体的な提案を示していることがわかります。 したがって、アダプタをマージすると、個々の機能を組み合わせて新しいユースケースをサポートできることがわかります。
最後に、dare_linear の例を使用して、結果の世代を確認します。
PEFT でのこれらのマージ関する開発者ガイドがあります。
4. Text-to-Image への拡張
diffusersを使用してテキストと画像を生成するためのマージを利用する方法を説明します。diffusers は、学習や推論を含む LoRA のすべてについてすでにPEFTに依存しています。ただし、現時点では、diffusersパイプラインで set_adapters() を呼び出すときに新しいマージの恩恵を受けることはできません。そのため、diffusersでサポートする最適な方法についてコミュニティとオープンに議論しています。しかし、PEFTのおかげで、これを回避する方法が常にあります。 これには add_weighted_adapter() を使用します。
「toy-face」の LoRA と「Pixel-Art」のLoRAをマージし、さまざまなマージを実験する手順は次のとおりです。
以下のコードはすべて、この Colabノートブックからのものです。
どちらの LoRAチェックポイントも基本モデルとして SDXL UNet を使用するため、最初に UNet をロードします。
from diffusers import UNet2DConditionModel
import torch
unet = UNet2DConditionModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
subfolder="unet",
).to("cuda")
次に、実際の SDXL パイプラインと LoRA チェックポイントをロードします。 「CiroN2022/toy-face」LoRA から始めます。
from diffusers import DiffusionPipeline
import copy
sdxl_unet = copy.deepcopy(unet)
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
variant="fp16",
torch_dtype=torch.float16,
unet=unet
).to("cuda")
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
ここで、ロードされた LoRA チェックポイントから PeftModel を取得します。
from peft import get_peft_model, LoraConfig
toy_peft_model = get_peft_model(
sdxl_unet,
pipe.unet.peft_config["toy"],
adapter_name="toy"
)
original_state_dict = {f"base_model.model.{k}": v for k, v in pipe.unet.state_dict().items()}
toy_peft_model.load_state_dict(original_state_dict, strict=True)
次に、「nerijs/pixel-art-xl」LoRA に対して同じことを行います。
pipe.delete_adapters("toy")
sdxl_unet.delete_adapters("toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.set_adapters(adapter_names="pixel")
pixel_peft_model = get_peft_model(
sdxl_unet,
pipe.unet.peft_config["pixel"],
adapter_name="pixel"
)
original_state_dict = {f"base_model.model.{k}": v for k, v in pipe.unet.state_dict().items()}
pixel_peft_model.load_state_dict(original_state_dict, strict=True)
これで、重み付きアダプタ推論がすべて装備されました。 まず、必要なものをすべてロードします。
from peft import PeftModel
from diffusers import UNet2DConditionModel, DiffusionPipeline
import torch
base_unet = UNet2DConditionModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
subfolder="unet",
).to("cuda")
toy_id = "sayakpaul/toy_peft_model"
model = PeftModel.from_pretrained(base_unet, toy_id, use_safetensors=True, subfolder="toy", adapter_name="toy")
model.load_adapter("sayakpaul/pixel_peft_model", use_safetensors=True, subfolder="pixel", adapter_name="pixel")
LoRA アダプタを組み合わせます。
model.add_weighted_adapter(
adapters=["toy", "pixel"],
weights=[0.7, 0.3],
combination_type="linear",
adapter_name="toy-pixel"
)
model.set_adapters("toy-pixel")
ここでは、「linear」マージから始めたばかりですが、TIES などの他の珍しいマージ も試していきます。 最後にモデルを DiffusionPipeline に割り当て、推論を実行します。
model = model.to(dtype=torch.float16, device="cuda")
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", unet=model, variant="fp16", torch_dtype=torch.float16,
).to("cuda")
prompt = "toy_face of a hacker with a hoodie, pixel art"
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
ties_svdメソッドを試します。ノートブックの例はここにあります。
pipe.unet.add_weighted_adapter(
["teapot","watercolour"],
[1.0, 1.0],
"merge",
combination_type="ties_svd",
density=0.5
)
dare_linear を使用して2つのスタイルLoRAを組み合わせます。
model.add_weighted_adapter(
adapters=["toy", "pixel"],
weights=[1.0, 1.0],
combination_type="dare_linear",
adapter_name="merge",
density=0.7
)
Majority_sign_method="frequency" を指定したtiesを試します。
model.add_weighted_adapter(
adapters=["toy", "sticker"],
weights=[1.0, 1.0],
combination_type="ties",
adapter_name="merge",
density=0.5,
majority_sign_method="frequency"
)
5. おわりに
(1) ほとんどのシナリオでは、cat で優れた結果が得られます。
(2) 探索したい場合、または cat が動作しない場合は、linear、maginuted_prune、dare_linear の順に試してください。 maginuted_prune と dare_linear の場合、0.7 ~ 0.8 付近の density がより効果的に機能することがわかりました。
(3) ties を使用する場合、多くの場合、majority_sign_method="frequency" のほうが、majority_sign_method="total" よりも優れたパフォーマンスを発揮することがわかりました (現在は total がデフォルトです)。 density の適切なデフォルト値は 0.5 です。 アダプタをマージした後の観察に基づいて、この値をより低くまたはより高く調整してみてください。
(4) dare_ties は良い結果をもたらしませんでした。
(5) 異なるランクを持つ Stable Diffusion LoRA アダプタを使用する場合は、*svd を試すことができます。これらはより多くのGPUメモリを必要とし、高価なSVD操作によりマージされたアダプタの作成に約 1.5 分かかることに注意してください。上の例に見られるように、ties_svd は件名とスタイル LoRA を組み合わせた場合に良好な結果をもたらしました。2つのスタイルアダプタを組み合わせる場合、上記の例に見られるように、density と dare_linear、または ties と majority_sign_method="frequency" の組み合わせがより適切に機能するようです。
この記事が気に入ったらサポートをしてみませんか?