LangChainのストリーミングレスポンスをFastAPIを介してクライアントに返す
こんにちは、Explazaでエンジニアをしています @_mkazutaka です。
LangChainのストリーミングレスポンスをFastAPIを介してクライアントに返す方法を紹介します。
方法
LangChainのAsyncIteratorCallbackHandlerとFastAPIのStreamingResponseを使います。基本的なコードは、こちらを参考にしています。
(事前準備) パッケージを作成する
mkdir langchain-streaming-response
cd langchain-streaming-response
poetry init
poetry shell
poetry add langchain fastapi uvicorn openai
(説明) FastAPIのStreamingResponseクラスについて
公式ドキュメントによるとStreamingResponseクラスは以下のように説明されています。ジェネレータ/イテレータを渡すと、イテレートごとの結果がレスポンスにストリームされます。
Takes an async generator or a normal generator/iterator and streams the response body.
訳: 非同期ジェネレーターまたは通常のジェネレーター/イテレーターを受け取り、レスポンス・ボディをストリームします。
(説明) LangChainのAsyncIteratorCallbackHandlerクラスについて
内部的には、asyncio.Queueを持ったクラスになります。LangChain内でトークンを受け取るとQueueに受け取ったトークンをenqueueされます。またaiter関数を呼び出すことでQueueをdequeueしてトークンを取り出す事ができます。
# https://github.com/langchain-ai/langchain/blob/a612800ef0ac8ea851cd96f98611d1d668d0e1b6/libs/langchain/langchain/callbacks/streaming_aiter.py
class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
"""Callback handler that returns an async iterator."""
queue: asyncio.Queue[str]
done: asyncio.Event
...
def __init__(self) -> None:
self.queue = asyncio.Queue()
self.done = asyncio.Event()
....
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if token is not None and token != "":
self.queue.put_nowait(token)
....
async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
...
ストリーミングレスポンスインスタンスを返す関数を実装する
コードは以下のようになります。StreamingResponseのコンストラクタで_streamを渡すことで、LangChainのストリーミングレスポンスをクライアントに返すことができます。
エラーが起きた際に処理が停止しないように、_stream内ではwrap_done関数でエラーハンドリング周りの処理をwrapした非同期タスクを作り直し、実行しています。
import asyncio
import json
from typing import AsyncIterator, Coroutine
from fastapi.responses import StreamingResponse
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
def stream(
fn: Coroutine,
callback: AsyncIteratorCallbackHandler,
) -> StreamingResponse:
return StreamingResponse(
media_type="text/event-stream",
content=_stream(fn, callback),
)
async def _stream(
fn: Coroutine,
callback: AsyncIteratorCallbackHandler,
) -> AsyncIterator[str]:
async def wrap_done(wrapped_fn: Coroutine, event: asyncio.Event):
try:
await wrapped_fn
except Exception as e:
# TODO: Error Handling
print(f"Caught exception: {e}")
finally:
# これがないとエラーが起きた際にeventが終了しない
event.set()
task = asyncio.create_task(
wrap_done(fn, callback.done),
)
# Response Data
# Match the format of the OpenAI response.
async for token in callback.aiter():
data = json.dumps(
{"choices": [{"delta": {"content": token}}]}
)
yield f"data: {data}\n\n"
# Response Finish
data = json.dumps({"choices": [{"finish_reason": "stop"}]})
yield f"data: {data}\n\n"
await task
エンドポイントを実装する
FastAPIの方法に従って、実装しています。
# app/main.py
from typing import List
from typing import List
from fastapi import FastAPI
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.manager import CallbackManager
from pydantic import BaseModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from app.stream import stream
app = FastAPI()
class ChatMessage(BaseModel):
role: str
content: str
def to_langchain_message(self) -> BaseMessage:
if self.role == 'system':
return SystemMessage(content=self.content)
if self.role == 'user':
return HumanMessage(content=self.content)
if self.role == 'assistant':
return AIMessage(content=self.content)
raise ValueError('role is not expected, role:', self.role)
class Chat(BaseModel):
messages: List[ChatMessage]
@app.post("/v1/chat/completions")
async def v1_chat_completions(
chat: Chat,
):
async_callback = AsyncIteratorCallbackHandler()
lc_messages = [m.to_langchain_message() for m in chat.messages]
llm = ChatOpenAI(
streaming=True,
callback_manager=CallbackManager([
async_callback
])
)
return stream(llm.apredict_messages(
messages=lc_messages,
), async_callback)
テストする
アプリケーションを起動します
poetry run uvicorn app.main:app
curlコマンドを送信します。無事受信できてそうです。
$ curl -X POST -H "Accept: text/event-stream" http://0.0.0.0:8000/v1/chat/completions -d '{"message": [{"role": "user", "content": "hello"}]}'
data: {"choices": [{"delta": {"content": "Hi"}}]}
data: {"choices": [{"delta": {"content": " there"}}]}
data: {"choices": [{"delta": {"content": "!"}}]}
data: {"choices": [{"delta": {"content": " How"}}]}
data: {"choices": [{"delta": {"content": " can"}}]}
data: {"choices": [{"delta": {"content": " I"}}]}
data: {"choices": [{"delta": {"content": " assist"}}]}
data: {"choices": [{"delta": {"content": " you"}}]}
data: {"choices": [{"delta": {"content": " today"}}]}
data: {"choices": [{"delta": {"content": "?"}}]}
data: {"choices": [{"finish_reason": "stop"}]}
まとめ
AsyncIteratorCallbackHandlerを使うことで比較的簡単に処理を行うことができました。
この記事が気に入ったらサポートをしてみませんか?