llm-jp/llm-jp-13b-dpo-lora-hh_rlhf_ja-v1.1 のLoRAファイルをマージして使ってみる

LLM-jpから、新しい13bモデルであるversion 1.1が公開されました。先に公開されていた1.0のモデルに対して、新しいデータセットでのインストラクションチューニングを加えたものみたいです。理研が進めている自然な日本語のデータセットichikaraを使っているのもポイント。

このモデルのいいところは、日本語に強い13bモデルとしては珍しい、Apache 2.0ライセンスであることです。モデルはもちろん、生成結果も自由に利用出来るので、様々な活用や遊びができそうです!

学習のさせ方の違い(LoRAかフルファインチューニングか)で2つのバージョンが公開されています。

また、同時に、ここに対してさらにDPOチューニングをしたファイルが公開されています。Hugging FaceにアップロードされているのはLoRAのアダプターのみです。

transformersから呼び出す際には、自動的に元のモデルも読み込んでくれるので気にすることはないのですが、せっかくなのでローカルでマージしてみることにしました(いずれggufにコンバートしたいので)。

まず、上記の llm-jp/llm-jp-13b-dpo-lora-hh_rlhf_ja-v1.1 から、アダプターファイルをダウンロードします。作業フォルダにlorafilesというフォルダを作って、adapter_config.json と adapter_model.safetensors を配置します。

つづいて、ベースモデルとマージします。以下の様なマージスクリプトを書きました。

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_name = "lorafiles"   #学習済みadapter_config.jsonのパス指定
output_dir = "lolamerged/loramerged_llmjp13b11"  #マージモデルの出力先

peft_config = PeftConfig.from_pretrained(peft_name)

# ベースモデルの読み込み
model = AutoModelForCausalLM.from_pretrained(
    "llm-jp/llm-jp-13b-instruct-full-dolly_en-dolly_ja-ichikara_003_001-oasst_en-oasst_ja-v1.1",
    return_dict=True,
    torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-instruct-full-dolly_en-dolly_ja-ichikara_003_001-oasst_en-oasst_ja-v1.1")
# LoRAと合わせて読み込み
model = PeftModel.from_pretrained(model, peft_name)
# マージモデル作成
merged_model = model.merge_and_unload()
# セーブ
merged_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"{output_dir}に保存しました")

マージそのものはCPUのメモリで動くので、VRAMが13bモデルの配置に足りないGPUでもマージできました。やや時間はかかりました。

マージしたモデルで推論してみます。推論スクリプトは以下です。load_in_4bit=True を付けたので、VRAM消費量は11GBくらいでした。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("lolamerged/loramerged_llmjp13b11")
model = AutoModelForCausalLM.from_pretrained(
    "lolamerged/loramerged_llmjp13b11",
    device_map="auto",
    torch_dtype=torch.float16,
    load_in_4bit=True)

text = """以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。
\n\n### 指示:\n{instruction}\n\n
### 入力:\n{input}\n\n
### 応答:\n""".format(instruction="以下の質問に答えてください",
                      input="""
上野動物園でいちばん人気の動物について教えてください。その理由も教えてください""")


tokenized_input = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        tokenized_input,
        max_new_tokens=2048,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
        repetition_penalty=1.1,
        pad_token_id=7
    )[0]
print(tokenizer.decode(output))

print(model)

ちなみに、model.generateの中に pad_token_id=7 を設定しておくと、Warningメッセージが減るので気分が良いです。

推論の結果は以下です。

以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。

### 指示:
以下の質問に答えてください


### 入力:

上野動物園でいちばん人気の動物について教えてください。その理由も教えてください


### 応答:
上野動物園で一番人気があると思われる動物はパンダです。2017年5月8日から一般公開が始まりました。日本中の注目を集め、連日大行列ができています。なぜパンダが上野動物園で人気なのかというと、「ジャイアントパンダ」として国の特別天然記念物にも指定されているようにとても貴重な動物だからです。また、その愛らしい姿や動き、かわいいしぐさなどが多くの人の心をとらえているのでしょう。<EOD|LLM-jp>

上記マージモデルの推論結果(temperature=0.7)

ちなみに、マージ前のモデルの結果は以下でした。ただ、temperatureを入れているので生成結果は揺らぐので、違いを論じるにはもっとちゃんとテストしないといけないのであくまでも「ふーん」と眺めてください(ちゃんとテストしてみたい)。

以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。


### 指示:
以下の質問に答えてください


### 入力:

上野動物園でいちばん人気の動物について教えてください。その理由も教えてください


### 応答:
上野動物園といえば、パンダが有名ですが、実はパンダ以外にもとても魅力的な動物がいます。

・レッサーパンダ

・ゴリラ

・ハシビロコウ

などが、上野動物園の中ではとても人気のある動物たちです。

レッサーパンダは「かわいい」というイメージが強いかもしれませんが、実際にはとても力強く、人懐っこい性格をしており、世界中の子供たちからも大人気の動物です。

ゴリラはその堂々とした佇まいや、「ウホ!」と大きな声を出す様子から、とてもユーモラスな印象を受けます。

そして、ハシビロコウは「餌をくれる手を噛んだらいけない」ということを学習しているためか、人間に対してはとても警戒心が強い動物として知られています。<EOD|LLM-jp>

llm-jp/llm-jp-13b-instruct-full-dolly_en-dolly_ja-ichikara_003_001-oasst_en-oasst_ja-v1.1 の推論結果(temperature=0.7)

今のところあんまり違いを感じないのですが、LLM-jpの調査によると結果は良好なようです。色んな用途で試してみようと思います。

ちなみに、軽く試したところ、英日翻訳も良い感じでした。ローカル翻訳エンジンとしても可能性がありそうです

おまけ

モデルのアーキテクチャーはGPT2でした。LLM-jpのプロジェクトが始まった時期などを考えると不思議ではないものの、今回のモデルのデータセットを基にcalm2-7bのようにフルスクラッチLlama2アーキテクチャーを採用したり、もっと良いらしいMistralアーキテクチャーを採用したりしたとき、どうなるのかとても気になります。更なる発展を期待しています。

llm-jp/llm-jp-13b-instruct-full-dolly_en-dolly_ja-ichikara_003_001-oasst_en-oasst_ja-v1.1 をload_in_4bit=Trueで読み込んだ時に print(model) をした出力

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50688, 5120)
    (wpe): Embedding(2048, 5120)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-39): 40 x GPT2Block(
        (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Linear4bit(in_features=5120, out_features=15360, bias=True)
          (c_proj): Linear4bit(in_features=5120, out_features=5120, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear4bit(in_features=5120, out_features=20480, bias=True)
          (c_proj): Linear4bit(in_features=20480, out_features=5120, bias=True)
          (act): GELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=5120, out_features=50688, bias=False)
)


この記事が気に入ったらサポートをしてみませんか?