見出し画像

社長(AI)に頼んで今度こそ商用利用可能な日本語マルチターン会話データセットを作ってもらった

前回、けっこう時間をかけて作ったにも関わらず、よくみるとQarasu14BはShareGPTを使っているので商用利用可能かどうかは微妙な結果に終わってしまった。性能は抜群に高いのだが・・・

ところが、最近でてきたTencentのllama2Pro8Bは、かなり高性能にも関わらずわずか8Bでしかもllama2ライセンスなので今度は文句なしに商用利用可能(ただ月間7億ユーザーまで/どんな大成功サイトやねん)。

8Bなので並行して動かせる。とりあえず1GPUは温存するとして7GPUで24時間、社長(AI)をぶん回すと1万会話くらいが生成できた。これは速い。前回は一週間くらいかかったことを考えると驚異的である。vllm使ったのも効果的だったようだ。こんなに速いのかvllm

しかし、さすがに精度はQarasuに及ばない。そこは8B。いい会話データがとれるかどうかはけっこう運次第なのである。運次第ではあるが、この速さなら十分実用に耐える。一週間もあればスクリーニングしても1万会話くらいとれるだろう。

というわけでデータセットを公開した。3000でも試しにファインチューニングするくらいの用途には十分使えるのではないだろうか。

社長(AI)への稟議書は以下の通り

# 稟議書
# 社長、おねげえですだ。今度こそこれで商用利用可能な日本語データセットを作ってくだせえ
#
# 提案者: shi3z

from vllm import LLM, SamplingParams
import json
import random
import string
import time
import sys
import torch
import numpy as np
def torch_fix_seed(seed=42):
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)



sampling_params = SamplingParams(temperature=0.8, max_tokens=8000,  repetition_penalty=1.1)
llm = LLM(model="TencentARC/LLaMA-Pro-8B-Instruct", trust_remote_code=True)


#unixtimeを取得
seed=int(time.time())
torch_fix_seed(seed)
print(seed)

def convert_message(message):
    message_text = ""
    if message["content"] is None and message["role"] == "assistant":
        message_text += "<|assistant|>\n"  # final msg
    elif message["role"] == "system":
        message_text += "<|system|>\n" + message["content"].strip() + "\n"
    elif message["role"] == "user":
        message_text += "<|user|>\n" + message["content"].strip() + "\n"
    elif message["role"] == "assistant":
        message_text += "<|assistant|>\n" + message["content"].strip() + "\n"
    else:
        raise ValueError("Invalid role: {}".format(message["role"]))
    # gradio cleaning - it converts stuff to html entities
    # we would need special handling for where we want to keep the html...
    #message_text = html.unescape(message_text)
    # it also converts newlines to <br>, undo this.
    message_text = message_text.replace("<br>", "\n")
    return message_text


def pipe(utterances):
    prompts=[]
    for utterance in utterances:
        messages = [{"role": "system", "content": "あなたは役に立つAIです。ユーザの質問、依頼を正確に答えてください。全てJSON形式で返します。"}]
        messages.append({"role": "user", "content": utterance[0]})
        prompt = llm.llm_engine.tokenizer.apply_chat_template(conversation=messages, seed=seed,add_generation_prompt=True, tokenize=False)

        #print("prompt",prompt)
        prompt+='\n{"conversations":[{"生徒":"%sって何ですか?","先生":%sは、'%(utterance[1],utterance[1])
        prompts.append(prompt)
    outputs=llm.generate(prompts, sampling_params)
    result=[]
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        result.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    #res = pipe(prompt, max_new_tokens=1000, do_sample=False, temperature=0.8, return_full_text=False)
    return result

import time
import datasets
data=datasets.load_dataset("izumi-lab/wikipedia-ja-20230720",split="train").shuffle()
import json
cnt=0
import re



buffer=[]

for r in data:
    if r["title"]==r["text"]:
        continue
    if len(r["text"])<500:
        continue
    #try:
    if len(buffer)>6:
            del buffer[0]
    buffer.append({"title":r['title'],"text":r['text']})
    #if True:
    prompts=[]
    for row in buffer:
        prompts.append( (f"{row['title']}について書かれた以下の文章を読んで先生と生徒で会話する会話文を作りなさい。\n\n▪️{row['title']}\n\n"
                f"{row['text'][:2048]}\n\n" 
                '上記の文章について日本語での質問文と返答文のセットを作り、```"conversations":[{"生徒":"<質問1>","先生":"<回答1>"},{"生徒":"<質問2>",'
                '"先生":"<回答2>"},{"生徒":"<質問3>","先生":"<回答3>"},{"生徒":"<質問4>","先生":"<回答4>"}]```のように4つ以上の質問と答えを考え、'
                "それをJSON形式で返しなさい。ダブルクォーテーションは適切にエスケープしなさい\n",row['title'],row['text'] ))
    result=pipe(prompts)
    for i in range(len(result)):
        try:
            res=result[i]
            res=res.replace("```json","").replace("```","").replace("<|im_end|>","")
            res='{"conversations":[{"生徒":"%sって何ですか?","先生":"'%(prompts[i][1])+res.split("\n")[0]
            print(res)

            if '{"生徒": "<質問' in res or "<回答" in res:
                continue
            data = json.loads(res)

            #print(data, file=sys.stderr)
            if len(data['conversations'])<1:
                continue
            data['title']=prompts[i][1]
            data['body']=prompts[i][2]
            with open("llama2pro8b.txt","a") as f:
                print(json.dumps(data, ensure_ascii=False))
                f.write(json.dumps(data, ensure_ascii=False)+"\n")
        except Exception as e:
            print('### エラーが発生しました。%s'%res, file=sys.stderr)
            print(e, file=sys.stderr)
            pass
                
    """except Exception as e:
            print('### エラーが発生しました。%s'%res, file=sys.stderr)
            print(e, file=sys.stderr)
            pass
    """

    cnt+=1
    if cnt>100000:
            break

ちなみに社長のスペックだとA100 80GBを80GBくらいフルに使って1個のLLMを動かせるので、CUDA_VISIBLE_DEVICESで0から6まで使って同時並行的に推論させている。重複する可能性もあるが、そもそも重複していても会話が異なればデータセットとしては成り立つので無視している。

今回もFreeAIのAIスーパーコンピュータ、継之助社長にお世話になった。