見出し画像

WSL2でTOVAを試してみる...が

「Transformers are Multi-State RNNs」とのことなので、その実装であるTOVAを試してみます。
対応しているモデルは Llama と Mistral とのことなので、今回は、elyza/ELYZA-japanese-Llama-2-7b-instruct を使用してみます。

使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)
・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。


1. 準備

venvを構築します

python3 -m venv tova
cd $_
source bin/activate

TOVAのリポジトリをクローンします。

git clone https://github.com/schwartz-lab-NLP/TOVA.git

続いて、pip install。

pip install torch transformers==4.36.2 sentencepiece accelerate

pip list

$ pip list
Package                  Version
------------------------ ----------
accelerate               0.26.1
certifi                  2023.11.17
charset-normalizer       3.3.2
filelock                 3.13.1
fsspec                   2023.12.2
huggingface-hub          0.20.2
idna                     3.6
Jinja2                   3.1.3
MarkupSafe               2.1.3
mpmath                   1.3.0
networkx                 3.2.1
numpy                    1.26.3
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.18.1
nvidia-nvjitlink-cu12    12.3.101
nvidia-nvtx-cu12         12.1.105
packaging                23.2
pip                      22.0.2
psutil                   5.9.7
PyYAML                   6.0.1
regex                    2023.12.25
requests                 2.31.0
safetensors              0.4.1
sentencepiece            0.1.99
setuptools               59.6.0
sympy                    1.12
tokenizers               0.15.0
torch                    2.1.2
tqdm                     4.66.1
transformers             4.36.2
triton                   2.1.0
typing_extensions        4.9.0
urllib3                  2.1.0

2. コード

流し込むコードはこちら。TOVAを使用するための箇所を見てみましょう。

(1) from TOVA

TOVAのキャッシュを作成するために追加します。

import sys
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from TOVA import TOVACache, enable_tova_caching
from typing import List, Dict
import time

ここは特にかわらず

# argv
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default=None)
parser.add_argument("--no-instruct", action='store_true')
parser.add_argument("--no-use-system-prompt", action='store_true')

args = parser.parse_args(sys.argv[1:])

model_id = args.model_path
if model_id == None:
    exit

is_instruct = not args.no_instruct
use_system_prompt = not args.no_use_system_prompt

# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    #device_map="cuda",
    #low_cpu_mem_usage=True,
    trust_remote_code=True
)

(2) enable_tova_caching関数とTOVACache関数の追加

enable_tova_chaing関数に対してmodelを引数で呼び出します。
また、multi_state_sizeの大きさのキャッシュを作成(TOVACache)し、変数cacheに格納しておきます。

# use TOVA
enable_tova_caching(model)
multi_state_size = 512
cache = TOVACache(multi_state_size)

model.generateメソッドに引き渡すpast_key_valuesに上記で作成した cache を指定します。

streamer = TextStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True
)

DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"

# generation params
max_new_tokens = 256
generation_params = {
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "max_new_tokens": max_new_tokens,
    "repetition_penalty": 1.1,
    "past_key_values": cache,
}

以降は変わらず。

def q(
    user_query: str,
    history: List[Dict[str, str]]=None
):
    start = time.process_time()
    # messages
    messages = ""
    if is_instruct:
        if use_system_prompt:
            messages = [
                {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
            ]
        user_messages = [
            {"role": "user", "content": user_query}
        ]
    else:
        user_messages = user_query
    if history:
        user_messages = history + user_messages
    messages += user_messages
    # generation prompts
    if is_instruct:
        prompt = tokenizer.apply_chat_template(
            conversation=messages,
            add_generation_prompt=True,
            tokenize=False
        )
    else:
        prompt = messages
    print("--- prompt")
    print(prompt)
    print("--- output")
    #
    input_ids = tokenizer.encode(
        prompt,
        add_special_tokens=False,
        return_tensors="pt"
    )
    # 推論
    output_ids = model.generate(
        input_ids.to(model.device),
        streamer=streamer,
        **generation_params
    )
    output = tokenizer.decode(
        output_ids[0][input_ids.size(1) :],
        skip_special_tokens=True
    )
    if is_instruct:
        user_messages.append(
            {"role": "assistant", "content": output}
        )
    else:
        user_messages += output
    end = time.process_time()
    ##
    input_tokens = len(input_ids[0])
    output_tokens = len(output_ids[0][input_ids.size(1) :])
    total_time = end - start
    tps = output_tokens / total_time
    print(f"prompt tokens = {input_tokens:.7g}")
    print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
    print(f"   total time = {total_time:f} [s]")
    return user_messages

3.試してみる

上記 pythonコードを query.pyとして保存して、以下のように実行します。

python -i ./query.py --model-path elyza/ELYZA-japanese-Llama-2-7b-instruct

聞いてみましょう。

history = q("ドラえもんとはなにか")

これらを組み合わせることで、一早く解きます。そして、何度も繰り返し解いていきましょう。この記事はコラム「パズル・エデュケー ション」の第2回となります。次回は、「数学的思考力を養う ~スキャットル・パズルを通じて~」をお届けします。パズルを解く楽しさを知っ
た上で、それぞれのパズルをクリックしてご覧ください。「ピーク・ア・ボウル」「スカイプ・ザ・サーカス」
【4】 残った1つをマークします。

③ 6つのマークの中から2つのマークを選び、それらを交互に
prompt tokens = 58
output tokens = 256 (17.940468 [tps])
total time = 14.269416 [s]

elyza/ELYZA-japanese-Llama-2-7b-instruct by TOVA

・・・うん? なにかお気に召さないことでもあったのか。

history = q("続きを教えてください", history)

以下が第3回目の記事です。

第3回目 「スキャットル・パズル」で考える 〜数学的思考力を養う〜

今回は、10×10のグリッドに数字を配置し、その数字の合計が指定された数字になるように指示を出します。例えば、グリッドの上に1, 2, 3, ..., 9の数字が配置されていた場合、下記のように指示を出すことができます。

「4, 5, 6, 8」

これは、グリッドの左上隅に1が配置され、次に右上隅に2が配置され、そして次に右下隅に3が配置され、最後に左下隅に9が配置される ように指

elyza/ELYZA-japanese-Llama-2-7b-instruct by TOVA

うーん???

history = q("続きを教えてください", history)

以下が続きの記事です。

第4回目 「スキャットル・パズル」で考える 〜数学的思考力を養う〜

今回は、前回と同様にグリッドに数字を配置するというミッションを与えます。しかし、今回は先程とは反対の動きをします。つまり、 グリッドの上に配置した数字の合計が、指定された数字になるように配置を変更します。例えば、グリッドの上に1, 2, 3, ..., 9の数字が配置されていた場合、次のように配置を変更することができます。

「3, 4, 6, 8」

その他にも様々な方法で配置を変更する

elyza/ELYZA-japanese-Llama-2-7b-instruct by TOVA

「続き」という言葉は解釈できているようだが・・・。

history = q("続きを教えてください", history)

以下が続きの記事です。

第5回目 「スキャットル・パズル」で考える 〜数学的思考力を養う〜

今回は、前回と同様にグリッドに数字を配置するというミッションを与えます。しかし、今回は前回とは反対の動きをします。つまり、 グリッドの上に配置した数字の合計が、指定された数字になるように配置を逆に変更します。例えば、グリッドの上に1, 2, 3, ..., 9の数字が配置されていた場合、次のように配置を変更することができます。

「9, 5, 2, 1」

その他にも様々な方法で配置を

elyza/ELYZA-japanese-Llama-2-7b-instruct by TOVA

ふむ・・・。

3. まとめ

推論結果

TOVAを試してみましたが、明後日の方向の推論結果が出ています。使い方が不味いのかなんなのか。

ちなみに、instructではない elyza/ELYZA-japanese-Llama-2-7b でも似たような結果でした。

サムネイル

サムネイルの画像は、ChatGPT4に「TOVAと聞いて想像する図を描いてください。」として表示された画像です。

こちらが「TOVA」という概念を抽象的に表現した画像です。革新性、技術、進歩を感じさせるデザインで、抽象的な幾何学形状が複雑なネットワークで繋がっていることを示しています。未来的な色彩で、青、銀、ネオンのヒントが含まれています。全体的に動きと進歩を示唆するダイナミックな構成で、鋭角と滑らかな曲線が混ざり合って、視覚的に印象的で思考を促す作品になっています。

ChatGPT4

この記事が気に入ったらサポートをしてみませんか?