見出し画像

AWS_RDSからSnowflakeへのデータ投入を自動化する #461

RDS (MySQL)に保存しているデータを一部抽出してSnowflakeに保存する処理を自動化する方法についてまとめます。

Step FunctionsとLambda、S3を使用して段階的に処理しています。

Step Fuctionsの定義

以下のように定義できます(ちょこちょこマスキングを入れています)。

  1. Slackに開始メッセージを通知

  2. LambdaでRDSからデータ抽出してS3へtsv形式で保存

  3. 2の処理が成功したかチェック(失敗していたらSlackにエラー通知)

  4. LambdaでS3のファイル(のデータ)をSnowflakeへcsv形式で保存

  5. 4の処理が成功したかチェック(失敗していたらSlackにエラー通知)

  6. Slackに終了メッセージを通知

{
  "Comment": "Step Function to export data to S3, notify via Slack, and then import to DB if export_file_name is not None",
  "StartAt": "StartSlackNotification",
  "States": {
    "StartSlackNotification": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:XXXXXXXXXXXXXX",
      "Parameters": {
        "status": "normal",
        "message": "開始メッセージ"
      },
      "Next": "ExportToS3"
    },
    "ExportToS3": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:YYYYYYYYYYYYY",
      "Next": "CheckExportResult"
    },
    "CheckExportResult": {
      "Type": "Choice",
      "Choices": [
        {
          "Variable": "$.statusCode",
          "NumericEquals": 200,
          "Next": "ImportToSnowflakeFromS3"
        }
      ],
      "Default": "ErrorSlackNotification"
    },
    "ImportToSnowflakeFromS3": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:ZZZZZZZZZZZZZ",
      "Parameters": {
        "import_file_name.$": "$.body.export_file_name",
        "target_database_name": "TARGETDB"
      },
      "ResultPath": "$.importResult",
      "Next": "CheckImportResult"
    },
    "CheckImportResult": {
      "Type": "Choice",
      "Choices": [
        {
          "Variable": "$.importResult.statusCode",
          "NumericEquals": 200,
          "Next": "EndSlackNotification"
        }
      ],
      "Default": "ErrorSlackNotification"
    },
    "EndSlackNotification": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:XXXXXXXXXXXXXX",
      "Parameters": {
        "status": "normal",
        "message": "終了メッセージ"
      },
      "End": true
    },
    "ErrorSlackNotification": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:XXXXXXXXXXXXXX",
      "Parameters": {
        "status": "error",
        "message": "エラーメッセージ"
      },
      "End": true
    }
  }
}

以下で、Lambdaを活用している2と4の処理について触れます。

LambdaでRDSからデータ抽出してS3へ保存

ここでは大まかに以下の手順になっています。

  1. Lambdaのレイヤーに「pymysql」をインストール

  2. pymysqlにRDSの接続情報を渡し、コネクションを作る

  3. 保存先の情報を作る

    1. S3クライアントのインスタンスを作成

    2. 保存ファイル名やパスを設定

  4. RDS (MySQL) へSQLを叩いて結果をフェッチ

  5. 結果をtsv形式のファイルへ書き込み

    1. csvにしていないのは、結果の文字列にカンマを含む場合があるため

  6. 作成したtsvファイルをS3に保存

    1. 保存前の処理として、seek()メソッドに0を指定してストリーム内のポインタを先頭に移動

    2. これはs3.put_object()メソッドが正しいデータを読み込むように準備するため

    3. s3.put_object()は、ストリーム内の先頭からデータを読み込む

  7. 無事にS3へ保存されたら、ステータスコードやファイル名などの情報をStep Functionsへ返す

import logging
import pymysql  # 外部ライブラリのためLambdaのレイヤーでインストール
import boto3
import os
import datetime
from io import StringIO

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Lambda環境変数から設定を読み込み
user_name = os.environ.get('USER_NAME')
password = os.environ.get('PASSWORD')
host = os.environ.get('HOST')
db_name = os.environ.get('DB_NAME')
s3_bucket = os.environ.get('S3_BUCKET')
save_file_path = os.environ.get('SAVE_FILE_PATH', 'sample_directory')

sql_query = """
    select
        table_a.name as sample_name,
        table_b.data as sample_data
    from table_a
    inner join table_b
    on table_a.data_id = table_b.id
"""

try:
    # RDS MySQL への接続
    conn = pymysql.connect(host=host, user=user_name, passwd=password, db=db_name, connect_timeout=5)
    logger.info("SUCCESS: Connection to RDS for MySQL instance succeeded")
except pymysql.MySQLError as e:
    logger.error("ERROR: Unexpected error: Could not connect to MySQL instance.")
    logger.error(e)
    sys.exit(1)

# S3 クライアントのインスタンスを作成
s3 = boto3.client('s3')

def lambda_handler(event, context):
    save_file_name = event.get('save_file_name') or datetime.datetime.now().strftime("%Y%m%d") + '.tsv'
    full_save_file_name = os.path.join(save_file_path, save_file_name)
    logger.info(f"save file: {s3_bucket}/{full_save_file_name}")

    try:
        with conn.cursor() as cursor:
            cursor.execute(sql_query)
            result = cursor.fetchall()

            if not result:
                logger.info(f"ERROR: No data found to upload.")
                return {
                    "statusCode": 500,
                    "body": {
                        "count_export_record": 0,
                        "export_file_name": "None",
                        "full_export_file_name": "None"
                    }
                }

            # 結果を TSV 形式で S3 に保存
            tsv_buffer = StringIO()
            for row in result:
                tsv_buffer.write('\t'.join(str(value) for value in row) + '\n')
            tsv_buffer.seek(0)  # 次のs3.put_object()メソッドは、ストリーム内の先頭からデータを読み込むため、seek()でストリーム内のポインタを先頭に移動する
            response = s3.put_object(Bucket=s3_bucket, Key=full_save_file_name, Body=tsv_buffer.getvalue())
            if response['ResponseMetadata']['HTTPStatusCode'] == 200:
                logger.info(f"SUCCESS: Uploaded data to S3 bucket {s3_bucket}/{full_save_file_name}")
                return {
                    "statusCode": 200,
                    "body": {
                        "count_export_record": len(result),
                        "export_file_name": save_file_name,
                        "full_export_file_name": full_save_file_name
                    }
                }
            else:
                logger.info(f"ERROR: Uploaded data to S3 bucket {s3_bucket}/{full_save_file_name}")
                return {
                    "statusCode": 500,
                    "body": {
                        "count_export_record": 0,
                        "export_file_name": "None",
                        "full_export_file_name": "None"
                    }
                }

    except Exception as e:
        logger.error("ERROR: Unexpected error occurred.")
        logger.error(e)
        return {
            "count_export_record": 0,
            "export_file_name": "None",
            "full_export_file_name": "None"
        }

続いてS3のファイルからSnowflakeへデータを保存する処理です。

LambdaでS3のファイルのデータをSnowflakeへ保存

処理の解説に入る前に、外部のデータストレージに格納されたデータへSnowflakeがアクセスするための重要な仕組みである、名前付きステージ (Named Stage)について触れておきます。

名前付きステージ (Named Stage)

特定の外部ストレージ(S3など)の場所とクレデンシャルを設定します。これにより、Snowflakeはステージ経由で外部データにアクセスできます。ステージの作成方法は公式ドキュメントを参照してください。

つまりSnowflakeからS3のパスを指定する際は、名前付きステージのパスとS3ディレクトリのパスを繋げます。また、名前付きステージのパスを指定する時は頭に「@」をつけます。

@名前付きステージのパス/S3ディレクトリのパス


Lambdaの手順

では具体的な手順を見ていきます。

  1. Lambdaのレイヤーに「snowflake.connector」をインストール

  2. S3ディレクトリへのパス、Snowflakeへの接続情報などを環境変数から読み込み

  3. Step Functionsから先ほどRDSから保存したファイル名を取得(lambda_handlerの引数event)

  4. Snowflakeの対象テーブルを指定し、更新処理(delete - insert)

    1. snowflake.connectorで接続を確立

    2. トランザクションを開始

    3. 削除前のレコード数を取得してログ出力したうえでdelete

    4. 削除後のレコード数を取得(削除確認であり0件になる想定)

    5. S3のデータを対象テーブルにインポート(ステージ経由で)

    6. インポート後のレコード数を取得してログ出力

    7. コミット

  5. Step Functionsにステータスコードや更新情報(レコード数)を返す

[環境変数の一部]

# 「@」マークを付けることで、Snowflakeの名前付きステージ(Named Stages)に対する参照を表す
# 以下ではwork_stageがステージで、sample_directoryはS3のディレクトリ(ステージを経由してディレクトリを指定している)
S3_BUCKET_PATH = @targetdb.work.work_stage/sample_directory
import snowflake.connector  # 外部ライブラリのためLambdaのレイヤーでインストール
import logging
import os

# ロガーの設定
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 環境変数からの読み込み
SCHEMA = os.getenv('SNOWFLAKE_SCHEMA')
S3_BUCKET_PATH = os.environ.get('S3_BUCKET_PATH')  # Snowflakeのステージ経由のS3バケットパス

# Snowflake の接続パラメータの初期設定
SNOWFLAKE_PARAMS = {
    'user': os.getenv('SNOWFLAKE_USER'),
    'password': os.getenv('SNOWFLAKE_PASSWORD'),
    'account': os.getenv('SNOWFLAKE_ACCOUNT'),
    "warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
    "database": os.getenv('SNOWFLAKE_DATABASE'),  # 後で更新するケースがある
    "schema": SCHEMA,
    "role": os.getenv('SNOWFLAKE_ROLE'),
}

def lambda_handler(event, context):
    # import_file_name パラメータの取得
    import_file_name = event.get('import_file_name', '')

    # フル S3 パスの構築
    full_s3_path = f'{S3_BUCKET_PATH}/{import_file_name}'
    logger.info(f'import s3 file name: {full_s3_path}')

    # Snowflake の接続パラメータを更新
    SNOWFLAKE_PARAMS['database'] = event.get('target_database_name', os.getenv('SNOWFLAKE_DATABASE'))
    logger.info(f"database: {SNOWFLAKE_PARAMS['database']}")

    # 対象テーブルを指定
    target_table_name = os.getenv('TARGET_TABLE')

    try:
        # Snowflake に接続
        ctx = snowflake.connector.connect(**SNOWFLAKE_PARAMS)

        # トランザクションを開始
        ctx.cursor().execute('BEGIN TRANSACTION;')


        # 削除前のレコード数を取得
        count_before_import = ctx.cursor().execute(f'SELECT COUNT(*) FROM {target_table_name};').fetchone()
        logger.info(f'Count before import: {count_before_import[0]}')

        # テーブルからデータを削除
        ctx.cursor().execute(f'DELETE FROM {target_table_name};')

        # 削除後のレコード数を取得
        count_after_delete = ctx.cursor().execute(f'SELECT COUNT(*) FROM {target_table_name};').fetchone()
        logger.info(f'Count after delete: {count_after_delete[0]}')

        # S3 からデータをインポート
        # FORCE = TRUE は、強制的にファイルを読み込むモード
        # Snowflake では同一ファイルを読み込む場合、デフォルトでスキップする。
        # Delete 後、ファイルでテーブルを更新するため、このモード設定が必要。
        copy_sql = f'''
        COPY INTO {target_table_name} FROM '{full_s3_path}'
        FILE_FORMAT = (TYPE = 'CSV' FIELD_DELIMITER = '\\t')
        FORCE = TRUE
        '''
        ctx.cursor().execute(copy_sql)

        # インポート後のレコード数を取得
        count_after_import = ctx.cursor().execute(f'SELECT COUNT(*) FROM {target_table_name}').fetchone()
        logger.info(f'Count after import: {count_after_import[0]}')

        # トランザクションをコミット
        ctx.cursor().execute('COMMIT;')

    except Exception as e:
        # エラーが発生した場合はロールバック
        ctx.cursor().execute('ROLLBACK;')
        logger.error(f'Error: {e}')
        return {
            'statusCode': 500,
            'body': {'error': str(e)}
        }
    finally:
        ctx.close()

    return {
        'statusCode': 200,
        'body': {
            'count_before_import': count_before_import[0],
            'count_after_delete': count_after_delete[0],
            'count_after_import': count_after_import[0]
        }
    }


あとはこれをEvent Bridgeなりで定期実行するように仕込めば、完全自動化の完成です。

ここまでお読みいただきありがとうございました!

参考


この記事が気に入ったらサポートをしてみませんか?