見出し画像

テーブルデータ用ニューラルネットワークは勾配ブースティング木にどこまで迫れるのか?(本編)

はじめまして、三菱UFJフィナンシャル・グループ(以下MUFG)の戦略子会社であるJapan Digital Design(以下JDD)でMUFG AI Studio(以下M-AIS)に所属する蕭喬仁です。2022年の12月からデータサイエンティストとして入社し、普段はMUFGに向けたPoCやR&Dに従事しています。

今回はR&Dとして取り組んでいた「テーブルデータ用ニューラルネットワーク(以下テーブルデータ用NN)の検証」について内容の共有をしたいと思います。

TL;DR

様々な実験をしているうちに長文になってしまったので、まずはまとめから。こちらに合わせて検証結果の表を見れば本投稿の趣旨がだいたい理解できると思います。
今回は複数のモデルを並列で何回も学習させる必要があったのですが、JDDの分析環境であるDatabricksのJobs機能を活用することで簡単に並列計算が実現できました。

特徴量エンジニアリングを実施しない状態で勾配ブーステインング木(LightGBM)とテーブルデータ用NN(TabNet、FTTransformer、GATE)、時系列NN(1D CNN、GRU、Transformer)の学習を行い、精度や学習時間、ハイパラ探索の効果等を検証

2019年から2023年にかけて様々なテーブルデータ用NNが提案されているが、精度やその安定性、学習時間を鑑みると勾配ブースティング木より使いやすいものは未だ無い

TabNetはテーブルデータ用NNの中では古参のモデルで新しいモデルよりも精度が劣るが、学習時間が短いためアンサンブルの候補としては有力

時系列性のあるテーブルデータに対してはテーブルデータ用NNよりも時系列NNの方が高性能だったが、有効な特徴量を自動で見つけるのは難しく、特徴量エンジニアリングをする方が精度向上を期待できる

取り組みの背景

MUFG向けのPoCではテーブルデータを用いたAI開発を行うことが多いのですが、利用するモデルは大抵勾配ブーステイング木で、テーブルデータ用NNを利用することはほとんどありませんでした。

2019年に発表されたTabNetを皮切りにこれまで数多くのテーブルデータ用NNが提案されてきましたが、勾配ブーステイング木よりも学習時間が長く計算資源としてGPUが必要なことが実務利用の際のネックとなっていました。
というのも、学習にかかる時間が短いほど「学習→評価→エラー分析→改善策の検討」というサイクルを多く回せるので最終的な精度も向上しやすいからです。

テーブルデータ用NNの学習時間が多少長くても勾配ブースティング木以上の精度が期待できれば良いのですが、実際は精度面でも劣るというのが業界での通説です。
私のこれまでの観測範囲においても、テーブルデータ用NNは論文中で勾配ブースティング木以上の性能を謳っているものの、実務やデータ分析コンペティションでその活躍を耳にすることは稀だったと思います。
また、Why do tree-based models still outperform deep learning on tabular data? という論文では回帰と分類の両方のタスクを含んだ45個のデータセットに対して勾配ブースティング木とテーブルデータ用NNの精度を比較しているのですが、ハイパラ探索の有無に関わらず勾配ブースティング木がテーブルデータ用NNよりも高い性能を出すことを実験的に明らかにしています。

このような背景もあり、私がテーブルデータを扱う業務に従事する際はData Centricなアプローチとして、「モデルは勾配ブースティング木に固定し、学習用データの準備やデータ特性の理解、特徴量エンジニアリングといったステップにできるだけ時間を割く」ことで精度改善を目指すようにしています。
時折、精度が伸び悩んだ時にNN系のモデルを試したくなることもありますが、実装や学習のコストが高い割に精度が期待できないことを踏まえると、結局別の方策に落ち着いてしまいます。

しかし、このような業務のやり方を続けているとテーブルデータ用NN系のモデルを触る機会が発生しないため、知見がいつまでも蓄積されず、本当に使いたくなった時に足取りが重たくなり十分な検証ができない可能性があります。

そこで今回は普段の業務では触りづらいテーブルデータ用NNに対する知見を蓄積すべく、kaggleのコンペティションを題材に予測精度や学習時間、ハイパラ探索による精度上昇幅などを検証しました。

検証の設定

検証に利用したコンペティション

今回はkaggleのAmerican Express - Default Prediction というコンペティション(以下Amexコンペ)を題材にしました。

このコンペではクレジットカード利用者の最大13ヶ月分の匿名特徴量を用いて利用者が120日以内にデフォルトするかの予測が課題となっていました。匿名特徴量は全部で188個存在し、177個が数値特徴量で11個がカテゴリ特徴量として与えられていました。

Amexコンペを題材にした理由は、時系列性を持ったテーブルデータを業務でよく扱うというのと、学習データとして45万人分のデータが存在しNNを学習するのにも十分な量がありそうだったからです。

検証したモデル

  • TabNet

  • FTTransformer

  • GATE

  • 1D CNN

  • GRU

  • Transformers

  • LightGBM

今回の検証では上記のモデルによる検証を実施しました。まずはテーブルデータ用NNの検証対象について説明します。

TabNetはTabNet: Attentive Interpretable Tabular Learning という論文で2019年に提案されたモデルです。TabNetはテーブルデータ用NNの中では古参のモデルですが、教師なし事前学習できる数少ないモデルなので今回の検証対象としました。モデルの仕組みについてはこちらのスライドがわかりやすいです。
TabNetにはカテゴリ特徴量に対するembeddingアルゴリズムが備わっているので、カテゴリ特徴量であると明言されていた11個の特徴量についてはそちらを利用するようにしました。
今回の検証では公式レポジトリの実装を利用しました。2019年に公開されたレポジトリにも関わらず、とても使いやすい実装となっていました。

FTTransformerはRevisiting Deep Learning Models for Tabular Dataという論文でNeurIPS 2021で発表されたモデルになります。TabNetが発表された2019年から2021年の間にも様々なテーブルデータ用NNが提案されているのですが、上述したWhy do tree-based models still outperform deep learning on tabular data?の論文中でFTTransformerが最良の結果を残していたため、こちらのみを検証対象としています。モデルの仕組みについてはこちらの記事がわかりやすいです。
FTTransformerにもカテゴリ特徴量に対するembeddingアルゴリズムが備わっているので11個のカテゴリ特徴量はそれを利用するようにしています。
実装には公式レポジトリを利用しました。こちらのレポジトリもかなり使いやすく整えられています。

GATEはGATE: Gated Additive Tree Ensemble for Tabular Classification and Regressionという論文で2022年に提案されたモデルです。詳細は論文に譲りますが、GRUに着想を得たGated Feature Learning Unitというモジュールによって特徴量の選択と非線形変換を行い、決定木を模したNNのスタッキングにより予測を得る手法となっています。論文中でLightGBMと同等の精度かつFTTransformer以上の精度を達成したと謳っていた最近のモデルなので検証対象としました。
こちらのモデルに対してもどれがカテゴリ特徴量であるかは明示的に教えて学習を実施しました。
実装には論文の実験で利用されていたpytorch_tabularを利用しました。基本的に使いやすいレポジトリなのですが、pytorch_lightningをラップしていたり、独自のプログレスバーが実装されていて少し癖のある印象でした。

2022年に提案されたテーブルデータ用NNの中には、小規模データに対して高精度かつ高速な学習を謳っているTabPFNやMetaから提案された解釈性の高さを謳うSPAMというモデルもありました。
しかし、TabPFNに関しては特徴量数100未満かつ行数1024行未満のデータを対象としていたため今回のデータセットサイズでは公式レポジトリの実装では動かすことができませんでした。
また、SPAMにも公式レポジトリが存在したのですが、非商用ライセンスでの公開であったため今回は検証対象から外すことにいたしました。

次に時系列NNについてです。今回のデータは顧客ごとに最大13ヶ月分の特徴量が利用できたので、それらを時系列方向に並べることで時系列NNによるモデリングが可能でした。

そこで1D CNNとGRU、Transofrmerといった基本的なアーキテクチャを何層か重ねて全結合層を2つ重ねたシンプルな構造を検証対象とし、どれくらいの精度を出せるかの検証を行いました。ただし、何層のブロックを重ねるか、中間層の次元をいくつにするかなどは自明では無いため、最終的な構造はハイパラ探索の要領で決定しました。
また、カテゴリ特徴量を入れる際はembeddingをした上でモデルに投入するようにしました。

最後に、勾配ブースティング木のモデルとしては実務でも利用する機会の多いLightGBMを対象としました。CatboostやXGBoostなどのモデルも検証対象にしても良かったのですが、あくまでテーブルデータ用NNの検証をしたかったため今回は省くことにしました。

モデリングに使用する特徴量

今回は最新1ヶ月分の特徴量を利用した場合と13ヶ月分の特徴量を利用した場合の2パターンを検証しました。

最新1ヶ月分の特徴量による検証も実施した理由は、13ヶ月分の特徴量を全て使った場合テーブルデータ用NNの学習時間が極端に遅くなることや、精度が逆に悪化したケースがあったからです。

特徴量エンジニアリングの方向性としては、匿名特徴量に対する平均や最大、最小といった集約演算やモデルが予測したデフォルト確率に集約演算をかけるメタ特徴量を作成するといったアプローチが上位解法では採用されていました。

ただ、今回はあえてそのような特徴量エンジニアリングを一切行わない前提で検証を実施しました。テーブルデータ用NNでよく謳われている特徴量の非線形変換の効果を検証したかったからです。
NNによる特徴量の自動抽出に関しては、画像や自然言語、音声などの分野で人間が設計した特徴量よりもNNが大量のデータから自発的に獲得した特徴量を用いることで様々なタスクで目覚ましい成果をあげています。果たしてテーブルデータに対しても同様な効果が期待できるのでしょうか?

データセットの前処理には、欠損値補完(今回は全て-2で補完)や特徴量が13ヶ月分に満たない顧客に対してのパディングを共通で実施しました。
テーブルデータ用NNのモデルには追加の前処理として数値データに対するGaussian Quantile Transformationによって値の正規化を実施しました。ただし、この正規化は1D CNNとGRU、Transformersに対しては逆効果だったため、これらのモデルに対しては実施しない場合の精度で比較を行いました。

比較項目

  • デフォルトパラメータにおけるCVとPublic LB、Private LBの値、学習時間

  • 探索によって得られたベストパラメータにおけるCVとPublic LB、Private LB、学習時間

  • ハイパラ探索にかかった時間およびその効果

CVとPublic LB、Private LBはいずれもコンペで用いられた指標である「ジニ係数と予測値上位4%のRecallの平均値」で評価しました。CVは5FoldのOOFで計算し、Public LBとPrivate LBは5Foldモデルの平均値を提出し評価を確認しました。各モデルの精度の安定性も見たかったため、10通りの乱数シードを用いて実験を行い結果を収集しました。

モデルの学習にはOOFの精度を用いたearly_stopping (patience = 2)を適用し、5Fold全てを学習するにかかった時間を学習時間としています。また、NNの最適化手法は共通してAdamとCosienLearningLRを用いました。

ハイパラ探索にはOptunaを利用しました。最適化する評価値は探索時間短縮のため学習データを8:2に分けたsingle Foldの値を用いました。各モデルの探索範囲などは後日リリースする(補足編)にて紹介します。

また、基本的な特徴量エンジニアリングを行った場合の精度のベンチマークとしてAMEX LightGBM Quickstartという公開ノートブックの値を転載しています。こちらのノートブックでは平均や最小、最大、最新の値といった集約演算を行いLightGBMで予測を行っています。

使用した計算リソース

計算リソースにはAWSのEC2をバックエンドとしたDatabricks Jobsを利用しました。こちらの機能を利用することで自由なインスタンス構成で様々な実験を並列で動かすことができました。
最新1ヶ月分の特徴量のみを使う検証ではg4dn.2xlarge(8core, 32GB, 1GPU)を利用し、全期間の特徴量を使う場合はCPUメモリの関係でg4dn.4xlarge(16core, 64GB, 1GPU)のインスタンスを利用しました。学習時間を載せる時にg4dn.4xlargeを使ったものには*マークをつけています。
また、LightGBMの学習にはCPUのみを利用し、その他のNNにはGPUを利用してモデルの学習を行いました。

検証結果

各モデルの評価指標及び学習時間の一覧

今回の検証結果を上記の表にまとめました。各評価指標の上位3位までの数値は太字にしています。**がついている数値についてはハイパラ探中最も良かった評価値と75%タイルの評価値の差を記載しているものです。

空白となっているセルの補足

今回ベンチマークとしたノートブックではLightGBMのパラメータを手動で調整していたので、ハイパラ探索後という扱いをし、デフォルトパラメータの値部分は空白としました。

GRUや1D CNN、Transformesらは上述したようにデフォルトの構造やハイパラが定まっているわけでは無いので、構造をハイパラ探索で決定した後の値を載せています。

GATEについてはハイパラ探索中に学習がハングする事象に遭遇したため、結果を得ることができませんでした。
この事象はモデルの予測結果を得る際にも時折発生していたのですが、pytorch_tabularがデフォルトで利用しているプログレスバーを使用しないようにすることで、予測時のハングは回避することが可能でした。一方、学習時のコードから取り除くことが困難だったため、今回の検証では断念しました。

TabNetの事前学習有りのハイパラ探索時間が空欄になっているのは、事前学習無しで探索したベストパラメータを共有しているからです。

最後に、FTTransformerとGATEに13ヶ月分の特徴量を入れた場合についてですが、学習に数日かかることが判明したため検証から外しました。

デフォルトパラメータを用いた場合の結果について

CVとPublic LBにおける評価ではLightGBMの精度が最も高く、Private LBにおいては数値のブレが大きいもののFTTransformerを用いた場合が最良の結果を示しました。

GATEは論文中でFTTransformerよりも高精度と謳っていたのですが、残念ながらAmexコンペにおいては逆の結果となりました。

TabNetは古参のモデルということもありテーブルデータ用NNの中でも最も精度が悪かったです。
意外だったのは13ヶ月分の特徴量を利用するよりも最新1ヶ月分の特徴量を利用した方が高い精度となった点です。TabNetなどのテーブルデータ用NNでは特徴量のマスクを学習することで特徴量選択をしますが、この結果を見るに特徴量が多くなると特徴量マスクが上手く働かなくなるようです。(ちゃんと機能するならば少なくとも最新1ヶ月分の精度と同等の値になるはずです。)
事前学習の効果は最新1ヶ月分の特徴量を利用した場合は殆ど確認できませんでしたが、13ヶ月分の特徴量を利用した場合 には効果がありました。数十万規模のデータに対する事前学習の効果はほとんどないが、データがノイジーな場合は効果が見込めるということでしょう。

学習時間という観点だと、最も高速だったのがLightGBMの1ヶ月分特徴量を用いる場合で、最も遅かったのがFTTransformerの180分、次いでLightGBMの13ヶ月分特徴量を用いる場合とGATEという結果となりました。

LightGBMはデフォルトでは全特徴量を用いて決定木を作成する仕様なので特徴量数に応じて学習時間が伸びてしまったのは当然の結果と言えます。

TabNetは精度が少し劣るものの、13ヶ月分の2444個の特徴量を入れても事前学習無しであれば計算時間が1時間もかからないのは実務者としては嬉しいポイントです。FTTransformerやGATEはこの規模の特徴量を現実的な時間で扱えず、1ヶ月分の188個の特徴量に対しても1回の学習に数時間単位必要なので、実務的にはTabNetよりも有用性が劣りそうです。

kaggleの上位解法で時折TabNetがアンサンブルに組み入れられているのはこのような理由からなのかもしれません。

ベストパラメータを用いた場合の結果について

CVとPublic LB、Private LBの全てにおいて最も精度が高かったのはLIghtGBMで特徴量エンジニアリングを行ったものでした。
CVでは0.794とNN系のモデルで最も精度の高かったGRUに対して+0.006ポイントの大差をつけています。
この結果から推測するに数十万規模程度のデータでは、NNの非線形変換に特徴量の自動抽出はあまり期待できず、人手で特徴量設計を行った方が効果が期待できることがわかります。

一方、特徴量エンジニアリングを行わなかったモデルの中ではLightGBMと時系列NNは殆ど性能がなかったため、モデリング性能自体は同等と言えそうです。

NNの中では時系列系のモデルが安定して高精度を達成していました。今回は系列長が13だったので学習時間も数十分と短かく、テーブルデータ用NNよりもはるかに使い勝手が良い印象です。

テーブルデータ用NNの中で最も精度が高かったのはFTTransformerですが、ハイパラ探索による精度向上は確認できなかった上にCVとPublic LB、Pribate LBともに精度が不安定です。FTTransformerのハイパラ探索は18回のグリッドサーチのみでしたが合計60時間以上かかっており、最終的な学習時間も1回7時間とっているので、総合的に考えると使い勝手はかなり悪いと言えます。

TabNetについてはハイパラ探索により精度が多少向上しましたが、結果としては最下位のままでしたが、ハイパラ探索時間は10時間前後とNNにしてはかなり短かったです。

GATEの結果が得られなかったのが惜しいですが、テーブルデータ用NNは元々の表現力が高いためハイパラ探索を実施してもあまり精度は伸びないと思っておいた方が良さそうです。

感想

様々なNNをAmexコンペに適用してみて、その精度や使い勝手を検証しましたが、現状テーブルデータに対してLightGBMを超えるようなモデルはまだ登場していないと言えそうです。無論、今回は1つのデータでしか検証をしていないため、より正確な結論を得るためには様々なテーブルデータで検証する必要はありますが。

テーブルデータに対してNNを使うのは筋が悪いことは検証前からある程度わかっていましたが、自分で手を動かしてみたことでLightGBMがいかに使い勝手の良いモデルかを身をもって実感できました。

予測精度やその安定性、学習時間およびハイパラ探索に必要な時間などを考えると、自分がテーブルデータのPoCでこれまで採用していた「予測アルゴリズムをLightGBMに固定し、データ整備や特徴量エンジニアリングにエネルギーを使う」というData Centricなアプローチはあながち間違っていなかったことが確認できて良かったです。


以上、「テーブルデータ用ニューラルネットワークは勾配ブースティング木にどこまで迫れるのか?(本編)」についてモデル間の精度比較に重点をおいて内容の共有を行いました。

「本編」に入りきらならかった個々のモデルの詳細な検証結果については「補足編」に紹介する予定です。

最後までご覧いただきありがとうございました!


Japan Digital Design株式会社では、一緒に働いてくださる仲間を募集中です。カジュアル面談も実施しておりますので下記リンク先からお気軽にお問合せください。


この記事に関するお問い合わせはこちら

M-AIS
Kyojin Syo