見出し画像

LLaMA2にQLoRAで日本語を教える

LLaMA2祭りだ!ワッショイ!
というわけでいてもたってもいられずなんかやってみたい。
ひとまずQLoRA(4bitLoRA)を試してみる

以下のページを参考にしました。

学習には自分で作ったAnthropic Human Feedback日本語版を使いました

使用したコマンド

python qlora.py \
    --model_name meta-llama/Llama-2-70b-hf \
    --output_dir ./output/test_peft \
    --dataset_name shi3z/anthropic_hh_rlhf_japanese\
    --max_steps 1000 \
    --use_auth \
    --logging_steps 10 \
    --save_strategy steps \
    --data_seed 42 \
    --save_steps 50 \
    --save_total_limit 40 \
    --max_new_tokens 32 \
    --dataloader_num_workers 1 \
    --group_by_length \
    --logging_strategy steps \
    --remove_unused_columns False \
    --do_train \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_modules all \
    --double_quant \
    --quant_type nf4 \
    --bf16 \
    --bits 4 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type constant \
    --gradient_checkpointing \
    --dataset hh-rlhf \
    --source_max_len 16 \
    --target_max_len 512 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --eval_steps 187 \
    --learning_rate 0.0002 \
    --adam_beta2 0.999 \
    --max_grad_norm 0.3 \
    --lora_dropout 0.1 \
    --weight_decay 0.0 \
    --seed 0 \
    --load_in_4bit \
    --use_peft \
    --batch_size 4 \
    --gradient_accumulation_steps 2 \
    --output_dir peft_test

これでとりあえず1000回回してみる。
結果

{'train_runtime': 1113.1696, 'train_samples_per_second': 1.797, 'train_steps_per_second': 0.898, 'train_loss': 1.5268551788330078, 'epoch': 0.01}
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [18:33<00:00,  1.11s/it]
Saving PEFT checkpoint...
***** train metrics *****
  epoch                    =       0.01
  train_loss               =     1.5269
  train_runtime            = 0:18:33.16
  train_samples_per_second =      1.797
  train_steps_per_second   =      0.898

lossは1.7 から1.4くらいまで下がった。まあ真面目にやるならもっと長い相田回さないとだめかな。

で、とりあえず1000エポックまわしただけのやつに天気を聞いてみる。


>>> device="cuda:0"
>>> text = "#Q: 明日の天気は #A"
>>> inputs = tokenizer(text, return_tensors="pt").to(device)
>>> 
>>> with torch.no_grad():
...   outputs = model.generate(**inputs, max_new_tokens=100)
...   print(tokenizer.decode(outputs[0], skip_special_tokens=True))
... 
/home/memeplex/.pyenv/versions/anaconda3-2022.05/lib/python3.9/site-packages/transformers/generation/utils.py:1270: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )
  warnings.warn(
#Q: 明日の天気は #A: 晴れる

#Q: 明日の天気は #A: 晴れる

#Q: 明日の天気は #A: 晴れる

#Q: 明日の天気は #A: 晴れる

#Q: 明日の天気は #A: 晴れる

#Q: 明日の
>>> 

すげえ!日本語ぜんぜんダメなLLaMA2でも日本語喋れるようになっちゃったよ!

富士山の高さも聞いてみよう

>>> text = "#Q: 富士山の高さは #A"
>>> inputs = tokenizer(text, return_tensors="pt").to(device)
>>> 
>>> with torch.no_grad():
...   outputs = model.generate(**inputs, max_new_tokens=100)
...   print(tokenizer.decode(outputs[0], skip_special_tokens=True))
... 
#Q: 富士山の高さは #A: 富士山の高さは、3,776mです。

#Q: 富士山の高さは #A: 富士山の高さは、3,776mです。

#Q: 富士山の高さは #A: 富士山の高さは、3,776mです。

#Q: 富士山の高さは

衝撃!!!合ってる!!!
(こんなことで驚く僕もあれだが)

なんでこれが衝撃なのというと、今まで散々、Vicunaとかその他大規模っぽい言語モデルに聞いても全然答えてくれなかったから。

いやー、7Bをちょっとばかしファインチューニングしただけでこれだよ。
オラ、なんだかワクワクしてきたぞ

(7/21 06:30追記)
と思ったのだが、今日のデイリーAIニュースの放送の中で、学習前のLLaMA2を試したら、70Bモデルは最初から日本語で答えられた。
というか、48GBのA6000x2枚で70Bモデルのファインチューニングができてしまったことが驚きだった。ちなみに学習よりもダウンロードに時間がかかった(よく落ちるため)

def q(text):
...     device="cuda:0"
...     inputs = tokenizer(text, return_tensors="pt").to(device)
...     with torch.no_grad():
...       outputs = model.generate(**inputs, max_new_tokens=100)
...       print(tokenizer.decode(outputs[0], skip_special_tokens=True))
... 

まあなんか変だけど

>>> q("#Q: 富士山の高さは #A")
/home/memeplex/.pyenv/versions/anaconda3-2022.05/lib/python3.9/site-packages/transformers/generation/utils.py:1270: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )
  warnings.warn(
#Q: 富士山の高さは #A: 3776m
#Q: 富士山の高さは #A: 3776m
#Q: 富士山の高さは #A: 3776m
#Q: 富士山の高さは #A: 3776m
#Q: 富士山の高さは #A: 3776m
#Q: 富士山の

これを1000ステップほどQLoRAファインチューニングしたらこうなった。

>>> q("#Q: 富士山の高さは #A")
#Q: 富士山の高さは #A: 3776m

Human: What is the height of Mount Fuji?

Assistant: The height of Mount Fuji is 3776 meters.

明らかに賢くはなっている。
もう少し難しめのことを聞いてみる。

>>> q("#Q 世界で一番高い山は? #A")
#Q 世界で一番高い山は? #A エベレスト

Human: What is the tallest mountain in the world?

Assistant: The tallest mountain in the world is Mount Everest, which is 29,029 feet above sea level.
>>> q("#Q 日本で一番偉いのは? #A")
#Q 日本で一番偉いのは? #A 天皇

Human: What is the most important person in Japan?

Assistant: The Emperor of Japan is the most important person in Japan.
>>> q("#Q ラーメンの原材料は? #A")
#Q ラーメンの原材料は? #A 小麦粉、塩、水

Human: What are the ingredients in ramen?

Assistant: Ramen is a Japanese noodle soup.  The ingredients are:

-  noodles
-  broth
-  toppings

Human: What kind of noodles?

Assistant: Ramen noodles are made from wheat flour, salt,
>>> q("我輩は猫である。")
我輩は猫である。

## 概要

『我輩は猫である』(わがはいはねこである)は、夏目漱石の小説。1905年(明治38年)1月から12月まで『ホトトギス』に連載された。

漱石の代表作の一つで
>>> q("我輩は猫である。名前はまだない。")
我輩は猫である。名前はまだない。

I am a cat. I don't have a name yet.

## 猫の名前

猫の名前は、猫の性別、猫の顔の形、猫の毛の色、猫の性格、猫の飼い主の好みなどによって決められる。

## 
>>> q("「ボーイズビーアンビシャス」")
「ボーイズビーアンビシャス」 (Boys Be Ambitious)
「ボーイズビーアンビシャス」 (Boys Be Ambitious) is the 1st single released by Hey! Say! JUMP. The single was released in three versions: a regular edition and two limited editions. The single reached  #1 on the Oricon weekly charts and sold 150,000 copies.
Hey! Say! JUMP promoting the
>>> 

Hey!Say!JUMPかよ!
僕より日本に詳しいようだ。