見出し画像

rinnaのllama-3-youko-8bを試す。


やったこと

とりあえず、cliでモデルロード後、続けて入力と出力が出てくるコード。

cudaが使える環境とPyTorchがインストールされていればtransformersをインストールすれば動くはず。
GPUは16G程度必要です。

pip install transformers

user: 入力でプロンプトを入力、何も入力しなければ終了します。
rinna: で応答が出力されます。
生成時間が応答後に表示されます。

コード

import transformers
import torch
import time

model_id = "rinna/llama-3-youko-8b"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto"
)

while True:
    try:
        message = input("user:")
    except EOFError:
        message = ""
    if not message:
        print("exit...")
        break
    start_time=time.time()
    output = pipeline(
        message,
        max_new_tokens=256,
        do_sample=True
    )
    #print (output)
    ans=output[0]["generated_text"]
    print("rinna;",ans)
    print("genaraton time=",round(time.time()-start_time,2),"秒")

結果

user:大阪は、
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
rinna; 大阪は、北新地、梅田、なんばとどこでもご紹介します。
genaraton time= 0.67 秒