見出し画像

ABCIでOrion14Bを使って10Kマルチターン日本語会話データセットを生成した

そろそろ年度末。
みなさんいかがお過ごしでしょうか。
年度末といえばABCIもそろそろ混み始めるお年頃。
ABCIのA100は960基しかないので奪い合いが続いています。

一方でMetaは年内に34万基のH100を導入するそうですが・・・

まさにB-29と竹槍。

とはいえ、前世代のV100ならまだ空きがあるので残ったポイントを今のうちに使ってしまおうかなと思い、Orion14B-ChatとWikipediaデータセットを使って日本語マルチターン会話データセットを作りました。V100x4マシンを10台使用して半日で1万会話(10K)を生成。

Orion14Bは商用利用可能とのことですが、ライセンスがllama2に比べて格段にややこしいのでよく法務と相談して使うかどうか決めてください。

本当はInt4でやりたかったんだけど、V100だとInt4使えないのでフル精度でぶん回してます。Orion14BはInt4でも精度が1%しか落ちないという触れ込みなのでローカルLLM向きだと思います。Orion14B-LongChatでもやってみようとしたんだけどLongだけあってものすごく推論に時間がかかるので諦めました。

しかしV100もそろそろ退役が近いかなあ。Int4も使えないしBF16も使えないので最新のモデルはAmpere世代以降じゃないとなかなか厳しい感じになってきています。SDXLのLoRAファインチューニングもV100だとできないし。

Orion14Bを使って思ったのは、前回まで使っていたllamaPro8Bは、あんまりJSONで出してくれなくて1/3くらいしかちゃんとしたJSONファイルが得られなかったのに対し、Orion14Bは生成時にJSONかどうかチェックすればちゃんと「使える」データを生成してくれるあたりが便利です。

14Bでも「全然あり」な感じになってきました。
一応生成に使ったソースコード載せておきます

from transformers import pipeline


from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline
from transformers.generation.utils import GenerationConfig
import json
import random
import string
import time
import sys
import torch


from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("OrionStarAI/Orion-14B-Chat",trust_remote_code=True,torch_dtype=torch.bfloat16,device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("OrionStarAI/Orion-14B-Chat" ,trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained(
    "OrionStarAI/Orion-14B-Chat"
)
model.generation_config.max_new_tokens=5000

def gpt(utterance,theme):
    messages = [{"role": "system", "content": "あなたは役に立つAIです。ユーザの質問、依頼を正確に便利に答えてください。正解がわからない場合に「正解がわかりません」と答えてください。全てJSON形式で返します。"}]
    messages.append({"role": "user", "content": utterance})
    res = model.chat(tokenizer, messages, streaming=False)

    return res

def generate_random_string(length):
    letters = string.ascii_letters
    result_str = ''.join(random.choice(letters) for i in range(length))
    return result_str

import time
import datasets
#data=datasets.load_dataset("izumi-lab/wikipedia-ja-20230720",split="train").shuffle()
data=datasets.load_dataset("hpprc/wikipedia-20240101",split="train").shuffle()

import json
cnt=0
import re
res=""
for row in data:
    if row["title"]==row["text"]:
        continue
    if len(row["text"])<500:
        continue
    try:
            res=gpt(f"{row['title']}について書かれた以下の文章を読んで先生と生徒で会話する会話文を作りなさい。\n\n▪️{row['title']}\n\n"
                    f"{row['text'][:4096]}\n\n" 
                    '上記の文章について日本語での質問文と返答文のセットを作り、```"conversations":[{"生徒":"<質問1>","先生":"<回答1>"},{"生徒":"<質問2>",'
                    '"先生":"<回答2>"},{"生徒":"<質問3>","先生":"<回答3>"},{"生徒":"<質問4>","先生":"<回答4>"}]```のように4つ以上の質問と答えを考え、'
                    "それをJSON形式で返しなさい。ダブルクォーテーションは適切にエスケープしなさい\n",row['title'])
            
            if '{"生徒": "<質問' in res or "<回答" in res:
                continue
            print(res)
            data = json.loads(res)

            #print(data, file=sys.stderr)
            if len(data['conversations'])<1:
                continue
            data['title']=row['title']
            data['body']=row['text']
            print(json.dumps(data, ensure_ascii=False))
            with open("orin14b_log.txt","a") as f:
                f.write(json.dumps(data, ensure_ascii=False)+"\n")

    except Exception as e:
            print('### エラーが発生しました。%s'%res, file=sys.stderr)
            print(e, file=sys.stderr)
            pass

    cnt+=1
    if cnt>100000:
            break
    

ABCI用投入スクリプト

#!/bin/bash

#$-l rt_G.large=1
#$-j y
#$-l h_rt=168:00:00
#$-cwd

export HF_DATASETS_CACHE=/scratch/aca10054zv
source /etc/profile.d/modules.sh
module load python/3.10/3.10.10
module load cuda/12.1/12.1.1
module load cudnn/8.9/8.9.2
nvidia-smi
python3 orionMakeconv.py