MacBook Proでcalm2-7b-chatを試してみる
以前、WSL2で試してみたのですが、今回はMacBook Proで試してみましょう。
試したマシンは、MacBook Pro M3 Proチップ、メモリ18GBです。
準備
python3 -m venv calm2
cd $_
source bin/activate
pip install。
pip install -U pip
pip install torch
pip install transformers
pip install accelerate
pip list
% pip list
Package Version
------------------ ----------
accelerate 0.24.1
certifi 2023.11.17
charset-normalizer 3.3.2
filelock 3.13.1
fsspec 2023.10.0
huggingface-hub 0.19.4
idna 3.4
Jinja2 3.1.2
MarkupSafe 2.1.3
mpmath 1.3.0
networkx 3.2.1
numpy 1.26.2
packaging 23.2
pip 23.3.1
psutil 5.9.6
PyYAML 6.0.1
regex 2023.10.3
requests 2.31.0
safetensors 0.4.0
setuptools 58.0.4
sympy 1.12
tokenizers 0.15.0
torch 2.1.1
tqdm 4.66.1
transformers 4.35.2
typing_extensions 4.8.0
urllib3 2.1.0
試してみよう
前回とほぼ同じ。M3 Proなので、cudaとあるコードは実行してもアレなので削除しました。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import time
llm = "cyberagent/calm2-7b-chat"
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(llm)
model = AutoModelForCausalLM.from_pretrained(
llm,
# device_map="auto",
torch_dtype="auto"
)
# if torch.cuda.is_available():
# model = model.to("cuda")
RTX 4090と比較するとちょっと引っかかりますね。
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:24<00:00, 12.44s/it]
>>>
どんどん流し込みます。
streamer = TextStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
def build_prompt(user_query, chat_history):
prompt = "USER: " + user_query + "\n"
prompt += "ASSISTANT: "
if chat_history:
prompt = chat_history + "<|endoftext|>\n" + prompt
return prompt
def q(user_query, chat_history):
start = time.process_time()
# 推論の実行
prompt = build_prompt(user_query, chat_history)
input_ids = tokenizer.encode(
prompt,
return_tensors="pt"
)
output_ids = model.generate(
input_ids.to(device=model.device),
max_new_tokens=12000,
do_sample=True,
temperature=0.8,
streamer=streamer,
)
output = tokenizer.decode(
output_ids[0][input_ids.size(1) :],
skip_special_tokens=True
)
# print(output)
chat_history = prompt + output
end = time.process_time()
print(end - start)
return chat_history
では、チャットしましょう。
chat_history = ""
chat_history = q("小学生にでもわかる言葉で教えてください。ドラえもんとはなにか", chat_history)
約2分…。ちょっと、いや、かなりゆっくりです。
続きを聞いてみます。
chat_history = q("続きを教えてください", chat_history)
ちょっと遅い…。
リソース
メモリの使用量は、14GB前後でした。
CPUは100%越えとなっていました。