rinnaの30億パラメータ日本語会話モデルを試す

rinnaから日本語会話モデルがリリースされていたので試した。

import torch 
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft")

if torch.cuda.is_available():
  model = model.to("cuda")

def rinna(prompt):
  token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
  with torch.no_grad():
     output_ids = model.generate(
         token_ids.to(model.device),
         do_sample=True,
         max_new_tokens=128,
         temperature=0.7,
         pad_token_id=tokenizer.pad_token_id,
         bos_token_id=tokenizer.bos_token_id,
         eos_token_id=tokenizer.eos_token_id
     )
  output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
  output = output.replace("<NL>", "\n")
  print(output)


def getp():
  s=input()
  return f"ユーザー: {s}<NL>システム: "

rinnaのトークナイザは特殊で、\nを解釈しないで<NL>になるのでちょっとだけ注意が必要。また、会話モデルもなぜか独自の「ユーザー: {プロンプト}:<NL>システム: 」という感じで指定しなければならないのでちょっとだけ癖がある。

簡単な質問をしてみる

>>> rinna(getp())
富士山について教えて
富士山は活火山です。頂上には噴火口や噴気孔があり、頂上には巨大な溶岩ドームもあります。</s>
>>> rinna(getp())
東京ドームについて
東京ドームは、アジア最大級のドーム球場です。収容人数は55,000人以上で、野球、コンサート、格闘技、フットボールなど、様々なイベントが開催されてきました。また、ドーム内にはレストラン、カフェ、バーもあります。</s>

結構いい。
サイバーエージェントのOpenCALM7Bとは目的が違うのだが、受け答えという点ではしっかりしている。

やはりモデルサイズよりもどんなデータを学習したかということが大事なようだ