見出し画像

WSL2でMixtral-8x7B-Instructを試してみる

自称「Llama 2 70Bの6倍速い」というMixtral-8x7Bの4bit量子化を、無謀と分かりつつ試してみます。
やってみなくちゃ分からない。いや、やらなくても分かっている。

使用するPCは、GALLERIA UL9C-R49(RTX 4090 laptop 16GB)、メモリは64GB、OSはWindows 11+WSL2です。


準備

python3 -m venv mixtral-8x7b
cd $_
source bin/activate

pip install

pip install torch transformers accelerate
pip install bitsandbytes scipy

pip list

$ pip list
Package                  Version
------------------------ ------------
accelerate               0.25.0
bitsandbytes             0.41.3.post2
certifi                  2023.11.17
charset-normalizer       3.3.2
filelock                 3.13.1
fsspec                   2023.12.2
huggingface-hub          0.19.4
idna                     3.6
Jinja2                   3.1.2
MarkupSafe               2.1.3
mpmath                   1.3.0
networkx                 3.2.1
numpy                    1.26.2
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.18.1
nvidia-nvjitlink-cu12    12.3.101
nvidia-nvtx-cu12         12.1.105
packaging                23.2
pip                      22.0.2
psutil                   5.9.6
PyYAML                   6.0.1
regex                    2023.10.3
requests                 2.31.0
safetensors              0.4.1
scipy                    1.11.4
setuptools               59.6.0
sympy                    1.12
tokenizers               0.15.0
torch                    2.1.1
tqdm                     4.66.1
transformers             4.36.0
triton                   2.1.0
typing_extensions        4.9.0
urllib3                  2.1.0

流し込むコード

load_in_4bit = True です。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import time

if torch.cuda.is_available():
    torch.device('cuda')
else:
    torch.device('cpu')

llm = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(llm)
model = AutoModelForCausalLM.from_pretrained(
    llm,
    torch_dtype="auto",
    load_in_4bit=True
)

streamer = TextStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True
)

# Llama 2 based
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"

def build_prompt(user_query, chat_history):
    prompt = "{chat_history}{b_inst} {system}{prompt} {e_inst} ".format(
        chat_history=chat_history,
        b_inst=B_INST,
        system=f"{B_SYS}{DEFAULT_SYSTEM_PROMPT}{E_SYS}",
        prompt=user_query,
        e_inst=E_INST,
    )
    return prompt

def q(user_query, chat_history):
    start = time.process_time()
    # 推論の実行
    prompt = build_prompt(user_query, chat_history)
    input_ids = tokenizer.encode(
        tokenizer.bos_token + prompt,
        add_special_tokens=False,
        return_tensors="pt"
    )
    output_ids = model.generate(
        input_ids.to(device=model.device),
        max_new_tokens=1024,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.8,
        streamer=streamer,
    )
    output = tokenizer.decode(
        output_ids[0][input_ids.size(1) :],
        skip_special_tokens=True
    )
    # print(output)
    chat_history = prompt + output
    end = time.process_time()
    print(end - start)
    return chat_history

試してみる

溢れるメモリ

流し込みます。でかい。。。

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████| 19/19 [00:50<00:00,  2.66s/it]

想定の範囲内ですが、GPUからは溢れてしまいました。ただ、ロード自体は正常に終了。要したGPUメモリは26.8GBでした。

Mixtral-8x7B-Instruct-v0.1 4bit量子化をロードした際のGPUメモリ

聞いてみる

めげずに推論します。

chat_history = ""
chat_history = q("ドラえもんとはなにか", chat_history)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
ドラえもんは、日本の漫画とアニME série で、小さな💎 boy で知られる yellow 色のCharacter 。彼の名前は、「ドラえもん」であり 、それは「дorayaki」(デspringframework リ薩) に由来し、彼はそのデсерトを大好kind of dislike しています。彼は、素晴らしいRobot であり、そして、主人公の小学校生A boy named Nobita と、彼の friends のAvventure に出演します。ドラえもんは、fly できるようになる能力を持っています、そして、時々、彼のRobot のようなFeature を使って、Nobita のアイデアを実現して助けます。
198.19729039999999

Mixtral-8x7B-Instruct-v0.1 より

GPUメモリが溢れてしまっているので、当然に時間も溢れています。その点は仕方が無いですね。
回答がルー大柴気味なのは、日本語対応が未だからか。

しかし、spring frameworkなる文字列が紛れ込んでいるのは、ドラえもんがJavaでできているということを示唆しているのでしょうか(そんなわけない。

めげずに追加質問です。

chat_history = q("続きを教えてください", chat_history)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
ドラえもん series は、1970年代にデブー Debut し、世界中で愛されるPopular 漫画とアニメ serie になりました。彼らのアバンチュ ール Adventures は、小学校生の日常生活の中で起こるEvent で、科学的na scientific なDevice、Space travel、時空 continuum にはたらくFantasy な要素が入ります。

ドラえもん series は、幼児から子供までのChildren の教育の一環としても認められ、諸Genre のTheme 、如 Stories of friendship, honesty, courage, and the struggles of growing up が取り上げられています。同時に、ドラえもん series は、日本の文化のCulture
と歴史のAspect を示しています、彼らは、日本の歴史的Site 、Traditional clothing and festivals を描いたり、日本のMythology と Folklore からのReference を取ったりします。

ドラえもんは、また、日本のポップカルチャーPop culture のIcon として、登場人物Character や、系列のMusic theme など、多様なProduct とMedia を生み出しました。彼の Popularity は、かつての日本のPrime minister までにも及び、彼は世界中で愛されています。
359.5565687

Mixtral-8x7B-Instruct-v0.1 より

やはり回答はルー大柴気味。

この記事が気に入ったらサポートをしてみませんか?