見出し画像

stabilityai/japanese-stablelm-instruct-alpha-7b-v2にAPIを付けてサーバ化する

stabilityai/japanese-stablelm-instruct-alpha-7b-v2をアプリから容易に使えるようにAPI化しました。とても単純なAPIですが、サーバとして可動させれば、会話のみならず、翻訳や他のタスクでも利用できます。ローカルで変数を保持していないので、複数のユーザーで共用できて便利です。

環境

全記事と同じですが、FastAPIのインストールが必要です。

pip install fastapi

サーバ側コード

import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM
from fastapi import FastAPI,Form
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

tokenizer = LlamaTokenizer.from_pretrained(
    "novelai/nerdstash-tokenizer-v1", additional_special_tokens=["▁▁"],
    legacy=False
)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/japanese-stablelm-instruct-alpha-7b-v2",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    variant="fp16",
    )
if torch.cuda.is_available():
    model = model.to("cuda")
 #model  = AutoModelForCausalLM.from_pretrained(
#   "stabilityai/japanese-stablelm-instruct-alpha-7b-v2",
#    trust_remote_code=True,
#    load_in_8bit=True,
#    device_map="auto",  
#    variant="int8",
#    )

model.eval()

app = FastAPI()

# rootエンドポイント
@app.get("/", response_class=HTMLResponse)
async def get_root(request):
    return templates.TemplateResponse("index.html", {"request": request})

class AnswerRequest(BaseModel):
     sys_msg : str
     user_query:str
     user:str
     max_token:int
     temperature:float 

@app.post("/generate/")
def  genereate(gen_request: AnswerRequest):
    sys_msg       =gen_request.sys_msg
    user_query =gen_request.user_query
    user                 =gen_request.user
    max_token =gen_request.max_token
    get_temperature=gen_request.temperature

    user_inputs = {
                               "user_query": user_query,
                               "inputs": user,
                                   }
    prompt = build_prompt(sys_msg ,**user_inputs)
    print("prompt =",prompt )
    input_ids = tokenizer.encode(
                                prompt, 
                                add_special_tokens=False, 
                                return_tensors="pt"
                                )
    
    # パッドトークンIDの設定
    pad_token_id = tokenizer.eos_token_id  # パディングトークンIDをeos_token_idに設定

    tokens = model.generate(
                                input_ids.to(device=model.device),
                                max_new_tokens=max_token,
                                temperature=get_temperature,
                                top_p=0.95,
                                do_sample=True,
                                pad_token_id= pad_token_id,
                                )
    all_out = tokenizer.decode(tokens[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
    print("SYS_OUT=", all_out)
    out=all_out.split("###")[0]
    result=200
    return {'message':result, "out":out,"all_out":all_out}

def build_prompt(sys_msg ,user_query, inputs="", sep="\n\n### "):
    p = sys_msg
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": \n"]
    if inputs:
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + inputs)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8005)

例によって前半部分はfp16か8bit量子化のどちらかをコメントアウトします。gradio版をFastAPIに置き換えているだけです。genereate関数をPOSTでアクセスする仕様です。コードの内部は以下の記事を参考にしてください。

クライアント側

import requests
import json

# Infer with prompt without any additional input
data = {"sys_msg" : "以下は、応答の指示と、文脈のある入力の組み合わせです。入力の要求を適切に満たす応答を書きなさい。1回の入力に対して何度も応答してはいけません。",
                "user_query":"女子高校生のめぐを演じるんだ。めぐは賢くて、おちゃめで、少しボーイッシュ、天真爛漫で好奇心旺盛な女子高生だよ。品川区の目黒川の近くで生まれたんだ。いつもタメ口で話すし、自分のことをめぐと言うんだ。質問に応答えるときもあれば、聞かれたことに対して適当な言葉や文章で応答をすることもるね。応答の長さは最大でも30字以内だし、応答は1ターンだけで答えるんだ。めぐはおちゃめなので時々文脈と関係のない応答をするよ。応答は、ちょっと背伸びした感じで、ため口で相手にツッコミを入れるんだ。めぐのよく使う口癖は次のとおりで、よく使う語尾は、だよね、みたいだ、そうなんだ、違うと思うけどね、だれ?、どこ?、",
                "user":"美味しいお店教えてくれる?",
                "max_token":50,
                "temperature":0.9
                }

# FastAPIエンドポイントのURL
url = 'http://0.0.0.0:8005/generate/'  # FastAPIサーバーのURLに合わせて変更してください

# POSTリクエストを送信
response = requests.post(url, json=data)

# レスポンスを表示
if response.status_code == 200:
    result = response.json()
    print("サーバーからの応答message:", result.get("message"))
    print("サーバーからの応答all_out:", result.get("all_out"))
    print("サーバーからの応答out:", result.get("out"))
    
else:
    print("リクエストが失敗しました。ステータスコード:", response.status_code)

POSTリクエスト用のdataを作成し、 requests.postでサーバにリクエストしています。上記はAIキャラ作成用のプロンプトの場合です。
以下は、翻訳タスク実行時のプロンプトをdataに埋め込んでいます。

import requests
import json

# Infer with prompt without any additional input
data = {"sys_msg" : "英語のcentenceを翻訳する文脈が含まれています。与えられた指示に従い、英文を日本語に訳しなさい",
"user_query":"以下の英文を正しい日本語に訳しなさい。",
"user":" The image shows a residential area with tall trees on either side of the road. There are buildings on either side of the road, and a park in the distance. The sky is a light blue color with some clouds.\
The image shows a street with trees on either side of the road. There are buildings in the background with windows and balconies. The sky is blue and there are clouds in the background.\
This image is a cityscape with neon lights and tall buildings in the background. There are cars parked on the street and people walking in the foreground. The overall atmosphere is dark and moody, with red and pink tones in the neon lights..",
"max_token":50,
"temperature":0.9
}

# FastAPIエンドポイントのURL
url = 'http://0.0.0.0:8005/generate/'  # FastAPIサーバーのURLに合わせて変更してください

# POSTリクエストを送信
response = requests.post(url, json=data)

# レスポンスを表示
if response.status_code == 200:
    result = response.json()
    print("サーバーからの応答message:", result.get("message"))
    print("サーバーからの応答out:", result.get("out"))
    
else:
    print("リクエストが失敗しました。ステータスコード:", response.status_code)

このように、クライント側のプロンプトを変えるとサーバのタスクが変わります。