見出し画像

BERTベースモデルのFine-TuningにTrainerクラスを利用する

こんにちは、エンジニアのすずきです。

以前、IBM論文の参考コードでTabBERTモデルの事前学習を行い、Fine-Tuningについては自作コードを実装しました。

自作コードで一応Fine-Tuningをできるようになったのですが、F1スコアなどのメトリクスを計算するだけでも面倒さを感じていました。

事前学習のときと同様にTransformersのTrainerクラスを使えればメトリクスも簡単に出せるのに...といろいろ調べてみたところ、下流タスク用のヘッドをボディ(事前学習済モデル)に付加したものについても、普通にTrainerクラスが使えることがわかりました。

そんなわけで、今回は以前の自作コードをTrainerクラスを使ってリファクタリングしました。

また、今回のリファクタリングのついでに、WandBの導入やオーバーサンプリング処理の追加も行ったので、最後の方におまけで書いています。

実装

ポイントとなる部分だけ書きます。

モデルの作成

Trainerクラスでモデルを扱うためには以下がポイントとなります。

  • PreTrainedModelを継承する

  • 出力をlossとlogitsのタプルで返す

公式Docsに記載されているように、TrainerクラスではPreTrainedModelで動作するように最適化されるようです。

Trainer is optimized to work with the PreTrainedModel provided by the library. You can still use your own models defined as torch.nn.Module as long as they work the same way as the 🤗 Transformers models.

Hugging Face

nn.Moduleでも大丈夫とのことでしたが、実際にこちらを継承したらエラーがでました。

また、自作コードではforwardでlogits(予測結果)のみを返すようにしていたのですが、lossとlogitsのタプルで返すようにしました。

lossを返すためにinitで損失関数loss_fnを指定し、推論時にもモデルを使用することを想定して、損失関数loss_fnを指定しない場合はloss=Noneを返すようにしました。

なお、下流タスク(分類)のヘッドとして、事前学習済モデルにLSTM層とLinear層を付加しています。

from transformers import BertConfig, BertModel, PreTrainedModel
import torch.nn as nn
import torch

class VisitorReactionModel(PreTrainedModel):
    def __init__(self,
                 config='./output_pretraining/action_history/checkpoint-500/config.json',
                 num_categories=2,
                 loss_fn=None,
                 pretrained_model='./output_pretraining/action_history/checkpoint-500/pytorch_model.bin'):
        super().__init__(config=config, num_categories=num_categories, loss_fn=loss_fn, pretrained_model=pretrained_model)
        self.model = BertModel.from_pretrained(pretrained_model, config=config)
        self.lstm = nn.LSTM(self.config.hidden_size, self.config.hidden_size, batch_first=True)
        self.regressor = nn.Linear(self.config.hidden_size, num_categories)
        self.loss_fn = loss_fn

    def forward(self, 
                input_ids=None, 
                attention_mask=None, 
                token_type_ids=None,
                output_attentions=False,
                output_hidden_states=False,
                labels=None):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        out, _ = self.lstm(outputs[0], None)
        sequence_output = out[:, -1, :]
        logits = self.regressor(sequence_output)
        
        loss=None
        if labels is not None and self.loss_fn is not None:
            loss = self.loss_fn(logits, torch.max(labels, 1)[1])
        
        # ModelOutputだとlossがスカラーじゃないというエラーが出るためTupleで返す
        return loss, logits

compute_metricsの作成

モデルをTrainerクラスに適用できるようになったのですが、これだけだとprecision, recall, f1といったメトリクスを導出することができません。

そんなときに使用するのがcompute_metricsとなります。
※Transformersの3系バージョンだと使用できなかったので、4.26.0へ事前にバージョンアップしています。

引数の型はEvalPrediction、戻り値の型はOptional[Dict[str, float]]となります。

def compute_metrics(res: EvalPrediction):
    logits = res.predictions.argmax(axis=1)
    labels = res.label_ids.argmax(axis=1)
    precision = precision_score(labels, logits, average='macro')
    recall = recall_score(labels, logits, average='macro')
    f1 = f1_score(labels, logits, average='macro')
        
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

あとは、Trainerの引数にmodelとcompute_metricsを与えれば、メトリクスが計算されます。
分類タスクでFine-Tuningを行うため、損失関数にはCrossEntropyを用いています。

    loss_fn = CrossEntropyLoss()
    model = VisitorReactionModel(config=config, pretrained_model=pretrained_model, loss_fn=loss_fn)
        
    training_args = TrainingArguments(
        output_dir=args.output_dir,  # output directory
        num_train_epochs=args.num_train_epochs,  # total number of training epochs
        per_device_train_batch_size=args.num_train_batch_size,
        per_device_eval_batch_size=args.num_eval_batch_size,
        save_steps=args.save_steps,
        do_train=True,
        do_eval=True,
        evaluation_strategy="epoch", # epochかsteps(デフォルト500)ごとに評価
        overwrite_output_dir=True,
        save_total_limit=1,
        report_to="wandb"
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics
    )

その他

今回のリファクタリングでいくつか改善も行ったので、おまけで記載します。

WandBの導入

Trainerの利用でWandBの導入が楽になったので、リファクタリングにあわせて実装しました。

Fine-Tuningのコード内に、WandBのログインと初期化のコードを追加します。
このとき、dotenvを利用して、ローカル学習の際は.envから、SageMaker Training Jobsの際はEstimatorsのhyperparametersからAPIキーを読み込むようにします。

load_dotenv()
WANDB_API_KEY = os.getenv('SM_HP_WANDB_API_KEY')

wandb.login(key=WANDB_API_KEY) # Pass your W&B API key here
wandb.init(project="tabformer-opt") # Add your W&B project name 
estimator = Estimator(
    image_uri="",
    role=role,
    instance_type="ml.g4dn.2xlarge",
    instance_count=1,
    base_job_name="tabformer-opt-fine-tuning",
    output_path="",
    code_location="",
    sagemaker_session=session,
    entry_point="fine-tuning.sh",
    dependencies=["tabformer-opt"],
    hyperparameters={
        "data_root": "/opt/ml/input/data/input_data/",
        "data_fname": "",
        "output_dir": "/opt/ml/model/",
        "model_path": "/opt/ml/input/data/input_model/",
        "wandb_api_key": <APIキー>
    }
)

あとは、TrainingArgumentsにreport_to="wandb"を追加するだけで、学習結果が記録されるようになります。

オーバーサンプリング

今回、ポジティブラベルが10 %以下の不均衡データを使用しており、そのままモデルで学習を行ってもprecisionやrecallが低い結果となってしまいます。

そのため、少数派のポジティブラベルデータをオーバーサンプリングで増やすようにしました。
この際、単純なデータ複製で過学習を起こさないために、少数派のデータからランダムでデータを選択し、そのデータからランダムで選択された近傍点を用いて、両者の合成データを作成する、SMOTEという手法を用いました。

SMOTE処理を以下の関数にまとめ、データ前処理のコードに加えました。

    def overSampling(data):
        sm = SMOTE(random_state=42)
        X = data.drop(columns='reaction', axis=1)
        y = data['reaction']
        X_sample, Y_sample = sm.fit_resample(X, y)

        over_sampling = pd.DataFrame()
        over_sampling = X_sample
        over_sampling['reaction'] = Y_sample

        return over_sampling

参考資料

採用情報

バックエンドが得意な方を募集中です。

AWSやバックエンドの経験があれば、インフラ設計やパフォーマンスチューニングなどなんでもお任せします。

もしご興味があれば、採用情報ページの画面左下のボタンからチャット(かWeb通話)でお声がけいただけると幸いです。

OPTEMOというサービスです

最近、YOUTRUSTにも登録しました。
カジュアル面談に興味がある方はぜひ…!


この記事が気に入ったらサポートをしてみませんか?