見出し画像

Simple Transformers 入門 (6) - T5

Simple Transformers」で「T5」を行う方法をまとめました。

1. T5 Transformer

T5」(Text-to-Text Transfer Transformer)は「分類」「翻訳」「要約」などの様々な自然言語処理タスクを「Text-to-Text」で解くモデルです。

「Text-to-Text」は、入力を "タスク:問題"、出力を "回答" の形式で、全てのタスクを同じモデルで解きます。学習データを変えるだけで、同じモデルで様々なタスクが解けるのが特徴になります。

「テキストペア分類」の最小限のコードは、次のとおりです。

import logging

import pandas as pd
from simpletransformers.t5 import T5Model

# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# 学習データ
train_data = [
    ["convert", "one", "1"],
    ["convert", "two", "2"],
]
train_df = pd.DataFrame(train_data, columns=["prefix", "input_text", "target_text"])

# 評価データ
eval_data = [
    ["convert", "three", "3"],
    ["convert", "four", "4"],
]
eval_df = pd.DataFrame(eval_data, columns=["prefix", "input_text", "target_text"])
eval_df = train_df.copy()

# モデルの作成
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 10,
    "train_batch_size": 2,
    "num_train_epochs": 200,
}
model = T5Model("t5-base", args=model_args)

# 学習
model.train_model(train_df)

# 評価
results = model.eval_model(eval_df)

# 予測
print(model.predict(["convert: four"]))

◎ データ形式
「T5」の「入力」には、次のパターンがあります。

"<prefix>: <input_text> </s>"

「T5」の「ラベル」には、次のパターンがあります。

"<target_sequence> </s>"

◎ 学習データと評価データ
train_model()
eval_model()の入力は、prefixinput_texttarget_textの3つの列を含むPandasDataFrameである必要があります。

・prefix : 実行するタスクを示す文字列(ex: "question", "stsb")。
・input_text : 入力テキスト。prefixは、完全な入力にするために自動的に付加されます(<prefix>: <input_text>)。
・target_text : ターゲットテキスト。

モデル引数でpreprocess_inputsTrueに設定されている場合、prefixとinput_textの間に「</ s>」と「:」(prefixセパレータ)が自動的に追加されます。それ以外の場合、入力DataFrameには「</ s>」と「:」(prefixセパレータ)が含まれている必要があります。

◎ 予測データ
予測データは、prefixと「:」(prefixセパレータ)が含まれている文字列のリストである必要があります。

モデル引数でpreprocess_inputsTrueに設定されている場合、「</ s>」がリスト内の各文字列に自動的に追加されます。それ以外の場合、文字列には「</ s>」を含める必要があります。

2. カスタムメトリックを使用した評価

カスタムメトリック関数(学習中の評価を含む)を使用して、モデルの生成されたテキストを評価できます。 ただし、「T5」出力の生成方法により、他のモデルでの評価よりも大幅に遅くなる可能性があります。

生成されたシーケンスを評価するには、evaluate_generated_textTrueに設定する必要があります。

import logging
import pandas as pd
import sklearn
from simpletransformers.classification import ClassificationModel
from simpletransformers.classification.multi_modal_classification_model import \
    MultiModalClassificationModel
from simpletransformers.experimental.classification import ClassificationModel
from simpletransformers.language_representation import RepresentationModel
from simpletransformers.seq2seq import Seq2SeqModel
from simpletransformers.t5 import T5Model

# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# 学習データ
train_data = [
    ["convert", "one", "1"],
    ["convert", "two", "2"],
]
train_df = pd.DataFrame(train_data, columns=["prefix", "input_text", "target_text"])

# 評価データ
eval_data = [
    ["convert", "three", "3"],
    ["convert", "four", "4"],
]
eval_df = pd.DataFrame(eval_data, columns=["prefix", "input_text", "target_text"])
eval_df = train_df.copy()

# モデルの作成
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 10,
    "train_batch_size": 2,
    "num_train_epochs": 200,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    # "silent": True,
    "evaluate_generated_text": True,
    "evaluate_during_training": True,
    "evaluate_during_training_verbose": True,
}
model = T5Model("t5-base", args=model_args)

# カスタムメトリック関数
def count_matches(labels, preds):
    print(labels)
    print(preds)
    return sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])

# 学習
model.train_model(train_df, eval_data=eval_df, matches=count_matches)

# 評価
print(model.eval_model(eval_df, matches=count_matches))

3. 新しいタスクに関するT5モデルの学習

Question Generation With T5
The Guide to Multi-Tasking with the T5 Transformer


この記事が気に入ったらサポートをしてみませんか?