![見出し画像](https://assets.st-note.com/production/uploads/images/125830532/rectangle_large_type_2_ead00a89a1d9b0e2ebc5044760b6191e.png?width=800)
Google Colab で vLLM を試す
「Google Colab」で「vLLM」を試したので、まとめました。
【注意】Google Colab Pro/Pro+のA100で動作確認しています。
1. vLLM
「vLLM」は、LLMの高速推論のためのライブラリです。
・最先端のサービススループット
・PagedAttendantによるアテンションキーと値のメモリの効率的な管理
・受信リクエストの継続的なバッチ処理
・CUDA/HIP グラフによる高速モデル実行
・量子化 (GPTQ、AWQ、SqueezeLLM)
・最適化されたCUDAカーネル
「vLLM」には柔軟性があり、次の用途に対応しています。
・HuggingFaceモデルとのシームレスな統合
・並列サンプリング、Beam Searchなどを含む様々なデコードアルゴリズムによる高スループットサービス
・分散推論のためのテンソル並列処理のサポート
・ストリーミング出力
・OpenAI対応APIサーバー
・NVIDIA GPU と AMD GPU をサポート
2. サポートモデル
「vLLM」は、次のモデルを含む多くのHuggingFaceモデルをシームレスにサポートします。
・Aquila & Aquila2 (BAAI/AquilaChat2-7B, BAAI/AquilaChat2-34B, BAAI/Aquila-7B, BAAI/AquilaChat-7B, etc.)
・Baichuan & Baichuan2 (baichuan-inc/Baichuan2-13B-Chat, baichuan-inc/Baichuan-7B, etc.)
・BLOOM (bigscience/bloom, bigscience/bloomz, etc.)
・ChatGLM (THUDM/chatglm2-6b, THUDM/chatglm3-6b, etc.)
・DeciLM (Deci/DeciLM-7B, Deci/DeciLM-7B-instruct, etc.)
・Falcon (tiiuae/falcon-7b, tiiuae/falcon-40b, tiiuae/falcon-rw-7b, etc.)
・GPT-2 (gpt2, gpt2-xl, etc.)
・GPT BigCode (bigcode/starcoder, bigcode/gpt_bigcode-santacoder, etc.)
・GPT-J (EleutherAI/gpt-j-6b, nomic-ai/gpt4all-j, etc.)
・GPT-NeoX (EleutherAI/gpt-neox-20b, databricks/dolly-v2-12b, stabilityai/stablelm-tuned-alpha-7b, etc.)
・InternLM (internlm/internlm-7b, internlm/internlm-chat-7b, etc.)
・LLaMA & LLaMA-2 (meta-llama/Llama-2-70b-hf, lmsys/vicuna-13b-v1.3, young-geng/koala, openlm-research/open_llama_13b, etc.)
・Mistral (mistralai/Mistral-7B-v0.1, mistralai/Mistral-7B-Instruct-v0.1, etc.)
・Mixtral (mistralai/Mixtral-8x7B-v0.1, mistralai/Mixtral-8x7B-Instruct-v0.1, etc.)
・MPT (mosaicml/mpt-7b, mosaicml/mpt-30b, etc.)
・OPT (facebook/opt-66b, facebook/opt-iml-max-30b, etc.)
・Phi (microsoft/phi-1_5, microsoft/phi-2, etc.)
・Qwen (Qwen/Qwen-7B, Qwen/Qwen-7B-Chat, etc.)
・Yi (01-ai/Yi-6B, 01-ai/Yi-34B, etc.)
3. Colabでの実行
Colabでの実行手順は、次のとおりです。
(1) Colabのノートブックを開き、メニュー「編集 → ノートブックの設定」で「GPU」の「A100」を選択。
(2) パッケージのインストール。
# パッケージのインストール
!pip install vllm
(3) LLMの準備。
今回は、「elyza/ELYZA-japanese-Llama-2-13b-instruct」を使います。
from vllm import LLM
# LLMの準備
llm = LLM(model="elyza/ELYZA-japanese-Llama-2-13b-instruct")
(4) プロンプトテンプレートの準備。
import string
# プロンプトテンプレートの準備
template = string.Template("""<s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
${instruct} [/INST] """)
(5) 推論の実行。
%%time
from vllm import SamplingParams
# プロンプトの準備
prompts = [
"まどか☆マギカでは誰が一番かわいい?",
]
for i in range(len(prompts)):
prompts[i] = template.safe_substitute({"instruct": prompts[i]})
# 推論の実行
outputs = llm.generate(
prompts,
sampling_params = SamplingParams(
temperature=0.5,
max_tokens=256,
)
)
for output in outputs:
print("Prompt:", output.prompt, "\n")
print("Response:", output.outputs[0].text, "\n----\n")
Prompt: <s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
まどか☆マギカでは誰が一番かわいい? [/INST]
Response: まどか☆マギカに登場するキャラクターは、全員魅力的でかわいいと言える要素を持っています。しかし、人によって好みは異なるため、一概に誰が一番かわいいかは言うことができません。
ただし、人気投票などでは、以下のキャラクターが上位にランクインすることが多いです。
1. 杏子
2. 澪
3. まどか
4. ひまわり
5. 輪る
これはあくまで一意見でしかありません。
----
CPU times: user 4.74 s, sys: 0 ns, total: 4.74 s
Wall time: 4.73 s
VRAM使用量は、次のとおりです。
![](https://assets.st-note.com/img/1703737310977-xS2mKWbUYl.png?width=800)
「SamplingParams」のパラメータは、次のとおりです。
・n : プロンプトに対して返される出力シーケンス数
・best_of : プロンプトから生成される出力シーケンスの数。best_ofシーケンスから、上位nシーケンスが返される
・frequency_penalty : 生成されたテキスト内の頻度に基づいて、新しいトークンにペナルティを与える浮動小数点数。値 > 0 の場合は新しいトークンの使用を推奨、値 < 0 の場合はトークン繰り返しを推奨
・repetition_penalty : 新しいトークンがプロンプトおよびこれまでに生成されたテキストに表示されるかどうかに基づいて、新しいトークンにペナルティを与える浮動小数点数。値 > 1 の場合は新しいトークンの使用を推奨、値 < 1 の場合はトークン繰り返しを推奨
・temperature : サンプリングのランダム性を制御する浮動小数点数。値が低いほどより決定的、値が高いほどよりランダム。0は貪欲なサンプリング
・top_p : 考慮する上位トークンの累積確率を制御する浮動小数点数。 (0, 1] でなければならない。すべてのトークンを考慮するには 1 に設定
・top_k : 考慮する上位トークンの数を制御する整数。すべてのトークンを考慮するには、-1 に設定
・min_p : 最も可能性の高いトークンの確率と比較して、考慮されるトークンの最小確率を表す浮動小数点数。[0, 1] になければならない。これを無効にするには 0 に設定
・use_beam_search : サンプリングの代わりにBeam Searchを使用するかどうか
・length_penalty : 長さに基づいてシーケンスにペナルティを与える浮動小数。Beam Searchに使用
・early_stopping : Beam Searchの停止条件を制御
・True : best_of の完全な候補が存在するとすぐに生成が停止
・False : ヒューリスティックが適用され、より適切な候補が見つかる可能性が非常に低い場合に生成が停止
・never : Beam Search手順は、より良い候補が存在しない場合にのみ停止
・stop : 生成時に生成を停止する文字列のリスト。返される出力には停止文字列は含まれない
・stop_token_ids : 生成時に生成を停止するトークンのリスト。返される出力には、ストップトークンがスペシャルトークンでない限り、ストップトークンが含まれます
・include_stop_str_in_output : 出力テキストに停止文字列を含めるかどうか。デフォルトはFalse
・ignore_eos : EOS トークンが生成された後、EOS トークンを無視してトークンの生成を続行するかどうか
・max_tokens : 出力シーケンスごとに生成するトークンの最大数
・logprobs : 出力トークンごとに返されるログの確率の数。 実装は OpenAI API に従っていることに注意。返される結果には、最も可能性の高いlogprobsトークンのログ確率と、選択されたトークンが含まれる。API は常にサンプリングされたトークンの対数確率を返すため、応答には最大 logprobs+1 要素が含まれる可能性がある
・prompt_logprobs : プロンプト トークンごとに返されるログの確率の数。
・skip_special_tokens : 出力内の特別なトークンをスキップするかどうか
・space_between_special_tokens : 出力内の特別なトークンの間にスペースを追加するかどうか。デフォルトはTrue
・logits_processors : 以前に生成されたトークンに基づいてロジットを変更する関数のリスト
(5) 3つのプロンプトの推論の実行。
プロンプトを3つに増やしても、処理時間は3倍にはならないことを確認します。
%%time
from vllm import SamplingParams
# プロンプトの準備
prompts = [
"まどか☆マギカでは誰が一番かわいい?",
"自然言語処理とは?",
"Pythonでtest.txtを読み込むコードは?",
]
for i in range(len(prompts)):
prompts[i] = template.safe_substitute({"instruct": prompts[i]})
# 推論の実行
outputs = llm.generate(
prompts,
sampling_params = SamplingParams(
temperature=0.5,
max_tokens=256,
)
)
for output in outputs:
print("Prompt:", output.prompt, "\n")
print("Response:", output.outputs[0].text, "\n----\n")
Prompt: <s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
まどか☆マギカでは誰が一番かわいい? [/INST]
Response: まどか☆マギカに登場するキャラクターは、それぞれ魅力的で愛らしいところがあります。
しかし、一般的には、以下のキャラクターが特にかわいいと言われています。
1. 美鈴
2. 杏子
3. 鶴乃
4. 絹子
5. まどか
これらの順位は一意見ではなく、個人の感想によって異なると考えられます。
----
Prompt: <s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
自然言語処理とは? [/INST]
Response: 自然言語処理 (Natural Language Processing) とは、人間が自然な言語で入力した文章などをコンピュータが理解し、様々な処理を行うための技術のことです。
自然言語は、人間が日常的に使用する言語であり、様々な表現や曖昧さが含まれているため、コンピュータにとって理解しやすい形式のデータとは異なります。そのため、自然言語をコンピュータに理解させるためには、専用のアルゴリズムや技術が必要であり、これを自然言語処理と呼びます。
自然言語処理の具体的な応用例としては、チャットボッ
----
Prompt: <s>[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
Pythonでtest.txtを読み込むコードは? [/INST]
Response: test.txtを読み込むには、open関数を使います。以下が一例です。
```python
with open("test.txt", "r") as f:
data = f.read()
```
----
CPU times: user 5.95 s, sys: 0 ns, total: 5.95 s
Wall time: 5.92 s
この記事が気に入ったらサポートをしてみませんか?