
Google Colab + CTranslate2 による Rinnaの高速推論を試す
「Google Colab」と「CTranslate2」による Rinnaの高速推論を試したのでまとめました。
【注意】「Google Colab」でモデル変換するのにハイメモリが必要でした。T4のノーマルのメモリで大丈夫そうです。
1. CTranslate2
「CTranslate2」は、Transformerモデルを効率的に推論するためのC++ および Python ライブラリです。
2. モデルの変換
「Google Colab」でのモデル変換の手順は、次のとおりです。
(1) Googleドライブのマウント。
変換したモデルを保存したいのでGoogleドライブをマウントします。
# Googleドライブのマウント
from google.colab import drive
drive.mount("/content/drive")
(2) 作業フォルダへの移動
# 作業フォルダへの移動
import os
os.makedirs("/content/drive/My Drive/work", exist_ok=True)
%cd "/content/drive/My Drive/work"
(3) インストール。
# パッケージのインストール
!pip install ctranslate2
!pip install sentencepiece transformers
(4) モデルの変換。
CTranslate2フォーマットに変換します。
!ct2-transformers-converter \
--model rinna/japanese-gpt-neox-3.6b-instruction-ppo \
--quantization int8 \
--output_dir ./rinna_ppo
パラメータは、次のとおりです。
・-h, --help : ヘルプ
・--model_path <MODEL_PATH> : モデルパス
・--output_dir <OUTPUT_DIR> :出力ディレクトリ
・--vocab_mapping <VOCAB_MAPPING> : 語彙マッピングファイル
・--quantization {int8,int8_float16,int16,float16} : 重みの量子化種別
・--force : 出力ディレクトリが存在する場合に上書き
ハイメモリ必要で、2分ほどかかりました。
出力フォルダに次の3つのファイルが生成されます。
・config.json
・model.bin
・vocabulary.txt
3. 推論
「Google Colab」での推論の手順は、次のとおりです。
(1) ジェネレーターとトークナイザーの準備
import ctranslate2
import transformers
import torch
# ジェネレーターとトークナイザーの準備
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = ctranslate2.Generator("./rinna_ppo", device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(
"rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False)
(2) 推論。
# プロンプトの準備
prompt = "ユーザー: まどか☆マギカでは誰が一番かわいい?<NL>システム: "
# 推論の実行
tokens = tokenizer.convert_ids_to_tokens(
tokenizer.encode(prompt, add_special_tokens=False)
)
results = generator.generate_batch(
[tokens],
max_length=64,
sampling_topk=10,
sampling_temperature=0.7,
include_prompt_in_result=False,
)
text = tokenizer.decode(results[0].sequences_ids[0])
print(text)
私は鹿目まどかが一番かわいいと思います。彼女はとてもかわいくて、優しくて、思いやりがあります。また、まどかはとても強いです。彼女はとても強い力を持っており、とても強いです。彼女は自分の運命に逆らって、自分の力を使って立ち向かっていきます。彼女はとても強いです
generator.generate_batch()のパラメータは、次のとおりです。
・start_tokens: List[List[str]] : 開始トークンのバッチ。デコーダが<s>のような特別な開始トークンから開始する場合、このトークンをこの入力に追加する必要があります。
・max_batch_size: int = 0 : 最大バッチサイズ
・batch_type: str = 'examples' : バッチ種別
・asynchronous: bool = False : 非同期実行
・beam_size: int = 1 : ビームサイズ (1:貪欲検索)
・patience: float = 1 : ビーム検索忍耐係数
・num_hypotheses: int = 1 : 返される仮説数
・length_penalty: float = 1 : 長さペナルティ
・repetition_penalty: float = 1 : 反復ペナルティ
・no_repeat_ngram_size: int = 0 : このサイズの ngram の繰り返しを防止(0:無効)
・disable_unk: bool = False : 不明トークンの生成を無効
・suppress_sequences: Optional[List[List[str]]] = None : トークンの一部のシーケンスの生成を無効
・end_token: Optional[Union[str, List[str], List[int]]] = None : 停止トークン(デフォルト:EOSトークン)
・return_end_token: bool = False : 結果に終了トークンを含める
・max_length: int = 512 : 最大長
・min_length: int = 0 : 最小長
・static_prompt: Optional[List[str]] = None : システムプロンプト
・cache_static_prompt: bool = True : 同じプロンプトを使用した時に再利用
・include_prompt_in_result: bool = True : 結果にstart_tokensを含める
・return_scores: bool = False : 出力にスコアを含める
・return_alternatives: bool = False : 最初の制約のないデコード位置で代替案を返す
・min_alternative_expansion_prob: float = 0 : 代替案を展開する最小の初期確率
・sampling_topk: int = 1 : 上位K個の候補から予測をランダムサンプリング
・sampling_temperature: float = 1 : サンプリング温度
・callback: Callable[[GenerationStepResult], bool] = None : コールバック