rinnaの30億パラメータ日本語会話モデルを試す
rinnaから日本語会話モデルがリリースされていたので試した。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft")
if torch.cuda.is_available():
model = model.to("cuda")
def rinna(prompt):
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
output_ids = model.generate(
token_ids.to(model.device),
do_sample=True,
max_new_tokens=128,
temperature=0.7,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
output = output.replace("<NL>", "\n")
print(output)
def getp():
s=input()
return f"ユーザー: {s}<NL>システム: "
rinnaのトークナイザは特殊で、\nを解釈しないで<NL>になるのでちょっとだけ注意が必要。また、会話モデルもなぜか独自の「ユーザー: {プロンプト}:<NL>システム: 」という感じで指定しなければならないのでちょっとだけ癖がある。
簡単な質問をしてみる
>>> rinna(getp())
富士山について教えて
富士山は活火山です。頂上には噴火口や噴気孔があり、頂上には巨大な溶岩ドームもあります。</s>
>>> rinna(getp())
東京ドームについて
東京ドームは、アジア最大級のドーム球場です。収容人数は55,000人以上で、野球、コンサート、格闘技、フットボールなど、様々なイベントが開催されてきました。また、ドーム内にはレストラン、カフェ、バーもあります。</s>
結構いい。
サイバーエージェントのOpenCALM7Bとは目的が違うのだが、受け答えという点ではしっかりしている。
やはりモデルサイズよりもどんなデータを学習したかということが大事なようだ