見出し画像

MeZO(0次最適化)による大規模言語モデルのファインチューニングを試す

プリンストン大学の自然言語処理研究所が公開したMeZOというアルゴリズムを使うと、本来VRAMに入りきらないサイズの言語モデルでもファインチューニングができるという。

A100 80GBx1だと本来は2.7Bモデルまでしかファインチューニングできないが、MeZOを使うと30Bモデルを学習可能に?

勾配を保存せずに推定値と乱数で生成するから推論時と同じメモリ使用量で最適化できるらしい。えーと、それが本当だとするとこれまで勾配が増えるたびにメモリ量が莫大に必要だったという話がどこかへ飛んでいってしまうのだが・・・。

早速A100 80GBx1マシンで試す。
元のコードはOPT-13Bベースだったのだが、OPT-13Bはそう簡単に手に入らないので同じ規模のモデルであるLlama2-13Bで試してみた。

まず、Transformerが特定のバージョンじゃないと動かないので注意

$ pip install transformers==4.28.1
$ pip install accelerate==0.17.1

Torchは2.1、CUDAは11.8が推奨されているが、僕はCUDA12.1でPyTorchは2.1.0+cu121で動作確認した。

とりあえずLlama-2-13bをSST2でLoRAファインチューニングしてみる。

$ MODEL=meta-llama/Llama-2-13b-hf TASK=SST2 MODE=lora EPS=5e-5 LR=1e-2 CUDA_VISIBLE_DEVICES=0 nohup bash mezo.sh  & 
2023-11-11 19:45:11,711 - INFO - ***** Running training *****
2023-11-11 19:45:11,711 - INFO -   Num examples = 1000
2023-11-11 19:45:11,711 - INFO -   Num Epochs = 318
2023-11-11 19:45:11,711 - INFO -   Instantaneous batch size per device = 16
2023-11-11 19:45:11,711 - INFO -   Total train batch size (w. parallel, distributed & accumulation) = 16
2023-11-11 19:45:11,711 - INFO -   Gradient Accumulation steps = 1
2023-11-11 19:45:11,711 - INFO -   Total optimization steps = 20000
2023-11-11 19:45:11,711 - INFO -   Number of trainable parameters = 2048000
2023-11-11 22:44:43,317 - INFO - There are 0 training samples and 872 validation samples
{'eval_loss': 0.385498046875, 'eval_runtime': 7.835, 'eval_samples_per_second': 63.816, 'eval_steps_per_second': 8.041, 'epoch': 317.46}
{'train_runtime': 10771.6048, 'train_samples_per_second': 29.708, 'train_steps_per_second': 1.857, 'train_loss': 0.5619064697265626, 'epoch': 317.46}
  0%|          | 0/872 [00:00<?, ?it/s]2023-11-11 22:44:43,317 - INFO - ========= Example =========
2023-11-11 22:44:43,317 - INFO - Candidate: [0, 1]
2023-11-11 22:44:43,318 - INFO - Correct candidate: 1
2023-11-11 22:44:43,359 - INFO - === Candidate 0 ===
2023-11-11 22:44:43,359 - INFO - <s>the special effects and many scenes of weightlessness look as good or better than in the original , while the oscar-winning sound and james horner 's rousing score make good use of the hefty audio system . It was terrible
2023-11-11 22:44:43,361 - INFO - Log probabilities of the option tokens: tensor([-7.8711], dtype=torch.float16)
2023-11-11 22:44:43,391 - INFO - === Candidate 1 (without context)===
2023-11-11 22:44:43,391 - INFO - <s>the special effects and many scenes of weightlessness look as good or better than in the original , while the oscar-winning sound and james horner 's rousing score make good use of the hefty audio system . It was great
2023-11-11 22:44:43,391 - INFO - Log probabilities of the option tokens: tensor([-5.7539], dtype=torch.float16)
2023-11-11 22:44:43,391 - INFO - Prediction scores: [-7.87109375, -5.75390625]
2023-11-11 22:44:43,392 - INFO - ========= Example =========
2023-11-11 22:44:43,392 - INFO - Candidate: [0, 1]
2023-11-11 22:44:43,392 - INFO - Correct candidate: 0
2023-11-11 22:44:43,422 - INFO - === Candidate 0 ===
2023-11-11 22:44:43,422 - INFO - <s>well-nigh unendurable ... though the picture strains to become cinematic poetry , it remains depressingly prosaic and dull . It was terrible
2023-11-11 22:44:43,422 - INFO - Log probabilities of the option tokens: tensor([-6.6367], dtype=torch.float16)
2023-11-11 22:44:43,451 - INFO - === Candidate 1 (without context)===
2023-11-11 22:44:43,452 - INFO - <s>well-nigh unendurable ... though the picture strains to become cinematic poetry , it remains depressingly prosaic and dull . It was great
2023-11-11 22:44:43,452 - INFO - Log probabilities of the option tokens: tensor([-8.6484], dtype=torch.float16)
2023-11-11 22:44:43,452 - INFO - Prediction scores: [-6.63671875, -8.6484375]
  0%|          | 2/872 [00:00<00:58, 14.88it/s]2023-11-11 22:44:43,452 - INFO - ========= Example =========
2023-11-11 22:44:43,452 - INFO - Candidate: [0, 1]
2023-11-11 22:44:43,452 - INFO - Correct candidate: 0
2023-11-11 22:44:43,481 - INFO - === Candidate 0 ===
2023-11-11 22:44:43,481 - INFO - <s>i 've always dreamed of attending cannes , but after seeing this film , it 's not that big a deal . It was terrible
2023-11-11 22:44:43,481 - INFO - Log probabilities of the option tokens: tensor([-6.3789], dtype=torch.float16)
2023-11-11 22:44:43,510 - INFO - === Candidate 1 (without context)===
2023-11-11 22:44:43,510 - INFO - <s>i 've always dreamed of attending cannes , but after seeing this film , it 's not that big a deal . It was great
2023-11-11 22:44:43,511 - INFO - Log probabilities of the option tokens: tensor([-6.5586], dtype=torch.float16)
2023-11-11 22:44:43,511 - INFO - Prediction scores: [-6.37890625, -6.55859375]
100%|██████████| 872/872 [00:51<00:00, 17.04it/s]
2023-11-11 22:45:34,494 - INFO - There are 0 training samples and 500 validation samples
  0%|          | 0/500 [00:00<?, ?it/s]2023-11-11 22:45:34,494 - INFO - ========= Example =========
2023-11-11 22:45:34,494 - INFO - Candidate: [0, 1]
2023-11-11 22:45:34,494 - INFO - Correct candidate: 1
2023-11-11 22:45:34,523 - INFO - === Candidate 0 ===
2023-11-11 22:45:34,523 - INFO - <s>i liked it just enough . It was terrible
2023-11-11 22:45:34,523 - INFO - Log probabilities of the option tokens: tensor([-6.2656], dtype=torch.float16)
2023-11-11 22:45:34,551 - INFO - === Candidate 1 (without context)===
2023-11-11 22:45:34,551 - INFO - <s>i liked it just enough . It was great
2023-11-11 22:45:34,552 - INFO - Log probabilities of the option tokens: tensor([-6.0547], dtype=torch.float16)
2023-11-11 22:45:34,552 - INFO - Prediction scores: [-6.265625, -6.0546875]
2023-11-11 22:45:34,552 - INFO - ========= Example =========
2023-11-11 22:45:34,552 - INFO - Candidate: [0, 1]
2023-11-11 22:45:34,552 - INFO - Correct candidate: 1
2023-11-11 22:45:34,580 - INFO - === Candidate 0 ===
2023-11-11 22:45:34,580 - INFO - <s>make you laugh It was terrible
2023-11-11 22:45:34,580 - INFO - Log probabilities of the option tokens: tensor([-6.9883], dtype=torch.float16)
2023-11-11 22:45:34,608 - INFO - === Candidate 1 (without context)===
2023-11-11 22:45:34,608 - INFO - <s>make you laugh It was great
2023-11-11 22:45:34,608 - INFO - Log probabilities of the option tokens: tensor([-5.5273], dtype=torch.float16)
2023-11-11 22:45:34,608 - INFO - Prediction scores: [-6.98828125, -5.52734375]
  0%|          | 2/500 [00:00<00:28, 17.55it/s]2023-11-11 22:45:34,608 - INFO - ========= Example =========
2023-11-11 22:45:34,608 - INFO - Candidate: [0, 1]
2023-11-11 22:45:34,608 - INFO - Correct candidate: 1
2023-11-11 22:45:34,636 - INFO - === Candidate 0 ===
2023-11-11 22:45:34,637 - INFO - <s>is more accurate than anything i have seen in an american film It was terrible
2023-11-11 22:45:34,637 - INFO - Log probabilities of the option tokens: tensor([-6.7930], dtype=torch.float16)
2023-11-11 22:45:34,665 - INFO - === Candidate 1 (without context)===
2023-11-11 22:45:34,665 - INFO - <s>is more accurate than anything i have seen in an american film It was great
2023-11-11 22:45:34,665 - INFO - Log probabilities of the option tokens: tensor([-5.9766], dtype=torch.float16)
2023-11-11 22:45:34,665 - INFO - Prediction scores: [-6.79296875, -5.9765625]
100%|██████████| 500/500 [00:28<00:00, 17.29it/s]
2023-11-11 22:46:03,414 - INFO - ===== Train set 0 =====
2023-11-11 22:46:03,414 - INFO - {'accuracy': 0.7477064220183486, 'dev_accuracy': 0.74}

1エポックしか回してないのでaccは0.74くらい。1エポックにかかった時間は3時間くらいだった。まあタスクが単純なのでこれだとちゃんとできてるのかもともとできるのかよくわからない。ただ、元のタスクと違いすぎるからたぶん学習そのものはできていると思われる。

ちなみに試しにA100一枚で70Bを推論してみると、もの凄く時間がかかった(数十秒から数分)。当たり前なのだが言語モデルの規模が大きくなると計算量も大きくなるため、学習にも推論にも時間がかかる。

ChatGPTを使い始めた時、かなり早く応答が来てびっくりしたのだが、GPT-3.5-TurboからGPT-4になるとめちゃくちゃ遅くなった。これはGPT-4の方が大きめのモデルを使っていたためだと思われる。

ただ、このSST2というタスクはある文章を読ませてネガティブなものかポジティブなものかを0か1で返すというものなのだが、DPOPのようにコンテキストから質問に応答するような、少し難しめのタスクにしたい場合はどうするか。

タスクを増やすには、tasks.pyに自分のやりたいタスクを追加して、templates.pyにテンプレートを追加する。

class DROPDataset(Dataset): #DROPタスク
    metric_name = "f1"
    generation = True

    def __init__(self, subtask=None, **kwargs) -> None:
        self.load_dataset()
        
    def load_dataset(self):
        dataset = load_dataset("drop")
        train_examples = dataset["train"]
        valid_examples = dataset["validation"]

        train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
        valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
        self.samples = {"train": train_samples, "valid": valid_samples}
    
    # for generative tasks, candidates are []
    def build_sample(self, example, idx):
        answers = example['answers_spans']['spans']
        assert len(answers) > 0
        return Sample(
            id=idx,
            data={
                "context": example['passage'],
                "question": example['question'],
                "answers": answers
            },
            candidates=None,
            correct_candidate=answers
        )
        
    def get_template(self, template_version=0):
        return {0: DROPTemplate}[template_version]()

templates.pyのテンプレート設定部分

class DROPTemplate(Template):

    def encode(self, sample):
        question = sample.data['question'].strip()
        # title = sample.data['title']
        context = sample.data['context']
        answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one

        return f"Passage: {context}\nQuestion: {question}\nAnswer:"

    def verbalize(self, sample, candidate):
        question = sample.data['question'].strip()
        # title = sample.data['title']
        context = sample.data['context']
        answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one

        return f"Passage: {context}\nQuestion: {question}\nAnswer: {answer}\n"

    
    def encode_sfc(self, sample):
        raise NotImplementedError

    def verbalize_sfc(self, sample, candidate):
        raise NotImplementedError

encodeとverbalizeのところでテンプレートを生成している。
DROPというタスクではContextとQuestionとAnswerを設定しているようだ。

ここを[INST][/INST]とかに変えればいわゆるインストラクションモデルになるということだろう。

48GBとかもっと小さいVRAMのGPUでもそれまで扱えなかった大きさのモデルをファインチューニングできるかもしれない。

ただ、MeZOは追試してる人が海外含めてあまり見当たらないので知見を貯めていきたい。