rinnaのllama-3-youko-8bを試す。
やったこと
とりあえず、cliでモデルロード後、続けて入力と出力が出てくるコード。
cudaが使える環境とPyTorchがインストールされていればtransformersをインストールすれば動くはず。
GPUは16G程度必要です。
pip install transformers
user: 入力でプロンプトを入力、何も入力しなければ終了します。
rinna: で応答が出力されます。
生成時間が応答後に表示されます。
コード
import transformers
import torch
import time
model_id = "rinna/llama-3-youko-8b"
pipeline = transformers.pipeline(
"text-generation",
model=model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
device_map="auto"
)
while True:
try:
message = input("user:")
except EOFError:
message = ""
if not message:
print("exit...")
break
start_time=time.time()
output = pipeline(
message,
max_new_tokens=256,
do_sample=True
)
#print(output)
ans=output[0]["generated_text"]
print("rinna;",ans)
print("genaraton time=",round(time.time()-start_time,2),"秒")
結果
user:大阪は、
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
rinna; 大阪は、北新地、梅田、なんばとどこでもご紹介します。
genaraton time= 0.67 秒