見出し画像

SetFit による Sentence TransformersのFew-Shotファインチューニングを試す

「SetFit」による「Sentence Transformers」のFew-Shotファインチューニングを試したので、まとめました。

1. SetFit

SetFit」は、「Sentence Transformers」をFew-Shotファインチューニングするためのフレームワークです。ラベル付き学習データをほとんど使用せずに、高い精度を実現します。

例えば「Customer Reviews (CR) sentiment dataset」では、1クラス8個の学習データセットのみで、3000個の完全な学習データセットで「RoBERTa Large」をファインチューニングしたものと同等の精度になります。

特徴は、次のとおりです。

◎ プロンプトは必要なし
少数のラベル付き学習データから埋め込みを直接生成するため、プロンプトは不要です。

◎ 学習が高速
T0やGPT-3 のような大規模モデルなしに、高い精度を実現します。結果、学習が高速です。

◎ 多言語サポート
Hub上の任意の「Sentence Transformers」を使用できるため、多言語チェックポイントを使用することで、多言語サポートが可能になります。

2. 英語でSetFitを試す

「Google Colab」で試します。

(1) SetFitのインストール。

# パッケージのインストールインストール
!pip install setfit

(2) データセットの読み込み
テキスト分類用のネガ/ポジのデータセットになります。

from datasets import load_dataset

# データセットの読み込み読み込み
dataset = load_dataset("SetFit/SentEval-CR")

◎ SetFit/SentEval-CR
英語のでネガ・ポジのデータセットになります。

(3) 学習データセットとテストデータセットの準備。
学習データは16個のみとします。

# 学習データセットセットとテストデータセットの準備
train_ds = dataset["train"].shuffle(seed=42).select(range(8 * 2))
test_ds = dataset["test"]

# 確認
print("train count:", len(train_ds))
print("test count:", len(test_ds))
for t in train_ds:
    print(t)
train count: 16
test count: 753

{'text': '* slick-looking design and improved interface', 'label': 1, 'label_text': 'positive'}
{'text': "the day finally arrived when i was sure i 'd leave sprint .", 'label': 0, 'label_text': 'negative'}
    :

(4) SetFitモデルの読み込みと、トレーナーの準備。

from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer

# SetFitモデルの読み込み読み込み
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# トレーナーの生成
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds, # 学習データセット
    eval_dataset=test_ds, # テストデータセット
    loss_class=CosineSimilarityLoss, # 損失
    batch_size=16, # バッチサイズ
    num_epochs=20, # エポック
)

◎ sentence-transformers/paraphrase-mpnet-base-v2
英語のSentence Transformersのモデルになります。

(5) 学習の実行。
4分ほどかかりました。

%%time

# 学習
trainer.train()
CPU times: user 2min 26s, sys: 1min 19s, total: 3min 46s
Wall time: 3min 47s

(6) 評価の実行。
3秒ほどかかりました。

%%time

# 評価
metrics = trainer.evaluate()
print(metrics)
{'accuracy': 0.7875166002656042}
CPU times: user 1.72 s, sys: 124 ms, total: 1.84 s
Wall time: 2.48 s

精度は 0.78 と学習できてそうです。

3. 日本語でSetFitを試す

「Google Colab」で試します。

(1) SetFitのインストール。

# パッケージのインストールインストール
!pip install setfit

(2) データセットの読み込み
テキスト分類用のネガ/ニュートラル/ポジのデータセットになります。

from datasets import load_dataset

# データセットの読み込み読み込み
dataset = load_dataset("tyqiangz/multilingual-sentiments", "japanese")

tyqiangz/multilingual-sentiments
マルチリンガルのネガ・ニュートラル・ポジのデータセットになります。

(3) 学習データセットとテストデータセットの準備。
学習データは16個のみとします。filter()でネガ・ニュートラル・ポジからニュートラルを消しています。

# 学習データセットセット(クラスごとに8個)とテストデータセットの生成
train_ds = dataset["train"].shuffle(seed=42).filter(lambda data: data["label"] != 1).select(range(8 * 2))
test_ds = dataset["test"].filter(lambda data: data["label"] != 1).select(range(753))

# 確認
print("train count:", len(train_ds))
print("test count:", len(test_ds))
for t in train_ds:
    print(t)
train count: 16
test count: 753

{'text': '4つ購入したが、2つは使用できず。接触が悪いのかわからないが、2度と購入しません', 'source': 'amazon_reviews_multi', 'label': 2}
{'text': '以前購入したものが使えなくなり購入。使いやすい、音も問題ない。良い買い物でした。', 'source': 'amazon_reviews_multi', 'label': 0}
    :

(4) SetFitモデルの読み込みと、トレーナーの準備。

from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer

# SetFitモデルの読み込み読み込み
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# トレーナーの生成
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds, # 学習データセット
    eval_dataset=test_ds, # テストデータセット
    loss_class=CosineSimilarityLoss, # 損失
    batch_size=16, # バッチサイズ
    num_epochs=20, # エポック
)

◎ sentence-transformers/paraphrase-multilingual-mpnet-base-v2
マルチリンガルのSentence Transformersのモデルになります。

(5) 学習の実行。
10分ほどかかりました。

%%time

# 学習
trainer.train()
CPU times: user 4min 47s, sys: 5min 5s, total: 9min 53s
Wall time: 9min 59s

(6) 評価の実行。
6秒ほどかかりました。

%%time

# 評価
metrics = trainer.evaluate()
print(metrics)
{'accuracy': 0.8047808764940239}
CPU times: user 2.86 s, sys: 145 ms, total: 3 s
Wall time: 5.68 s

精度は 0.80 と学習できてそうです。

(7) 推論の実行。

text = "購入して良かった。"
print(text, ":", model.predict([text])[0])

text = "ひどい商品!"
print(text, ":", model.predict([text])[0])
購入して良かった。 : 0
ひどい商品! : 2



この記事が気に入ったらサポートをしてみませんか?