見出し画像

【ローカルLLM】Gradio+CTranslate2で日本語LLMのチャットUIをつくる

夏になって立て続けに日本語LLMが公開されたので、遅ればせながらいくつか試している。

生成AIのColabでおなじみのcamenduruさんがtext-generation-webUIの日本語LLM用Colabをアップしていたので、使わせてもらっている。

ただ、軽い量子化モデルも試したいのと、自分用のシンプルなチャットUIがあったら便利かも、という思いつきで自作してみることにした。

今回は、ウェブUI用のPythonライブラリGradio+言語モデルの量子化・推論用のライブラリCTranslate2を使って、「line-corp-japanese-large-lm-3.6b」をウェブUIで動かしてみる素人DIY。

Gradioのひな形を探す

当初Gradioの公式ガイドに掲載されているチャットUIのサンプルコードを参考にしようとした。が、ここはOpenAIなどのAPIを利用したコードが中心で、そのまま流用できそうなものは見当たらず。

そこでHuggingFace Spaceを見ていると、今年5月に「rinna/japanese-gpt-neox-3.6b-instruction-sft」が公開されたときのGradioデモ(すでに稼働は停止)のファイルが残っていた。同じ日本語LLMだしちょうどいいと思い、これをひな形に使わせてもらうことにした。

CTranslate2で量子化

日本語LLMは3.6B~7Bパラメータくらい(Weblabが10Bで最大)でそこまで大きくない。ただローカルで使うことを考えると、やはりサイズを圧縮したいし、なおかつCPUでも実行できると嬉しい。なので、量子化モデルを使いたいモチベーションがある。

日本語LLMはGPT-NeoX系のモデルが中心で、GGMLで量子化できるものが多い。GGMLモデルをPythonで使う場合、llama-cpp-pythonまたはC Transformersといったライブラリを利用できる。ただ、前者は現時点でLlama系のモデルしか使えなさそうで、後者はGPT-NeoX系モデルだとGPUが使えないっぽい(?)。

そのため、今回はGGMLとは別の、CTranslate2というライブラリで量子化・推論を試してみることにした。

こちらの記事を参考にさせて頂き、CTranslate2で「line-corp-japanese-large-lm-3.6b」を8bit量子化(モデル名を入れ替えて実行しただけ)。

# パッケージのインストール
!pip install ctranslate2 
!pip sentencepiece transformers

# CTranslate2フォーマットに変換
!ct2-transformers-converter \
    --model line-corporation/japanese-large-lm-3.6b-instruction-sft \
    --quantization int8 \
    --output_dir ./line-lm-sft

これで7GB強あったモデルサイズが半分に圧縮された。8bitならモデルの質はあまり低下しないはず。

Google ColabでチャットUIを試す

次に、量子化した「line-corp-japanese-large-lm-3.6b」を使えるように、先ほどのひな形コードを修正する。

「inference_func」関数でTransformersによる推論を使っていた部分をCTranslate2を使うように修正し、不要なコードを削除したくらい。正直よく分かってないが、とりあえず動いたのでよしとする。

単純にChatbotコンポーネントを使うだけならもっとシンプルに済みそうだが、チャット履歴をコンテキストとして保持させるために少し煩雑化してるっぽい。

せっかくなので、Google Colab向けのコードを以下に掲載する。
※ColabのCPUだとまともに動かないので、ランタイムをGPUに切り替えて実行する必要あり。

# パッケージのインストール
!pip install ctranslate2 sentencepiece transformers gradio

# line-corp-japanese-large-lm-3.6b 量子化モデルのダウンロード
# (Colab用に変換済みモデルをHuggingFaceにアップしたもの)
!git clone https://huggingface.co/TFMC/line-corp-japanese-large-lm-3.6b-instruction-sft-ct2

# ウェブUIの起動
import os
import itertools

import torch
from transformers import AutoTokenizer
import ctranslate2
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"
generator = ctranslate2.Generator("./line-corp-japanese-large-lm-3.6b-instruction-sft-ct2", device=device)
tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/japanese-large-lm-3.6b-instruction-sft", use_fast=False)

def inference_func(prompt, max_length=128, sampling_temperature=0.7):
    tokens = tokenizer.convert_ids_to_tokens(
        tokenizer.encode(prompt, add_special_tokens=False)
    )
    results = generator.generate_batch(
        [tokens],
        max_length=max_length,
        sampling_topk=10,
        sampling_temperature=sampling_temperature,
        include_prompt_in_result=False,
    )
    output = tokenizer.decode(results[0].sequences_ids[0])
    return output

def make_prompt(message, chat_history, max_context_size: int = 10):
    contexts = chat_history + [[message, ""]]
    contexts = list(itertools.chain.from_iterable(contexts))
    if max_context_size > 0:
        context_size = max_context_size - 1
    else:
        context_size = 100000
    contexts = contexts[-context_size:]
    prompt = []
    for idx, context in enumerate(reversed(contexts)):
        if idx % 2 == 0:
            prompt = [f"システム: {context}"] + prompt
        else:
            prompt = [f"ユーザー: {context}"] + prompt
    prompt = "\n".join(prompt)
    return prompt


def interact_func(message, chat_history, max_context_size, max_length, sampling_temperature):
    prompt = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {prompt}")
    generated = inference_func(prompt, max_length, sampling_temperature)
    print(f"generated: {generated}")
    chat_history.append((message, generated))
    return "", chat_history


with gr.Blocks() as demo:
    with gr.Accordion("Configs", open=False):
        # max_context_size = the number of turns * 2
        max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
        max_length = gr.Number(value=128, label="max_length", precision=0)
        sampling_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="temperature")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    msg.submit(
        interact_func,
        [msg, chatbot, max_context_size, max_length, sampling_temperature],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(debug=True, share=True)

実行すると、モデルのダウンロードを含め3-5分くらいでGradioのリンクURL(https://*********.gradio.live)が表示される。これをクリックするとウェブUIが起動できる。

なお「line-corp-japanese-large-lm-3.6b」などのinstruction tuningした日本語LLMは、あくまで指示応答モデルという感じ。チャットAIの感覚で挨拶とか雑談をふると変な応答が返ってきやすい。日本語はさすがに流暢でうれしい。


参考

GradioによるChatbot構築の公式ガイドは下記(これからちゃんと読む…)。

CTranslate2による基本的な量子化・推論は下記記事が分かりやすい。

<追記>Llama系の言語モデルを使う場合は(「CTranslate2」でも対応しているが)「llama-cpp-python」を利用することができる。