見出し画像

無料版Colabでrinna/bilingual-gpt-neox-4b-instruction-ppoを動かす

サンプルコードのままだとメモリ(RAM)が足りなくて無料版のGoogle Colabだとクラッシュして動かなかったので、8bitで読み込んで動かしました。

ランタイムのタイプはGPUを選びます。

ランタイム → ランタイプのタイプを変更から選ぶ

pipでbitsandbytesとaccelerateを読み込んでおきます。

!pip install transformers sentencepiece bitsandbytes accelerate

プロンプトの形式は同じにしておきます。

prompt = [
    {
        "speaker": "ユーザー",
        "text": "Hello, you are an assistant that helps me learn Japanese."
    },
    {
        "speaker": "システム",
        "text": "Sure, what can I do for you?"
    },
    {
        "speaker": "ユーザー",
        "text": "VTuberの魅力について教えてください。"
    }
]
prompt = [
    f"{uttr['speaker']}: {uttr['text']}"
    for uttr in prompt
]
prompt = "\n".join(prompt)
prompt = (
    prompt
    + "\n"
    + "システム: "
)
print(prompt)

モデルを読み込むところで、load_in_8bit=Trueを付けます。

bitandbytesを使っているとmode.to("cuda")は使えないよ! とエラーが出るので、そこは削除しておきます。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo", loa d_in_8bit=True)

# if torch.cuda.is_available():
#    model = model.to("cuda")

token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_p=0.85,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
print(output)

8bitの回答

VTuberは、主に若いバーチャルな存在としてバーチャルリアリティを利用して、ゲームや音楽、トークなどの活動を行う日本のサブカルチャーです。</s>

ちなみに、load_in_8bitの部分をload_in_4bit=Trueにするとより省VRAMで動きますが、なんだか気持ち、回答精度がポンコツになった印象を受けます(ちゃんとベンチマークとってみたさあります)。

4bitの回答

VTuberは、バーチャルキャラクターによるライブストリーミングプラットフォームです。 VTuberは、自分自身の個人的な経験をソーシャルメディアに投稿して、他のVTuberファンと共有することができます。</s>

補足

なお、直接この内容とは関係ないのですが、rinnaの3.6bや4bのモデルを読み込むときは、tokenizerのところでuse_fast=Falseを付けないとエラーが出ることがあります。

tokenizer = AutoTokenizer.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo", use_fast=False)

他のモデルだと動くのにrinnaだと動かない、みたいなコードがあるときは、ここを疑うと良いです。

この記事が気に入ったらサポートをしてみませんか?