見出し画像

Simple Transformers 入門 (7) - Seq2Seq

Simple Transformers」で「Seq2Seq」を行う方法をまとめました。

1. Seq2Seq

Seq2Seq」(Sequence-to-Sequence)は、入力とターゲットの両方がテキストであるモデルです。 「翻訳」や「要約」のタスクに有用なモデルになります。

「Seq2Seq」には主に3つのモデル種別があります。

・BART(要約)
・Marian(翻訳)
・Encoder-Decoder(汎用)

これらのモデルは、指定されたタスクに限定されないことに注意してください。 タスクは単に出発点として与えられています。

現在、次のルールがEncoder-Decoderに適用されます。

・Decoderはbertでなければならない。
・Encoderは、[bert、roberta、distilbert、camembert、electra]のいずれか。
・EncodrとDecoderは同じ「サイズ」でなければならない。
(例: roberta-base Encoderとbert-base-uncased Decoder)

保存されたEncoder-Decoderのロードには、既知の問題があります。 ロードされたモデルは、保存されたモデルと比較してパフォーマンスが低いようです。

1-1. BARTによるSeq2Seq

「BART」による「Seq2Seq」の最小限のコードは、次のとおりです。「Seq2Seq」は、encoder_decoder_type = "bart"で初期化し、encoder_decoder_nameでモデル名を指定する必要があります。

# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# 学習データ
train_data = [
    ["one", "1"],
    ["two", "2"],
]
train_df = pd.DataFrame(train_data, columns=["input_text", "target_text"])

# 評価データ
eval_data = [
    ["three", "3"],
    ["four", "4"],
]
eval_df = pd.DataFrame(eval_data, columns=["input_text", "target_text"])

# モデルの作成
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 10,
    "train_batch_size": 2,
    "num_train_epochs": 10,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_during_training": True,
    "evaluate_generated_text": True,
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": False,
    "max_length": 15,
    "manual_seed": 4,
}
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="bart-large",
    args=model_args,
)

# 学習
model.train_model(train_df)

# 評価
results = model.eval_model(eval_df)

# 予想
print(model.predict(["five"]))

# モデルの保存
model1 = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="outputs",
    args=model_args,
)
print(model1.predict(["five"]))

◎ 学習データと評価データ
train_model()eval_model()の入力は、input_texttarget_textの2つの列を含むPandasDataFrameである必要があります。

・input_text : 入力テキスト。
・target_text : ターゲットテキスト。

◎ 予測データ
予測データは文字列のリストである必要があります。

1-2. MarianによるSeq2Seq

「Seq2Seq」は、encoder_decoder_type = "marian"で初期化し、encoder_decoder_nameでモデル名を指定する必要があります。

それ以外は、BARTの使用法と同じです。

# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# モデルの作成
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 50,
    "train_batch_size": 2,
    "num_train_epochs": 10,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_generated_text": True,
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": False,
    "max_length": 50,
    "manual_seed": 4,
}
model = Seq2SeqModel(
    encoder_decoder_type="marian",
    encoder_decoder_name="Helsinki-NLP/opus-mt-en-de",
    args=model_args,
)

# 予測
src = [
    "People say nothing is impossible, but I do nothing every day.",
    "My opinions may have changed, but not the fact that I'm right.",
    "He who laughs last didn't get the joke.",
]
predictions = model.predict(src)
for en, de in zip(src, predictions):
    print("-------------")
    print(en)
    print(de)
    print()

1-3. Encoder-DecoderによるSeq2Seq

# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# 学習データ
train_data = [
    ["one", "1"],
    ["two", "2"],
]
train_df = pd.DataFrame(train_data, columns=["input_text", "target_text"])

# 評価データ
eval_data = [
    ["three", "3"],
    ["four", "4"],
]
eval_df = pd.DataFrame(eval_data, columns=["input_text", "target_text"])

# モデルの作成
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 10,
    "train_batch_size": 2,
    "num_train_epochs": 10,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_generated_text": True,
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": False,
    "max_length": 15,
    "manual_seed": 4,
}
encoder_type = "roberta"
model = Seq2SeqModel(
    encoder_type,
    "roberta-base",
    "bert-base-cased",
    args=model_args,
    use_cuda=True,
)

# 学習
model.train_model(train_df)

# 評価
results = model.eval_model(eval_df)

# 予測
print(model.predict(["five"]))


# モデルの作成
model1 = Seq2SeqModel(
    encoder_type,
    encoder_decoder_name="outputs",
    args=model_args,
    use_cuda=True,
)

# 予測
print(model1.predict(["five"]))

2. カスタムメトリックを使用した評価

カスタムメトリック関数(学習中の評価を含む)を使用して、モデルの生成されたテキストを評価できます。 ただし、これは他のモデルでの評価よりも大幅に遅い場合があります。

生成されたテキストを評価するには、evaluate_generated_textTrueに設定する必要があることに注意してください。

import logging
import pandas as pd
from simpletransformers.t5 import T5Model

# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# 学習データ
train_data = [
    ["convert", "one", "1"],
    ["convert", "two", "2"],
]
train_df = pd.DataFrame(train_data, columns=["prefix", "input_text", "target_text"])

# 評価データ
eval_data = [
    ["convert", "three", "3"],
    ["convert", "four", "4"],
]
eval_df = pd.DataFrame(eval_data, columns=["prefix", "input_text", "target_text"])
eval_df = train_df.copy()

# モデルの作成
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 10,
    "train_batch_size": 2,
    "num_train_epochs": 200,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_generated_text": True,
    "evaluate_during_training": True,
    "evaluate_during_training_verbose": True,
}
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="bart-large",
    args=model_args,
    use_cuda=True,
)

# カスタムメトリック関数
def count_matches(labels, preds):
    print(labels)
    print(preds)
    return sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])

# 学習
model.train_model(train_df, eval_data=eval_df, matches=count_matches)


この記事が気に入ったらサポートをしてみませんか?