LangChainのストリーミングレスポンスをFastAPIを介してクライアントに返す

こんにちは、Explazaでエンジニアをしています @_mkazutaka です。

LangChainのストリーミングレスポンスをFastAPIを介してクライアントに返す方法を紹介します。

方法

LangChainのAsyncIteratorCallbackHandlerとFastAPIのStreamingResponseを使います。基本的なコードは、こちらを参考にしています。

(事前準備) パッケージを作成する

mkdir langchain-streaming-response
cd langchain-streaming-response
 
poetry init
poetry shell
poetry add langchain fastapi uvicorn openai

(説明) FastAPIのStreamingResponseクラスについて

公式ドキュメントによるとStreamingResponseクラスは以下のように説明されています。ジェネレータ/イテレータを渡すと、イテレートごとの結果がレスポンスにストリームされます。

Takes an async generator or a normal generator/iterator and streams the response body.
訳: 非同期ジェネレーターまたは通常のジェネレーター/イテレーターを受け取り、レスポンス・ボディをストリームします。

https://fastapi.tiangolo.com/advanced/custom-response/?h=streamingresponse#streamingresponse

(説明) LangChainのAsyncIteratorCallbackHandlerクラスについて

内部的には、asyncio.Queueを持ったクラスになります。LangChain内でトークンを受け取るとQueueに受け取ったトークンをenqueueされます。またaiter関数を呼び出すことでQueueをdequeueしてトークンを取り出す事ができます。

# https://github.com/langchain-ai/langchain/blob/a612800ef0ac8ea851cd96f98611d1d668d0e1b6/libs/langchain/langchain/callbacks/streaming_aiter.py

class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
    """Callback handler that returns an async iterator."""

    queue: asyncio.Queue[str]
    done: asyncio.Event

    ...

    def __init__(self) -> None:
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()

    ....

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if token is not None and token != "":
            self.queue.put_nowait(token)

    ....

    async def aiter(self) -> AsyncIterator[str]:
        while not self.queue.empty() or not self.done.is_set():
            ...

ストリーミングレスポンスインスタンスを返す関数を実装する

コードは以下のようになります。StreamingResponseのコンストラクタで_streamを渡すことで、LangChainのストリーミングレスポンスをクライアントに返すことができます。
エラーが起きた際に処理が停止しないように、_stream内ではwrap_done関数でエラーハンドリング周りの処理をwrapした非同期タスクを作り直し、実行しています。

import asyncio
import json
from typing import AsyncIterator, Coroutine
from fastapi.responses import StreamingResponse
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler


def stream(
    fn: Coroutine,
    callback: AsyncIteratorCallbackHandler,
) -> StreamingResponse:
    return StreamingResponse(
        media_type="text/event-stream",
        content=_stream(fn, callback),
    )


async def _stream(
        fn: Coroutine,
        callback: AsyncIteratorCallbackHandler,
) -> AsyncIterator[str]:
    async def wrap_done(wrapped_fn: Coroutine, event: asyncio.Event):
        try:
            await wrapped_fn
        except Exception as e:
            # TODO: Error Handling
            print(f"Caught exception: {e}")
        finally:
            # これがないとエラーが起きた際にeventが終了しない
            event.set()

    task = asyncio.create_task(
        wrap_done(fn, callback.done),
    )

    # Response Data
    # Match the format of the OpenAI response.
    async for token in callback.aiter():
        data = json.dumps(
            {"choices": [{"delta": {"content": token}}]}
        )
        yield f"data: {data}\n\n"

    # Response Finish
    data = json.dumps({"choices": [{"finish_reason": "stop"}]})
    yield f"data: {data}\n\n"

    await task

エンドポイントを実装する

FastAPIの方法に従って、実装しています。

# app/main.py

from typing import List

from typing import List

from fastapi import FastAPI
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.manager import CallbackManager
from pydantic import BaseModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage

from app.stream import stream

app = FastAPI()


class ChatMessage(BaseModel):
    role: str
    content: str

    def to_langchain_message(self) -> BaseMessage:
        if self.role == 'system':
            return SystemMessage(content=self.content)
        if self.role == 'user':
            return HumanMessage(content=self.content)
        if self.role == 'assistant':
            return AIMessage(content=self.content)
        raise ValueError('role is not expected, role:', self.role)


class Chat(BaseModel):
    messages: List[ChatMessage]


@app.post("/v1/chat/completions")
async def v1_chat_completions(
    chat: Chat,
):
    async_callback = AsyncIteratorCallbackHandler()

    lc_messages = [m.to_langchain_message() for m in chat.messages]

    llm = ChatOpenAI(
        streaming=True,
        callback_manager=CallbackManager([
            async_callback
        ])
    )

    return stream(llm.apredict_messages(
        messages=lc_messages,
    ), async_callback)

テストする

アプリケーションを起動します

poetry run uvicorn app.main:app

curlコマンドを送信します。無事受信できてそうです。

$ curl -X POST -H "Accept: text/event-stream" http://0.0.0.0:8000/v1/chat/completions -d '{"message": [{"role": "user", "content": "hello"}]}'
data: {"choices": [{"delta": {"content": "Hi"}}]}

data: {"choices": [{"delta": {"content": " there"}}]}

data: {"choices": [{"delta": {"content": "!"}}]}

data: {"choices": [{"delta": {"content": " How"}}]}

data: {"choices": [{"delta": {"content": " can"}}]}

data: {"choices": [{"delta": {"content": " I"}}]}

data: {"choices": [{"delta": {"content": " assist"}}]}

data: {"choices": [{"delta": {"content": " you"}}]}

data: {"choices": [{"delta": {"content": " today"}}]}

data: {"choices": [{"delta": {"content": "?"}}]}

data: {"choices": [{"finish_reason": "stop"}]}

まとめ

AsyncIteratorCallbackHandlerを使うことで比較的簡単に処理を行うことができました。


この記事が気に入ったらサポートをしてみませんか?