見出し画像

ABCIでMPT-7Bのファインチューニングを試す

前提知識

MPT-7Bは最近発表された商用利用可能な大規模言語モデルで、LLaMAに匹敵する性能を持っていると言われています。

ABCIは経産省が管轄する日本在住者なら誰でも安価に使えるスーパーコンピュータです。
(ただし登録がいろいろ大変なので法人が前提です/利用料は最低20万円から)

対象読者

行間が読める人。本文が間違っていても自分でソースコードに手を加えて修正できるスキルがある人。ABCIを使えるポジションの人。
僕も人間なのでミスはよくありますし、備忘録とこれからやろうとする人のために書いています。質問は受け付けません(自分でなんとかしてください)。

準備

思ったより大変だったのでメモ
まず、大前提として自宅のA6000x2のマシンでできるかと思ったら、ダメだった(12:57更新。ウソ:A6000x2でちゃんとできました)。

まず、MPTはTransformerなのでRWKVと違い、VRAMをめちゃくちゃ要求します。必要なVRAMの容量は、12*N(Nはパラメータ数)で概算できます。

たとえばGPT-13Bをやりたければ、12*13=156GBが必要ということになります。

12*7=84GBなので、もしかすると自宅のドスパラ製Memeplexサーバ
A6000x2(48GBx2=96GB)でもできるのかもしれない(12:57更新:できた)けど、とりあえず面倒だから確実に学習できるABCIで練習しました。

手順としてはまず、llm-foundryのリポジトリをgit cloneします。 

$ git clone https://github.com/mosaicml/llm-foundry.git
$ cd llm-foundry

ABCIの場合は、ここですぐセットアップできません。
まず、moduleをロードします。
MPT-7Bが動作するモジュールの組み合わせは、python 3.10 / cuda 11.7.1 / cudnn 8.4.1です。はっきりいってこの情報だけでもメモっておきたいのでこのエントリを書いてます

$ module load python/3.10/3.10.10
$ module load cuda/11.7/11.7.1
$ module load cudnn/8.4/8.4.1

ここで注意しなければならないのは、venvを使う場合は、venvを設定した後でモジュールを読み込む必要があることです。間違うとパスの順番の関係でPythonが動かなくなります。

ここでようやくpip install を走らせることができます。

$ pip install -e ".[gpu]" 

ここまでインタラクティブノード上でできるはずですが、エラーが出たら自分でなんとかしてください。
ここを突破できないスキルの人はこの先はもっと難しいと思います。

データセットの変換

まず最初は大人しくサンプルのページにある通りに動くかやってみましょう。ここで正常に動かなかったらセットアップのやり直しです。

僕はファインチューニングが試したかったので最初から用意されているサンプルではダメでした。

ファインチューニングをするには、まずhttps://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/mpt/finetune/7b_dolly_sft.yamlにある設定ファイルを使います。

そのままでも動くかもしれないけど、これだけでは芸がないので、dolly_15kを日本語化したhttps://huggingface.co/datasets/kunishou/databricks-dolly-15k-jaに変更してみます。

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    hf_name: kunishou/databricks-dolly-15k-ja

また、modelも、エラーが出る場合は、損失関数をtorch_crossentropyにするととりあえず動く

# Model
model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 4096
  n_heads: 32
  n_layers: 32
  expansion_ratio: 4
  max_seq_len: ${max_seq_len}
  vocab_size: 50368
  attn_config:
    attn_impl: triton
  loss_fn: torch_crossentropy

あと、最後の行、なぜかネットからデータ読み込むことになってる(雑だ)ので、コメントアウトしておく。完全なyamlは巻末に置いておきます。

 #load_path : oci://my-bucket/my-folder/mpt-7b/checkpoints/some_checkpoint.pt




このまま実行するとエラーが起きるので、

vi /home/自分のABCIユーザー名/.local/lib/python3.10/site-packages/llmfoundry/data/finetuning/tasks.py

で以下のブロックを追加

@dataset_constructor.register('kunishou/databricks-dolly-15k-ja')
def dolly_preprocessing_function(inp: Dict):
    """Format the text string."""
    PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
    try:
        if inp['input'] != '':
            instruction = inp['instruction'] + '\n' + inp['input']
        else:
            instruction = inp['instruction']
        prompt = PROMPT_FORMAT.format(instruction=instruction)
        response = inp['output']
    except Exception as e:
        raise ValueError(
            f'Unable to extract prompt/response from {inp=}') from e
    return {'prompt': prompt, 'response': response}

インストールしたパッケージを弄るのはちょっと乱暴すぎるのでもう少しマシな方法が公式に説明されている。
が、面倒が増えそうなので無視した。ローカルのデータセットを学習するときにはこの方法を使う方が効率が良さそうだがとりあえず動くところを目指す。

$ cd scripts
$ composer train/train.py train/yamls/mpt/finetune/7b_dolly_sft.yaml

さあこれでうまくいけば学習が開始される。
ちなみにデフォルトのyamlだと1ep(1epoch)しか学習しない設定なので適宜ほしいエポック数に変えること。
また、yamlの末尾にセーブするディレクトリなどを設定する。
自分のユーザーディレクトリに保存すると、すぐにクォータがいっぱいになって死ぬので、保存先は必ずスクラッチパッドを使うこと

[epoch=1][batch=213/235]:
	 Train time/batch: 212
	 Train time/sample: 13457
	 Train time/batch_in_epoch: 212
	 Train time/sample_in_epoch: 13457
	 Train time/token: 27559936
	 Train time/token_in_epoch: 27559936
	 Train memory/allocated_mem: 17.6580
	 Train memory/active_mem: 17.6580
	 Train memory/inactive_mem: 2.2777
	 Train memory/reserved_mem: 36.8890
	 Train memory/alloc_retries: 2
	 Train trainer/device_train_microbatch_size: 8
	 Train loss/train/total: 4.3715
	 Train metrics/train/LanguageCrossEntropy: 4.3677
	 Train metrics/train/LanguagePerplexity: 78.8618
	 Train throughput/batches_per_sec: 0.2320
	 Train throughput/samples_per_sec: 14.7099
	 Train throughput/device/batches_per_sec: 0.0290
	 Train throughput/device/samples_per_sec: 1.8387
	 Train throughput/flops_per_sec: 1300668282604287.2500
	 Train throughput/device/flops_per_sec: 162583535325535.9062
	 Train throughput/device/mfu: 0.5211
	 Train time/train: 0.2841
	 Train time/val: 0.0000
	 Train time/total: 0.2841
	 Train lr-DecoupledAdamW/group0: 0.0000
	 Train time/remaining_estimate: 30.9840

さて、学習した。
とりあえず1epoch。こんなので何も変わらないと思うが学習できたということが大事だ。

推論

学習したら推論である。
しかし、MPT-7Bは、そう簡単に推論できない。
まず、MPT-7Bで推論できるようにするために、学習したチェックポイントを変換する必要がある。

$ python3 inference/convert_composer_to_hf.py --composer_path dolly15j/checkpoints/ep0-ba200-rank0.pt --hf_output_path /scratch/自分のABCIID/out --output_precision bf16

これで/scratch/ABCIユーザーID/outに変換された。
いよいよ推論だ。

$ python3 inference/hf_generate.py --name_or_path /scratch/aca10054zv/out --prompts "光の三原色"
Loading HF Config...
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Loading HF model to device=cuda:0 and dtype=torch.bfloat16...
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.
n_params=6658859008

Loading HF tokenizer...

Generate kwargs:
{'max_new_tokens': 100, 'temperature': 1.0, 'top_p': 1.0, 'top_k': 50, 'use_cache': True, 'do_sample': True, 'eos_token_id': 0, 'pad_token_id': 0}

Tokenizing prompts...
NOT using autocast...
Warming up...
Generating responses...
####################################################################################################
光の三原色、キ��はジ、フラシニアンレイロ
バリ�、デリ
バミ・ママ・アッドレはマカ(マは、デ、ラのルリ()年
ビ年
オント、ロメラーラーのカリ:パレホである(セフ�とベスメン」はあ月月
####################################################################################################
bs=1, input_tokens=array([6]), output_tokens=array([89])
total_input_tokens=6, total_output_tokens=89
encode_latency=73.43ms, gen_latency=1981.96ms, decode_latency=41.22ms, total_latency=2096.61ms
latency_per_output_token=23.56ms/tok
output_tok_per_sec=42.45tok/sec

やはり1エポックではなにもできないというか却って悪くなってる気さえする。
まあ本当はresponse形式にしなきゃなんないのかもしれないけど。

とにかく学習のようなものが回せて、推論できた。
あとはデータをどうするかハイパラをどうするか考えるだけだ。

疲れた。朝5時から取り掛かっていろいろハマって昼過ぎになってしまった。

おまけ:完全な設定ファイル(yaml)

個人のディレクトリ的なところだけ隠した完全な設定ファイルを置いておきます。参考まで

max_seq_len: 2048
global_seed: 17

# Run Name
run_name: dolly15j 
# If left blank, will be read from env var $C/baOMPOSER_RUN_NAME

# Model
model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 4096
  n_heads: 32
  n_layers: 32
  expansion_ratio: 4
  max_seq_len: ${max_seq_len}
  vocab_size: 50368
  attn_config:
    attn_impl: triton
  loss_fn: torch_crossentropy

# Tokenizer
tokenizer:
  name: EleutherAI/gpt-neox-20b
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    hf_name: kunishou/databricks-dolly-15k-ja 
    split: train
    max_seq_len: ${max_seq_len}
    allow_pad_trimming: false
    decoder_only_format: true
    shuffle: true
    # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` to profile
    # this run's optimal packing_ratio
    # packing_ratio:
  drop_last: true
  num_workers: 8
  pin_memory: false
  prefetch_factor: 2
  persistent_workers: true
  timeout: 0

# There is no validation split so we skip eval_loader

# Optimization
scheduler:
  name: linear_decay_with_warmup # linear no warmup is HF default which dolly used
  t_warmup: 0ba
  alpha_f: 0

optimizer:
  # mimic HF defaults to replicate dolly
  name: decoupled_adamw
  lr: 1.0e-5
  betas:
  - 0.9
  - 0.999
  eps: 1.0e-8
  weight_decay: 0

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0


max_duration: 100ep
eval_interval: 1 # this is the only allowed value for no eval
# eval_first: false
# eval_subset_num_batches: -1
global_train_batch_size: 64 # assuming 8 gpus

# System
seed: ${global_seed}
device_eval_batch_size: 8
device_train_microbatch_size: 8
# device_train_microbatch_size: auto
precision: amp_bf16

# FSDP
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: true
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
  verbose: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}

# loggers:
#   wandb: {}

# Checkpoint to local filesystem or remote object store
  
save_interval: 200ba
# save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK



save_folder: /scratch/あなたのABCIユーザー名/{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

# Load from remote object store
# REPLACE THE BELOW with you own checkpoint! #load_path : oci://my-bucket/my-folder/mpt-7b/checkpoints/some_checkpoint.pt