見出し画像

SageMakerで学習したXGBoostモデルのFeature Importanceを取得するAirflowオペレーター

電通デジタルでデータサイエンティストとして働いている長島です。

本記事ではSageMakerで学習したXGBoostモデルのFeature Importance取得をAirflowで自動化する方法を紹介します。

SageMakerにはXGBoostをはじめとする組み込みモデルが多数用意されており、容易に学習・推論を行うことができます[1]。 これらに対応したAirflowオペレーターもあるため、機械学習フローをまとめて管理するのに非常に便利です[2]。 一方、モデル依存の情報を抜き出すような機能は備えていないため、XGBoostの強みの一つである、Feature Importanceの取得はSageMakerの機能として備わっていません。

今回は、SageMakerXGBoostImportanceOperatorという名前で、タイトルのような機能を持ったオペレーターを作成していきます。 これをトレーニングやチューニングオペレーターの後ろに繋げることで、学習が済んだモデルでどのfeatureが重要度が高いのか、自動的に出力することができます。

DAGイメージ

トレーニングとデプロイを行うDAGに、importance出力用のオペレータを追加すると以下のようになります。

train_xgb_op >> {deploy_xgb_op, plot_importance_op}

画像1

全体の流れ

基本的な流れは、ローカルやノートブックで実行するときと変わりません。

  • 1. モデルファイル(model.tar.gz)をS3からダウンロード&解凍
    2. pickleでロード
    3. モデルにfeature名を紐づける
    4. importanceをファイル出力
    5. 出力結果をS3にアップロード
  • という流れになります。
3.モデルにfeature名を紐づける

これは、モデルファイルに学習時のfeature名が保存されていないために別途用意が必要です。 出力後に照らし合わせても良いのですが、便利な xgboost.plot_importance を使いたいので、先に紐づけることにします。

事前準備

以下のライブラリをAirflow環境に事前にインストールしておきます。

  • ・pickle
    ・xgboost==0.90
    →最新版は1.0.2ですが、SageMakerのxgboostは0.90が最新です。(公開時点)
    ・japanize_matplotlib
    →これは必須ではないですが、feature_namesに日本語が入るときに簡単に文字化けが解消できるのでおすすめです。

コンストラクタ

オペレーターに渡す情報は以下4つです。

  • ・Airflowに登録したS3コネクションID
    ・S3バケット名
    ・モデルファイルのS3キー
    トレーニングの際、output_pathで指定した場所に、 model.tar.gz ファイルが出力されます。バケット名よりも後ろの部分になります。
    ・feature名

S3へのアクセスは、 airflow.hooks.S3_hook で提供されています[3]

    def __init__(self,
                 s3_conn_id: str,
                 s3_bucket: str,
                 model_key: str,
                 feature_names: list,
                 **kwargs) -> None:
        """
        Args:
            s3_conn_id: Airflow connection ID for S3
            s3_bucket: S3 bucket name
            model_key: key for trained XGBoost model
            feature_names: list of feature names
        Returns:
            S3 key where importance plot will be uploaded
        """

        self.s3_hook = S3Hook(aws_conn_id=s3_conn_id)
        self.s3_bucket = s3_bucket
        self.model_key = model_key
        self.feature_names = feature_names

        super(SageMakerXGBoostImportanceOperator, self).__init__(**kwargs)

モデルファイルのダウンロード・解凍

model.tar.gz ファイルには、唯一 xgboost-model というファイルが含まれています。 これをS3からダウンロードし、tarfileで解凍します。

        ## S3にファイルがあるか確認
        if not self.s3_hook.check_for_key(f's3://{self.s3_bucket}/{self.model_key}'):
            raise AirflowException(f'{self.model_key} is not created')

        ## 作業用に一時ディレクトリを作成
        tempdir = tempfile.TemporaryDirectory()

        ## ダウンロード&解凍
        local_path = path.join(tempdir.name, 'model.tar.gz')
        try:
            self.conn.download_file(
                Bucket=self.s3_bucket,
                Key=s3_key,
                Filename=local_path
            )
            with tarfile.open(local_path, 'r') as tar:
                mem_name = tar.getnames()[0]
                tar.extract(mem_name, tempdir.name)
        except Exception as e:
            raise e        

モデルのロード

xgboost-model はpickleで保存されたモデルファイルなので、 pickle.load を使うことで、メモリ上にモデルを復元することができます。 前述したとおり、このモデルファイルはfeature名を含まないので、別途用意したfeature名を紐づけています。

        ## モデルのロード
        model_path = path.join(tempdir.name, mem_name)
        model = pkl.load(open(model_path, 'rb'))
        model.feature_names = self.feature_names        

特徴量出力

通常の学習済みXGBoostモデルの扱いと同じです。 今回は、全featureのimportance一覧と、importanceの高いfeatureをplotしたものを出力します。

        ## importanceのファイル出力
        score_dict = model.get_score(importance_type='weight')
        csv_name = path.join(tempdir.name, 'importance.csv')
        with open(csv_name, 'wt') as f:
            f.write('\n'.join([name + ',' + str(imp) for name, imp in enumerate(score_dict)]))                

        ## plot
        fig = plt.figure(figsize=(10, 20))
        fig.subplots_adjust(left=0.2)
        ax = fig.add_subplot(1, 1, 1)
        xgboost.plot_importance(model,
                                ax=ax,
                                importance_type='weight',
                                show_values=False,
                                max_num_features=20)

        plt_name = path.join(tempdir.name, 'importance.png')
        plt.savefig(plt_name, format='png')

plotは以下のような形で出力されます。(feature名はマスクしてあります)

画像2

結果のアップロード

出力した2つのファイルを保存します。 ダウンロード同様、S3_hookを使って、モデルファイルがあった場所と同じ階層に置くことにします。

        ## 結果のアップロード
        try:
            csv_key = self.model_key.replace('model.tar.gz', 'importance.csv')
            self.conn.upload_file(
                Bucket=self.s3_bucket,
                Key=csv_key,
                Filename=csv_name
            )

            plt_key = self.model_key.replace('model.tar.gz', 'importance.png')
            self.conn.upload_file(
                Bucket=self.s3_bucket,
                Key=plt_key,
                Filename=plt_name
            )

        except Exception as e:
            raise e

全コード

from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
import xgboost
import pickle as pkl
import matplotlib.pyplot as plt
import tarfile
import tempfile
from os import path
import japanize_matplotlib ## importするだけでOK

class SageMakerXGBoostImportanceOperator(BaseOperator):

    def __init__(self,
                 s3_conn_id: str,
                 s3_bucket: str,
                 model_key: str,
                 feature_names: list,
                 **kwargs) -> None:
        """
        Args:
            s3_conn_id: Airflow connection ID for S3
            s3_bucket: S3 bucket name
            model_key: key for trained XGBoost model
            feature_names: list of feature names
        Returns:
            S3 key where importance plot will be uploaded
        """

        self.model_key = model_key
        self.feature_names = feature_names
        self.s3_hook = S3Hook(aws_conn_id=s3_conn_id)
        self.s3_bucket = s3_bucket
        self.conn = self.s3_hook.get_conn()

        super(SageMakerXGBoostImportanceOperator, self).__init__(**kwargs)

    def execute(self, context) -> str:
        ## S3にファイルがあるか確認
        if not self.s3_hook.check_for_key(f's3://{self.s3_bucket}/{self.model_key}'):
            raise AirflowException(f'{self.model_key} is not created')

        ## 作業用に一時ディレクトリを作成
        tempdir = tempfile.TemporaryDirectory()

        ## ダウンロード&解凍
        local_path = path.join(tempdir.name, 'model.tar.gz')
        try:
            self.conn.download_file(
                Bucket=self.s3_bucket,
                Key=self.model_key,
                Filename=local_path
            )
            with tarfile.open(local_path, 'r') as tar:
                mem_name = tar.getnames()[0]
                tar.extract(mem_name, tempdir.name)
        except Exception as e:
            raise e        

        ## モデルのロード
        model_path = path.join(tempdir.name, mem_name)
        model = pkl.load(open(model_path, 'rb'))
        model.feature_names = self.feature_names

        ## importanceのファイル出力
        score_dict = model.get_score(importance_type='weight')
        csv_name = path.join(tempdir.name, 'importance.csv')
        with open(csv_name, 'wt') as f:
            f.write('\n'.join([name + ',' + str(imp) for name, imp in enumerate(score_dict)]))                

        ## plot
        fig = plt.figure(figsize=(10, 20))
        fig.subplots_adjust(left=0.2)
        ax = fig.add_subplot(1, 1, 1)
        xgboost.plot_importance(model,
                                ax=ax,
                                importance_type='weight',
                                show_values=False,
                                max_num_features=20)

        plt_name = path.join(tempdir.name, 'importance.png')
        plt.savefig(plt_name, format='png')

        ## 結果のアップロード
        try:
            csv_key = self.model_key.replace('model.tar.gz', 'importance.csv')
            self.conn.upload_file(
                Bucket=self.s3_bucket,
                Key=csv_key,
                Filename=csv_name
            )

            plt_key = self.model_key.replace('model.tar.gz', 'importance.png')
            self.conn.upload_file(
                Bucket=self.s3_bucket,
                Key=plt_key,
                Filename=plt_name
            )

        except Exception as e:
            raise e
        return plt_key

まとめ

本記事では、SageMakerで学習したXGBoostモデルのFeature Importanceを取得・描画するAirflowオペレーターを作成してみました。 今回はXGBoostを取り上げましたが、他の組み込みモデルでもpickleやMXNet checkpoint形式でS3に保存されます。 学習済みのモデルから情報を抜き出したい場合は同じような手順で解決できますので、独自の分析を行いたい場合には試してみるとよいかもしれませんね。

参考

[1]https://docs.aws.amazon.com/ja_jp/sagemaker/latest/dg/algos.html
[2]https://aws.amazon.com/jp/blogs/news/build-end-to-end-machine-learning-workflows-with-amazon-sagemaker-and-apache-airflow/
[3]https://airflow.readthedocs.io/en/stable/_modules/airflow/hooks/S3_hook.html