見出し画像

WSL2でChatRWKV (RWKV-5-World-1B5-v2-20231025-ctx4096)を試してみる

RWKV-5-World-1B5-v2-20231025-ctx4096 のモデルを試してみます。
使用するPCは、GALLERIA UL9C-R49(RTX 4090 laptop 16GB)、Windows 11+WSL2です。

事前準備

python3 -m venv rwkv5
cd $_
source bin/activate

続いて、パッケージインストール。

pip install torch pynvml rwkv Ninja gradio

pip listはこんな感じです。

$ pip list
Package                   Version
------------------------- ------------
aiofiles                  23.2.1
aiohttp                   3.8.6
aiosignal                 1.3.1
altair                    5.1.2
annotated-types           0.6.0
anyio                     3.7.1
async-timeout             4.0.3
attrs                     23.1.0
certifi                   2023.7.22
charset-normalizer        3.3.2
click                     8.1.7
colorama                  0.4.6
contourpy                 1.2.0
cycler                    0.12.1
exceptiongroup            1.1.3
fastapi                   0.104.1
ffmpy                     0.3.1
filelock                  3.13.1
fonttools                 4.44.1
frozenlist                1.4.0
fsspec                    2023.10.0
gradio                    4.3.0
gradio_client             0.7.0
h11                       0.14.0
httpcore                  1.0.2
httpx                     0.25.1
huggingface-hub           0.17.3
idna                      3.4
importlib-resources       6.1.1
Jinja2                    3.1.2
jsonschema                4.19.2
jsonschema-specifications 2023.7.1
kiwisolver                1.4.5
linkify-it-py             2.0.2
markdown-it-py            2.2.0
MarkupSafe                2.1.3
matplotlib                3.8.1
mdit-py-plugins           0.3.3
mdurl                     0.1.2
mpmath                    1.3.0
multidict                 6.0.4
networkx                  3.2.1
ninja                     1.11.1.1
numpy                     1.26.2
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu12          2.18.1
nvidia-nvjitlink-cu12     12.3.52
nvidia-nvtx-cu12          12.1.105
orjson                    3.9.10
packaging                 23.2
pandas                    2.1.3
Pillow                    10.1.0
pip                       22.0.2
pydantic                  2.5.0
pydantic_core             2.14.1
pydub                     0.25.1
Pygments                  2.16.1
pynvml                    11.5.0
pyparsing                 3.1.1
python-dateutil           2.8.2
python-multipart          0.0.6
pytz                      2023.3.post1
PyYAML                    6.0.1
referencing               0.30.2
requests                  2.31.0
rich                      13.6.0
rpds-py                   0.12.0
rwkv                      0.8.20
semantic-version          2.10.0
setuptools                59.6.0
shellingham               1.5.4
six                       1.16.0
sniffio                   1.3.0
starlette                 0.27.0
sympy                     1.12
tokenizers                0.14.1
tomlkit                   0.12.0
toolz                     0.12.0
torch                     2.1.0
tqdm                      4.66.1
triton                    2.1.0
typer                     0.9.0
typing_extensions         4.8.0
tzdata                    2023.3
uc-micro-py               1.0.2
urllib3                   2.1.0
uvicorn                   0.24.0.post1
websockets                11.0.3
yarl                      1.9.2

requirements.txtにgradio==3.28.1 とあったのですが、上手く動かなかったので、最新のgradio (4.3.0) にして、以下で説明するapp.pyを1行修正しています。

ソースの修正と実行 - app.py

こちらのapp.pyをベースに少し修正します。

$ diff -u app.py.orig app.py
--- app.py.orig 2023-11-15 12:25:30.805741073 +0900
+++ app.py      2023-11-15 10:06:31.649338778 +0900
@@ -123,6 +123,6 @@
         clear.click(lambda: None, [], [output])
         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

-demo.queue(concurrency_count=1, max_size=10)
+#demo.queue(concurrency_count=1, max_size=10)
+demo.queue()
 demo.launch(share=False)

では、実行しましょう。

python app.py

https://127.0.0.1:7860/ にブラウザでアクセスして試します。

ChatRWKV - RWKV-5 World v2

できました。

CLIでもやってみよう

GUIがアレなわたしとしては、プログラミングしやすいようにしておきたいわけで。

こんな感じで query.pyを作ります。app.pyをベースにしてます。

import os, gc, copy, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
ctx_limit = 2000
title = "RWKV-5-World-1B5-v2-20231025-ctx4096"

os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)

from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
model = RWKV(model=model_path, strategy='cuda fp16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

def generate_prompt(instruction, input=""):
    instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
    input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
    if input:
        return f"""Instruction: {instruction}
Input: {input}
Response:"""
    else:
        return f"""User: hi
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free
 to ask any question and I will always answer it.
User: {instruction}
Assistant:"""

def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                     alpha_frequency = countPenalty,
                     alpha_presence = presencePenalty,
                     token_ban = [], # ban the generation of some tokens
                     token_stop = [0]) # stop generation whenever you see any token here
    ctx = ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
    for i in range(int(token_count)):
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
        #
        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= 0.996
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1
        #
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            #yield out_str.strip()
            out_last = i + 1
    #
    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
    del out
    del state
    gc.collect()
    torch.cuda.empty_cache()
    #yield out_str.strip()
    return out_str

def query(user_query):
    print(evaluate(generate_prompt(user_query)))

yield部分をコメントアウトせずに、

def query(user_query):
    for out_str evaluate(generate_prompt(user_query)):
        print(out_str)

とすると、FF2っぽくなります。

そんな話は横に置き、

>>> query("ドラえもんの登場人物をJSONで")
vram 17171480576 used 8015347712 free 9156132864
 以下はドラえもんの登場人物をJSONで表示します。
```json
{
  "Alice": "ドラえもんの主人公の妹である『マイマイ』、身長180cm、体重50kg、身体能力を得意としている魔法使い 。小さい頃から人間として育てられ、変身を使用する際はその魔法で自分を変身させることができる。"
}
```
このJSONデータには、ドラえもんの主人公である「Alice」の情報が含まれています。彼女はドラえもんの主人公であり 、身長180cm、体重50kg、身体能力を得意としている魔法使いであり、小さい頃から人間として育てられ、変身を使用す る際はその魔法で
>>>

はい、できました。

メモリの使用量など

メモリの使用量ですが、3.7GBほど。かるい!

RWKV-5-World-1B5-v2-20231025-ctx4096のメモリ使用量

この記事が気に入ったらサポートをしてみませんか?