SageMaker Training JobsでFine-Tuningを行う際にmodel.tar.gzをS3から読み込む
こんにちは、エンジニアのすずきです。
以前の記事で、SageMaker Training JobsによるTabBERTモデルの事前学習を行ったので、今回は事前学習の結果model.tar.gzを元にFine-Tuningを実行するJobを作成しました。
基本的には事前学習と同じようなJobなのですが、以下の部分で工夫が必要だったのでメモとしてまとめました。
tarファイルの展開
環境変数によるローカルとSageMaker間での引数の切替
なお、Fine-Tuningの元コードについては、以下の記事で解説しています。
tarファイルの展開
事前学習のJobによって、S3上にmodel.tar.gzが保存されています。
model.tar.gz内には、モデルpytorch_model.bin、設定ファイルconfig.json、辞書ファイルvocab.nb、トークン→id変換ファイルvocab_token2id.binが入っています。
Fine-Tuningではこれらを読み込む必要があるため、Jobを実行するときにtarファイルを展開するような工夫を行います。
まずは、Jobファイルのinput_modelでmodel.tar.gzのS3パスを指定します。
これでJob実行時にmodel.tar.gzが/opt/ml/input/data/input_model/(model_path)以下に置かれます。
import sagemaker
from sagemaker.estimator import Estimator
session = sagemaker.Session()
role = sagemaker.get_execution_role()
estimator = Estimator(
image_uri=<イメージURL>,
role=role,
instance_type="ml.g4dn.2xlarge",
instance_count=1,
base_job_name="tabformer-opt-fine-tuning",
output_path="s3://<バケット名>/sagemaker/output_data/fine_tuning",
code_location="s3://<バケット名>/sagemaker/output_data/fine_tuning",
sagemaker_session=session,
entry_point="fine-tuning.sh",
dependencies=["tabformer-opt"],
hyperparameters={
"data_root": "/opt/ml/input/data/input_data/",
"data_fname": "summary",
"output_dir": "/opt/ml/model/",
"model_path": "/opt/ml/input/data/input_model/",
}
)
estimator.fit({
"input_data": "s3://<バケット名>/sagemaker/input_data/summary.csv",
"input_model": "s3://<バケット名>/sagemaker/output_data/pre_training/tabformer-opt-2022-12-16-07-00-45-931/output/model.tar.gz"
})
次にFine-Tuningの実行ファイルtabformer_bert_fine_tuning.py上に以下を記載します。
with tarfile.open(name=path.join(args.model_path, f'model.tar.gz'), mode="r:gz") as mytar:
mytar.extractall(path.join(args.model_path, f'model'))
token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")
vocab_file = path.join(args.model_path, f"model/vocab.nb")
pretrained_model = path.join(args.model_path, f"model/checkpoint-500/pytorch_model.bin")
pretrained_config = path.join(args.model_path, f"model/checkpoint-500/config.json")
tarfile.open()でmodel.tar.gzが読み込まれ、mytar.extractall(path.join(args.model_path, f'model'))で/opt/ml/input/data/input_model/model/以下に中身が展開されます。
これで、token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")のように展開されたファイルを読み込むことができるようになります。
環境変数によるローカルとSageMaker間での引数の切替
これでS3上のmodel.tar.gzを読み込めるようになったのですが、ローカルでFineTuningを行う際には読み込み先を変えたいケースもあると思います。
そんなときは、os.getenv('SM_MODEL_DIR')でSageMakerの環境変数SM_MODEL_DIR(コンテナ終了時にS3へアップロードされるディレクトリ)を取得し、ローカルとSageMaker(のJob)で読み込み先を切り替えます。
key = os.getenv('SM_MODEL_DIR')
if key :
with tarfile.open(name=path.join(args.model_path, f'model.tar.gz'), mode="r:gz") as mytar:
mytar.extractall(path.join(args.model_path, f'model'))
token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")
vocab_file = path.join(args.model_path, f"model/vocab.nb")
pretrained_model = path.join(args.model_path, f"model/checkpoint-500/pytorch_model.bin")
pretrained_config = path.join(args.model_path, f"model/checkpoint-500/config.json")
else :
vocab_file = path.join(args.model_path, f"vocab.nb")
token2id_file = path.join(args.model_path, f"vocab_token2id.bin")
pretrained_model = path.join(args.model_path, f"checkpoint-500/pytorch_model.bin")
pretrained_config = path.join(args.model_path, f"checkpoint-500/config.json")
参考資料
この記事が気に入ったらサポートをしてみませんか?