見出し画像

LangChainでストリーミングを有効にしつつ、会話やRAGのトークン消費数を計測する方法

はじめに

こんにちは、@_mkazutakaと申します。今日は、LangChainでストリーミングを有効にしつつ、会話やRAGのトークン消費数を計測する方法について紹介します。

LangChainを使用するときのトークン消費量は、以下のドキュメントに記載されているように `get_openai_callback` 関数を利用すれば簡単に取得できます。しかしこれには注意点があり、この関数はストリーミングを有効にした場合には機能しません(Issueでの報告もあります)。

原因は、シンプルで get_openai_callback で使って取得できるトークン消費量の値は、openaiからのレスポンスの値を使っているからです。ストリーミングを有効にした場合、openaiからのレスポンスにトークン消費量の値がが含まれないので、求めている動作にならなくなります。以下が参照コードです。

# https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/callbacks/openai_info.py#L149-L156
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
				## StreamingをEnableにした場合ここに引っかかる
        if response.llm_output is None:
            return None
        self.successful_requests += 1
        if "token_usage" not in response.llm_output:
            return None
        token_usage = response.llm_output["token_usage"]

なので諦めましょうといいたいのですが、クライアントの要望等でトークン消費数を計測してほしいみたいなケースは多々あります。ということで改めてですが、LangChainでストリーミングを有効にしつつ、会話やRAGのトークン消費数を計測する方法について紹介します。

方針

openaiからのレスポンスに含まれない以上、自分たちでトークン数を計測する必要があります。トークン数の計算は、openaiのexamplesにあるHow to Count tokens with tiktoken にて紹介されています。方針としては、これらを実装した独自のCallbackHandlerを実装して、トークン数を計測していきます。

実装

最終的には、以下のような形になりました。

最初に、examplesを少し簡略化しかつLangChainのMessageの型に合わせた関数 num_tokens_from_messages を作成し、この関数を用いてトークン数を計測します。

Tokenの計算は、Prompt(入力)とCompletion(出力)の両方に対して行わないと行けないので、まず on_chat_model_start 内でPromptのトークン数を計算しています。Completionの計算は、on_llm_new_token で取得できるtokenの値を保存しておき、on_llm_end もしくは on_chain_end 内で計測します。
on_chain_end も用意している理由としては、QARetrieverを使った際は、on_llm_end が呼ばれなかったからです。このあたりのどのタイミングでどのメソッドが呼ばれるかはドキュメントに記載されていない(はず)なので、Try&Errorで頑張る必要があります。

import tiktoken

from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMessage
from langchain.schema.output import LLMResult


def num_tokens_from_messages(messages: List[BaseMessage], model="gpt-3.5-turbo-0613"):
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message = 3
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        num_tokens += len(encoding.encode(message.type))
        num_tokens += len(encoding.encode(message.content))
    num_tokens += 3
    return num_tokens


class CostCalculateCallbackHandler(BaseCallbackHandler):
    token = ""
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.token += token

    def on_llm_end(
            self,
            response: LLMResult,
            **kwargs: Any,
    ) -> Any:
        self.completion_tokens = len(self.encoding.encode(self.token))
        self.total_tokens = self.completion_tokens + self.prompt_tokens

    def on_chain_end(
            self,
            outputs: Dict[str, Any],
            **kwargs: Any,
    ) -> Any:
        self.completion_tokens = len(self.encoding.encode(self.token))
        self.total_tokens = self.completion_tokens + self.prompt_tokens

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            **kwargs: Any,
    ) -> Any:
        for m in messages:
            self.prompt_tokens = num_tokens_from_messages(m)

テスト

get_openai_callback を使った結果と、上記のCallbackを使った結果を比較してトークンの計測があっているかを確認しています。手元で動かし、正しく動作するのを確認しましたので、おそらくあっているかと思います。

def test_chat() -> None:
    chat = ChatOpenAI(
        temperature=0,
    )
    messages = [
        SystemMessage(content="You are a helpful assistant that translates English to French."),
        HumanMessage(content="Translate this sentence from English to French. I love programming."),
        AIMessage(content="J'adore la programmation."),
        HumanMessage(content="Translate this sentence from English to French. Hello World")
    ]
    with get_openai_callback() as cb:
        result = chat.predict_messages(messages)
        want = result.content
        want_total_tokens = cb.total_tokens
        want_prompt_tokens = cb.prompt_tokens
        want_completion_tokens = cb.completion_tokens
    assert want_total_tokens != 0

    handler = CostCalculateCallbackHandler()
    chat = ChatOpenAI(
        temperature=0,
        streaming=True,
        callbacks=[handler]
    )
    result = chat.predict_messages(messages)
    got = result.content

    assert want == got
    assert want_prompt_tokens == handler.prompt_tokens
    assert want_completion_tokens == handler.completion_tokens
    assert want_total_tokens == handler.total_tokens


def test_qa_retriever() -> None:
    loader = TextLoader("docs/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    docsearch = FAISS.from_documents(texts, embeddings)

    with get_openai_callback() as cb:
        qa_chain = load_qa_chain(
            llm=ChatOpenAI(
                temperature=0,
            ),
            chain_type="stuff",
        )
        qa = RetrievalQA(combine_documents_chain=qa_chain, retriever=docsearch.as_retriever())
        want = qa.run("What did the president say about Ketanji Brown Jackson?")
        want_total_tokens = cb.total_tokens
        want_prompt_tokens = cb.prompt_tokens
        want_completion_tokens = cb.completion_tokens
    assert want_total_tokens != 0

    handler = CostCalculateCallbackHandler()
    qa_chain = load_qa_chain(
        llm=ChatOpenAI(
            temperature=0,
            streaming=True,
            callbacks=[handler],
        ),
        chain_type="stuff",
        callbacks=[handler],
    )
    qa = RetrievalQA(
        combine_documents_chain=qa_chain,
        retriever=docsearch.as_retriever(),
    )
    got = qa.run("What did the president say about Ketanji Brown Jackson?")

    assert want == got
    assert want_prompt_tokens == handler.prompt_tokens
    assert want_completion_tokens == handler.completion_tokens
    assert want_total_tokens == handler.total_tokens

軽くハマったところは load_qa_chain を使う際に、chainとLLMの両方にhandlerを渡さないといけなかった部分です。これをしないと記憶上、on_chain_end on_llm_end も呼ばれないようになります。

    qa_chain = load_qa_chain(
        llm=ChatOpenAI(
            temperature=0,
            streaming=True,
            callbacks=[handler],
        ),
        chain_type="stuff",
        callbacks=[handler],
    )

注意点

streamingを有効にしてない場合、on_llm_new_token が呼ばれないため上記のHandlerは正しく動作しないので気をつけてください。

まとめ

LangChainでストリーミングを有効にしつつ、会話やRAGのトークン消費数を計測する方法に付いて紹介しました。
AgentやToolを使って計測したい場合、これ以上の実装が必要かもしません。

いいなと思ったら応援しよう!