見出し画像

rinnaをCTranslate2から遊ぶメモ

このメモを読むと

・CTranslate2を導入できる
・ローカルLLMの文章生成が爆速になる

検証環境

・Windows11
・VRAM24GB
・ローカル(Anaconda)
・2023/6/M時点

事前準備

Anacondaを使うメモ|おれっち (note.com)
Gitを使うメモ|おれっち (note.com)

CTranslate2とは

transeformerモデルで効率的に推論するライブラリです。
ローカルLLMの文章生成がとにかく速くなります。
ためしてみましょう!

CTranslate2導入

とても簡単です!

1. 仮想環境を作成し、環境切替

conda create -n ct2test python=3.10
activate ct2test

2. 追加パッケージのインストール

pip install ctranslate2 
pip install sentencepiece transformers protobuf==3.20.0
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

完了です!

CTranslate2を使ってみる

用意するものは以下の一つです。
 ・ベースモデル
 こちらをお借りしました:rinna/japanese-gpt-neox-3.6b-instruction-ppo

モデル変換

CTranslate2では既存のモデルを変換する必要があります。
好きな名前で下記スクリプトを作成し実行します。

import subprocess

base_model = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
output_dir = "rinnna-ppo-int8"
quantization_type = "int8"

command = f"ct2-transformers-converter --model {base_model} --quantization {quantization_type} --output_dir {output_dir}"
subprocess.run(command, shell=True)

推論

好きな名前で下記スクリプトを作成し実行します。

import ctranslate2
import transformers
import time

base_model = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
quant_model = "rinnna-ppo-int8"

generator = ctranslate2.Generator(quant_model, device="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, use_fast=False)

# プロンプトを作成する
def prompt(msg):
    p = [
        {"speaker": "ユーザー", "text": msg},
    ]
    p = [f"{uttr['speaker']}: {uttr['text']}" for uttr in p]
    p = "<NL>".join(p)
    p = p + "<NL>" + "システム: "
    return p

# 返信を作成する
def reply(msg):
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt(msg),add_special_tokens=False,))

    results = generator.generate_batch(
        [tokens],
        max_length=256,
        sampling_topk=10,
        sampling_temperature=0.9,
        include_prompt_in_result=False,
    )

    text = tokenizer.decode(results[0].sequences_ids[0])
    print("A: " + text)
    return text

if __name__ == "__main__":
    while True:
        msg = input("Q: ")
        start_time = time.time()
        reply(msg)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Elapsed time: {elapsed_time} seconds")

Q: カレーにジャガイモは入れるべき?
A: はい、ジャガイモはカレーに非常に良い付け合わせです。ジャガイモは、カレーに味と食感を与え、カレーがより風味豊かになるため、非常に人気があります。
Elapsed time: 0.31789183616638184 seconds

出力

VRAM使用量は4GB前後でした。

おわり

爆速です。

参考資料

CTranslate2でrinna instructionをquantizeして動かす|if001 (note.com)
RinnaのppoモデルをCTranslate2で高速に動かす (zenn.dev)

この記事が気に入ったらサポートをしてみませんか?