見出し画像

HuggingFace Evaluate を試す

HuggingFace Evaluate を試したのでまとめました。

1. HuggingFace Evaluate

「Evaluate」は、機械学習モデルとデータセットを評価するためのライブラリです。1行のコードで、さまざまなドメイン (NLP、Computer Vision、強化学習など) の数十の評価指標を利用できます。

評価モジュールは、次の3つのカテゴリに分類されます。

Metric : モデルの性能評価のために使用。通常、モデルの予測といくつかのground truthラベルが含まれる。
Comparison : 2つのモデルを比較するために使用。予測をground truthラベルと比較し、一致を計算することで実行できる。
Measurement : データセットのプロパティを調査するために使用。

2. インストール

Google Colabでのパッケージのインストール方法は、次のとおりです。

# パッケージのインストール
!pip install evaluate

3. 評価モジュールの一覧表示

list_evaluation_modules() で、評価モジュールを一覧表示できます。

# 評価モジュールの一覧表示
import evaluate
evaluate.list_evaluation_modules(
    module_type="metric",  # ツール種別
    include_community=False,  # コミュニティを含むか
    with_details=True)  # 詳細を含むか
[{'name': 'precision', 'type': 'metric', 'community': False, 'likes': 0},
 {'name': 'code_eval', 'type': 'metric', 'community': False, 'likes': 3},
 {'name': 'roc_auc', 'type': 'metric', 'community': False, 'likes': 0},
    :
 {'name': 'accuracy', 'type': 'metric', 'community': False, 'likes': 6},
 {'name': 'exact_match', 'type': 'metric', 'community': False, 'likes': 1},
 {'name': 'indic_glue', 'type': 'metric', 'community': False, 'likes': 0},
 {'name': 'spearmanr', 'type': 'metric', 'community': False, 'likes': 0},
    :
 {'name': 'nist_mt', 'type': 'metric', 'community': False, 'likes': 0},
 {'name': 'character', 'type': 'metric', 'community': False, 'likes': 0},
 {'name': 'charcut_mt', 'type': 'metric', 'community': False, 'likes': 0}]

4. 評価モジュールの属性

評価モジュールには、次の属性が付属しています。

・description : 評価モジュールの簡単な説明。
・citation : 利用可能な場合の引用のための BibTex 文字列。
・features : 入力形式を定義する Features オブジェクト。
・input_description : 評価モジュールの docstring と同等。
・homepage : 評価モジュールのホームページ。
・license : 評価モジュールのライセンス。
・codebase_urls : 評価モジュールのコードへのリンク。
・reference_urls : 追加の参照URL。

accuracy metricのdescription属性で簡単な説明を確認します。

# 評価モジュール属性の確認
accuracy = evaluate.load("accuracy")
print(accuracy.description)
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
Accuracy = (TP + TN) / (TP + TN + FP + FN)
 Where:
TP: True positive
TN: True negative
FP: False positive
FN: False negative

accuracy metricのfeatures属性で入力形式を確認します。

accuracy.features
{'predictions': Value(dtype='int32', id=None),
 'references': Value(dtype='int32', id=None)}

5. 評価モジュールのスコア計算

評価モジュールのスコア計算を行うには、次の2つの方法があります。

・All-in-one : 一度にすべての入力を EvaluationModule.compute() に渡す。
・Incremental : EvaluationModule.add() または EvaluationModule.add_batch() で入力を追加し、最後に EvaluationModule.compute() でスコア計算。

All-in-one でスコア計算を行う手順は、次のとおりです。

accuracy.compute(references=[0,1,0,1], predictions=[1,0,0,1])
{'accuracy': 0.5}

Incremental でスコア計算を行う手順は、次のとおりです。

for ref, pred in zip([0,1,0,1], [1,0,0,1]):
    accuracy.add(references=ref, predictions=pred)
accuracy.compute()
{'accuracy': 0.5}

Incremental のバッチでスコア計算を行う手順は、次のとおりです。

for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]):
    accuracy.add_batch(references=refs, predictions=preds)
accuracy.compute()
{'accuracy': 0.5}

6. 複数の評価モジュールのスコア計算

複数の評価モジュールを組み合せてスコア計算を行うには、evaluate.combine() を使います。

# 評価モジュールを組み合わせる
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])
{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'precision': 1.0,
 'recall': 0.5}

7. Evaluatorのスコア計算

「Evaluator」を使うと、モデル、データセット、メトリクスを用意するだけで、スコア計算が可能になります。上記の例のように、予測値を自分で用意するのではなく、モデルとデータセットによる推論で準備します。

(1) パッケージのインストール。

# パッケージのインストール
!pip install transformers datasets
!pip install evaluate[evaluator]

(2) Evaluatorによるスコア計算

from transformers import pipeline
from datasets import load_dataset
from evaluate import evaluator
import evaluate

# モデルとデータセットとメトリクスの準備
pipe = pipeline("text-classification", model="lvwerra/distilbert-imdb", device=0)
data = load_dataset("imdb", split="test").shuffle().select(range(1000))
metric = evaluate.load("accuracy")

# Evaluatorの準備
eval = evaluator("text-classification")

# スコアの計算
results = eval.compute(
    model_or_pipeline=pipe, 
    data=data, metric=metric,
    label_mapping={"NEGATIVE": 0, "POSITIVE": 1},
)
print(results)
{
'accuracy': 0.937, 
'total_time_in_seconds': 10.774307620999934, 
'samples_per_second': 92.81338858850891, 
'latency_in_seconds': 0.010774307620999936
}



この記事が気に入ったらサポートをしてみませんか?