見出し画像

【AI】BERTの応用モデルでクレジットカードの不正利用検知をおこなう② ~環境構築・事前学習~

はじめに

こんにちは、エンジニアのすずきです。

2022年8月からAI関連の仕事をしており、BERTという自然言語処理モデルについて勉強しています。

最近は、多変量の時系列表データの学習に使用する、TabBERT(Hierarchical Tabular BERT)というBERTの応用モデルに関する論文を読みました(紹介記事も書きました)。

この論文、なんとありがたいことに、著者が事前学習のコードとデータをGitHubにあげています。

環境構築やモデル学習の経験を積むことで理解をさらに深められると思い、こちらのコードをいろいろ触ってみることにしました。
とりあえず第一歩として、AWSで仮想環境をたてて事前学習を実行するところまでをやってみました。

以下について知りたい方に、特に役立つ記事になると思います。

  • GPUインスタンスの立て方(AMIとインスタンス選び、vCPUの制限解除方法)

  • Anacondaによる環境構築(condaのエラー解消、仮想環境の選択)

  • 事前学習まわりのコードの中身

※仮想環境を選択したのは、ローカルマシンのGPUがNVIDIA製ではなく、CUDAとcuDNNが動かせなかったためです。setup.ymlで環境構築を行おうとしたら、ResolvePackageNotFoundエラーでいきなり失敗しました...

(base) [9:10:58] → conda env create -f setup.yml
                                     
Collecting package metadata (repodata.json): | / done
Solving environment: failed

ResolvePackageNotFound:
  - pytorch==1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
name: tabformer
channels:
  - anaconda
  - pytorch
  - huggingface
  - conda-forge
dependencies:
  - python>=3.8
  - pip>=21.0
  - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
  - torchvision
  - pandas
  - scikit-learn
  - transformers
  - numpy
  - libgcc
  - pip:
      - transformers==3.2.0

環境構築

AWSでGPUインスタンスをたてて、その中でさらにAnacondaで仮想環境をつくるまでの流れとなります。

GPUインスタンスの作成

AWSで用意されていたDeep Learning用のAMIと一番安価なGPUインスタンスを使用します。

  • AMI(Amazon Machine Image): amazon/Deep Learning AMI GPU PyTorch 1.12.0 (Amazon Linux 2) 20220803

  • インスタンスのタイプ: g4dn.2xlarge(xlargeだとメモリが枯渇したため) *説明ではg4ad.xlargeとなっています。すみません。

  • EBSボリューム: 300 GB(45 GBだとストレージが枯渇したためOSError: [Errno 28] No space left on device)

インスタンス作成

インスタンスを起動しようとしたら失敗しました。

インスタンス起動失敗

g4adインスタンスを作成するのに必要なvCPUが足りないとのことです。

制限解除

右上の「制限緩和のリクエスト」から以下のリクエストを送ったところ、4時間くらいで制限が解除され、無事にインスタンスを起動することができました。

制限緩和リクエスト

Anacondaで仮想環境構築

SSHでEC2インスタンスに接続します。
※パブリックIPはコンソール画面のEC2 > インスタンスからパブリックIPv4アドレスをコピーしてください。

(base) [10:55:26] → ssh -i ~/.ssh/MyKeyPair.pem ec2-user@<パブリックIP>

インスタンスにコードを落とします。

[ec2-user@ip-172-31-21-45 ~]$ git clone https://github.com/IBM/TabFormer.git

setup.ymlのあるディレクトリに移動して、conda env create ~を実行します。

[ec2-user@ip-172-31-21-45 ~]$ cd TabFormer
[ec2-user@ip-172-31-21-45 TabFormer]$ conda config --set channel_priority flexible
[ec2-user@ip-172-31-21-45 TabFormer]$ conda env create -f setup.yml

※conda config --set channel_priority flexibleをしないと、以下の文言が出て構築に失敗します。

Collecting package metadata (repodata.json): done
Solving environment: /
Found conflicts! Looking for incompatible packages.

conda init後に、作成した仮想環境(tabformer)をアクティベートします(アクティベート前に再度リモートログインが必要)。

[ec2-user@ip-172-31-21-45 TabFormer]$ conda init bash
[ec2-user@ip-172-31-21-45 TabFormer]$ conda activate tabformer

最後に、特定バージョンのcudatoolkitとpytorchがコンフリクトするというAnacondaのバグがあったので、以下を実行します。

[ec2-user@ip-172-31-21-45 TabFormer]$ pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

※以上をインストールしないと、python実行時に以下のエラーが出ます。

OSError: /home/ka37/anaconda3/envs/fail/lib/python3.8/site-packages/torch/lib/../../../../libcublas.so.11: symbol free_gemm_select version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference

事前学習

論文では2種類のデータが用意されているのですが、今回はクレジットカードのトランザクションデータを使用し、学習を行います。

まずは、ローカルからデータをコピーします。
ディレクトリごとコピーするために、-rオプションをつける必要があります。
*READMEに記載されていた./data/card/ではなく、./data/prsa/にデータをコピーします(多分誤記)。

(base) [14:31:15] → scp -r -i ~/.ssh/MyKeyPair.pem ~/Programs/TabFormer/data/credit_card ec2-user@18.180.227.218:/home/ec2-user/TabFormer/data/

学習を実行します。

[ec2-user@ip-172-31-21-45 TabFormer]$ python main.py --do_train --mlm --field_ce --lm_type bert --field_hs 64 --data_type card --output_dir ./output_card/

6時間でストレージ(300 GB)がいっぱいになってしまいましたが、とりあえず動きました。
以下が学習が止まった時点でのcheckpoint(35000, 35500)の中身です。

checkpointの中身

作成された学習済モデルpytorch_model.binやconfig.json optimizer.ptをFine-Tuningの際に使用します。
以下がconfig.jsonとpytorch_model.binの読み込み例です。

from transformers import BertConfig, BertModel

config = BertConfig.from_pretrained(
    "./output_card/checkpoint-35000/config.json"
)

model = BertModel.from_pretrained("./output_card/checkpoint-35000/pytorch_model.bin", config=config)

print(config)
print(model)
BertConfig {
  "architectures": [
    "TabFormerHierarchicalLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "field_hidden_size": 64,
  "flatten": false,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "ncols": 12,
  "nhead": 8,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_layers": 1,
  "pad_token_id": 0,
  "total_flos": 41577468710400000,
  "type_vocab_size": 2,
  "vocab_size": 143492
}
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(143492, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
    ...

コード解説

コード全文の説明は難しいので、要点のみ解説します。
以下がmain.pyです。

Dataset

TransactionDatasetクラスはPytorchのDatasetクラスを継承しており、データの前処理やトークンとidの対応を辞書に登録するようなメソッドが登録されています。

  • encode_data():生データを前処理データに変換

  • init_vocab():トークンとidの対応を登録(辞書作成)

  • user_level_data():ユーザーごとのデータとラベルをすべての連結してリストで出力

    • unique_users: Userカラムのリスト(24001行までなら[0 1])

    • user_data: Userごとのデータフレーム

    • trans_data: 1Userのラベル以外のデータを全部まとめたリスト[[…], […], […]…] special tokenはま special tokenはまだ入っていない

    • trans_labels: 1Userのラベルを全部まとめたリスト

  • format_trans():user_level_dataで出力された連結データを入力して、 11フィールドずつのレコードに分割。その後にトークンをidに変換、[SEP]トークンを付加して出力

  • prepare_samples():format_transで出力されたデータとラベルを10データ(window)ずつ連結して出力

    • user_idx: trans_dataのインデックス

    • user_row: 1Userの生データリスト

    • user_row_ids: 1Userのid変換後のリスト [SEP]付与済

    • ids: strideで5こずつずらしながらseq_lenでuser_row_idsを10連結ずつに区切ったリスト[…]

    • self.data: 10連結リストidsをまとめたリスト[[…], […]…] 全User含む

  • __getitem__:prepare_samplesで作成されたデータからindexに対応したものを取得。flatten(平坦化)フラグがFalseであれば10行(window)ずつreshape[10, 12]。最後のカラムはformat_transで付与した[SEP]

自然言語(例えば英語)の場合であれば、単語をBertTokenizerに入力すれば対応idを出力してくれますが、TabBERT用のTokenizerは存在しないため、辞書を作成するような処理が必要となります。
トークンから辞書を作成するメソッドset_id、辞書からidを取得するメソッドget_idはVocabularyクラスから呼び出して使用しています。

Model

TabFormerBertLMからモデルを読み込みます。

DataCollator

DataCollatorForLanguageModelingを継承したTransDataCollatorForLanguageModelingを作成し、トークンのマスク化やinput_idsとlabels(MASKの正解)の出力を行います。

  • mask_tokens():MLMのためにトークンをマスク化

  • __call__:Datasetの__getitem__で取得したデータをバッチサイズ分まとめてinput_ids([8, 10, 12])とlabelsとしてTensorで出力。バッチサイズは固定。バッチ内のUserは混ざっている

Trainer

Dataset、Model、DataCollatorをそれぞれ引数にとり、事前学習をやってくれます。
DataLoaderがラップされており、DataCollatorの__call__の処理はこの中で実行してくれます。

さいごに

Fine-Tuning以降のコードはなさそうだったので、現在調査しながら実装しています。
FIne-Tuning編、分類タスク編の記事も10月くらいに書く予定です。

また、ジェイタマズではエンジニアを募集しています。
会社やサービスに興味がある!という方がいらっしゃいましたら、ぜひ気軽にカジュアル面談しましょう!


この記事が気に入ったらサポートをしてみませんか?