![見出し画像](https://assets.st-note.com/production/uploads/images/86709237/rectangle_large_type_2_0060e369acefa63722bde5b05082b190.png?width=800)
HuggingFace Accelerate の概要
「Accelerate」の概要についてまとめました。
・Accelerate v0.12.0
1. Accelerate
「Accelerate」は、PyTorchの CPU / GPU / TPU 対応を共通コードで書けるようにするためのパッケージです。
次に例を示します。
import torch
import torch.nn.functional as F
from datasets import load_dataset
+ from accelerate import Accelerator
+ accelerator = Accelerator()
- device = 'cpu'
+ device = accelerator.device
model = torch.nn.Transformer().to(device)
optimizer = torch.optim.Adam(model.parameters())
dataset = load_dataset('my_dataset')
data = torch.utils.data.DataLoader(dataset, shuffle=True)
+ model, optimizer, data = accelerator.prepare(model, optimizer, data)
model.train()
for epoch in range(10):
for source, targets in data:
source = source.to(device)
targets = targets.to(device)
optimizer.zero_grad()
output = model(source)
loss = F.cross_entropy(output, targets)
- loss.backward()
+ accelerator.backward(loss)
optimizer.step()
学習スクリプトに5 行追加するだけで、任意の単一・分散ノード設定 (single CPU / single GPU / multi-GPU / TPU) で実行できるようになります。デバイス配置も処理するため、学習ループをさらに単純化できます。
import torch
import torch.nn.functional as F
from datasets import load_dataset
+ from accelerate import Accelerator
- device = 'cpu'
+ accelerator = Accelerator()
- model = torch.nn.Transformer().to(device)
+ model = torch.nn.Transformer()
optimizer = torch.optim.Adam(model.parameters())
dataset = load_dataset('my_dataset')
data = torch.utils.data.DataLoader(dataset, shuffle=True)
+ model, optimizer, data = accelerator.prepare(model, optimizer, data)
model.train()
for epoch in range(10):
for source, targets in data:
- source = source.to(device)
- targets = targets.to(device)
optimizer.zero_grad()
output = model(source)
loss = F.cross_entropy(output, targets)
- loss.backward()
+ accelerator.backward(loss)
optimizer.step()
2. 起動スクリプト
「Accelerate」には、スクリプトの実行前に学習環境をすばやく設定およびテストできるCLIツールも提供します。
$ accelerate config
In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): 0
Which type of machine are you using? ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU [4] MPS): 0
Do you want to run your training on CPU only (even if a GPU is available)? [yes/NO]:NO
Do you want to use DeepSpeed? [yes/NO]: NO
Do you wish to use FP16 or BF16 (mixed precision)? [NO/fp16/bf16]: NO
そして聞かれた質問に回答します。これにより、実行時に利用する設定ファイルが生成されます。
$ accelerate launch my_script.py --args_to_my_script
3. MPI による multi CPU 実行
MPI による multi CPU 実行を開始する別の方法を次に示します。Open MPI のインストール方法については、このページで学習できます。 Intel MPI または MVAPICH も使用できます。クラスターで MPI をセットアップしたら、以下を実行します。
$ mpirun -np 2 python examples/nlp_example.py
4. DeepSpeed による学習開始
「Accelerate」は、「DeepSpeed」によるsingle/ multi GPU での学習をサポートします。これを利用するために、コードを変更する必要はありません。Accelerateの設定で設定できます。
Python スクリプトから DeepSpeed関連の引数をファインチューニングしたい場合は、DeepSpeedPlugin を利用します。
from accelerator import Accelerator, DeepSpeedPlugin
# deepspeed needs to know your gradient accumulation steps before hand, so don't forget to pass it
# Remember you still need to do gradient accumulation by yourself, just like you would have done without deepspeed
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)
accelerator = Accelerator(fp16=True, deepspeed_plugin=deepspeed_plugin)
# How to save your 🤗 Transformer?
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
5. ノートブックでの学習開始
「Accelerate」の分散学習の起動は、ノートブックでも使用できます。notebook_launcher()が提供されており、これはTPUバックエンドを備えた Colab または Kaggleノートブックで役立ちます。
training_functionで学習ループを定義し、最後のセルに以下を追加します。
from accelerate import notebook_launcher
notebook_launcher(training_function)
使用例は、このノートブックにあります。
6. Accelerateを使用すべき場合
学習ループの完全な制御を放棄することなく、分散環境で学習スクリプトを手軽に実行したい場合は、「Accelerate」を使用すべきです。これはPyTorchの高レベルフレームワークではなくラッパーであるため、新しいライブラリを学習する必要はありません。実際、「Accelerate」の 全APIはAcceleratorクラスにあります。
7. Accelerateを使用すべきでない場合
自分で学習ループを作成したくない場合は、「Accelerate」を使用すべきではありません。PyTorchに高レベルのライブラリがたくさんあります。
この記事が気に入ったらサポートをしてみませんか?