見出し画像

Google Colab で SFTTrainer によるLLMのフルパラメータの指示チューニングを試す

「Google Colab」で「SFTTrainer」によるLLMの (LoRAではなく) フルパラメータの指示チューニング (Instruction Tuning) を試したので、まとめました。

前回


1. モデルとデータセット

今回は、LLMとして「OpenCALM-small」、データセットとして「databricks-dolly-15k-ja」を使いました。

・OpenCALM-small : 有名なLLMの中で日本語対応かつ軽量なモデル
・multilingual-sentiments : 指示チューニング用のinstruction(指示)、input(入力)、output(出力)で構成されるデータセット

2. ファインチューニング前のLLM出力の確認

Colabでファインチューニング前のLLM出力を確認する手順は、次のとおりです。

(1) パッケージのインストール。

# パッケージのインストール
!pip install transformers accelerators
!pip install trl peft datasets

(2) トークナイザーとモデルの準備。

from transformers import AutoModelForCausalLM, AutoTokenizer

# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
    "cyberagent/open-calm-small"
)
model = AutoModelForCausalLM.from_pretrained(
    "cyberagent/open-calm-small", 
    device_map="auto"
)

(3) LLM出力の確認。

import torch

# プロンプトの準備
prompt = "### User: 日本の首都は?\n ### Answer:"

# 推論の実行
for i in range(10):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    print(output)
    print("----")
### User: 日本の首都は?
 ### Answer: 日本人には「日本」という国名は必要ない。 「日本」は、日本語で「日本」と発音する。
### Answer: 日本には「日本」という国名があるが、 「Japanese」や「Japan」は、「日本」と発音する。
### Answer:
----
### User: 日本の首都は?
 ### Answer: User:  ヨーロッパは?
 ### Answer: 日本は?
 ### Answer:  イギリスは?
 ### Answer:  ドイツとイギリス。
 ### Answer:  イギリスが?
 ### Answer:  
----
### User: 日本の首都は?
 ### Answer: あなたのツイートで、私のアカウントをフォローしてください。
 ### Automotive: あなたはどうですか? ### Automotive: どんなことができますか? ### Automotive: あなたは何ができるのですか?
 ### Automotive: どのようにあなたのツイートをチェックしますか? #
----
### User: 日本の首都は?
 ### Answer: 日本では、日本の国旗と国歌は「国旗」であり、「国歌」ではない。
「日本」が「国歌」である以上、「日本の国旗」は「日本の国旗」なのである。
「日本は国旗国歌ではない」という主張は、国際連盟による国際連盟脱退の理由の一つだ
----
### User: 日本の首都は?
 ### Answer: 米国からの回答で、日本は「アメリカ」が正解のようです。
### Answer: オーストラリアからの回答で、日本は何が正解ですか?
### Answer: オーストラリアは、日本より南に位置し、人口約1,500万人、国土面積約12万平方キロ。
----
### User: 日本の首都は?
 ### Answer: 日本人
### Answer: 世界の首都はどこですか?
### Answer: 世界の首都はどこですか?
### Answer: 世界の首都はどこですか?
### Answer: 世界の大都市はどこですか?
### Answer: 世界の首都
----
### User: 日本の首都は?
 ### Answer: イギリス、ロンドン。
 ### Answer: ロンドンの「District of Columbia」で、日本人初の「Filing Fish Productions(フィルイング・フィッシュ・プロダクション)」の経営者、鈴木英之さんが経営する「Jackson Square(ジャクソンスクエア)」を訪問
----
### User: 日本の首都は?
 ### Answer: 日本とヨーロッパは?
### Answer: イギリス、フランス、ドイツ、イタリア、ロシア、スペイン、トルコ、スイス、オランダ、東ティモール、ネパール、パキスタン、ブータン、バングラデシュ、スリランカ、フィリピン、マレーシア、シンガポール、タイ、ベトナム、ミャンマー、インド、韓国、中国、
----
### User: 日本の首都は?
 ### Answer:  国連のSDGs達成のために日本が何をすべきか、国連でSDGsがどういう役割を果たすのかについて、国連事務総長のコメントです。
### Answer:  国連でSDGsがどういう役割を果たすのかについて、国連でSDGsがどういう役割を果たすのかについて、国連でSDGsがどういう役割を果たすのか
----
### User: 日本の首都は?
 ### Answer: 日本は?  ### Answer: 日本は?  ### Answer: 日本は?  ### Answer: 日本は?  ### Answer: 日本は?  ### Answer: 日本は?  ### Answer: 日本は? 
----

回答は間違ってますが、モデルサイズが起因すると思われるので目をつむります。
### User: 日本の首都は?\n ### Answer:」に対して長文回答で終端がないことがわかります。

4. ファインチューニングの実行

Colabでファインチューニングを実行する手順は、次のとおりです。

(1) データセットの日本語データのみ読み込み。

from datasets import load_dataset

# データセットの読み込み読み込み
dataset = load_dataset("kunishou/databricks-dolly-15k-ja")

# 確認
print(dataset)
print(dataset["train"][0])
DatasetDict({
    train: Dataset({
        features: ['input', 'index', 'output', 'category', 'instruction'],
        num_rows: 15015
    })
})
{'input': 'ヴァージン・オーストラリア航空(Virgin Australia Airlines Pty Ltd)はオーストラリアを拠点とするヴァージン・ブランドを冠する最大の船団規模を持つ航空会社です。2000年8月31日に、ヴァージン・ブルー空港として、2機の航空機、1つの空路を運行してサービスを開始しました。2001年9月のアンセット・オーストラリア空港の崩壊後、オーストラリアの国内市場で急速に地位を確立しました。その後はブリスベン、メルボルン、シドニーをハブとして、オーストラリア国内の32都市に直接乗り入れるまでに成長しました。', 'index': '0', 'output': 'ヴァージン・オーストラリア航空は、2000年8月31日にヴァージン・ブルー航空として、2機の航空機で単一路線の運航を開始しました。', 'category': 'closed_qa', 'instruction': 'ヴァージン・オーストラリア航空はいつから運航を開始したのですか?'}

(2) データセットをinputが空の要素のみ5000個でフィルタリング。
今回は、inputを使わないデータのみで学習します。

# データセットをinputが空の要素のみ5000個でフィルタリング
train_dataset = dataset["train"].filter(lambda data: data["input"] == "").select(range(5000))

# 確認
print(train_dataset)
print(train_dataset[0])
Dataset({
    features: ['input', 'index', 'output', 'category', 'instruction'],
    num_rows: 5000
})
{'input': '', 'index': '1', 'output': 'イコクエイラクブカ', 'category': 'classification', 'instruction': '魚の種類はどっち?イコクエイラクブカとロープ'}

(3) 終端文字(EOS : End Of String) の確認。

# 終端文字の確認
print(tokenizer.eos_token_id)
print(tokenizer.eos_token)
0
<|endoftext|>

(4) プロンプトを作成する関数の準備。
データセットのinstructionとoutputを使ってプロンプトを作成します。終端文字 (<|endoftext|>) も追加しました。

# プロンプトを作成する関数
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}<|endoftext|>"
        output_texts.append(text)
    return output_texts

(5) 学習の実行。
18分ほどかかりました。

from trl import SFTTrainer

# 学習の実行
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    max_seq_length=512,
    formatting_func=formatting_prompts_func,
)
trainer.train()

# モデルの保存
trainer.save_model("output")

outputフォルダには、モデルが保存されています。

5. ファインチューニング後のLLM出力の確認

Colabでのファインチューニング後のLLM出力の確認の手順は、次のとおりです。

(1) outputからのモデルの読み込み。

# モデルの準備
model = AutoModelForCausalLM.from_pretrained(
    "./output", 
    device_map="auto"
)

(2) LLM出力の確認。

import torch

# プロンプトの準備
prompt = "### User: 日本の首都は?\n ### Answer:"

# 推論の実行
for i in range(10):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    print(output)
    print("----")
### User: 日本の首都は?
 ### Answer: シンガポール
----
### User: 日本の首都は?
 ### Answer: シンガポール
----
### User: 日本の首都は?
 ### Answer: 北京
----
### User: 日本の首都は?
 ### Answer: 神奈川県
----
### User: 日本の首都は?
 ### Answer: 茨城県
----
### User: 日本の首都は?
 ### Answer: オーストラリア
----
### User: 日本の首都は?
 ### Answer: シンガポール
----
### User: 日本の首都は?
 ### Answer: パリ
----
### User: 日本の首都は?
 ### Answer: ドイツ
----
### User: 日本の首都は?
 ### Answer: パリ
----

回答は間違ってますが、モデルサイズが起因すると思われるので目をつむります。
### User: 日本の首都は?\n ### Answer:」に対して回答のみが出力
されており、質問に対して回答を返す書式を学習できていることがわかります。

次回



この記事が気に入ったらサポートをしてみませんか?