見出し画像

【AI】BERTの応用モデルでクレジットカードの不正利用検知をおこなう③ ~Fine-Tuning~

はじめに

こんにちは、エンジニアのすずきです。

最近、多変量の時系列表データの学習に使用する、TabBERT(Hierarchical Tabular BERT)というBERTの応用モデルに関する論文を読み、付属コードで事前学習まで行いました。
ただ、付属コードだと事前学習までしか行えなかったため、さらなる理解のために、Fine-Tuningと分類タスクについては自分で実装してみることにしました。

前回までの記事については、以下をご覧ください。

コード解説

以前の記事から繰り返しとなりますが、事前学習では大量の教師なしデータから各レコードの特性を双方向学習することでトランザクション間の関係を把握しました。
Fine-Tuningでは、少量の教師ありデータから事前学習済モデルのパラメータを微調整することで、個々のタスクへの最適化を行います。

以下がFine-Tuningのコード全文となります。

Model

教師ありデータによる調整のため、BERT層に予測のための層(Prediction Layer)を追加したのが今回作成したCommonModelです。

まず、事前学習で得られた学習済モデルpytorch_model.binとconfig.jsonを読み込みます。
論文にならってBERT層の後にLSTM層を追加し、その先にPrediction LayerとしてLinear層を追加します。
今回は不正検知で2クラス(0, 1)の出力が必要となるので、Linear層も2クラスで出力するように設定します(不正じゃない:[1, 0]、不正:[0, 1])。

__init__で、これらの層を定義を行い、forwardで層の連結および予測値logitsを出力します。
BERTのようなTransformerモデルでは、通常[CLS]トークンの最後の隠れ状態をLinear層に通して予測値を出力するのですが、今回のデータには[CLS]トークンを付与しておらず、代わりに付与した[SEP]トークンの隠れ状態をLinear層に通していますsequence_output = out[:, -1, :]。

loss_fnでは、PytorchのMSELossクラスを利用してRMSEを損失関数とした損失計算を行います。

class CommonModel(nn.Module):
    def __init__(self,
                 pretrained_config='./output_pretraining/action_history/checkpoint-500/config.json',
                 pretrained_model='./output_pretraining/action_history/checkpoint-500/pytorch_model.bin'):
        super(CommonModel, self).__init__()
        self.config = BertConfig.from_pretrained(pretrained_config)
        self.model = BertModel.from_pretrained(pretrained_model, config=self.config)
        self.lstm = nn.LSTM(self.config.hidden_size, self.config.hidden_size, batch_first=True)
        self.regressor = nn.Linear(self.config.hidden_size, 2) # 2クラスで

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        out, _ = self.lstm(outputs[0], None)
        sequence_output = out[:, -1, :]
        logits = self.regressor(sequence_output)

        return logits

    def loss_fn(self, logits, label):
        loss = torch.sqrt(nn.MSELoss(reduction='mean')(logits, label))
        return loss
「Kaggleで学んだBERTをfine-tuningする際のTips②〜精度改善編〜」より引用

Dataset

TabBERTでは専用のTokenizerがないため、事前学習ではTransactionDatasetで辞書データの作成とtoken→id変換を行いました。
Fine-Tuningでもこれらを再利用したいので、TransactionDatasetを継承したDataset(FineTuningDataset)を新しく作成します。

Datasetは元々のデータをすべて持ちますが、__getitem__では、指定したindexの入力データ(平坦化で10連結)と連結データに対応する正解ラベル(Window label)をペアでTensorで返します。
CommonModelで予測ラベルを2クラスで出力するため、こちらでも正解ラベルを2クラス(one-hot)で出力するようにします。

init_vocabとformat_transでは、事前学習のときに作成保存した辞書データを読み込んでtoken→idに変換する処理を行います。

class FineTuningDataset(TransactionDataset):

    # 平坦化のためLabelもWindowごとにまとめる
    def __getitem__(self, index):
        one_hot_window_label = F.one_hot(torch.tensor(self.window_label[index]), num_classes=2)
        return_data = (torch.tensor(self.data[index], dtype=torch.long), one_hot_window_label.tolist())

        return return_data

    # pre-trainingでset_idが完了している
    def init_vocab(self):
        column_names = list(self.trans_table.columns)
        self.vocab.set_field_keys(column_names)

    # pre-trainingで保存した辞書でtoken2idをおこなう
    def format_trans(self, trans_lst, column_names):
        with open('./output_pretraining/credit_card/vocab_token2id.bin', 'rb') as p:
            vocab_dic = pickle.load(p)

        trans_lst = list(divide_chunks(trans_lst, len(self.vocab.field_keys) - 2))  # 2 to ignore isFraud and SPECIAL
        user_vocab_ids = []

        sep_id = self.vocab.get_id(self.vocab.sep_token, special_token=True)

        for trans in trans_lst:
            vocab_ids = []
            for jdx, field in enumerate(trans):
                vocab_id, _ = vocab_dic[column_names[jdx]][field]
                vocab_ids.append(vocab_id)

            # TODO : need to handle ncols when sep is not added
            if self.mlm:  # and self.flatten:  # only add [SEP] for BERT + flatten scenario
                vocab_ids.append(sep_id)

            user_vocab_ids.append(vocab_ids)

        return user_vocab_ids

DataCollator

DataCollatorForLanguageModelingを継承したFineTuningDataCollatorForLanguageModelingを作成し、input_idsとlabels(正解ラベル)を辞書型で返します。
このようにすることでDataLoaderからもデータを辞書型で取り出すことができるようになります。

class FineTuningDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):

    def __call__(
            self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:

        input_ids = []
        labels = []

        for example in examples:
            input_ids.append(example[0])
            labels.append(example[1])

        batch_input_ids = self._tensorize_batch(input_ids)
        batch_labels = torch.tensor(labels)
        
        return {"input_ids": batch_input_ids, "label": batch_labels}

DataLoader

DataLoaderは、Datasetのインスタンスを渡すことで、ミニバッチ化した後のデータを返します。
Pytorchで用意されているクラスを使えばよく、実装する必要はありません。

    train_loader = DataLoader(
                        train_dataset,
                        collate_fn=data_collator,
                        batch_size=BS,
                        pin_memory=True, 
                        shuffle=True, 
                        drop_last=True, 
                        num_workers=0)

optimizer

論文のようにBERT層のパラメータのみをフリーズするため、まずはモデルのすべてのパラメータをフリーズした後、LSTM層とLinear層のみフリーズから解放します。
最適化アルゴリズムとしてはAdamWを選択しています。

また、エポックごとに学習率を調整するためのschedulerも用意します。

    # set models
    model = CommonModel()
    model.to(device)

    # freeze parameters in all network
    for name, param in model.named_parameters():
        param.requires_grad = False

    # activate parameters in only lstm network
    for name, param in model.lstm.named_parameters():
        param.requires_grad = True

    # activate parameters in only linear network
    for name, param in model.regressor.named_parameters():
        param.requires_grad = True

    # set optimizer
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    max_train_steps = N_EPOCHS * len(train_loader)
    warmup_steps = int(max_train_steps * WARM_UP_RATIO)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_train_steps
    )

勾配計算とパラメータ最適化

1エポック内でミニバッチのサイズ分(32)のループをまわします。
ループ内では、まずmodel.train()でネットワークを学習モードにします。

計算された損失lossに対してloss.backward()すると、requires_grad = Trueとなっているtorch.Tensorについて誤差逆伝播で勾配計算が行われます。
この状態でoptimizer.step()を実行すると、学習率に応じてパラメータの重みの更新が行われます。
最後に、計算された勾配結果をoptimizer.zero_grad()で0にリセットします。

    for epoch in range(N_EPOCHS):
        for d in train_loader:
            all_step += 1
            model.train()

            logits = model(
                d["input_ids"].to(device),
                attention_mask=None,
                token_type_ids=None
            )
            loss = model.loss_fn(logits, d["label"].float().to(device))
            loss = loss / ACCUMULATE

            train_iter_loss += loss.item()
            loss.backward()

            if all_step % ACCUMULATE == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

                valid_loss = validation_loop(valid_loader, model)
                if valid_best_loss > valid_loss:  
                    valid_best_loss = valid_loss
                train_iter_loss = 0
            bar.update(1)

検証データについても同様に損失を計算しています。
こちらではmodel.eval()でネットワークを推論モードに切り替えています。

def validation_loop(valid_loader, model):
    model.eval()
    preds = []
    true = []

    for d in valid_loader:
        with torch.no_grad():
            logits = model(
                d["input_ids"].to(device),
                attention_mask=None,
                token_type_ids=None
            )
        preds.append(logits)
        true.append(d["label"].float().to(device))
        
    y_pred = torch.hstack(preds).cpu().numpy() # tensor連結してndarrayに変換
    y_true = torch.hstack(true).cpu().numpy()
    
    return mean_squared_error(y_true, y_pred, squared=False)

パラメータ保存

パラメータ最適化のループが終わったら、最後にパラメータを保存します。
このとき、state_dict()を使うことでネットワーク構造や各レイヤの引数といったムダな情報を取り除き、必要な情報のみを保存することができます。

torch.save(model.state_dict(), args.output_model_dir)

さいごに

分類タスク編の記事は11月に書く予定です。

また、ジェイタマズではエンジニアを募集しています。
会社やサービスに興味がある!という方がいらっしゃいましたら、ぜひ気軽にカジュアル面談しましょう!

この記事が気に入ったらサポートをしてみませんか?