見出し画像

gtp-2-simple (2) - gtp-2-simpleのノートブック

以下のノートブックが面白かったので、ざっくり翻訳しました。

Train a GPT-2 Text-Generating Model w/ GPU For Free

1. はじめに

「Google Colab」で「gpt-2-simple」を使用して、ファインチューニングテキスト生成を行います。詳細については、リポジトリブログも参照してください。

2. パッケージのインポート

%tensorflow_version 1.x
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files

3. Google ColabのGPU

「Google Colab」のGPUには、「T4」と「K80」があります。「T4」は「K80」より若干高速で、多くのメモリを持ち、より大きな「GPT-2モデル」を訓練することができるます。

以下のセルで、GPUを確認できます。

!nvidia-smi
【メモ】今はP100がColabで一番良いGPUな模様。

4. GPT-2モデルのダウンロード

はじめに、「GPT-2モデル」をダウンロードする必要があります。
「GPT-2モデル」は、次の4種類が提供されています。

・smallモデル : 124Mパラメータモデル(ディスク 500MB)。
・mediumモデル : 335パラメータモデル(ディスク 1.5GB)。
・largeモデル : 774Mパラメータモデル (ディスク 3GB)。
 Colabでファインチューニング不可。テキスト生成可能。
・extra largeモデル : 1558Mパラメータモデル。
 Colabでファインチューニング不可。「T4」でのみテキスト生成可能。

モデルが大きいほど知識は増えますが、ファインチューニングとテキスト生成に時間がかかります。

以下のセルの「model_name」でモデルを指定できます。

gpt2.download_gpt2(model_name="124M")

5. Googleドライブのマウント

「Google Colab」で「訓練テキスト」「モデル」を取得するには、Googleドライブにマウントする必要があります。

以下のセルで、Googleドライブにマウントします。

gpt2.mount_gdrive()

6. 訓練テキストのアップロード

「Google Colab」のサイドバーの「ファイル」を介してchar-rnnで提供しているtinyshakespeareデータセット(1MBのテキストファイル)をアップロードします。

画像1

セル内のファイル名(file_name = "<xxx>")を変更して、セルを実行します。

file_name = "input.txt"

訓練テキストが10MB以上の場合は、はじめにGoogleドライブにアップロードしてから、「Google Colab」にコピーすることをお勧めします。

gpt2.copy_file_from_gdrive(file_name)

7. ファインチューニング

以下のセルは、ファインチューニングを開始し、指定したステップ数の訓練を実行します。無期限に実行するには、「steps = -1」を指定します。

sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
    dataset=file_name, # ファイル名
    model_name='124M', # モデル名 (124M, 355M, 744M)
    steps=1000, # ステップ数
    restore_from='fresh', # リストア場所 (fresh:ベースから, latest:チェックポイントから)
    run_name='run1', # チェックポイント内のサブフォルダ名
    print_every=10, # 訓練の進捗状況を何ステップ毎に表示するか
    sample_every=200, # サンプルを何ステップ毎に行うか
    save_every=500, # チェックポイントを何ステップ毎に保存するか
    learning_rate=1e-4, # 学習率
    overwrite=False # 上書き
    )

デフォルトで「/checkpoint/run1」にチェックポイントが保存されます。チェックポイントは、500ステップ毎(変更可能)に、保存されます。

【注意】このセルを再実行する場合は、「ランタイム->ランタイムの再起動」で再起動してください。インポートを再実行する必要がありますが、ファイルを再コピーする必要はありません。

8. チェックポイントのダウンロード

ファインチューニングが完了したら、チェックポイントをローカルにダウンロードします。Googleドライブにコピーしてから、ローカルにダウンロードすることを、強くお勧めします。

gpt2.copy_checkpoint_to_gdrive(run_name='run1')

9. チェックポイントの読み込み

以下のセルで、チェックポイントがGoogleドライブから「Google Colab」にコピーされます。

gpt2.copy_checkpoint_from_gdrive(run_name='run1')

次のセルで、チェックポイントからモデルを読み込むことができます。

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')
【注意】このセルを再実行する場合は、「ランタイム->ランタイムの再起動」で再起動してください。インポートを再実行する必要がありますが、ファイルを再コピーする必要はありません。

10. テキスト生成

チェックポイントからモデルをロードした後、テキスト生成を行います。generate()は、モデルから単一テキストを生成します。

gpt2.generate(sess, run_name='run1')
・length : 生成するトークンの数(デフォルトは1024)。
・temperature : 温度(デフォルト0.7、推奨0.7〜1.0)。
 温度が高いほど、ランダムな補完になる。
・prefix : 特定テキストの後にに続くテキストを生成。
・nsamples : 生成するテキストの数。
・batch_size :  複数のテキストを並行して生成する並列数。
 nsamplesで割り切れる必要がある。
 「Google Colab」では、最大20。
・top_k : 各ステップでkの単語を考慮して補完を行う。
 デフォルトは0(制限なし)。
 40が一般的に良い値。
・top_p : 生成テキストを累積確率に制限。
 top_p = 0.9のデータセットで良い結果を得る。
・truncate : 指定したシーケンスで入力テキストを切り捨てる。
 たとえば、truncate = '<|endoftext|>'の場合、返されるテキストには最初の<|endoftext|>の前のすべてが含まれる。
・include_prefix : truncateおよびinclude_prefix = Falseを使用する場合、指定されたprefixは返されるテキストに含まれない。

生成したテキストを別の場所に渡す必要がある場合は、次のセルを実行できます。

text = gpt2.generate(sess, return_as_list=True)[0]

11. 複数のテキスト生成

次のセルでは、複数のテキストを生成できます。さらに多くのテキストを生成するために、セルを何度でも再実行できます。

gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow())

gpt2.generate_to_file(sess,
    destination_path=gen_file,
    length=500,
    temperature=0.7,
    nsamples=100,
    batch_size=20
    )

12. 生成テキストのダウンロード

ダウンロードするファイルを取得するために2回実行する必要がある場合があります。

# 生成したテキストのダウンロード
files.download(gen_file)

13. 774モデルおよび1558モデルからのテキスト生成

774Mモデルまたは1558Mモデルからテキスト生成する場合は、model_nameをgpt2.load_gpt2()およびgpt2.generate()に渡します。

model_name = "774M"

gpt2.download_gpt2(model_name=model_name)
sess = gpt2.start_tf_sess()

gpt2.load_gpt2(sess, model_name=model_name)
gpt2.generate(sess,
    model_name=model_name,
    prefix="The secret of life is",
    length=100,
    temperature=0.7,
    top_p=0.9,
    nsamples=5,
    batch_size=5
    )

【おまけ】 GPU Sync Fail

「GPU Sync Fail」などのノートブックにエラー時は、以下のコマンドで強制終了および再起動してください。

!kill -9 -1



この記事が気に入ったら、サポートをしてみませんか?
気軽にクリエイターの支援と、記事のオススメができます!
2

こちらでもピックアップされています

機械学習入門
機械学習入門
  • 77本

機械学習関連のノートをまとめました

コメントを投稿するには、 ログイン または 会員登録 をする必要があります。