機械学習のデータドリフト検知を自動化する方法
電通デジタルでデータサイエンティストを務める吉田です。
本記事では、機械学習においてモデル学習時点でのデータと推論時点でのデータが経時的に乖離を起こしていく、いわゆるデータドリフトの検知を自動化するために構築したワークフローについてご紹介いたします。
データドリフトによる機械学習モデルの劣化とは
機械学習モデルを実運用していく際に課題になる事象の1つとして、データドリフトの問題があります。
一般的に、機械学習ではいくつかの特徴量Xに対する目的変数Yとの隠れた関係を定式化します。XとYの関係は時間が経つにつれて変化していくことがしばしばあり、これに伴って一度作成したモデルの推論精度も低下していきます。
簡単な例として、あるWebサービスにおいてサイト上の行動ログを元にユーザーごとにコンバージョンの発生を予測する機械学習モデルを作成したとします。このモデルは、「平均的に10分以上閲覧しているユーザーはコンバージョン率が高い」などの関係性を探索的に学習し、10分以上サイトを閲覧しているユーザーのコンバージョン率を高く出力するような挙動を取ることになります。
一方、サイトを運営していくうちに、サイト上のコンテンツが増えてきてユーザー全体の平均的な閲覧時間が引き上がるといったようなことがあると、モデルが学習した「平均的に10分以上閲覧しているユーザーはコンバージョン率が高い」という関係性が実状と合わなくなってくる可能性があります。
このように、推論時点でのデータが学習時点のものと許容範囲を超えて変化することをデータドリフトといい、モデルの推論精度の低下などの影響を及ぼし、機械学習モデルの劣化を引き起こします。
データ分布の変化を検出する
データドリフトによる機械学習モデルの劣化への対処法はいくつかありますが、比較的シンプルな方法としてデータ分布の変化を監視し検出する方法があります。
各特徴量のデータ分布について、モデル学習時とその後の推論時での乖離度合いを示す距離関数を定めて定期的に監視します。
Amazon SageMaker Model MonitorやAzureMLのDataDriftDetectorオブジェクト、TensorFlow Data Validationなど、様々な機械学習プラットフォームでデータドリフトを検出するモジュールが実装されています。
最近ではGoogle Cloud Day: Digital ’21にて、データドリフト検知も含め特徴量の統合的な管理が可能なVertex Feature Storeが発表されています。
今回、既存の社内のデータパイプラインへの適応や実装コストなどを考慮して、同様の機能を持ったデータモニタリングワークフローを独自に実装しました。
数値列とカテゴリ列で分布の扱いは異なるため、分布間の距離関数の考え方も異なります。
数値列についてはシンプルで、比較する2つの数値列に対して、確率分布間の距離関数として一般的に用いられるWasserstein距離やKolmogorov-Smirnov距離などを計算することで分布の乖離度合いを測ることができます。
カテゴリ列については、各カテゴリ列について1つ1つ個別値の存在割合に対してユークリッド距離やカイ二乗統計量を計算することで、1つ1つの個別値の発生分布の乖離度合いを測ることができます。
今回の実装では簡便性や汎用性の観点から、数値列に対してはKolmogorov-Smirnov距離を、カテゴリ列に対してはカイ二乗統計量を適用しています。
Kolmogorov-Smirnov距離は2つの数値列の経験累積分布関数の差の最大値をとります。これを検定統計量として2つの確率分布に差があるかを検定するKolmogorov-Smirnov検定が一般に知られています。
カイ二乗統計量は通常カイ二乗検定に用いられますが、2つのカテゴリ列の対応する個別値に対して、それぞれの存在割合にカイ二乗統計量を適用することで2つのカテゴリ列の個別値の存在割合についての適合度検定を行うことができます。
処理の実装
電通デジタルではデータ管理にはBigQuery、機械学習パイプライン、データ処理バッチ構築にはAirflow(Cloud Composer)を使うケースが多いです。
そのため、AirflowでBigQuery上の指定したデータを抽出し、分布間距離を計算する仕様で実装します。
ここでは学習時のデータをBaseline、それと比較する推論時のデータをTargetとして、下記のようにBigQueryOperatorによって指定のクエリでデータを抽出してテーブルを作成、BigQueryToCloudStorageOperatorによってそれぞれCloud Storageへ転送します。このとき、日時やパーティションの指定で最新分のデータをスライスして動的に抽出できるようにクエリを書いておきます。
PythonOperatorで、Cloud Storageに抽出したデータから各特徴量についてデータ分布の乖離を検証する処理を実行し、一定以上の乖離があった場合にSlackにアラートを送信することでデータドリフトを検知できるようにします。
PythonOperatorのcallbackする処理がメインの部分になります。
BigQueryからCloud Storageに出力したファイルをデータフレームに読み込みます。
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
import pandas as pd
gcs_hook = GoogleCloudStorageHook(google_cloud_storage_conn_id='CONN_ID')
obj_baseline_data = gcs_hook.download(bucket_name, baseline_data_file, 'baseline.csv')
obj_target_data = gcs_hook.download(bucket_name, target_data_file, 'target.csv')
df_baseline = pd.read_csv('baseline.csv', encoding='utf-8', sep=',')
df_target = pd.read_csv('target.csv', encoding='utf-8', sep=',')
読み込んだBaselineデータフレームから1列ずつ読んでいき、数値型列なら対応するTargetデータの列とのKolmogorov-Smirnov距離を計算、オブジェクト型列(カテゴリ列)ならそれぞれの個別値の存在割合を算出し対応するTargetデータの列とのカイ二乗統計量を計算し、結果用のデータフレームに格納します。
Kolmogorov-Smirnov距離とカイ二乗統計量はScipyのks_2sampとchisquareを使うと簡単に計算でき、これらはそれぞれKolmogorov-Smirnov検定と適合度に関するカイ二乗検定を実行するので、算出されるp値も含めて結果用のデータフレームに格納します。
from scipy.stats import ks_2samp, chisquare
df_result = pd.DataFrame(columns=['col_name', 'col_type', 'statistic', 'pvalue'])
for col in df_baseline:
col_type = df_baseline[col].dtype
# numerical variables
if col_type in ['int', 'float']:
ks = ks_2samp(df_baseline[col], df_target[col])
df_result = df_result.append({'col_name': col, 'col_type': 'numerical', 'statistic': ks.statistic, 'pvalue': ks.pvalue}, ignore_index=True)
# categorical variables
elif col_type == 'object':
df_baseline_vc = df_baseline[col].value_counts(normalize=True)
df_stream_vc = df_stream[col].value_counts(normalize=True)
vc = pd.merge(df_stream_vc, df_baseline_vc, how='left', left_index=True, right_index=True)
vc = vc.fillna(0) * 100
chsq = chisquare(vc[f'{col}_y'], vc[f'{col}_x'])
df_result = df_result.append({'col_name': col, 'col_type': 'categorical', 'statistic': chsq.statistic, 'pvalue': chsq.pvalue}, ignore_index=True)
計算結果から「p値が設定した閾値を下回る」 = 「BaselineとTargetで分布に有意な差が認められる」ような列を抽出することでデータドリフトを起こしている特徴列を検知します。p値の閾値の設定は任意ですが、ここでは通常の統計的仮説検定で用いられる0.05を設定しています。
SlackWebhookHookを用いてデータドリフトを検知した場合に列名と併せてSlackに通知されるようにします。
from airflow.contrib.hooks.slack_webhook_hook import SlackWebhookHook
pval_th = 0.05
df_drift = df_result[df_result['pvalue'] < pval_th]
if df_drift.empty:
pass
else:
SlackWebhookHook(
webhook_token='SLACK_WEBHOOK_URL',
attachments=[{
'title':'Data Drift Alert!',
'text': df_drift.to_string(header=False, index=False)
}],
channel='SLACK_CHANNEL'
).execute()
こうすることで、ドリフトした特徴列を検知した際に下記のような通知が指定のチャンネルに送信されます。
また、データ分布間距離の算出結果をBigQueryテーブルに保存し、データポータルを用いて時系列に可視化することで、ベースラインとの乖離の推移を見ることができます。
このようなモニタリングワークフローでドリフトを検知することで、モデルの再学習などの対応をクイックに取ることが可能になります。
機械学習モデルの劣化の原因はこのような入力データのドリフトだけではない場合もありますが、日々取得されるデータの性質を適切に監視できることは、より品質の高い機械学習運用に繋がります。
データドリフト検知の手法について関心のある方は各機械学習プラットフォームのドキュメントや参考文献にある論文などもご参照ください。
参考文献
・Amazon SageMaker Model Monitor
・AzureML DataDriftDetecter
・TensorFlow Data Validation
・Automating Large-Scale Data Quality Verification - Amazon Research
・Differential Data Quality Verification
on Partitioned Data - Amazon Research
・Methodologies for Data Quality Assessment and Improvement - Carlo Batini, Cinzia Cappiello, Chiara Francalanci, Andrea Maurino