見出し画像

re: WSL2でMedusaを試してみる

こちらの記事は以下の記事の続きで、

Xに書いた次の件(stage2だと速い!)をまとめた記事です。

(注)まとめに際して、使用するデータセットを shi3z/Japanese_Wikipedia_Conversationから、shi3z/ja_conv_wikipedia_orion14B_100Kに変更しています。


1. 学習の前に

使用するモデル

ベースとするモデルは、elyza/ELYZA-japanese-Llama-2-7b-instruct です。ありがとうございます!

使用するデータセット

MedusaのREADMEにあったShareGPT_Vicuna_unfilteredをデータセットとして学習すると、日本語能力が奪われてしまいました。システムプロンプトに何を与えてもすべて英語で回答してしまいます。そりゃそうか…。

「これはとても困った」ので、shi3z さんが公開されている日本語マルチターンデータセット(10万会話)を入力にして試します。貴重なデータセットの公開、ありがとうございます!

ダウンロード

データセットのダウンロードです。

git clone https://huggingface.co/datasets/shi3z/ja_conv_wikipedia_orion14B_10K

設定ファイル

stage1はこちら。datasets.pathを適切に修正します。num_epochsは1にしています。

base_model: elyza/ELYZA-japanese-Llama-2-7b-instruct
base_model_config: elyza/ELYZA-japanese-Llama-2-7b-instruct
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: ja_conv_wikipedia_orion14B_100K/ja_conv_orion14b_100K.jsonl
    type: sharegpt
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1

adapter: qlora
lora_model_dir:

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
  - lm_head
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0005

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 40
eval_steps: 40
save_steps:
save_total_limit: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

medusa_num_heads: 5
medusa_num_layers: 1
medusa_heads_coefficient: 0.2
medusa_decay_coefficient: 0.8
medusa_logging: true
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
medusa_only_heads: true
ddp_find_unused_parameters: true
# Stage 1: only train the medusa heads
# Stage 2: train the whole model

stage2は、こちら。

base_model: elyza/ELYZA-japanese-Llama-2-7b-instruct
base_model_config: elyza/ELYZA-japanese-Llama-2-7b-instruct
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: ja_conv_wikipedia_orion14B_100K/ja_conv_orion14b_100K.jsonl
    type: sharegpt
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2

adapter: qlora
lora_model_dir:

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
  - lm_head
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:
lora_model_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0005

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 40
eval_steps: 40
save_steps:
save_total_limit: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

medusa_num_heads: 5
medusa_num_layers: 1
medusa_heads_coefficient: 0.2
medusa_decay_coefficient: 0.8
medusa_logging: true
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
# medusa_only_heads: true
# ddp_find_unused_parameters: true
# Stage 1: only train the medusa heads
# Stage 2: train the whole model

参考までに、stage1とstage2の差分はこちらです。

$ diff -u axolotl/examples/medusa/elyza_7b_qlora_stage[12]-01.yml
--- axolotl/examples/medusa/elyza_7b_qlora_stage1-01.yml        2024-01-30 11:20:20.591707705 +0900
+++ axolotl/examples/medusa/elyza_7b_qlora_stage2-01.yml        2024-01-30 11:20:38.939502607 +0900
@@ -13,7 +13,7 @@
     type: sharegpt
 dataset_prepared_path:
 val_set_size: 0.01
-output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1
+output_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2

 adapter: qlora
 lora_model_dir:
@@ -33,6 +33,7 @@
 lora_target_linear:
 lora_fan_in_fan_out:
 lora_modules_to_save:
+lora_model_dir: ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1

 sequence_len: 4096
 sample_packing: true
@@ -86,7 +87,7 @@
 medusa_logging: true
 medusa_scheduler: constant
 medusa_lr_multiplier: 4.0
-medusa_only_heads: true
-ddp_find_unused_parameters: true
+# medusa_only_heads: true
+# ddp_find_unused_parameters: true
 # Stage 1: only train the medusa heads
 # Stage 2: train the whole model
$

2. 学習

stage2に進むためには、stage1の学習結果が必要となるようですので、順番に実行していきます。

途中でターミナルへの接続がタイムアウトしてプロセスがkillされたらとても悲しいので、nohupをかましてstdoutの内容はlogファイルに書き出すようにしています。

# stage1
CUDA_VISIBLE_DEVICES=0 nohup accelerate launch -m axolotl.cli.train \
	./axolotl/examples/medusa/elyza_7b_qlora_stage1.yml \
	> ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1.log &
# stage2
CUDA_VISIBLE_DEVICES=0 nohup accelerate launch -m axolotl.cli.train \
	./axolotl/examples/medusa/elyza_7b_qlora_stage2.yml \
	> ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2.log &

そんなこんなでウン十時間経過。はい、できました。

Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1:

-rw-rw-r-- 1 user user        812  1月 30 14:17 adapter_config.json
-rw-rw-r-- 1 user user 2062971154  1月 30 14:17 adapter_model.bin
-rw-rw-r-- 1 user user       1132  1月 30 11:21 config.json
-rw-rw-r-- 1 user user       2545  1月 30 14:17 README.md
-rw-rw-r-- 1 user user        551  1月 30 11:21 special_tokens_map.json
-rw-rw-r-- 1 user user       1011  1月 30 11:21 tokenizer_config.json
-rw-rw-r-- 1 user user     499723  1月 30 11:21 tokenizer.model

Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2:

-rw-rw-r-- 1 user user        812  1月 30 20:38 adapter_config.json
-rw-rw-r-- 1 user user 2062971154  1月 30 20:38 adapter_model.bin
-rw-rw-r-- 1 user user       1132  1月 30 14:26 config.json
-rw-rw-r-- 1 user user       2545  1月 30 20:38 README.md
-rw-rw-r-- 1 user user        551  1月 30 14:26 special_tokens_map.json
-rw-rw-r-- 1 user user       1011  1月 30 14:26 tokenizer_config.json
-rw-rw-r-- 1 user user     499723  1月 30 14:26 tokenizer.model

3. 試してみる

stage1を試す

まずは、stage1から。

CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli \ 
	--model ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage1 \
	--conv-system-msg "あなたは誠実で優秀な日本人のアシスタントです。" \
	--max-steps 256

5回聞きます。平均 秒あたり 39.0 トークンでした。
※「 !!reset 」は会話履歴をクリアする、「 !!exit 」は処理終了のそれぞれコマンドです。

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:57<00:00, 58.64s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at elyza/ELYZA-japanese-Llama-2-7b-instruct and are newly initialized: ['medusa_head.3.1.weight', 'medusa_head.2.1.weight', 'medusa_head.1.1.weight', 'medusa_head.2.0.linear.bias', 'medusa_head.4.0.linear.weight', 'medusa_head.2.0.linear.weight', 'medusa_head.3.0.linear.bias', 'medusa_head.1.0.linear.weight', 'medusa_head.1.0.linear.bias', 'medusa_head.0.0.linear.bias', 'medusa_head.0.0.linear.weight', 'medusa_head.3.0.linear.weight', 'medusa_head.0.1.weight', 'medusa_head.4.0.linear.bias', 'medusa_head.4.1.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[INST]: ドラえもんとはなにか
[/INST]: /mnt/data/shoji_noguchi/venv/medusa-llm/Medusa/medusa/model/medusa_model.py:232: UserWarning: Please specify medusa choice configuration!
  warnings.warn('Please specify medusa choice configuration!')
承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。

ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。

ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (35.315387 [tps])
   total time = 7.107384 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。

ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。

ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (39.240447 [tps])
   total time = 6.396461 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。

ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。

ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (40.192272 [tps])
   total time = 6.244982 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。

ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。

ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (40.316779 [tps])
   total time = 6.225696 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: 承知しました。
ドラえもんは、藤子・F・不二雄によって生み出された漫画の主人公です。

ドラえもんは、小学生の男の子で、大のおっかけであるのが特徴です。彼は、「ひみつ道具」を持っていて、様々な場面で役に立つ道具を使います。彼の生活の中心は、のび太とのび太の家族である、のび太 の父、のび太の母、ジャイアン、スネ夫、しずか、どこでもドアの8人です。

ドラえもんは、1970年から始まったテレビアニメでも人気が高く、様々なメディアでの展開が行われています。れ
---
prompt tokens = 19
output tokens = 251 (40.712306 [tps])
   total time = 6.165212 [s]
[INST]: !!exit
exit...

stage2を試す

続いて、stage2。こちらも5回聞きます。

CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli \
	--model ./Elyza-japanese-Llama-2-7b-instruct_qlora_ja_conv_stage2 \
	--conv-system-msg "あなたは誠実で優秀な日本人のアシスタントです。" \
	--max-steps 256


5回聞きます。平均 秒あたり 50.9 トークンでした。
ただ、推論結果がstage1よりも短くなる傾向にあるのですよね。

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:31<00:00, 15.73s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at elyza/ELYZA-japanese-Llama-2-7b-instruct and are newly initialized: ['medusa_head.1.0.linear.weight', 'medusa_head.0.1.weight', 'medusa_head.0.0.linear.weight', 'medusa_head.3.0.linear.bias', 'medusa_head.2.0.linear.bias', 'medusa_head.2.0.linear.weight', 'medusa_head.1.0.linear.bias', 'medusa_head.1.1.weight', 'medusa_head.4.1.weight', 'medusa_head.3.1.weight', 'medusa_head.4.0.linear.bias', 'medusa_head.3.0.linear.weight', 'medusa_head.2.1.weight', 'medusa_head.0.0.linear.bias', 'medusa_head.4.0.linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[INST]: ドラえもんとはなにか
[/INST]: /mnt/data/shoji_noguchi/venv/medusa-llm/Medusa/medusa/model/medusa_model.py:232: UserWarning: Please specify medusa choice configuration!
  warnings.warn('Please specify medusa choice configuration!')
ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1996年までテレビアニメ化されました。また、1990年代からは映画化も されています。9
---
prompt tokens = 19
output tokens = 142 (45.777373 [tps])
   total time = 3.101969 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (52.974490 [tps])
   total time = 2.208610 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (52.957893 [tps])
   total time = 2.209302 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (51.590991 [tps])
   total time = 2.267838 [s]
[INST]: !!reset
resetting...
[INST]: ドラえもんとはなにか
[/INST]: ドラえもんは、藤子・F・不二雄の漫画作品であり、日本の漫画家である藤子・F・不二雄の代表作です。
1970年から1977年まで週刊少年サンデーに連載され、1979年から1980年まではテレビアニメ化されました。
---
prompt tokens = 19
output tokens = 117 (53.476681 [tps])
   total time = 2.187869 [s]
[INST]: !!exit
exit...

4. まとめ

これまでの弊環境における 秒あたりトークン数をまとめると以下です。

# Medusa stage1 by Legacy: datasets = Aeala/ShareGPT_Vicuna_unfiltered
29.5

# Medusa stage1 by axolotl: datasets = Aeala/ShareGPT_Vicuna_unfiltered
23.4

# Medusa stage1 by axolotl: datasets = shi3z/ja_conv_wikipedia_orion14B_100K
39.0

# Medusa stage2 by axolotl: datasets = shi3z/ja_conv_wikipedia_orion14B_100K
50.9

# transfomers
11.9
 
# vLLM
54.1

データセットの違いで速度が変わるように見えます。頭が混乱しないからかしら。

vLLMより遅いように見えますが、普通に使用していると秒あたり 55とか59 トークンとvLLMを超えるときもあります。

[INST]: しずかちゃんについて詳しく教えてください
[/INST]: しずかちゃんは、ドラえもんの妹であり、藤子不二雄の漫画作品『ドラえもん』に登場する架空の人物です。彼女は小学生の女の子で、ドラえもんの妹であり、彼女の妹であるしずかちゃんも登場します。彼女は兄のドラえもんと同じく、冷凍庫に入っていることがあります。し
---
prompt tokens = 188
output tokens = 151 (55.873504 [tps])
   total time = 2.702533 [s]
[INST]: もっと詳しく教えてください
[/INST]: もっと詳しく教えると、しずかちゃんはドラえもんの妹であり、彼女は小学生の女の子です。彼女はドラえもんと同じく、冷凍庫に入っていることがあります。彼女はドラえもんの妹であり、彼女は小学生の女の子です。
---
prompt tokens = 364
output tokens = 121 (59.690177 [tps])
   total time = 2.027134 [s]

だだ、全体的に問いに対する回答が微妙なのですよね…。

追記 - 2024/2/2

nums_epoch:2 の結果

Medusaのstage1, 2の学習をepoch 2で回してlossを減らしたら、変わるかしら?と思い、試しました。

確かに、epoch 1と比較してlossは僅かですが小さくなりました。
・stage1: 1.9101 ⇒ 1.8838
・stage2: 2.3897 ⇒ 2.3564

ただ、推論結果の傾向は大きくは変わらずでした。
・stage1: ベースモデルに近しい内容
・stage2: 箇条書き、同一単語の利用も多くなり、応答が短くなる傾向あり

先読みが影響しているのですかね。

生成したMedusaの頭

stage1とstage2とも、Hugging Faceにアップロードしています。ぜひ、ご賞味ください。

・stage1

・stage2

P.S.
サムネイルはCopilotに「ElyzaとMedusaと聞いて想像するイメージ」として作成されたものです。

この記事が気に入ったらサポートをしてみませんか?