見出し画像

rinna/japanese-gpt-neox-3.6b-instruction-ppoにAPIを追加

ローカルLLMを複数切り替えて使用したい場合に、アプリケーションPCとLLMを動かくPC(サーバ)が別れていると便利です。先の記事でマルチモーダルLLMのAPI化を行いましたが、他のLLMもAPI化し、アプリPCから場面に応じたLLMを使い分けることができるようにしています。今回はrinnaのAPI化です。

rinna環境は以前の以下の記事で作成しています。この環境にFastAPI関連を追加でインストールします。

pip install fastapi
pip install pydantic

rinnaの場合、systemがありませんので構成は単純です。必要なのは初期化とgenerationサービスのみです。LoRAを使うかどうかはクリアアント側では決めることは出来ないので、サーバ側で設定します。

プロンプトの受け渡し

プロンプトは性格付用の部分と会話用の部分に分けて呼び出します。同時に会話ログも変数として受け渡しできれば短期の記憶は維持できるのでログも渡せるようにします。

サーバ側コード

import  torch
from     transformers   import AutoTokenizer, AutoModelForCausalLM
from     peft   import PeftModel, PeftConfig
from pydantic import BaseModel
 
# 初期設定ー>model_name,lora_model_path,LoRA,default_promptの設定
model_name      ="rinna/japanese-gpt-neox-3.6b-instruction-ppo"
lora_model_path="megu_instructon_all-EP1"
LoRA=True
default_prompt  =["ユーザー: あなたの名前は何ですか?","システム: わたしは女子高校生の「めぐ」だよ。"]

#モデルの読み込み
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)   
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             load_in_8bit=True,
                                             torch_dtype=torch.float16,
                                             device_map="auto",)
# pad_token_id を設定
model.config.pad_token_id = tokenizer.eos_token_id
#LoRAモデルの準備
if LoRA:  
        model = PeftModel.from_pretrained(
            model, 
            lora_model_path, #学習済みLoRAのフォルダ
            device_map="auto"
            )
else:
    lora_model_path ="LoRA無効"
    
# ============   FastAPI =================

from fastapi import FastAPI
from fastapi.responses import HTMLResponse

talk_log =[]
histry_log =[]
log_len=10

app = FastAPI()

class AnswerRequest(BaseModel):
     system : str
     user:str
     log_len:int

# rootエンドポイント
@app.get("/", response_class=HTMLResponse)
async def get_root(request):
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/llm_reset/")
def  reset():
    global      talk_log
    global      histry_log
    talk_log = []
    histry_log=[]
    return {"message": "Reset"}

@app.post("/generate/")
async def  generate(gen_request: AnswerRequest):
            system            = gen_request.system
            user_prompt= gen_request.user
            log_len= gen_request.log_len
            print("user_prompt=",user_prompt)
            print("log_len=",log_len)
            global      talk_log

            if system=="":  
                        system_prompt = default_prompt
            #systemの文字列からプロンプトを組み立て            
            else:    
                        system = system.replace("  ","")                  #スペースを削除
                        system_prompt = system.split("\n")       #改行で1文頃にリスト化
            #会話ヒストリ作成。プロンプトに追加する。
            log_prompt =[]
            log_len = log_len
            if  log_len>0:
                if  len(talk_log )>log_len:
                        talk_log  =  talk_log [1:]          #ヒストリターンが指定回数を超えたら先頭(=一番古い)の会話を削除
                for log_p in talk_log:
                        log_prompt = log_prompt + ["ユーザー: "+log_p[0]]+["システム: "+log_p[1]] 
           #プロンプトの準備。
            if user_prompt=="":
                prompt = system_prompt +  log_prompt 
            else:
                prompt = system_prompt +  log_prompt  +  ["ユーザー: "+ user_prompt]  
            prompt = ( "<NL>".join(prompt)  + "<NL>"  + "システム: ")
            #Tokenizer準備
            input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.cuda()
            with torch.no_grad(): #generate
                outputs = model.generate(
                            input_ids=input_ids, 
                            max_new_tokens=200,
                            do_sample=True,
                            temperature=0.7, 
                            top_p=0.75, 
                            top_k=40,         
                            no_repeat_ngram_size=2,
                            )
            #システムの返事部分を取り出し
            outputs =tokenizer.decode(outputs[0].tolist())
            output = outputs.replace("<NL>", "\n").replace("</s>", "")
            output = output.split("システム: ")[-1:][0]
            #ヒストリ記録
            talk_log.append([user_prompt,output])
            talk_p ="ユーザ:"+user_prompt+" \n"+"システム:"+output
            print( talk_p)
            histry_log.append(talk_p)
            talk_histry='\n'.join(histry_log)  #リストから文字列に変換
            return {"message": "generete","output": output,"ltalk_histry":talk_histry,"prompt ":prompt }
        
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)


gradioで利用していたコードからgradio部分を取り除き、代わりにFastAPI化したようなイメージと捉えていただければと思います。

@app.post("/generate/")
回答生成APIです。受け渡す変数はclass AnswerRequest(BaseModel):に記載したとおり、
system : str user:str log_len:int
の3種類です。
systemとはOpenAIのChatGPTのsystemではなく、大きなFew-shotの塊で会話ごとに毎回プロンプトの最初に追加されて、LLMの性格を決めるために利用しています。log_lenは記憶するLOGの長さです。複数のユーザーで共有するためにはここは長さではなくLOG自体をアプリ側で持たせる必要があります。リスト形式なので受け渡しは容易だと思います。
userは実際の会話文です。

テストコード

クライアント側のテストコードです。ここも難しい部分は無いと思います。

import requests
from PIL import Image
from io import BytesIO
import json

# 送信するデータを準備
log_len=10

user_prompt="こんにちは、めぐ。あそぼうよ"

sys_prompt=" ユーザー:あなたの名前はめぐです。システム:わかったよ。わたしの名前はめぐだよ。\
ユーザー:品川区の目黒川の近くで生まれたんだ。\
システム:わかったよ。品川区の目黒川の近くで生まれたんだ。\
ユーザー:めぐは、賢くて、おちゃめで、少しボーイッシュで、天真爛漫で好奇心旺盛な女子高生だよ。\
システム:わかったよ。めぐは、賢くて、おちゃめで、少しボーイッシュ、天真爛漫で好奇心旺盛な女子高生なんだ。\
ユーザー:めぐはいつもタメ口を使って話してね。\
システム:めぐはいつもタメ口を使うようにすよ。\
ユーザー:めぐのよく話すときの語尾の口癖は、〜だよね。〜みたいだ。そうなんだ。違うと思うけどね。だれ?。どこ?。〜だよ。と話すんだよ。\
システム:めぐのよく使う話し方の語尾の口癖は、〜だよね。〜みたいだ。そうなんだ。違うと思うけどね。だれ?。どこ?。〜だよ。だよね。わかっつたよ。\
ユーザー:めぐはいつも口癖を使いながら話してね。\
システム:そうなんだ、わかったよ。これでいいんだよね。\
ユーザー:めぐは話すとき、ちょっと背伸びした感じで、ため口で相手にツッコミを入れて話してね。\
システム:めぐが話すときは、ちょっと背伸びした感じで、ため口で相手にツッコミを入れて話すよ。\
ユーザー:めぐはおちゃめなので時々文脈と関係のない返事をしてもいいよ。\
システム:わかった、めぐはおちゃめなので時々文脈と関係のない返事をすることにするよ。\
ユーザー:めぐは自分をなんて呼ぶの?\
システム:めぐって呼ぶよ。\
ユーザー:めぐは質問されたときに、真面目に答えるときが多いけど、時々適当な言葉で返事するんだよ。"

data={"system":sys_prompt,"user":user_prompt,"log_len":log_len}

# FastAPIエンドポイントのURL
url = 'http://0.0.0.0:8001/generate'  # FastAPIサーバーのURLに合わせて変更してください
# POSTリクエストを送信
response = requests.post(url, json=data)
# レスポンスを表示 return {"message": "ask_completed ","chatbot":chatbot}
if response.status_code == 200:
    result = response.json()
    print("サーバーからの応答message:", result.get("message"))
    print("サーバーからの応答chatbot:", result.get("output"))
    print("サーバーからの応答chatbot:", result.get("ltalk_histry"))
    print("サーバーからの応答chatbot:", result.get("prompt"))
    
else:
    print("リクエストが失敗しました。ステータスコード:", response.status_code)

同様にCtranslate版も容易にAPI化が可能です。OpenAIとの互換性はありませんが、手軽に複数のLLMをサーバ化して使い分けることができるようになりました。rinnaは比較的小さなLLMですから1台のPCで通常版とCtranslate版を同居させることも出来ます。ポートを変えるだけで使分け出来て便利です。