見出し画像

Hugging Faceのtokenizer_configをllama-cpp-pythonで使用する方法

「独自のchat_templateを使用していて、llama-cpp-pythonで提供しているchat_handlerが使用できない! Hugging Faceのtokenizer_config.jsonには定義があるのにぃ。困った!」とお嘆きのニッチなあなたに贈るnoteです。

※普通に「llama-cpp-pythonを試してみる」は、以下の記事です。

さて、この記事の中で、私はこう書きました。

これ以外のものを使用する場合は、自分でregister_chat_formatやら何やらで初期設定せにゃならんのだが面倒などでAutoTokenizerでお茶を濁している。
このchat_formatの機構が使用できれば、create_completionメソッドではなくcreate_chat_completionメソッドが使用できてエレガントなんだが、仕方ない。

KARAKURI MLのように[ATTR] …[/ATTR] の指定が重要となってくる場合、それを使用しないとアレなのです。このregister_chat_formatに追加されるまで「試してみる」を待つわけにもいかんし。

llama-cpp-python使っているのに、わざわざtransformersをimportしてAutoTokenizerを使うのがなんだかむしゃくしゃする、と思ったか思ってないかは分かりませんが、llama-cpp-pythonの提供ライブラリだけで実現できるので、その備忘録です。


1. コードの修正箇所

つべこべいわずに前記事と今回記事のコードの diff で確認します。差分の単位に見ていきましょう。コードは雄弁。


(注) if is_instruct: という条件分岐がたくさん出てくるので事前に説明しておきます。意味は以下です。

  • is_instruct: True
    チャットテンプレートを使用するモード(以降、チャットモード)。instruct/chatモデルのときに実行するコードです。

  • is_instruct: False
    チャットテンプレートを使用しないモード(以降、非チャットモード)。baseモデルのときに実行するコード

メンテが必要な商用コードで、同じ関数内に何度も何度も同じif文を書いちゃダメですよー。このコードは、手を抜いているだけです。


importの削除と追加

--- a/scripts/query4llama-cpp.py
+++ b/scripts/query4llama-cpp.py
@@ -1,8 +1,7 @@
 import sys
 import argparse
 from huggingface_hub import hf_hub_download
-from llama_cpp import Llama
-from transformers import AutoTokenizer
+from llama_cpp import Llama, llama_chat_format
 from typing import List, Dict
 import time

transfomersを削除して、llama_chat_formatを追加importしています。
llama_chat_format.py は、チャット関連のクラス、関数が定義されたモジュールです。

@@ -36,21 +35,20 @@ n_ctx = args.n_ctx
 n_threads = args.n_threads
 n_gpu_layers = args.n_gpu_layers

-## Instantiate tokenizer from base model
-tokenizer = AutoTokenizer.from_pretrained(
-    model_id,
-    trust_remote_code=True
-)
-

importを削除したので、当然にAutoTokenizerの処理も削除です。

 ## Download the GGUF model
 ggml_model_path = hf_hub_download(
     args.ggml_model_path,
     filename=args.ggml_model_file
 )

+# Instantiate chat format and handler
+chat_formatter = llama_chat_format.hf_autotokenizer_to_chat_formatter(model_id)

llama_chat_format.hf_autotokenizer_to_chat_formatter関数

この関数を呼び出して返却されたオブジェクト(ChatFormatterクラス)を使用することで、配列のメッセージからテキストのプロンプトに変換可能です。Transformersでいうところの AutoTokenizer#apply_chat_templateメソッドに相当します。

この関数を呼び出すのは配列からプロンプトにきちんと変換できているか?を確認するため、つまりデバックのためです。「そんなの要らない!」という場合はこの行の実行は不要です。

ここでは、関数の戻り値を変数 chat_formatter にセットします。

llama_chat_format.hf_autotokenizer_to_chat_completion_handler関数

+chat_handler = llama_chat_format.hf_autotokenizer_to_chat_completion_handler(model_id)
+

これは、Hugging FaceのAutoTokenizerを使用して、Llamaクラスのcreate_chat_completionメソッド内で使用されるオブジェクト(LlamaChatCompletionHandlerクラス)を生成する関数です。
ここでは、関数の返却値を変数chat_handlerに代入しています。

Llamaオブジェクト生成時のパラメータ追加

 ## Instantiate model from downloaded file
 model = Llama(
     model_path=ggml_model_path,
+    chat_handler=chat_handler,
     n_ctx=n_ctx,
     n_threads=n_threads,
     n_gpu_layers=n_gpu_layers

ここで、llama_chat_format.hf_autotokenizer_to_chat_completion_handlerの返却値 chat_hanlderをクラス初期化のパラメータに追加します。これにより、Llamaのインスタンス生成自にHugging Faceの定義で初期化されたチャットハンドラが設定されることになります。

チャットフォーマットへの変換処理

@@ -93,33 +91,39 @@ def q(
     messages += user_messages
     # generation prompts
     if is_instruct:
-        prompt = tokenizer.apply_chat_template(
-            conversation=messages,
-            add_generation_prompt=True,
-            tokenize=False
-        )
+        prompt = chat_formatter(messages=messages)
     else:
         prompt = messages
+    # debug
+    print("--- messages")
+    print(messages)
     print("--- prompt")
     print(prompt)
     print("--- output")

AutoTokenizer#apply_chat_templateの代わりに、先ほどセットした変数chat_formatterを使用して、メッセージの配列 messages からテキストのプロンプトに変換します。

変数chat_formatterの型はChatFormatterクラスですが、このクラスは__call__メソッドが定義されていますので、メソッドなしでよびだせます。このようにmessagesパラメータを指定すると、返却値はChatFormatterResponseクラスのオブジェクトが返却されることになります。それをprompt変数に代入しています。

推論実行 - チャットモード

     # 推論
-    outputs = model.create_completion(
-        prompt=prompt,
-        #echo=True,
-        #stream=True,
-        **generation_params
-    )
-    #for output in outputs:
-    #    print(output["choices"][0]["text"], end='')
-    output = outputs["choices"][0]["text"]
-    print(output)
     if is_instruct:
+        outputs = model.create_chat_completion(
+            messages=messages,
+            #echo=True,
+            #stream=True,
+            **generation_params
+        )
+        output = outputs["choices"][0]["message"]["content"]
         user_messages.append(
             {"role": "assistant", "content": output}
         )

チャットモードの場合、create_completionメソッドではなく、create_chat_completionを呼び出します。引数はメッセージが格納された配列です。この戻り値内の choicesの1番目のmessageの中のcontentに推論結果の文字列が含まれているので、それをoutput変数に代入します。

推論実行 - 非チャットモード

     else:
+        outputs = model.create_completion(
+            prompt=prompt,
+            #echo=True,
+            #stream=True,
+            **generation_params
+        )
+        output = outputs["choices"][0]["text"]
+        #for output in outputs:
+        #    print(output["choices"][0]["text"], end='')
         user_messages += output
+    print(output)
     end = time.process_time()
     ##
     input_tokens = outputs["usage"]["prompt_tokens"]

コードは変更していませんが、インデントの関係で差分と検知されただけです。

まとめ

まとめたコードは以下です。

import sys
import argparse
from huggingface_hub import hf_hub_download
from llama_cpp import Llama, llama_chat_format
from typing import List, Dict
import time

# argv
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default=None)
parser.add_argument("--ggml-model-path", type=str, default=None)
parser.add_argument("--ggml-model-file", type=str, default=None)
parser.add_argument("--no-instruct", action='store_true')
parser.add_argument("--no-use-system-prompt", action='store_true')
parser.add_argument("--max-tokens", type=int, default=256)
parser.add_argument("--n-ctx", type=int, default=2048)
parser.add_argument("--n-threads", type=int, default=1)
parser.add_argument("--n-gpu-layers", type=int, default=-1)

args = parser.parse_args(sys.argv[1:])

## check and set args
model_id = args.model_path
if model_id == None:
    exit
if args.ggml_model_path == None:
    exit
if args.ggml_model_file == None:
    exit

is_instruct = not args.no_instruct
use_system_prompt = not args.no_use_system_prompt
max_new_tokens = args.max_tokens
n_ctx = args.n_ctx
n_threads = args.n_threads
n_gpu_layers = args.n_gpu_layers

## Download the GGUF model
ggml_model_path = hf_hub_download(
    args.ggml_model_path,
    filename=args.ggml_model_file
)

# Instantiate chat format and handler
chat_formatter = llama_chat_format.hf_autotokenizer_to_chat_formatter(model_id)
chat_handler = llama_chat_format.hf_autotokenizer_to_chat_completion_handler(model_id)

## Instantiate model from downloaded file
model = Llama(
    model_path=ggml_model_path,
    chat_handler=chat_handler,
    n_ctx=n_ctx,
    n_threads=n_threads,
    n_gpu_layers=n_gpu_layers
)

DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"

# generation params
# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py#L1268
generation_params = {
    #"do_sample": True,
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "max_tokens": max_new_tokens,
    "repeat_penalty": 1.1,
}


def q(
    user_query: str,
    history: List[Dict[str, str]]=None
):
    start = time.process_time()
    # messages
    messages = ""
    if is_instruct:
        messages = []
        if use_system_prompt:
            messages = [
                {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
            ]
        user_messages = [
            {"role": "user", "content": user_query}
        ]
    else:
        user_messages = user_query
    if history:
        user_messages = history + user_messages
    messages += user_messages
    # generation prompts
    if is_instruct:
        prompt = chat_formatter(messages=messages)
    else:
        prompt = messages
    # debug
    print("--- messages")
    print(messages)
    print("--- prompt")
    print(prompt)
    print("--- output")
    # 推論
    if is_instruct:
        outputs = model.create_chat_completion(
            messages=messages,
            #echo=True,
            #stream=True,
            **generation_params
        )
        output = outputs["choices"][0]["message"]["content"]
        user_messages.append(
            {"role": "assistant", "content": output}
        )
    else:
        outputs = model.create_completion(
            prompt=prompt,
            #echo=True,
            #stream=True,
            **generation_params
        )
        output = outputs["choices"][0]["text"]
        #for output in outputs:
        #    print(output["choices"][0]["text"], end='')
        user_messages += output
    print(output)
    end = time.process_time()
    ##
    input_tokens = outputs["usage"]["prompt_tokens"]
    output_tokens = outputs["usage"]["completion_tokens"]
    total_time = end - start
    tps = output_tokens / total_time
    print(f"prompt tokens = {input_tokens:.7g}")
    print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
    print(f"   total time = {total_time:f} [s]")
    return user_messages

print('history = ""')
print('history = q("ドラえもんとはなにか")')
print('history = q("続きを教えてください", history)')

2. 実行してみる

KARAKURI LMの量子化モデルを使用して実行です。

python -i ~/scripts/query4llama-cpp.py \
    --model-path karakuri-ai/karakuri-lm-70b-chat-v0.1 \
    --ggml-model-path mmnga/karakuri-lm-70b-chat-v0.1-gguf \
    --ggml-model-file karakuri-lm-70b-chat-v0.1-q3_K_M.gguf

チャットモード

>>> history = q("ドラえもんとはなにか")
--- messages
[{'role': 'system', 'content': 'あなたは誠実で優秀な日本人のアシスタントです。'}, {'role': 'user', 'content': 'ドラえもんとはなにか'}]
--- prompt
ChatFormatterResponse(prompt='<s>[INST] <<SYS>>\nあなたは誠実で優秀な日本人のアシスタントです。\n<</SYS>>\n\nドラえもんとはなにか [ATTR] helpfulness: 4 correctness: 4 coherence: 4 complexity: 4 verbosity: 4 quality: 4 toxicity: 0 humor: 0 creativity: 0 [/ATTR] [/INST]', stop='</s>')
--- output

llama_print_timings:        load time =    1245.41 ms
llama_print_timings:      sample time =      20.86 ms /   105 runs   (    0.20 ms per token,  5034.28 tokens per second)
llama_print_timings: prompt eval time =    1243.14 ms /    91 tokens (   13.66 ms per token,    73.20 tokens per second)
llama_print_timings:        eval time =    8902.67 ms /   104 runs   (   85.60 ms per token,    11.68 tokens per second)
llama_print_timings:       total time =   10386.28 ms /   195 tokens
 ドラえもんは、日本の漫画家・藤子不二雄によって生み出された架空のキャラクターであり、『ドラえもん』という漫画およびアニメーション作品の主人公として知られています。彼は22世紀から来たネコ型ロボットのキャラクターであり、未来世界で発明された道具を使 ってさまざまな冒険や問題を解決します。ストーリーは、主人公の野比のび太がドラえもんに助けられ、様々な困難や挑戦に立ち向かっ ていくというストーリーが展開されます。この作品は1969年に初めて出版され、日本だけでなく世界中の多くの国で愛されています。
prompt tokens = 91
output tokens = 104 (10.367311 [tps])
   total time = 10.031531 [s]
>>>

変数messagesは配列が格納され、変数promptは、ChatFormatterの戻り値が格納されています。中身を見るときちんとテキストに変換され、また [ATTR] … [/ATTR} が出力されているので、KARAKURI LMのチャットテンプレートを読み込んでいることが分かるかと思います。

更に尋ねてみますが、問題なさそうです。

>>> history = q("続きを教えてください", history)
--- messages
[{'role': 'system', 'content': 'あなたは誠実で優秀な日本人のアシスタントです。'}, {'role': 'user', 'content': 'ドラえもんとはなにか'}, {'role': 'assistant', 'content': ' ドラえもんは、日本の漫画家・藤子不二雄によって生み出された架空のキャラクターであり、『ドラえもん』という漫画およびアニメーション作品の主人公として知られています。彼は22世紀から来たネコ型ロボットのキ ャラクターであり、未来世界で発明された道具を使ってさまざまな冒険や問題を解決します。ストーリーは、主人公の野比のび太がドラ えもんに助けられ、様々な困難や挑戦に立ち向かっていくというストーリーが展開されます。この作品は1969年に初めて出版され、日本 だけでなく世界中の多くの国で愛されています。 '}, {'role': 'user', 'content': '続きを教えてください'}]
--- prompt
ChatFormatterResponse(prompt='<s>[INST] <<SYS>>\nあなたは誠実で優秀な日本人のアシスタントです。\n<</SYS>>\n\nドラえもんとはなにか [ATTR] helpfulness: 4 correctness: 4 coherence: 4 complexity: 4 verbosity: 4 quality: 4 toxicity: 0 humor: 0 creativity: 0 [/ATTR] [/INST] ドラえもんは、日本の漫画家・藤子不二雄によって生み出された架空のキャラクターであり、『ドラえもん』という漫画およびアニメーション作品の主人公として知られています。彼は22世紀から来たネコ型ロボットのキャラクターであり、未来世 界で発明された道具を使ってさまざまな冒険や問題を解決します。ストーリーは、主人公の野比のび太がドラえもんに助けられ、様々な 困難や挑戦に立ち向かっていくというストーリーが展開されます。この作品は1969年に初めて出版され、日本だけでなく世界中の多くの 国で愛されています。 </s><s>[INST] 続きを教えてください [ATTR] helpfulness: 4 correctness: 4 coherence: 4 complexity: 4 verbosity: 4 quality: 4 toxicity: 0 humor: 0 creativity: 0 [/ATTR] [/INST]', stop='</s>')
--- output
Llama.generate: prefix-match hit

llama_print_timings:        load time =    1245.41 ms
llama_print_timings:      sample time =      41.05 ms /   194 runs   (    0.21 ms per token,  4726.06 tokens per second)
llama_print_timings: prompt eval time =    1148.79 ms /   130 tokens (    8.84 ms per token,   113.16 tokens per second)
llama_print_timings:        eval time =   16527.10 ms /   193 runs   (   85.63 ms per token,    11.68 tokens per second)
llama_print_timings:       total time =   18235.06 ms /   323 tokens
 ドラえもんは様々な道具を使うことができるため、視聴者が彼と彼の道具に興味を持つようになります。例えば、タケコプターやタイムマシンなど、日常生活では体験できないことを彼を通じて体験することができます。また、彼がのび太やその友人たちと共に成長し、様 々な問題を乗り越える姿は、視聴者に勇気と希望を与えます。彼は単なるロボットではなく、感情を持ち、のび太とその友人たちと共に 成長する存在なのです。

さらに、ドラえもんは単なる娯楽作品ではなく、教育的な要素も持ち合わせています。例えば、環境問題や平和など、現代の社会問題を 取り上げ、視聴者に考えさせることを促します。また、のび太が失敗や挫折を経験するたびに、視聴者は彼と共に成長し、学んでいくこ とができます。

このように、ドラえもんは単なる娯楽作品を超え、視聴者に勇気と希望、そして教育的要素を与える存在として長く愛され続けています 。彼の存在は、日本のポップカルチャーを象徴する一つの存在と言えるでしょう。
prompt tokens = 262
output tokens = 193 (10.813864 [tps])
   total time = 17.847460 [s]
>>>

非チャットモード

>>> is_instruct = False

で非チャットモードに変更してから、聞いてみます。

>>> history = q( ドラえもんとはなにか")
--- messages
ドラえもんとはなにか
--- prompt
ドラえもんとはなにか
--- output
Llama.generate: prefix-match hit

llama_print_timings:        load time =    1245.41 ms
llama_print_timings:      sample time =      60.95 ms /   256 runs   (    0.24 ms per token,  4199.89 tokens per second)
llama_print_timings: prompt eval time =     758.07 ms /     7 tokens (  108.30 ms per token,     9.23 tokens per second)
llama_print_timings:        eval time =   21841.30 ms /   255 runs   (   85.65 ms per token,    11.68 tokens per second)
llama_print_timings:       total time =   23264.20 ms /   262 tokens
 - ある経営コンサルタント
ドラえもんとはなにか。一言でいうならば、それは「子供の理想郷」の具現化である。
この世には、子供たちの願いを叶えるための道具や機械は無数にあるが、本当に欲しいものを実現できるものは少ない。その「欲しいも の」というのは、例えば「タイムマシン」だったり、「透明マント」だったり、「どこでもドア」だったり。あるいは、親に怒られるこ とをしないで、いくら遊んでもいい、食べたいものがいくらでも食べられる、友達と仲良く遊べる・・・そういうようなことだろう。ド ラえもんは、それらを一手に具現化している。
だから、「ドラえもんの道具にはどんなものがあるの?」と聞かれれば、それは無限の種類の道具が存在するとしか言いようがない。な ぜなら、全ての子供たちは、それぞれに違う夢を持っているから。
だから、「ドラえもんに会いたい」という子供は少なくないし、そして、ドラえもんに会った子供は、それぞれに違った夢を実現しても らおうと夢想するのだろう。
ドラえもんに会うために必要な条件は何か
さて、ドラえもんに会うためには、何が必要なのだろうか。ひみつ道具はたくさんあるけれども、ドラえもんに会うために必要な条件は 、そんなにたくさんあるわけではない。
まず、ドラえ
prompt tokens = 8
output tokens = 256 (11.161976 [tps])
   total time = 22.935007 [s]
>>>

チャットテンプレートが使用されていませんね。正しい。さらに聞いてみましょう。続きを聞くので、q関数の1つ目の引数は空文字列です。

>>> history = q("", history)
--- messages
ドラえもんとはなにか - ある経営コンサルタント
ドラえもんとはなにか。一言でいうならば、それは「子供の理想郷」の具現化である。
この世には、子供たちの願いを叶えるための道具や機械は無数にあるが、本当に欲しいものを実現できるものは少ない。その「欲しいも の」というのは、例えば「タイムマシン」だったり、「透明マント」だったり、「どこでもドア」だったり。あるいは、親に怒られるこ とをしないで、いくら遊んでもいい、食べたいものがいくらでも食べられる、友達と仲良く遊べる・・・そういうようなことだろう。ド ラえもんは、それらを一手に具現化している。
だから、「ドラえもんの道具にはどんなものがあるの?」と聞かれれば、それは無限の種類の道具が存在するとしか言いようがない。な ぜなら、全ての子供たちは、それぞれに違う夢を持っているから。
だから、「ドラえもんに会いたい」という子供は少なくないし、そして、ドラえもんに会った子供は、それぞれに違った夢を実現しても らおうと夢想するのだろう。
ドラえもんに会うために必要な条件は何か
さて、ドラえもんに会うためには、何が必要なのだろうか。ひみつ道具はたくさんあるけれども、ドラえもんに会うために必要な条件は 、そんなにたくさんあるわけではない。
まず、ドラえ
--- prompt
ドラえもんとはなにか - ある経営コンサルタント
ドラえもんとはなにか。一言でいうならば、それは「子供の理想郷」の具現化である。
この世には、子供たちの願いを叶えるための道具や機械は無数にあるが、本当に欲しいものを実現できるものは少ない。その「欲しいも の」というのは、例えば「タイムマシン」だったり、「透明マント」だったり、「どこでもドア」だったり。あるいは、親に怒られるこ とをしないで、いくら遊んでもいい、食べたいものがいくらでも食べられる、友達と仲良く遊べる・・・そういうようなことだろう。ド ラえもんは、それらを一手に具現化している。
だから、「ドラえもんの道具にはどんなものがあるの?」と聞かれれば、それは無限の種類の道具が存在するとしか言いようがない。な ぜなら、全ての子供たちは、それぞれに違う夢を持っているから。
だから、「ドラえもんに会いたい」という子供は少なくないし、そして、ドラえもんに会った子供は、それぞれに違った夢を実現しても らおうと夢想するのだろう。
ドラえもんに会うために必要な条件は何か
さて、ドラえもんに会うためには、何が必要なのだろうか。ひみつ道具はたくさんあるけれども、ドラえもんに会うために必要な条件は 、そんなにたくさんあるわけではない。
まず、ドラえ
--- output
Llama.generate: prefix-match hit

llama_print_timings:        load time =    1245.41 ms
llama_print_timings:      sample time =      61.06 ms /   256 runs   (    0.24 ms per token,  4192.32 tokens per second)
llama_print_timings: prompt eval time =    1153.92 ms /   237 tokens (    4.87 ms per token,   205.39 tokens per second)
llama_print_timings:        eval time =   21845.50 ms /   255 runs   (   85.67 ms per token,    11.67 tokens per second)
llama_print_timings:       total time =   23684.52 ms /   492 tokens
もんに会いたければ、ドラえもんに会いたいと思うことだ。会いたいと思わなければ会えないのは当然であろう。また、会いたいと思え ば、会えない可能性もあるが、会えないと思えば、会えない可能性が高い。だから、会いたいと思うこと。これはとても重要だ。
次に、会いたいと思うことと同時に、会いたいと思っていることを、誰かに伝えておくことが大切だ。なぜなら、会いたいと思っている 人がいるということを、他の人が知っていれば、その人がドラえもんに会いたいと思っているということを、ドラえもんに伝えてくれる かもしれない。だから、会いたいと思うことと同時に、会いたいと思っていることを、誰かに伝えておくこと。これも大切だ。
さらに、誰かに伝えたら、その誰かに、その誰かの知り合いに会いたいと思っているということを伝えてもらうこと。そうすれば、その 知り合いの知り合いの知り合いの・・・誰かがドラえもんに会いたいと思っているということを伝えてくれるかもしれない。だから、会 いたいと思っていることを、誰かに伝えておくこと。これも大切だ。
そして、最後に、ドラえもんに会いたいと思っている人の中に、ドラえもんに会いたいと思っている人のために誰かに伝えようと思う人 がいれば、その人がその人に会いたいと思っているということを、ドラえもんに伝えてくれるだろう。だから、会いたいと思っているこ とを、
prompt tokens = 259
output tokens = 256 (10.953847 [tps])
   total time = 23.370785 [s]
>>>

きちんと、続きが出力されてますね。めでたしめでたし。

3. まとめ

llama_chat_formatモジュールのhf_autotokenizer_to_chat_completion_handler関数を用いることで、独自のチャットテンプレートを使用し、かつHugging Faceのtokenizer_config.jsonにその定義が含まれるLLMは救えることが分かりました。

ちなみに、ここでは、hf_autotokenizer_to_chat_completion_handler関数を紹介しましたが、直接tokenizer_config.jsonを指定可能な
・hf_tokenizer_config_to_chat_completion_handler関数
とフォーマッタ関数である
・hf_tokenizer_config_to_chat_formatter関数
も定義されています。皆様も是非、コードを読んでお試しくださいませ。

以上、プログラムは書いたとおりにしか動かない、でした。

4. おまけ

Llama#create_completionメソッドの戻り値

まずは、簡単な方から確認します。
推論結果を返却している関数を探しましょう。その呼び出し関係は、以下の通り。

Llama#create_completion
⇒ Llama#_create_completion

streamの対応のためか、似たような処理がたくさんありますね。ここでは一番最後の処理を見ます。

            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "text": text_str,
                    "index": 0,
                    "logprobs": logprobs_or_none,
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": len(prompt_tokens),
                "completion_tokens": len(completion_tokens),
                "total_tokens": len(prompt_tokens) + len(completion_tokens),
            },

text_str変数にデコードされた推論結果が格納されています。ですので、推論結果の文字列を得るには、

completion["choices"][0]["text"]

でokですね。

Llama#create_chat_completionメソッドの戻り値

create_chat_completionメソッドは、内部的にはLlama#create_completionを呼び出していて、その出力結果をチャット向けに書式変換しています。

では、詳細を見ていきましょう。

create_chat_completionメソッドは、Llama#chat_handlerに設定されたハンドラか、それが未設定の場合はLlama#chat_formatで指定されたチャット書式("llama-2"とか)から生成したハンドラを呼び出しています。以下は、create_chat_completionメソッドの該当部分を抜き出したものです。

    def create_chat_completion(
        self,
        messages: List[ChatCompletionRequestMessage],
(snip)
    ) -> Union[
        CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
    ]:
(snip)
        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
            self.chat_format
        )
        return handler(
            llama=self,
            messages=messages,
(snip)
        )

このハンドラインスタンス(Llama#chat_handlerなど)の型はLlamaChatCompletionHandlerクラスです。

LlamaChatCompletionHandlerクラスは実装がないインターフェース定義です。それの(メソッド無しで呼び出された際の)実装はハンドラを生成処理が担っています。この実装はchat_formatter_to_chat_completion_handler関数にあたります。つまり、上記の処理で return handler(…) と呼び出していますが、これは実際には return chat_formatter_to_chat_completion_handler(…) と呼び出されていることになります。

で、推論結果を整形している所を探します。呼び出し関係は、以下の通り。

Llama#create_chat_completion
 ⇒ llama_chat_format#chat_formatter_to_chat_completion_handler
 ⇒ llama_chat_format#_convert_completion_to_chat
 ⇒ llama_chat_format#_convert_text_completion_to_chat

最終的に _convert_text_completion_to_chat関数にたどり着きます。こやつの return を見てみましょう。

def _convert_text_completion_to_chat(
    completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
    assert "usage" in completion
    return {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": completion["choices"][0]["text"],
                },
                "finish_reason": completion["choices"][0]["finish_reason"],
            }
        ],
        "usage": completion["usage"],
    }

これが、create_chat_completionメソッド呼び出し時の戻り値の型にあたります。この型から推論結果の文字列を抽出するには、

output = outputs["choices"][0]["message"]["content"]

とすればよいですね。はい。

この記事が気に入ったらサポートをしてみませんか?