
Let's get Function calling working on a local LLM.

In this article I get Function calling, which is tricky to implement, working on a local LLM. The weather-lookup sample I ran here recognizes the function name 100% of the time, but to be honest, outside of this sample the accuracy of extracting the function name is not high. You apparently need to word each function's description carefully. Still, as long as a working method exists, I expect it will become genuinely usable as local LLM performance improves.

Environment

For the setup, please refer to my earlier article. This time I use llama-cpp-python. There is also a way to do this without standing up a server, so both approaches are shown in the hands-on sections.
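As a quick orientation, here is a minimal sketch of my own (the model path is only an example) of the two ways to drive llama-cpp-python that appear in this article: loading the model directly in Python, or going through a server.

# Minimal sketch (model path is an example): use llama-cpp-python directly, without a server.
from llama_cpp import Llama

llm = Llama(model_path="./models/mistral-7b-instruct-v0.1.Q8_0.gguf", n_ctx=2048)
print(llm("Q: What is the capital of Japan? A:", max_tokens=32)["choices"][0]["text"])

# llama-cpp-python also ships an OpenAI-compatible server, started roughly like this:
#   python -m llama_cpp.server --model ./models/mistral-7b-instruct-v0.1.Q8_0.gguf
# As described later, that route did not work out for me, so the server example
# further down uses a custom FastAPI server instead.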

The LLM model

The most troublesome part is the model. The sample code only extracted the function name correctly with Mistral-based models. Among Mistral-based models I tried the following:
mistral-7b-instruct-v0.1.Q8_0.gguf
japanese-stablelm-instruct-gamma-7b-q8_0.gguf

Q4 quantization also works.

Function calling with llama-cpp-python is described in an issue on the project, and at the end of that discussion there is a link to working code:

https://github.com/teleprint-me/py.gpt.prompt/blob/main/docs/notebooks/llama_cpp_grammar_api.ipynb

Example code without a server
I changed the last part of the original code so that the abstract answer returned by the function is explained as a Japanese sentence. It is still not great, though.
This sample defines a function that fetches the weather for a city and a harder-to-interpret function called binary_arithmetic. The former is recognized with high accuracy. binary_arithmetic, on the other hand, is rarely picked up by a local LLM. It can be picked up if the prompt is made almost identical to the description, but the original code does not actually include the code to call the function once it has been extracted. The code below adds logic that, when binary_arithmetic does happen to be extracted, passes the arguments correctly and calls the function.

import json
import os
from pprint import pprint
import requests
from typing import Literal, NotRequired, List, Union
from llama_cpp import ChatCompletionMessage, Llama, LlamaGrammar

MODEL_PATH = "./models/japanese-stablelm-instruct-gamma-7b-q8_0.gguf"

llm = Llama(model_path=MODEL_PATH, n_gpu_layers=35, n_ctx=2048)
llama_grammar = LlamaGrammar.from_file("./grammars/json.gbnf")
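# Note: json.gbnf is the JSON grammar bundled with the llama.cpp repository (grammars/json.gbnf);
# passing it as `grammar` below constrains create_chat_completion to emit valid JSON only.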

FUNCTIONS = [
    {
        "name": "get_current_weather",
        "description": "指定した場所の現在の天気を取得",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "都道府県名(例:大阪府)",
                },
                "unit": {"type": "string", "enum": ["metric", "uscs"]},
            },
            "required": ["location"],
        },
    },
    {
        "name": "binary_arithmetic",
        "description": "2 つのオペランドに対して 2 項算術演算を実行します。",
        "parameters": {
            "type": "object",
            "properties": {
                "left_op": {
                    "type": ["integer", "number"],
                    "description": "The left operand.",
                },
                "right_op": {
                    "type": ["integer", "number"],
                    "description": "The right operand.",
                },
                "operator": {
                    "type": "string",
                    "description": "算術演算子。サポートされている演算子は次のとおりです。 '+', '-', '*', '/', '%'.",
                    "enum": ["+", "-", "*", "/", "%"],
                },
            },
            "required": ["left_op", "right_op", "operator"],
        },
    },
]

def get_current_weather(location: str, unit: str = "metric") -> str:
    # Replace spaces with hyphens and commas with underscores for the wttr.in URL
    location = location.replace(" ", "-").replace(",", "_")
    # Determine the unit query parameter
    unit_query = "m" if unit == "metric" else "u"
    # Set the API response formatting
    res_format = "%l+%T+%S+%s+%C+%w+%t"
    # Make a request to the wttr.in service
    response = requests.get(
        f"http://wttr.in/{location}?{unit_query}&format={res_format}"
        )
    # Check if the request was successful
    if response.status_code == 200:
        return response.text
    else:
        return f"Could not get the weather for {location}."

def binary_arithmetic(
        left_op: Union[int, float], right_op: Union[int, float], operator: str
        ) -> Union[int, float]:
    if operator == "+":
        return str(left_op) + "+" +  str(right_op) + "= " + str( left_op + right_op)
    elif operator == "-":
        return str(left_op) + "-" +  str(right_op) + "= " + str(left_op - right_op)
    elif operator == "*":
        return str(left_op) + "*" +  str(right_op) + "= " +  str(left_op * right_op)
    elif operator == "/":
        if right_op == 0:
            raise ValueError("Division by zero is not allowed.")
        return str(left_op) + "/" +  str(right_op) + "= " + (left_op / right_op)
    elif operator == "%":
        return str(left_op) + "%" +  str(right_op) + "= " + (left_op % right_op)
    else:
        raise ValueError(
            f"Unsupported operator '{operator}'. Supported operators are '+', '-', '*', '/', '%'."
        )

function_map = {
    "get_current_weather": get_current_weather,
    "binary_arithmetic": binary_arithmetic,
    }

system_prompt = ChatCompletionMessage(
    role="system",
    content="""My name is Vincent and I am a helpful assistant. I can make function calls to retrieve information such as the current weather in a given location.\n{ "function_call": { "name": "get_current_weather", "arguments": { "location": "New York City, NY" } } }""",
    )

def generate_chat_sequence(
        user_query: str,
        function_def: dict,
        ) -> List[ChatCompletionMessage]:
    messages = [system_prompt]
    user_message = ChatCompletionMessage(role="user", content=user_query)
    function_message = ChatCompletionMessage(
        role="function", content=json.dumps(function_def)
    )
    messages.extend([user_message, function_message])
    return messages

def generate_combined_chat_sequence(
        user_query: str,
        function_list: list,
        ) -> List[ChatCompletionMessage]:
    function_messages = [
        ChatCompletionMessage(role="function", content=json.dumps(func_def))
        for func_def in function_list
    ]
    user_message = ChatCompletionMessage(role="user", content=user_query)
    messages = [system_prompt] + function_messages + [user_message]
    return messages

#messages = generate_chat_sequence("What is the weather like in New York City, New York today?", FUNCTIONS[0]) #original
#messages = generate_chat_sequence("東京都の天気は?", FUNCTIONS)
#user_msg="東京都の天気は?"
#user_msg="What is the weather like in New York City, New York today?"
messages = generate_chat_sequence(user_msg_org, FUNCTIONS)
print("++++++++++ 158")
response = llm.create_chat_completion(messages=messages, grammar=llama_grammar, temperature=0)
assistant_content = response["choices"][0]["message"]["content"]
function_content = json.loads(assistant_content)
function_call = function_content["function_call"]
print(function_call)
if function_call["name"]=="binary_arithmetic":
    callback = function_map[function_call["name"]]
    result = callback(function_call["arguments"]["a"],function_call["arguments"]["b"],"+")
    #result = callback(function_call["arguments"]["location"])
else:    
    callback = None
    for function in FUNCTIONS:
        if function["name"] == function_call["name"]:
            callback = function_map[function_call["name"]]
    result = callback(function_call["arguments"]["location"])
print("++++++++++ 163")
print(result)
function_message = ChatCompletionMessage(role="user", content=result)

messages.append(function_message)

sys_msg= "以下は、タスクを説明する指示と、抽象的な情報の入力の組み合わせです。入力情報を指示に従って適切に満たす応答を書きなさい。"
user_query ="以下の情報を文章で説明しなさい。"
user =result
prompt = sys_msg+"\n\n" + "### 指示: "+"\n" + user_query + "\n\n"  +  "### 入力:" +"\n"+ user + "\n\n"  +  "### 応答:"
output = llm(
        prompt,
        stop=["### 入力","\n\n### 指示"],
        max_tokens=512,
        top_k = 40 ,
        top_p = 0.95 ,
        temperature=0.5,
        repeat_penalty=1.1,
        frequency_penalty  =0.5,
        echo=True,
        )
try:
    ans = output["choices"][0]["text"].split("### 応答:")[1].split("###")[0]
except IndexError:
    ans = output["choices"][0]["text"].split("### 応答:")[1]
print("-----------------final ans  ----------------------------------------------------------")
print(ans) 
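For reference, because the grammar restricts the assistant to JSON output, assistant_content parses into a dict with the same shape as the example embedded in the system prompt, e.g. { "function_call": { "name": "get_current_weather", "arguments": { "location": "東京都" } } } (illustrative only; the actual arguments depend on what the model generates).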

Problems with llama-cpp-python's OpenAI-compatible server

I stood up the OpenAI-compatible server and reworked the code above to call it, but it did not work. It may just be that my approach was wrong: there is no ChatCompletionMessage on that path, so the requests get rejected. I gave up on OpenAI compatibility and instead handled this by adding endpoints to the custom FastAPI server built in the article below.
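For context, this is roughly how one would normally call the OpenAI-compatible server; it is a sketch, not the exact code I tried, and it assumes the server was started with python -m llama_cpp.server and that the openai v1 client is installed. This route did not work for me.

# Sketch only (assumptions: server started with
#   python -m llama_cpp.server --model ./models/mistral-7b-instruct-v0.1.Q8_0.gguf
# and openai>=1.0 installed). Not the exact code I used.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="sk-no-key-needed")
response = client.chat.completions.create(
    model="local",
    messages=[{"role": "user", "content": "東京都の天気は?"}],
    functions=FUNCTIONS,  # the same function definitions as above
    temperature=0,
)
print(response.choices[0].message)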

Here is the server-side code.
Four endpoints are defined:
@app.post("/generate/")
@app.post("/ChatCompletionMsg/")
@app.post("/create_chat_completion/")
@app.post("/grammar/") → not verified yet.

#from llama_cpp import Llama
from llama_cpp import ChatCompletionMessage, Llama, LlamaGrammar
from pprint import pprint
import json
import os
from typing import Literal, NotRequired, List, Union
from fastapi import FastAPI,Form
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

# Prepare the LLM
"""Load a llama.cpp model from `model_path`.
            model_path: Path to the model.
            seed: Random seed. -1 for random.
            n_ctx: Maximum context size.
            n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
            n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
            main_gpu: Main GPU to use.
            tensor_split: Optional list of floats to split the model across multiple GPUs. If None, the model is not split.
            rope_freq_base: Base frequency for rope sampling.
            rope_freq_scale: Scale factor for rope sampling.
            low_vram: Use low VRAM mode.
            mul_mat_q: if true, use experimental mul_mat_q kernels
            f16_kv: Use half-precision for key/value cache.
            logits_all: Return logits for all tokens, not just the last token.
            vocab_only: Only load the vocabulary no weights.
            use_mmap: Use mmap if possible.
            use_mlock: Force the system to keep the model in RAM.
            embedding: Embedding mode only.
            n_threads: Number of threads to use. If None, the number of threads is automatically determined.
            last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
            lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
            lora_path: Path to a LoRA file to apply to the model.
            numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
            verbose: Print verbose output to stderr.
            kwargs: Unused keyword arguments (for additional backwards compatibility).
"""
llm = Llama(model_path="./models/japanese-stablelm-instruct-gamma-7b-q4_K_M.gguf",
               n_gpu_layers=35,
               n_ctx=2048
                 )
llama_grammar = LlamaGrammar.from_file("./grammars/json.gbnf")

app = FastAPI()

class AnswerRequest(BaseModel):
     sys_msg : str
     user_query:str
     user:str
     talk_log_list:list =[[]]
     log_f:bool = False
     log_len :int = 0
     max_token:int = 256
     temperature:float = 0.8
     repeat_penalty:float =  1.1
     top_k:int  = 40
     top_p:float = 0.95
     frequency_penalty:float = 0.0

@app.post("/generate/")
def  genereate(gen_request: AnswerRequest):
    sys_msg         =gen_request.sys_msg
    user_query  =gen_request.user_query
    user                  =gen_request.user
    talk_log_list=gen_request.talk_log_list
    log_f                =gen_request.log_f
    log_len           =gen_request.log_len
    max_token =gen_request.max_token
    top_k              =gen_request.top_k
    top_p              =gen_request.top_p
    get_temperature     =gen_request.temperature
    repeat_penalty         =gen_request.repeat_penalty
    frequency_penalty =gen_request.frequency_penalty
    print("top_k:",top_k,"top_p:",top_p,"get_temperature :",get_temperature ,"repeat_penalty:",repeat_penalty,"frequency_penalty:",frequency_penalty)

    talk_log_list= talk_log_list[0]
     
    prompt = sys_msg+"\n\n" + "### 指示: "+"\n" + user_query + "\n\n"  +  "### 入力:" +"\n"+ user + "\n\n"  +  "### 応答:"
    print("-------------------talk_log_list-----------------------------------------------------")
    print("talk_log_list",talk_log_list)  

    # Build the conversation history and add it to the prompt.
    log_len = int(log_len)
    if log_f == True and log_len > 0:  # build talk_log_list only when history is on and the log length is non-zero
        sys_prompt = prompt.split("### 入力:")[0]
        talk_log_list.append(" \n\n" + "### 入力:" + " \n" + user + " \n")
        new_prompt = ""
        for n in range(len(talk_log_list)):
            new_prompt = new_prompt + talk_log_list[n]
        prompt = sys_prompt + new_prompt + " \n \n" + "### 応答:" + " \n"
    # Run inference
        """Sample a token from the model.
            top_k: The top-k sampling parameter.
            top_p: The top-p sampling parameter.
            temp: The temperature parameter.
            repeat_penalty: The repeat penalty parameter.
        Returns:
            The sampled token.
              # default parameters
               top_k: int = 40,
               top_p: float = 0.95,
               temp: float = 0.80,
               repeat_penalty: float = 1.1,
               frequency_penalty: float = 0.0,
               presence_penalty: float = 0.0,
               tfs_z: float = 1.0,
               mirostat_mode: int = 0,
               mirostat_eta: float = 0.1,
               mirostat_tau: float = 5.0,
               penalize_nl: bool = True,
        """
    print("-----------------prompt---------------------------------------------------------")
    print(prompt)
    output = llm(
        prompt,
        stop=["### 入力","\n\n### 指示"],
        max_tokens=max_token,
        top_k = top_k ,
        top_p = top_p,
        temperature=get_temperature,
        repeat_penalty=repeat_penalty,
        frequency_penalty  =frequency_penalty,
        echo=True,
        )
    print('------------------output["choices"][0]-------------------------------------------------')
    print(output["choices"][0])
    #output の"### 応答:"のあとに、"###"がない場合もあるので、ない場合は最初の"### 応答:"を選択
    try:
             ans = ans=output["choices"][0]["text"].split("### 応答:")[1].split("###")[0]
    except:
             ans = output["choices"][0]["text"].split("### 応答:")[1]
    print("-----------------final ans  ----------------------------------------------------------")
    print(ans)
    if len(talk_log_list) > log_len:
        talk_log_list = talk_log_list[2:]  # when the history exceeds the specified number of turns, drop the oldest exchange (input and response)
    talk_log_list.append("\n" + "###" + "応答:" + "\n" + ans.replace("\n", ""))
    result=200
    return {'message':result, "out":ans,"all_out":output,"prompt":prompt,"talk_log_list":talk_log_list }


class ChatCompletion(BaseModel):
     role: str
     content:str
@app.post("/ChatCompletionMsg/")
def  ChatCompletionMsg(gen_request: ChatCompletion):
    role           = gen_request.role
    content = gen_request.content
    result   = ChatCompletionMessage(role=role, content= content )
    print("ChatCompletionMsg",result)
    return {"result":result}

class Chatgrammar(BaseModel):
     gm_name: str
@app.post("/grammar/")
def  grammar(gen_request: Chatgrammar):
    gm_name               = gen_request.gm_name
    llama_grammar = LlamaGrammar.from_file("./grammars/"+ gm_name)
    return {"llama_grammar ":llama_grammar }

class create_chatRequest(BaseModel):
     messages: list
     temperature: float= 0.5
@app.post("/create_chat_completion/")
def  create_chat_completion(gen_request: create_chatRequest):
    messages            =gen_request.messages
    temperature     =gen_request.temperature
    out = llm.create_chat_completion(messages=messages,  grammar=llama_grammar , temperature=temperature)
    print(out)
    return {"out":out }

     
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)


Client side

I cleaned up the non-server code and changed the necessary parts to call the server's API instead.

import json
import os
from pprint import pprint
import requests
from typing import Literal, NotRequired, List, Union

# URL of the FastAPI endpoint
url = 'http://127.0.0.1:8000'  # change this to match your FastAPI server's URL

FUNCTIONS = [
    {
        "name": "get_current_weather",
        "description": "指定した場所の現在の天気を取得",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "都道府県名(例:大阪府)",
                },
                "unit": {"type": "string", "enum": ["metric", "uscs"]},
            },
            "required": ["location"],
        },
    },
    {
        "name": "binary_arithmetic",
        "description": "2 つのオペランドに対して 2 項算術演算を実行します。",
        "parameters": {
            "type": "object",
            "properties": {
                "left_op": {
                    "type": ["integer", "number"],
                    "description": "The left operand.",
                },
                "right_op": {
                    "type": ["integer", "number"],
                    "description": "The right operand.",
                },
                "operator": {
                    "type": "string",
                    "description": "算術演算子。サポートされている演算子は次のとおりです。 '+', '-', '*', '/', '%'.",
                    "enum": ["+", "-", "*", "/", "%"],
                },
            },
            "required": ["left_op", "right_op", "operator"],
        },
    },
]

def ChatCompletionMessage(data):
    ChatCompletionMsg = url+"/ChatCompletionMsg/"
    response = requests.post(ChatCompletionMsg , json=data)
    if response.status_code == 200:
        result = response.json()
        output  = result["result"]
    return  output

def get_current_weather(location: str, unit: str = "metric") -> str:
    # Replace spaces with hyphens and commas with underscores for the wttr.in URL
    location = location.replace(" ", "-").replace(",", "_")
    unit_query = "m" if unit == "metric" else "u"
    res_format = "%l+%T+%S+%s+%C+%w+%t"
    response = requests.get(
        f"http://wttr.in/{location}?{unit_query}&format={res_format}"
        )
    if response.status_code == 200:
        return response.text
    else:
        return f"Could not get the weather for {location}."

def binary_arithmetic(
        left_op: Union[int, float], right_op: Union[int, float], operator: str
        ) -> Union[int, float]:
    if operator == "+":
        return str(left_op) + "+" +  str(right_op) + "= " + str( left_op + right_op)
    elif operator == "-":
        return str(left_op) + "-" +  str(right_op) + "= " + str(left_op - right_op)
    elif operator == "*":
        return str(left_op) + "*" +  str(right_op) + "= " +  str(left_op * right_op)
    elif operator == "/":
        if right_op == 0:
            raise ValueError("Division by zero is not allowed.")
        return str(left_op) + "/" +  str(right_op) + "= " + (left_op / right_op)
    else:
        raise ValueError( f"Unsupported operator '{operator}'. Supported operators are '+', '-', '*', '/'.")

def generate_chat_sequence(user_query: str, function_def: dict):
    messages = [system_prompt]
    data = {"role": "user", "content": user_query}
    user_message = ChatCompletionMessage(data)
    data = {"role": "function", "content": json.dumps(function_def)}
    function_message = ChatCompletionMessage(data)
    messages.extend([user_message, function_message])
    return messages

def generate_combined_chat_sequence(user_query: str, function_list: list):
    function_messages = []
    for func_def in function_list:
        data = {"role": "function", "content": json.dumps(func_def)}
        function_messages.append(ChatCompletionMessage(data))
    data = {"role": "user", "content": user_query}
    user_message = ChatCompletionMessage(data)
    messages = [system_prompt] + function_messages + [user_message]
    return messages

#+++++++++++++++++++++++++++ start here +++++++++++++++++++++++++++++++
function_map = {
    "get_current_weather": get_current_weather,
    "binary_arithmetic": binary_arithmetic,
    }
data = {"role": "system",
        "content": """My name is Vincent and I am a helpful assistant. I can make function calls to retrieve information such as the current weather in a given location.\n { "function_call": { "name": "get_current_weather", "arguments": { "location": "New York City, NY" } } }""",
        }
system_prompt = ChatCompletionMessage(data)

#user_msg= generate_chat_sequence("What is the weather like in New York City, New York today?", FUNCTIONS[0]) #original
#user_msg= generate_chat_sequence("東京都の天気は?", FUNCTIONS)
user_msg="ニューヨークの天気は?"
#user_msg="What is the weather like in New York City, New York today?"
#user_msg="後に続く2 つのオペランドに対して演算子に基づく 2 項算術演算を実行します。2345, 3560,演算子= '-'"
messages = generate_chat_sequence(user_msg, FUNCTIONS)
print("++++++++++ 158")
create_chat_completion = url + "/create_chat_completion/"
data = {"messages": messages,
        "temperature": 0.0
        }
response = requests.post(create_chat_completion , json=data)
if response.status_code == 200:
    result = response.json()
    output=result["out"]
    response  = output
assistant_content = response["choices"][0]["message"]["content"]
function_content = json.loads(assistant_content)
function_call = function_content["function_call"]
print(function_call)
if function_call["name"]=="binary_arithmetic":
    callback = function_map[function_call["name"]]
    result = callback(function_call["arguments"]["a"],function_call["arguments"]["b"],"+")
else:    
    callback = None
    for function in FUNCTIONS:
        if function["name"] == function_call["name"]:
            callback = function_map[function_call["name"]]
    result = callback(function_call["arguments"]["location"])
print("++++++++++ 163")
print(result )
 #<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
sys_msg= "以下は、タスクを説明する指示と、抽象的な情報の入力の組み合わせです。入力情報を指示に従って適切に満たす応答を書きなさい。"
user_query ="以下の情報を文章で簡潔に説明しなさい。"
user =result
prompt = sys_msg+"\n\n" + "### 指示: "+"\n" + user_query + "\n\n"  +  "### 入力:" +"\n"+ user + "\n\n"  +  "### 応答:"

generate = url + "/generate/"
data = {"sys_msg": sys_msg,
        "user_query": user_query,
        "user": user,
        "temperature": 0.5,
        "max_token": 256,
        }
response = requests.post(generate, json=data)
if response.status_code == 200:
    result = response.json()
    ans = result["out"]
print("-----------------final ans  ----------------------------------------------------------")
print(ans)  
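To try the whole thing end to end, save the server code and the client code as separate files (for example server.py and client.py; the file names are my own choice), start the server first (it runs uvicorn on 0.0.0.0:8000), and then run the client in another terminal.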

Summary

It is not fully compatible and the accuracy has not improved, but the basics now work. At this point it is mostly stubbornness on my part. Since LangChain can do the same kind of processing, I will refrain from deliberately pushing this any further. Still, I think the mechanism itself has turned out rather well.