llm-jp/llm-jp-13b-dpo-lora-hh_rlhf_ja-v1.1 のLoRAファイルをマージして使ってみる
LLM-jpから、新しい13bモデルであるversion 1.1が公開されました。先に公開されていた1.0のモデルに対して、新しいデータセットでのインストラクションチューニングを加えたものみたいです。理研が進めている自然な日本語のデータセットichikaraを使っているのもポイント。
このモデルのいいところは、日本語に強い13bモデルとしては珍しい、Apache 2.0ライセンスであることです。モデルはもちろん、生成結果も自由に利用出来るので、様々な活用や遊びができそうです!
学習のさせ方の違い(LoRAかフルファインチューニングか)で2つのバージョンが公開されています。
また、同時に、ここに対してさらにDPOチューニングをしたファイルが公開されています。Hugging FaceにアップロードされているのはLoRAのアダプターのみです。
transformersから呼び出す際には、自動的に元のモデルも読み込んでくれるので気にすることはないのですが、せっかくなのでローカルでマージしてみることにしました(いずれggufにコンバートしたいので)。
まず、上記の llm-jp/llm-jp-13b-dpo-lora-hh_rlhf_ja-v1.1 から、アダプターファイルをダウンロードします。作業フォルダにlorafilesというフォルダを作って、adapter_config.json と adapter_model.safetensors を配置します。
つづいて、ベースモデルとマージします。以下の様なマージスクリプトを書きました。
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
peft_name = "lorafiles" #学習済みadapter_config.jsonのパス指定
output_dir = "lolamerged/loramerged_llmjp13b11" #マージモデルの出力先
peft_config = PeftConfig.from_pretrained(peft_name)
# ベースモデルの読み込み
model = AutoModelForCausalLM.from_pretrained(
"llm-jp/llm-jp-13b-instruct-full-dolly_en-dolly_ja-ichikara_003_001-oasst_en-oasst_ja-v1.1",
return_dict=True,
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-instruct-full-dolly_en-dolly_ja-ichikara_003_001-oasst_en-oasst_ja-v1.1")
# LoRAと合わせて読み込み
model = PeftModel.from_pretrained(model, peft_name)
# マージモデル作成
merged_model = model.merge_and_unload()
# セーブ
merged_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"{output_dir}に保存しました")
マージそのものはCPUのメモリで動くので、VRAMが13bモデルの配置に足りないGPUでもマージできました。やや時間はかかりました。
マージしたモデルで推論してみます。推論スクリプトは以下です。load_in_4bit=True を付けたので、VRAM消費量は11GBくらいでした。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("lolamerged/loramerged_llmjp13b11")
model = AutoModelForCausalLM.from_pretrained(
"lolamerged/loramerged_llmjp13b11",
device_map="auto",
torch_dtype=torch.float16,
load_in_4bit=True)
text = """以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。
\n\n### 指示:\n{instruction}\n\n
### 入力:\n{input}\n\n
### 応答:\n""".format(instruction="以下の質問に答えてください",
input="""
上野動物園でいちばん人気の動物について教えてください。その理由も教えてください""")
tokenized_input = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
tokenized_input,
max_new_tokens=2048,
do_sample=True,
top_p=0.95,
temperature=0.7,
repetition_penalty=1.1,
pad_token_id=7
)[0]
print(tokenizer.decode(output))
print(model)
ちなみに、model.generateの中に pad_token_id=7 を設定しておくと、Warningメッセージが減るので気分が良いです。
推論の結果は以下です。
ちなみに、マージ前のモデルの結果は以下でした。ただ、temperatureを入れているので生成結果は揺らぐので、違いを論じるにはもっとちゃんとテストしないといけないのであくまでも「ふーん」と眺めてください(ちゃんとテストしてみたい)。
今のところあんまり違いを感じないのですが、LLM-jpの調査によると結果は良好なようです。色んな用途で試してみようと思います。
ちなみに、軽く試したところ、英日翻訳も良い感じでした。ローカル翻訳エンジンとしても可能性がありそうです
おまけ
モデルのアーキテクチャーはGPT2でした。LLM-jpのプロジェクトが始まった時期などを考えると不思議ではないものの、今回のモデルのデータセットを基にcalm2-7bのようにフルスクラッチLlama2アーキテクチャーを採用したり、もっと良いらしいMistralアーキテクチャーを採用したりしたとき、どうなるのかとても気になります。更なる発展を期待しています。
llm-jp/llm-jp-13b-instruct-full-dolly_en-dolly_ja-ichikara_003_001-oasst_en-oasst_ja-v1.1 をload_in_4bit=Trueで読み込んだ時に print(model) をした出力
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50688, 5120)
(wpe): Embedding(2048, 5120)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-39): 40 x GPT2Block(
(ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): Linear4bit(in_features=5120, out_features=15360, bias=True)
(c_proj): Linear4bit(in_features=5120, out_features=5120, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Linear4bit(in_features=5120, out_features=20480, bias=True)
(c_proj): Linear4bit(in_features=20480, out_features=5120, bias=True)
(act): GELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=5120, out_features=50688, bias=False)
)
この記事が気に入ったらサポートをしてみませんか?