xgenでJGLUEを試す

日本語ベンチマークとしてJGLUE(JP Language Model Evaluation Harness)が提案されている。

以下のリポジトリで簡単に試せそうだったので試してみる。


xgenをqloraさせたので、他の日本語モデルと比較してみる

  • xgen-7b-8k-baseをdolly-jaで学習 (https://huggingface.co/Salesforce/xgen-7b-8k-base)

  • JGLUEでxgenを使う際にはtokenizer.from_pretrainedにtrust_code=Trueにする必要があるので注意(fork元に投げたPRがマージされてたので日本語版にも取り込まれるかも)

  • 学習時eos tokenとpadding tokenに不具合があった状態で学習したもので評価した。tokenizerは現在は修正済。ちゃんと学習し直すともう少しスコアが上がるかも

install

git clone -b jp-stable https://github.com/Stability-AI/lm-evaluation-harness.git
cd lm-evaluation-harness
pip install -e ".[ja]"

start

repositoryのmain関数を動かすだけ

引数


実行例

args=[
   "pretrained=Salesforce/xgen-7b-8k-base",
   "peft=./models/xgen-7b-8k-base/qlora",
   "load_in_8bit=True",
   "device_map_option=auto",
   "dtype=float16",
   'trust_remote_code=True'
]

MODEL_ARGS=','.join(args)
TASK="jcommonsenseqa-1.1-0.3,jnli-1.1-0.3,marc_ja-1.1-0.3,jsquad-1.1-0.3,xlsum_ja"

!python main.py \
--model hf-causal-experimental \
--model_args $MODEL_ARGS \
--tasks $TASK \
--num_fewshot "2,3,3,3,1" \
--device "cuda" \
--output_path "./result.json"

(すべてのタスクをまとめて実行するとcolab A100でもめちゃくちゃ時間かかるので注意。自分は別々にやりました)

結果

xgenをdolly_jaで学習させたものと、githubのリンクにあるrinna-instruction-ppoを比較する。

xgen-dolly_ja-qlora

JCommonsenseQA: 55.32
JNLI: 53.04
MARC-ja: 86.52
JSQuAD: 59.55

av: 63.61


rinna-instruction-ppo


JCommonsenseQA: 41.38
JNLI: 54.03
MARC-ja: 89.71
JSQuAD: 53.42

av: 59.63

質疑応答に関するJCommonsenseQAとJSQuADではxgenが高い値となっている。
文章のネガポジ判定のMARC-jaと文章間類似度であるJNLに関してはxgenがわずかに低くなっている。

所感

簡単にJGLUEが試せて良かった。現状強くて日本語喋れる大きなモデル(wizard-vicuna-13Bなど)とJGLUEでxgenやrinnaと比較してみたい。(日本語LLMに対する評価としてJGLUEのタスクだけでは足りてないという意見もあるので、JGLUEのスコアだけではわからないので注意)

一応スコアではxgenがrinnaより高い値となった。rinnaが3.6Bでxgenは7Bなのでモデルサイズ的には有利な気がする。ただcalm 7Bがrinnaよりもスコアは低いので日本語を喋れる7Bモデルとしてはxgenはなかなか使えるかも。xgenは長いcontextで学習させたことが特徴なので、それが質疑応答の結果に繋がった?

xgenが質疑応答強いみたいなので、更にきれいなデータで学習させると割といい感じになるのでは...?

raw result

xgenのresult


|         Task         |Version| Metric |Value |   |Stderr|
|----------------------|------:|--------|-----:|---|-----:|
|jcommonsenseqa-1.1-0.3|    1.1|acc     |0.5532|±  |0.0149|
|                      |       |acc_norm|0.5130|±  |0.0149|
|    Task    |Version| Metric |Value |   |Stderr|
|------------|------:|--------|-----:|---|-----:|
|jnli-1.1-0.3|    1.1|acc     |0.5304|±  |0.0101|
|            |       |acc_norm|0.4684|±  |0.0101|
|     Task      |Version| Metric |Value |   |Stderr| 5h
|---------------|------:|--------|-----:|---|-----:|
|marc_ja-1.1-0.3|    1.1|acc     |0.8652|±  |0.0045|
|               |       |acc_norm|0.8652|±  |0.0045|
|     Task     |Version|  Metric   | Value |   |Stderr| 6.5h
|--------------|------:|-----------|------:|---|------|
|jsquad-1.1-0.3|    1.1|exact_match|59.5452|   |      |
|              |       |f1         |72.2966|   |      |

rinnaのresult
https://github.com/Stability-AI/lm-evaluation-harness/blob/jp-stable/models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-ppo/result.json


この記事が気に入ったらサポートをしてみませんか?