
WSL2でSakana AIを試してみる
「進化的モデルマージにより日本語数学LLMとして構築したEvoLLM-JPは、数学のみならず、日本語の全般的な能力に長けている」らしいEvoLLM-JPを試してみます。
追記(2024/3/22)
10Bのモデルですが、torch_dtypeを"auto"からtorch.bfloat16に変更すると、推論のスピードが改善しました。
モデルEvoLLM-JPは、3種提供されています。今回は3つとも試します。
- 大規模言語モデル
SakanaAI/EvoLLM-JP-v1-10B : MICROSOFT RESEARCH LICENSE TERMS
SakanaAI/EvoLLM-JP-v1-7B : MICROSOFT RESEARCH LICENSE TERMS
SakanaAI/EvoLLM-JP-A-v1-7B : Apache License, Version 2.0
※画像言語モデルは、別の機会に。
SakanaAI/EvoVLM-JP-v1-7B : Apache License, Version 2.0
使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。
1. 準備
venv構築
python3 -m venv sakanaai
cd $_
source bin/activate
パッケージのインストール。
pip install torch transformers accelerate flash_attn
2. 流し込むコード
いつものコードです。query.pyという名前で保存します。
import sys
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from typing import List, Dict
import time
# argv
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default=None)
parser.add_argument("--no-chat", action='store_true')
parser.add_argument("--no-use-system-prompt", action='store_true')
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args(sys.argv[1:])
model_id = args.model_path
if model_id == None:
exit
is_chat = not args.no_chat
use_system_prompt = not args.no_use_system_prompt
max_new_tokens = args.max_tokens
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype="auto",
device_map="auto",
#device_map="cuda",
low_cpu_mem_usage=True,
trust_remote_code=True
)
#if torch.cuda.is_available():
# model = model.to("cuda")
streamer = TextStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"
# generation params
generation_params = {
"do_sample": True,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 40,
"max_new_tokens": max_new_tokens,
"repetition_penalty": 1.1,
}
def q(
user_query: str,
history: List[Dict[str, str]]=None
) -> List[Dict[str, str]]:
start = time.process_time()
# messages
messages = ""
if is_chat:
messages = []
if use_system_prompt:
messages = [
{"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
]
user_messages = [
{"role": "user", "content": user_query}
]
else:
user_messages = user_query
if history:
user_messages = history + user_messages
messages += user_messages
# generation prompts
if is_chat:
prompt = tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=True,
tokenize=False
)
else:
prompt = messages
input_ids = tokenizer.encode(
prompt,
add_special_tokens=True,
return_tensors="pt"
)
print("--- prompt")
print(prompt)
print("--- output")
# 推論
output_ids = model.generate(
input_ids.to(model.device),
streamer=streamer,
**generation_params
)
output = tokenizer.decode(
output_ids[0][input_ids.size(1) :],
skip_special_tokens=True
)
if is_chat:
user_messages.append(
{"role": "assistant", "content": output}
)
else:
user_messages += output
end = time.process_time()
##
input_tokens = len(input_ids[0])
output_tokens = len(output_ids[0][input_ids.size(1) :])
total_time = end - start
tps = output_tokens / total_time
print(f"prompt tokens = {input_tokens:.7g}")
print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
print(f" total time = {total_time:f} [s]")
return user_messages
print('history = ""')
print('history = q("ドラえもんとはなにか")')
print('history = q("続きを教えてください", history)')
3. 試してみる
(1) SakanaAI/EvoLLM-JP-v1-10B
pythonコマンドで起動して、
python -i query.py --model-path SakanaAI/EvoLLM-JP-v1-10B
と実行して、聞いてみます。
>>> history = q("ドラえもんとはなにか")
--- output
「ドラえもん」は、藤子・F・ガMOの漫画「ドラえもん」の主人公の小さなロボットです。ドラえもんは、冒険を好み、お客様を助けることを喜びます。彼はしゃらくま(コアラ)の胃の袋内に身を隠して、困難に直面した時に出現し、解決策を提供します。
prompt tokens = 55
output tokens = 136 (0.087481 [tps])
total time = 1554.620631 [s]
26分かかりました。少し化けてもいます。
追記(2024/3/22)
XにてSakana AIの中の方から「bfloat16にしてみて」とのコメントをいただきました。
Thanks for trying our models! For the 10B model, can you try loading it with bfloat16? I believe that's what caused the inference speed difference.
— Yujin Tang (@yujin_tang) March 21, 2024
で、試したところ、確かに大幅改善しました。
ドラえもんは、鳥居の下から生まれたしずくを養HANDYを助けた温泉忍者の勢into子がこNAMESCさんたらどんな名前になる?
prompt tokens = 55
output tokens = 63 (8.515899 [tps])
total time = 7.397927 [s]
(2) SakanaAI/EvoLLM-JP-v1-7B
>>> history = q("ドラえもんとはなにか")
ドラえもんは、安孫子の世界で最も有名な漫画キャラクターである。彼は小学生の主人公・しんちゃんの友人であり、未来から来た時空 を超える猫型ロボットである。彼の名前は「どらえもん」と読み、これ は「ドラ」(ドラム) + 「えもん」(英語の「emon」)という 意味で、「驚きの源泉」を表しています。
prompt tokens = 55
output tokens = 152 (26.959319 [tps])
total time = 5.638125 [s]
(3) SakanaAI/EvoLLM-JP-A-v1-7B
>>> history = q("ドラえもんとはなにか")
ドラえもんは、鳥のような姿をした恐竜の子供の幼体。ドラえもんの外見は鳥のような羽根と長い尾、小さな手足を持っている。ドラえもんは、かわいらしい外見と温和な性格で知られています。ドラえもん の名前は「小さなドラゴン」という意味です。ドラえもんは、地球で生息している胴龙の子供の幼体であり、その大きな龍の名前は「フー」です。
prompt tokens = 55
output tokens = 177 (4.062542 [tps])
total time = 43.568775 [s]
追記(2024/3/22)
こちらも bfloat16 で試したところ、、、
ドラえもんは、小さい恐竜のような外見をしているけど、心は大きく優しいマンガとテレビアニメのキャラクターです。彼はサtoshiと共に様々な冒険を体験します。
prompt tokens = 55
output tokens = 90 (24.724166 [tps])
total time = 3.640163 [s]
4. まとめ
ソースとしているモデルが異なるから当然と言えば当然ですが、応答結果と速度がかなり違いますね。
・EvoLLM-JP-v1-10B : 0.087481 [トークン/秒]
・EvoLLM-JP-v1-10B : 8.515899 [トークン/秒]
・EvoLLM-JP-v1-7B : 26.959319 [トークン/秒]・EvoLLM-JP-A-v1-7B : 4.062542 [トークン/秒]
・EvoLLM-JP-A-v1-7B : 24.724166 [トークン/秒]
「うん?推論、遅いなぁ」と思ったら、torch_dtype、"auto"ではなくtorch.bfloat16みたく指定してみるのがよいですね。