見出し画像

Google Colab + CTranslate2 による Rinnaの高速推論を試す

「Google Colab」と「CTranslate2」による Rinnaの高速推論を試したのでまとめました。

【注意】「Google Colab」でモデル変換するのにハイメモリが必要でした。T4のノーマルのメモリで大丈夫そうです。

1. CTranslate2

「CTranslate2」は、Transformerモデルを効率的に推論するためのC++ および Python ライブラリです。

2. モデルの変換

「Google Colab」でのモデル変換の手順は、次のとおりです。

(1) Googleドライブのマウント。
変換したモデルを保存したいのでGoogleドライブをマウントします。

# Googleドライブのマウント
from google.colab import drive
drive.mount("/content/drive")

(2) 作業フォルダへの移動

# 作業フォルダへの移動
import os
os.makedirs("/content/drive/My Drive/work", exist_ok=True)
%cd "/content/drive/My Drive/work"

(3) インストール。

# パッケージのインストール
!pip install ctranslate2
!pip install sentencepiece transformers

(4) モデルの変換。
CTranslate2フォーマットに変換します。

!ct2-transformers-converter \
    --model rinna/japanese-gpt-neox-3.6b-instruction-ppo \
    --quantization int8 \
    --output_dir ./rinna_ppo

パラメータは、次のとおりです。

・-h, --help : ヘルプ
・--model_path <MODEL_PATH> : モデルパス
・--output_dir <OUTPUT_DIR> :出力ディレクトリ
・--vocab_mapping <VOCAB_MAPPING> : 語彙マッピングファイル
・--quantization {int8,int8_float16,int16,float16} : 重みの量子化種別
・--force : 出力ディレクトリが存在する場合に上書き

ハイメモリ必要で、2分ほどかかりました。
出力フォルダに次の3つのファイルが生成されます。

・config.json
・model.bin
・vocabulary.txt

3. 推論

「Google Colab」での推論の手順は、次のとおりです。

(1) ジェネレーターとトークナイザーの準備

import ctranslate2
import transformers
import torch

# ジェネレーターとトークナイザーの準備
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = ctranslate2.Generator("./rinna_ppo", device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False)

(2) 推論。

# プロンプトの準備
prompt = "ユーザー: まどか☆マギカでは誰が一番かわいい?<NL>システム: "

# 推論の実行
tokens = tokenizer.convert_ids_to_tokens(
    tokenizer.encode(prompt, add_special_tokens=False)
)
results = generator.generate_batch(
    [tokens],
    max_length=64,
    sampling_topk=10,
    sampling_temperature=0.7,
    include_prompt_in_result=False,
)
text = tokenizer.decode(results[0].sequences_ids[0])
print(text)
私は鹿目まどかが一番かわいいと思います。彼女はとてもかわいくて、優しくて、思いやりがあります。また、まどかはとても強いです。彼女はとても強い力を持っており、とても強いです。彼女は自分の運命に逆らって、自分の力を使って立ち向かっていきます。彼女はとても強いです

generator.generate_batch()のパラメータは、次のとおりです。

・start_tokens: List[List[str]] : 開始トークンのバッチ。デコーダが<s>のような特別な開始トークンから開始する場合、このトークンをこの入力に追加する必要があります。
・max_batch_size: int = 0 : 最大バッチサイズ
・batch_type: str = 'examples' : バッチ種別
・asynchronous: bool = False : 非同期実行
・beam_size: int = 1 : ビームサイズ (1:貪欲検索)
 ・patience: float = 1 : ビーム検索忍耐係数
・num_hypotheses: int = 1 : 返される仮説数
・length_penalty: float = 1 : 長さペナルティ
・repetition_penalty: float = 1 : 反復ペナルティ
・no_repeat_ngram_size: int = 0 : このサイズの ngram の繰り返しを防止(0:無効)
・disable_unk: bool = False : 不明トークンの生成を無効
・suppress_sequences: Optional[List[List[str]]] = None : トークンの一部のシーケンスの生成を無効
・end_token: Optional[Union[str, List[str], List[int]]] = None : 停止トークン(デフォルト:EOSトークン)
・return_end_token: bool = False : 結果に終了トークンを含める
・max_length: int = 512 : 最大長
・min_length: int = 0 : 最小長
・static_prompt: Optional[List[str]] = None : システムプロンプト
・cache_static_prompt: bool = True : 同じプロンプトを使用した時に再利用
・include_prompt_in_result: bool = True : 結果にstart_tokensを含める
・return_scores: bool = False : 出力にスコアを含める
・return_alternatives: bool = False : 最初の制約のないデコード位置で代替案を返す
・min_alternative_expansion_prob: float = 0 : 代替案を展開する最小の初期確率
・sampling_topk: int = 1 : 上位K個の候補から予測をランダムサンプリング
・sampling_temperature: float = 1 : サンプリング温度
・callback: Callable[[GenerationStepResult], bool] = None : コールバック

参考



この記事が気に入ったらサポートをしてみませんか?