見出し画像

HuggingFace Transformers の チャットモデルテンプレート を試す

「HuggingFace Transformers」の新機能「チャットモデルテンプレート」を試したので、まとめました。

1. チャットモデルテンプレート

チャットモデルのメッセージの書式は、モデルによって様々です。間違った書式を使用すると、モデルが混乱し、パフォーマンスが低下してしまうこともあります。「チャットモデルテンプレート」はこの問題の解決します。

2. チャットモデルのメッセージの書式

例として「Llama 2」のメッセージの書式を確認します。

・システムメッセージあり

[INST] <<SYS>>
{システムメッセージ}
<</SYS>>

{ユーザーメッセージ} [/INST]

・システムメッセージなし

[INST] {ユーザーメッセージ} [/INST]


チャットモデルテンプレート以前では、テスト生成のプロンプトにメッセージの書式を直接記述していました。

# プロンプトの準備
prompt = "[INST] Who is the cutest in Madoka Magica? [/INST]"

3. チャットモデルテンプレートの使い方

「チャットモデルテンプレート」の使い方は、次のとおりです。

3-1. メッセージリストからのプロンプトの作成

tokenizer.apply_chat_template()にメッセージリストを渡すことで、チャットモデルテンプレートをもとにプロンプトが作成されます。メッセージは、ロール(system、user、assistantなど)とテキストのペアになります。
今回は tokenize=False を指定し、トークナイズせずテキストのまま返して結果を確認しています。

chat = [
    {"role": "user", "content": "Who is the cutest in Madoka Magica?"},
]
tokenizer.apply_chat_template(chat, tokenize=False)

<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\nWho is the cutest in Madoka Magica? [/INST]

デフォルトのシステムメッセージを無効化したい場合は、tokenizer.use_default_system_prompt = False を指定します。

chat = [
    {"role": "user", "content": "Who is the cutest in Madoka Magica?"},
]
tokenizer.use_default_system_prompt = False
tokenizer.apply_chat_template(chat, tokenize=False)

<s> [INST] Who is the cutest in Madoka Magica? [/INST]

3-2. メッセージリストからのトークン群の作成

テキスト生成部分をチャットモデルテンプレートで記述すると、次のようになります。tokenizer.apply_chat_template() に メッセージリストとreturn_tensors="pt" を指定しています。

# チャットメッセージの準備
chat = [
    {"role": "user", "content": "Who is the cutest in Madoka Magica?"},
]

# 推論の実行
with torch.no_grad():
    tokenizer.use_default_system_prompt = False
    token_ids = tokenizer.apply_chat_template(chat, return_tensors="pt")
    output_ids = model.generate(
        token_ids.to(model.device),
        temperature=0.1, 
        do_sample=True, 
        top_p=0.95, 
        top_k=40, 
        max_new_tokens=256,
    )
output = tokenizer.decode(output_ids[0][token_ids.size(1) :])
print(output)

3-4. チャットモデルテンプレートの確認

チャットモデルテンプレートは tokenizer.default_chat_template で確認できます。

tokenizer.default_chat_template

{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\n' + content.strip() + '\n<</SYS>>\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}

Jinjaテンプレート」で、メッセージの書式を指定しています。これを書き換えることでテンプレートを変更でき、トークナイザーを保存すると、tokenizer_config.json に記述され、HuggingFace Hub経由で配布できます。



この記事が気に入ったらサポートをしてみませんか?