見出し画像

FastChat への 新モデルの対応手順

「FastChat」への新モデルの対応手順をまとめました。

【注意】Google Colab Pro/Pro+ の T4のハイメモリで動作確認しています。


1. FastChat

「FastChat」は、LLMベースのチャットボットを提供、学習、評価するためのオープンプラットフォームです。
FastChat の主な機能は次のとおりです。

・LLMの推論、学習、評価。
・Web UI と OpenAI互換のREST APIを備えた分散マルチモデル提供システム。

2. FastChat対応済みのモデル

「FastChat」が公式で対応済みのモデルは、次のとおりです。

FastChat - Model Support

3. FastChatでの推論の実行

FastChatでの推論の実行手順は、次のとおりです。

(1) Colabのノートブックを開き、メニュー「編集 → ノートブックの設定」で「T4」の「ハイメモリ」を選択。

(2) パッケージのインストール。

# パッケージのインストール
!git clone https://github.com/lm-sys/FastChat
%cd FastChat
!pip install -e ".[model_worker,webui]"

(3) CUI版の推論の実行。
今回は、「FastChat」と開発元が同じ「Vicuna」(lmsys/vicuna-7b-v1.5)を使います。「--debug」のデバッグ出力で内部プロンプトを確認できます。

!python -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --debug

(4) 「USER: 」と表示されたらメッセージ入力。
今回は、「日本一高い山は?」と入力しました。

USER: 日本一高い山は?
ASSISTANT: 日本で最も高い山は、立ちはだかる「富士山」です。富士山は、日本の山岳地帯に位置し、標高は1億5千メートルです。富士山は、古代から信仰され、日本の国の神である「大原神」とされ、多くの登山者が訪れます。また、富士山は、日本の自然の美しさを代表する山であり、日本の山岳地帯において最も有名な山の一つです。

デバッグ出力は、次のとおりです。

{
    'conv_template': 'vicuna_v1.1',
    'prompt': "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.  USER: 日本一高い山は? ASSISTANT:",
    'outputs': '日本で最も高い山は、立ちはだかる「富士山」です。富士山は、日本の山岳地帯に位置し、標高は1億5千メートルです。富士山は、古代から信仰され、日本の国の神である「大原神」とされ、多くの登山者が訪れます。また、富士山は、日本の自然の美しさを代表する山であり、日本の山岳地帯において最も有名な山の一つです。',
    'speed (token/s)': 12.66
}

(5) 再度「USER: 」と表示されたらメッセージ入力。
今回は、「日本一高い山は?」と入力しました。


User: その山は何県にある?
Assistant: 富士山は、日本の山岳地帯に位置し、静岡県と山梨県の境にあります。富士山は、日本の最高峰であり、日本の山岳地帯において最も有名な山の一つです。また、富士山は、日本の文化的な遺産としても重要であり、多くの観光客が訪れます。

デバッグ出力は、次のとおりです。


{
    'conv_template': 'vicuna_v1.1',
    'prompt': "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.  USER: 日本一高い山は? ASSISTANT: 日本で最も高い山は、立ちはだかる「富士山」です。富士山は、日本の山岳地帯に位置し、標高は1億5千メートルです。富士山は、古代から信仰され、日本の国の神である「大原神」とされ、多くの登山者が訪れます。また、富士山は、日本の自然の美しさを代表する山であり、日本の山岳地帯において最も有名な山の一つです。</s>
USER: その山は何県にある?  ASSISTANT:",
    'outputs': '富士山は、日本の山岳地帯に位置し、静岡県と山梨県の境にあります。富士山は、日本の最高峰であり、日本の山岳地帯において最も有名な山の一つです。また、富士山は、日本の文化的な遺産としても重要であり、多くの観光客が訪れます。',
    'speed (token/s)': 15.63
}

会話履歴が含まれていることがわかります。

4. 新モデルの動作確認

対応前だと正しい会話テンプレートが設定されてないため、精度が落ちます。今回は、「Japanese StableLM Gamma 7B」を実行します。

(1) 「--debug」付きで推論の実行。

!python -m fastchat.serve.cli --model-path stabilityai/japanese-stablelm-instruct-gamma-7b --debug

(2) 推論を試す。

<|USER|>: 日本一高い山は?
<|ASSISTANT|>: 富士山。<|SYSTEM|># StableLM Tuned (Alpha version)…不要な長文…

デバッグ出力は、次のとおりです。

{
    'conv_template': 'stablelm',
    'prompt': '<|SYSTEM|># StableLM Tuned (Alpha version)\n- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.\n- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.\n- StableLM will refuse to participate in anything that could harm a human.\n<|USER|>日本一高い山は?<|ASSISTANT|>',
    'outputs': '富士山。<|SYSTEM|># StableLM Tuned (Alpha version)\n- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI...不要な長文...',
    'speed (token/s)': 13.4
}

正しい会話テンプレートが設定されてないため、応答が書式が壊れてることがわかります。

5. 新モデルの対応

「FastChat」で「Japanese StableLM Gamma 7B」が動くように設定します。


新モデルの対応手順は、次のとおりです。

5-1. 会話テンプレートの実装

fastchat/conversation.py」を以下のように編集します。

(1) 「SeparatorStyle」に定数「JSLM_GAMMA」を追加。
新モデルのスタイルを表す定数です。

class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
        :
    FALCON_CHAT = auto()
    JSLM_GAMMA = auto()    # ←追加

(2) Conversationクラスのget_prompt()にプロンプト生成を実装。
他のモデルを参考にして、新モデルのプロンプト生成を実装します。

        elif self.sep_style == SeparatorStyle.FALCON_CHAT:
            ret = ""
            if self.system_message:
                ret += system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"

            return ret
        elif self.sep_style == SeparatorStyle.JSLM_GAMMA:  # ←追加
            ret = self.system_message + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": \n" + message + self.sep
                else:
                    ret += role + ": \n"
            return ret

(2) Conversationクラス生成の追加。
他のモデルを参考にして、新モデルのConversationクラス生成を実装します。

# conv template for JSLM Gamma tokenizer
# source: https://huggingface.co/stabilityai/japanese-stablelm-base-gamma-7b
register_conv_template(
    Conversation(
        name="jslm_gamma",
        system_message="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
        roles=("指示", "応答"),
        sep_style=SeparatorStyle.JSLM_GAMMA,
        sep="\n\n### ",
        stop_token_ids=[2],
        stop_str="###",
    )
)

5-2. モデルアダプタの実装

fastchat/model/model_adapter.py」 を以下のように編集します。

(1) ModelAdapterアダプタクラスの定義。
他のモデルを参考にして、新モデルのModelAdapterクラス生成を実装します。

class JSLMGammaAdapter(BaseModelAdapter):  # ←追加
    """
    Model adapter for Japanese StableLM Gamma instruct model
    https://huggingface.co/stabilityai/japanese-stablelm-base-gamma-7b
    """
    model_variation = None

    def match(self, model_path: str):
        model_path = model_path.lower()
        if model_path == "japanese-stablelm-instruct-gamma-7b":
            self.model_variation = "gamma"

        return True if self.model_variation else False

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        revision = from_pretrained_kwargs.get("revision", "main")

        tokenizer = LlamaTokenizer.from_pretrained(
            "stabilityai/japanese-stablelm-instruct-gamma-7b"
        )
        from_pretrained_kwargs.pop("trust_remote_code", None)

        clm_cls = AutoModelForCausalLM

        model = clm_cls.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **from_pretrained_kwargs
        )
        return model, tokenizer

    def get_default_conv_template(self, model_path:str):
        return get_conv_template("jslm_gamma")

(2) ModelAdapterアダプタクラスの追加。
register_model_adapter()が並んでいる部分に追加します。

# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(JSLMGammaAdapter)  # ←追加
register_model_adapter(PeftModelAdapter)
register_model_adapter(VicunaAdapter)
register_model_adapter(AiroborosAdapter)

6. 新モデルの対応の動作確認

新モデルの対応の動作確認結果は、次のとおりです。

指示: 日本一高い山は?
応答: 日本最高峰の山は富士山です。標高は3,776mです。 富士山は、日本のみならず、世界的にも有名な火山です。 山頂からは、天気がよければ、伊豆半島から房総半島まで見ることができます。 毎年たくさんの人が富士山に登りますが、標高3,000mを超えると、登山者の半分は頂上に到達できないと言われています。 富士山は、日本一の高さを誇るだけでなく、日本一美しい山としても知られています。

デバッグ出力は、次のとおりです。

{
    'conv_template': 'jslm_gamma',
    'prompt': '以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n### 指示: \n日本一高い山は?\n\n### 応答: \n',
    'outputs': '日本最高峰の山は富士山です。標高は3,776mです。 富士山は、日本のみならず、世界的にも有名な火山です。 山頂からは、天気がよければ、伊豆半島から房総半島まで見ることができます。 毎年たくさんの人が富士山に登りますが、標高3,000mを超えると、登山者の半分は頂上に到達できないと言われています。 富士山は、日本一の高さを誇るだけでなく、日本一美しい山としても知られています。',
    'speed (token/s)': 13.66
}

マルチターンの会話も確認してみます。

指示: その山は何県にある?
応答: 静岡県にある。

{
  'conv_template':  'jslm_gamma',
  'prompt': '以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n### 指示: \n日本一高い山は?\n\n### 応答: \n日本最高峰の山は富士山です。標高は3,776mです。 富士山は、日本のみならず、世界的にも有名な火山です。 山頂からは、天気がよければ、伊豆半島から房総半島まで見ることができます。 毎年たくさんの人が富士山に登りますが、標高3,000mを超えると、登山者の半分は頂上に到達できないと言われています。 富士山は、日本一の高さを誇るだけでなく、日本一美しい山としても知られています。 富士山の周りには、樹木がほとんどなく、山の形が明瞭である。 富士山は、日本の国花である桜を見ることもできる、世界で最も美しい山のひとつとされています。 富士山の溶岩ドームは、日本最高峰の火山である。 富士山の山頂は、標高3,776mで、日本最高峰の山である。 富士山は、世界最高峰の山として、多くの人々を惹きつけています。 富士山の山頂には、標高3,776mの山頂があります。 富士山は、古来より日本の人々にとって神聖な山とされ、多くの文学作品や絵画の題材となっています。 富士山は、世界遺産に登録されています。 富士山は世界で最も美しい山とされています。 富士山は、世界\n\n### 指示: \nその山は何県にある?\n\n### 応答: \n',
  'outputs': '静岡県にある。',
  'speed (token/s)': 6.6
}

今回は、以下のような会話テンプレートで設定しています。

{システムメッセージ}

### 指示:
{メッセージ1}

### 応答:
{レスポンス1}

### 指示:
{メッセージ2}

### 応答:

「StableLM」は、「LLaMa-2」や「Rinna-4B」のような、マルチターン用の指示の書式が明示されてなかったので、仮の会話テンプレートになります。

【おまけ】ELYZA-japanese-Llama-2-7b

ELYZA-japanese-Llama-2-7b」は「LLaMa 2」会話テンプレート準拠かつモデルIDに「LLaMa-2」の記述があるため、自動的に正しい会話テンプレートが選択されました。

!python -m fastchat.serve.cli --model-path elyza/ELYZA-japanese-Llama-2-7b-fast-instruct --debug



この記事が気に入ったらサポートをしてみませんか?