見出し画像

Google Colab での JP Language Model Evaluation Harness による日本語LLMの評価手順

「Google Colab」での「JP Language Model Evaluation Harness」による日本語LLMの評価手順をまとめました。


1. JP Language Model Evaluation Harness

「JP Language Model Evaluation Harness」は、Stability AI が作成されてる、日本語 LLM の評価コードです。

2. 評価データセット

「JCommonsenseQA」と「JNLI」と「MARC-ja」と「JSQuAD」の4つのデータセットで評価しています

2-1. JCommonsenseQA

「JCommonsenseQA」は 常識的な推論能力を必要とする多肢選択式質問応答タスクのデータセットです。「CommonsenseQA」の日本語版になります。

2-2. JNLI

「JNLI」は、前提文が仮説文に対して持つ推論関係を認識するタスクのデータセットです。「entailment」(含意)「contradiction」(矛盾)「neutral」(中立)を分類します。「NLI」の日本語版になります。

2-3. MARC-ja

「MARC-ja」はテキスト分類タスクのデータセットです。「positive」(ポジティブ)「negative」(ネガティブ)を分類します。「MARC」の日本語版になります。

2-4. JSQuAD

「JSQuAD」は、文脈を読んで質問に答える質問応答タスクのデータセットです。
「SQuAD」の日本語版になります。

3. Colabでの実行

Colabでの実行手順は、次のとおりです。

(1) 「JP Language Model Evaluation Harness」のパッケージのインストール。

# パッケージのインストール
!git clone -b jp-stable https://github.com/Stability-AI/lm-evaluation-harness.git
%cd lm-evaluation-harness
!pip install -e ".[ja]"

(2) モデルのパッケージのインストール。 
今回は、「meta-llama/Llama-2-7b-hf」の評価を実行するコードを使います。

# モデルのパッケージのインストール
!pip install transformers sentencepiece accelerate bitsandbytes

(3) HuggingFaceのログイン。

# HuggingFaceのログイン
!huggingface-cli login

    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|
    
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: 
Add token as git credential? (Y/n) n
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful

(4) 環境変数の準備。
MODEL_ARGSでモデルID「meta-llama/Llama-2-7b-hf」を指定し、TASKはLlama (Llama-1) と同じ設定にしました。

# 環境変数の準備
import os
os.environ["MODEL_ARGS"] = "pretrained=meta-llama/Llama-2-7b-hf,load_in_8bit=True"
os.environ["TASK"] = "jsquad-1.1-0.3,jcommonsenseqa-1.1-0.3,jnli-1.1-0.3,marc_ja-1.1-0.3"

(5) 評価の実行。
「--num_fewshot」(Few-Shotの数) もLlama (Llama-1) と同じ設定にしました。

# 評価の実行
!python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3" --device "cuda" --output_path "result.json" --batch_size 2

「meta-llama/Llama-2-7b-hf」の評価はT4で12時間近くかかりました。

・JCommonsenseQA : 3時間
・JNLI : 1時間
・MARC-ja : 1時間30分
・JSQuAD : 6時間

{
  "results": {
    "jsquad-1.1-0.3": {
      "exact_match": 64.16028815848716,
      "f1": 77.62068005748563
    },
    "jcommonsenseqa-1.1-0.3": {
      "acc": 0.515638963360143,
      "acc_stderr": 0.01494639864919038,
      "acc_norm": 0.2993744414655943,
      "acc_norm_stderr": 0.013697125864334922
    },
    "jnli-1.1-0.3": {
      "acc": 0.29745275267050125,
      "acc_stderr": 0.00926777987291474,
      "acc_norm": 0.30156121610517667,
      "acc_norm_stderr": 0.009304239098715018
    },
    "marc_ja-1.1-0.3": {
      "acc": 0.8572691899540149,
      "acc_stderr": 0.00465240999785967,
      "acc_norm": 0.8572691899540149,
      "acc_norm_stderr": 0.00465240999785967
    }
  },
  "versions": {
    "jsquad-1.1-0.3": 1.1,
    "jcommonsenseqa-1.1-0.3": 1.1,
    "jnli-1.1-0.3": 1.1,
    "marc_ja-1.1-0.3": 1.1
  },
  "config": {
    "model": "hf-causal",
    "model_args": "pretrained=meta-llama/Llama-2-7b-hf,load_in_8bit=True",
    "num_fewshot": [
      2,
      3,
      3,
      3
    ],
    "batch_size": 2,
    "device": "cuda",
    "no_cache": false,
    "limit": null,
    "bootstrap_iters": 100000,
    "description_dict": {}
  }
}
hf-causal (pretrained=meta-llama/Llama-2-7b-hf,load_in_8bit=True), limit: None, provide_description: False, num_fewshot: 2,3,3,3, batch_size: 2
|         Task         |Version|  Metric   | Value |   |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3        |    1.1|exact_match|64.1603|   |      |
|                      |       |f1         |77.6207|   |      |
|jcommonsenseqa-1.1-0.3|    1.1|acc        | 0.5156|±  |0.0149|
|                      |       |acc_norm   | 0.2994|±  |0.0137|
|jnli-1.1-0.3          |    1.1|acc        | 0.2975|±  |0.0093|
|                      |       |acc_norm   | 0.3016|±  |0.0093|
|marc_ja-1.1-0.3       |    1.1|acc        | 0.8573|±  |0.0047|
|                      |       |acc_norm   | 0.8573|±  |0.0047|

accは100を掛けた値が、READMEの結果に表示されている値になります。
meta-llama/Llama-2-7b-hf」のAverateは57.79でした。

(64.1603 + 51.5638963360143 + 29.745275267050125 + 85.72691899540149) / 4 = 57.79

参考



この記事が気に入ったらサポートをしてみませんか?