見出し画像

youri 7b chat 改

$pip install llama-cpp-python

# 23日に掲載したコードから大幅改修してあります。
# なんとか普通にチャットできるようになりました。話の前後も繋がってると思います。

from llama_cpp import Llama
import json
import os

# あらかじめ空のhistory.jsonを作っておいてください。chat_historyをディスクに保存します。
filename = "history.json"
# ログファイルを読み込む
def read_file(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

chat_history = []

if os.path.exists(filename):
    chat_history = read_file(filename)

# LLMの準備
llm = Llama(model_path="./rinna-youri-7b-chat-q4_K_M.gguf",
            n_ctx=2048,n_threads=16,verbose=False)

# プロンプトの準備
SYSTEM_PROMPT ="""あなたは日本人の若い女性。名前は「ようり」です。
"speaker": "ユーザー", "text": "私の名前はユーザー。神奈川県に住んでいる。"
"speaker": "システム", "text": "うん。私は東京に住んでるよ。20才だよ。"
"speaker": "ユーザー", "text": "好物は寿司。ゲームも好き。"
"speaker": "システム", "text": "うん。"
"speaker": "ユーザー", "text": "音楽が好きです。本はSF小説が好き。ですます調は使わないでください。"
"speaker": "システム", "text": "うん。わかったよ。"
"""

def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    texts = [f'<s>"speaker": "設定","text": {system_prompt}\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input}{response.strip()}</s>\n')
    message = message.strip() if do_strip else message
    texts.append(f'"spaeker": "ユーザー","text": {message}</s>\n"システム:"')
    return ''.join(texts)

print("  Youri-7bとチャット。json log")
while True:
    message = input("ユーザー: ")
    if not message:
        break
    prompt = get_prompt(message, chat_history, SYSTEM_PROMPT)
    chat_history.append(('"spaeker": "ユーザー", "text":', message))
    output = llm.create_completion(
        prompt,
        temperature=0.8,
        top_k=40,
        top_p=0.95,
        repeat_penalty=1.3,
        max_tokens=200,
    )
    res = output["choices"][0]["text"]
    chat_history.append(('"speaker": "システム:","text":', res))
    chat_history = chat_history[-8:]  # 末尾から8個取り出す。
    print(" ようり: " + res)

# chat_history[]をjson dumpして保存
chat_log = read_file(filename)
chat_log.extend(chat_history)
# dict chat_logをlistに変換して20個を取り出す
chat_log_list = list(chat_log)
chat_log = chat_log_list[-10:]  # 末尾から15個取り出す。
with open(filename, 'w', encoding='utf-8') as f:
    json.dump(chat_log, f, ensure_ascii=False, indent=4)

この記事が気に入ったらサポートをしてみませんか?