WSL2でphi-2を試してみる
師走は忙しいと実感する今日この頃です。
来ました、Microsoftのphi-2です。MicrosoftですからやはりここはWSL2で試さなければなりません。
使用するPCは、GALLERIA UL9C-R49(RTX 4090 laptop 16GB)、メモリは64GB、OSはWindows 11+WSL2です。
準備
打ち込んでいきます。
python3 -m venv phi-2
cd $_
source bin/activate
pip install
pip install torch transformers accelerate
pip intall einops
pip list
$ pip list
Package Version
------------------------ ----------
accelerate 0.25.0
certifi 2023.11.17
charset-normalizer 3.3.2
einops 0.7.0
filelock 3.13.1
fsspec 2023.12.2
huggingface-hub 0.19.4
idna 3.6
Jinja2 3.1.2
MarkupSafe 2.1.3
mpmath 1.3.0
networkx 3.2.1
numpy 1.26.2
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
packaging 23.2
pip 22.0.2
psutil 5.9.7
PyYAML 6.0.1
regex 2023.10.3
requests 2.31.0
safetensors 0.4.1
setuptools 59.6.0
sympy 1.12
tokenizers 0.15.0
torch 2.1.2
tqdm 4.66.1
transformers 4.36.2
triton 2.1.0
typing_extensions 4.9.0
urllib3 2.1.0
プロンプトの書式
README.mdを読むに、
QA Format
Chat Format
Code Format
が例示されていますね。Instruction Formatで試してみます。
pythonコードを作る
以下の注意書きがあるので、tokenizer, modelのオブジェクト生成時に引数を追加しておきましょう。
こんな感じです。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import time
llm = "microsoft/phi-2"
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
llm,
trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
llm,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True
)
streamer = TextStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
B_INST, E_INST = "Instruction:", "Output:"
def build_prompt(user_query, chat_history):
prompt = "{b_inst} {prompt}\n{e_inst} ".format(
b_inst=B_INST,
prompt=user_query,
e_inst=E_INST
)
if chat_history:
prompt = chat_history + "\n" + prompt
return prompt
def q(user_query, chat_history):
start = time.process_time()
# 推論の実行
prompt = build_prompt(user_query, chat_history)
input_ids = tokenizer.encode(
tokenizer.bos_token + prompt,
add_special_tokens=False,
return_tensors="pt"
)
output_ids = model.generate(
input_ids.to(device=model.device),
max_new_tokens=128,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=True,
temperature=0.8,
streamer=streamer,
)
output = tokenizer.decode(
output_ids[0][input_ids.size(1) :],
skip_special_tokens=True
)
# print(output)
chat_history = prompt + output
end = time.process_time()
print(end - start)
return chat_history
聞いてみる
日本語が・・・
とりあえず聞いてみます。
>>> chat_history = ""
>>> chat_history = q("ドラえもんとはなにか", chat_history)
おっおー…。
>>> chat_history = q("続きを教えてください", chat_history)
ダメか・・・。
英語で
聞いてみます。
>>> chat_history = ""
>>> chat_history = q("What is Doraemon?", chat_history)
かなりいい。そして速い。ただ、「どらぼー」ってなんの道具だ?
続きを聞いてみましょう。
>>> chat_history = q("Please continue.", chat_history)
いい、さらに聞く。
>>> chat_history = q("Please continue.", chat_history)
ドラえもん、いいです! あれ?
リソース
7.2GBぐらいですかね。
nvidia-smiも参考に。
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.36 Driver Version: 546.33 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 ... On | 00000000:02:00.0 On | N/A |
| N/A 42C P8 6W / 150W | 7398MiB / 16376MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 4961 C /python3.10 N/A |
+---------------------------------------------------------------------------------------+
この記事が気に入ったらサポートをしてみませんか?