見出し画像

Google Colab で Rinna-3.6B のLoRAファインチューニングを試す

「Google Colab」で「Rinna-3.6B」のLoRAファインチューニングを試したのでまとめました。

【注意】Google Colab Pro/Pro+ の A100で動作確認しています。VRAMは14.0GB必要でした。


1. Rinna-3.6B

OpenCALM-7B」は、「サイバーエージェント」が開発した、日本語LLMです。商用利用可能なライセンスで公開されており、このモデルをベースにチューニングすることで、対話型AI等の開発が可能です。

Rinna-3.6B」は、「Rinna」が開発した、日本語LLMです。商用利用可能なライセンスで公開されており、このモデルをベースにチューニングすることで、対話型AI等の開発が可能です。

2. 学習

「Google Colab」で「Rinna-3.6B」のLoRAファインチューニングを行います。データセットは@kun1em0nさんの「kunishou/databricks-dolly-15k-ja」を使わせてもらいました。
学習手順は、次のとおりです。

(1) メニュー「編集→ノートブックの設定」で、「ハードウェアアクセラレータ」で「GPU」で「A100」を選択。

(2) Googleドライブのマウント。

# Googleドライブのマウント
from google.colab import drive
drive.mount("/content/drive")

(3) 作業フォルダへの移動。

# 作業フォルダへの移動
import os
os.makedirs("/content/drive/My Drive/work", exist_ok=True)
%cd '/content/drive/My Drive/work'

(4) 基本パラメータの定義。

# 基本パラメータ
model_name = "rinna/japanese-gpt-neox-3.6b"
dataset = "kunishou/databricks-dolly-15k-ja"
peft_name = "lora-rinna-3.6b"
output_dir = "lora-rinna-3.6b-results"

以下のパラメータを定義します。

・model_name : モデル名
・dataset : データセット名
・peft_name : PEFTの出力フォルダ
・output_dir : 学習結果の出力フォルダ

(5) PEFTのインストール。
Rinnaのトークナーザーでは、「sentencepiece」も必要になります。

# PEFTのインストール
!pip install -Uqq  git+https://github.com/huggingface/peft.git
!pip install -Uqq transformers datasets accelerate bitsandbytes
!pip install sentencepiece

(6) トークナイザーの準備。
Rinnaのトークナイザーでは、「use_fast=False」も必要になります。

from transformers import AutoTokenizer

# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

(7) トークナイザーのスペシャルトークンの確認。
EOSは3でした。

# スペシャルトークンの確認
print(tokenizer.special_tokens_map)
print("bos_token :", tokenizer.bos_token, ",", tokenizer.bos_token_id)
print("eos_token :", tokenizer.eos_token, ",", tokenizer.eos_token_id)
print("unk_token :", tokenizer.unk_token, ",", tokenizer.unk_token_id)
print("pad_token :", tokenizer.pad_token, ",", tokenizer.pad_token_id)
{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}
bos_token : <s> , 2
eos_token : </s> , 3
unk_token : [UNK] , 1
pad_token : [PAD] , 0

(8) トークナイズ関数の定義。
CALMのtokenizer()はEOSが自動追加されなかったので「+"<|endtext|>"」を付加しましたが、Rinnaのtokenizer()は自動追加されてました。

CUTOFF_LEN = 256  # コンテキスト長

# トークナイズ
def tokenize(prompt, tokenizer):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
    )
    return {
        "input_ids": result["input_ids"],
        "attention_mask": result["attention_mask"],
    }

(9) トークナイズ関数の確認。
「input_ids」の最後にEOS「3」が追加されてることを確認します。

# トークナイズの動作確認
tokenize("hi there", tokenizer)
{'input_ids': [3201, 634, 1304, 3], 'attention_mask': [1, 1, 1, 1]}

(10) データセットの準備。

from datasets import load_dataset

# データセットの準備
data = load_dataset(dataset)

(11) データセットの確認。

# データセットの確認
data["train"][5]
{'category': 'information_extraction',
 'instruction': 'ステイルメイトの時に、私の方が多くの駒を持っていたら、私の勝ちですか?',
 'output': 'いいえ。\nステイルメイトとは、引き分けた状態のことです。どちらがより多くの駒を捕獲したか、または優勢であるかは関係ない',
 'input': 'ステイルメイトとは、チェスにおいて、手番が回ってきたプレーヤーがチェックされておらず、合法的な手がない状態のことである。ステイルメイトの結果、引き分けとなる。終盤では、ステイルメイトは劣勢にあるプレイヤーが負けるのではなく、ゲームを引き分けることを可能にする戦術である[2]。より複雑なポジションでは、ステイルメイトはより稀で、通常は優勢側が不注意な場合にのみ成功する詐欺の形をとる[引用] ステイルメイトは終盤研究や他のチェスの問題においても共通のテーマである。\n\nステイルメイトが引き分けに統一されたのは19世紀である。それ以前は、ステイルメイトしているプレイヤーの勝利、引き分け、負けとみなされたり、反則となったり、ステイルメイトしているプレイヤーはターンを失うことになったりと、その扱いは様々であった。ステイルメイトのルールは、チェス以外のチャトランガ系ゲームごとに異なる。',
 'index': '5'}

(12) プロンプトテンプレートの準備。
学習用のレスポンス内容あり版になります。「Rinna-instruction」に習って、日本語かつ改行を<NL>にしています。

# プロンプトテンプレートの準備
def generate_prompt(data_point):
    if data_point["input"]:
        result = f"""### 指示:
{data_point["instruction"]}

### 入力:
{data_point["input"]}

### 回答:
{data_point["output"]}"""
    else:
        result = f"""### 指示:
{data_point["instruction"]}

### 回答:
{data_point["output"]}"""

    # 改行→<NL>
    result = result.replace('\n', '<NL>')
    return result

(13) プロンプトテンプレートの確認。

# プロンプトテンプレートの確認
print(generate_prompt(data["train"][5]))
### 指示:<NL>ステイルメイトの時に、私の方が多くの駒を持っていたら、私の勝ちですか?<NL><NL>### 入力:<NL>ステイルメイトとは、チェスにおいて、手番が回ってきたプレーヤーがチェックされておらず、合法的な手がない状態のことである。ステイルメイトの結果、引き分けとなる。終盤では、ステイルメイトは劣勢にあるプレイヤーが負けるのではなく、ゲームを引き分けることを可能にする戦術である[2]。より複雑なポジションでは、ステイルメイトはより稀で、通常は優勢側が不注意な場合にのみ成功する詐欺の形をとる[引用] ステイルメイトは終盤研究や他のチェスの問題においても共通のテーマである。<NL><NL>ステイルメイトが引き分けに統一されたのは19世紀である。それ以前は、ステイルメイトしているプレイヤーの勝利、引き分け、負けとみなされたり、反則となったり、ステイルメイトしているプレイヤーはターンを失うことになったりと、その扱いは様々であった。ステイルメイトのルールは、チェス以外のチャトランガ系ゲームごとに異なる。<NL><NL>### 回答:<NL>いいえ。<NL>ステイルメイトとは、引き分けた状態のことです。どちらがより多くの駒を捕獲したか、または優勢であるかは関係ない

(14) 学習データと検証データの準備。

VAL_SET_SIZE = 2000

# 学習データと検証データの準備
train_val = data["train"].train_test_split(
    test_size=VAL_SET_SIZE, shuffle=True, seed=42
)
train_data = train_val["train"]
val_data = train_val["test"]
train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x), tokenizer))
val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x), tokenizer))

(15) モデルの準備。
「load_in_8bit=True」でint8量子化を有効にしています。

from transformers import AutoModelForCausalLM

# モデルの準備
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)

(16) LoRAモデルの準備。

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

# LoRAのパラメータ
lora_config = LoraConfig(
    r= 8, 
    lora_alpha=16,
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# モデルの前処理
model = prepare_model_for_int8_training(model)

# LoRAモデルの準備
model = get_peft_model(model, lora_config)

# 学習可能パラメータの確認
model.print_trainable_parameters()
trainable params: 3244032 || all params: 3610489344 || trainable%: 0.08985020286491116

「LoraConfig」の主なパラメータは、次のとおりです。

・r (int) : low-rank行列の次元数。この値が小さいと、学習可能なパラメータが少なくなる
・lora_alpha (int) : low-rank行列のスケーリングファクター。大きいほどLoRAが活性化
・lora_dropout (float) : LoRAレイヤーのドロップアウト確率。学習時にランダムに一部のノードを無効にして汎化性能向上
・target_modules (Union[List[str],str])
LoRAを適用するモジュールの名前
・bias (str) : バイアス種別 ("none", "all", "lora_only")
・task_type (Union[str, peft.utils.config.TaskType] = None) : タスク種別 (CAUSAL_LM)

(17) トレーナーの準備。

import transformers
eval_steps = 200
save_steps = 200
logging_steps = 20

# トレーナーの準備
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        num_train_epochs=3,
        learning_rate=3e-4,
        logging_steps=logging_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        output_dir=output_dir,
        report_to="none",
        save_total_limit=3,
        push_to_hub=False,
        auto_find_batch_size=True
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

「TrainingArguments」の主なパラメータは、次のとおりです。

・num_train_epochs (float, optional, defaults to 3.0) : 学習エポック数
・max_steps (int, optional, defaults to -1) : 学習ステップ数
・output_dir (str) : 出力フォルダ (predictionとcheckpointの保存)
・report_to (str or List[str], optional, defaults to "all") : 結果とログのレポート先 ("azure_ml", "comet_ml", "mlflow", "neptune", "tensorboard","clearml", "wandb")
・save_total_limit (int, optional) : 保存するチェックポイントの最大数
・auto_find_batch_size (bool, optional, defaults to False) : バッチサイズの自動設定
・resume_from_checkpoint (str, optional) : 再開チェックポイント

・logging_strategy (str or IntervalStrategy, optional, defaults to "steps")
ログ戦略 ("no", "steps", "epoch")
・evaluation_strategy (str or IntervalStrategy, optional, defaults to "no")
評価戦略 ("no", "steps", "epoch")
・save_strategy (str or IntervalStrategy, optional, defaults to "steps")
保存戦略 ("no", "steps", "epoch")
・logging_steps (int or float, optional, defaults to 500) : ログ間隔 (ステップ)
・eval_steps (int or float, optional) : 評価間隔 (ステップ)
・save_steps (int or float, optional, defaults to 500) : 保存間隔 (ステップ)

(18) 学習の実行。
3時間ほどかかりました。

# 学習の実行
model.config.use_cache = False
trainer.train() 
model.config.use_cache = True

# LoRAモデルの保存
trainer.model.save_pretrained(peft_name)

3. 推論

「Google Colab」で「Rinna-3.6B」のLoRAモデルの推論手順は、次のとおりです。

(1) モデルとトークナイザーの準備。
Rinnaのトークナイザーでは、「use_fast=False」も必要になります。

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# モデルの準備
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)

# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# LoRAモデルの準備
model = PeftModel.from_pretrained(
    model, 
    peft_name, 
    device_map="auto"
)

# 評価モード
model.eval()

(2) プロンプトテンプレートの準備。
推論用のレスポンス内容なし版になります。

# プロンプトテンプレートの準備
def generate_prompt(data_point):
    if data_point["input"]:
        result = f"""### 指示:
{data_point["instruction"]}

### 入力:
{data_point["input"]}

### 回答:
"""
    else:
        result = f"""### 指示:
{data_point["instruction"]}

### 回答:
"""

    # 改行→<NL>
    result = result.replace('\n', '<NL>')
    return result

(3) テキスト生成関数の定義。
CALMのtokenizer()はEOSが自動追加されなかったのですが、Rinnaのtokenizer()は自動追加されてしまうため、add_special_tokens=Falseを指定しました。

# テキスト生成関数の定義
def generate(instruction,input=None,maxTokens=256):
    # 推論
    prompt = generate_prompt({'instruction':instruction,'input':input})
    input_ids = tokenizer(prompt, 
        return_tensors="pt", 
        truncation=True, 
        add_special_tokens=False).input_ids.cuda()
    outputs = model.generate(
        input_ids=input_ids, 
        max_new_tokens=maxTokens, 
        do_sample=True,
        temperature=0.7, 
        top_p=0.75, 
        top_k=40,         
        no_repeat_ngram_size=2,
    )
    outputs = outputs[0].tolist()
    print(tokenizer.decode(outputs))

    # EOSトークンにヒットしたらデコード完了
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
        decoded = tokenizer.decode(outputs[:eos_index])

        # レスポンス内容のみ抽出
        sentinel = "### 回答:"
        sentinelLoc = decoded.find(sentinel)
        if sentinelLoc >= 0:
            result = decoded[sentinelLoc+len(sentinel):]
            print(result.replace("<NL>", "\n"))  # <NL>→改行
        else:
            print('Warning: Expected prompt template to be emitted.  Ignoring output.')
    else:
        print('Warning: no <eos> detected ignoring output')

(4) 推論の実行。
質問に対する答えだけ返してEOSしているため、学習できてそうです。

generate("自然言語処理とは?")
人間が日常的に使っている自然な言語を理解し、自然に文章を生成する技術。


generate("日本の首都は?")
東京は、日本の首都です。


generate("まどか☆マギカで一番かわいいのは?")
暁美ほむらは、まどかマミの親友であり、魔女の使い魔であるキュウベエの力で、時間を遡り、自分の命を救ってくれた少女の命を救うために、魔法少女になる。  ほむほーむは「魔法少女まどか★マゥト」のアニメで人気となり、多くのファンに愛されました。

関連



この記事が気に入ったらサポートをしてみませんか?