見出し画像

Huggingface Transformers 入門 (16) - 言語モデルの学習スクリプト

以下の記事を参考に書いてます。

Language model examples - huggingface/transformers

前回

1. 言語モデルの学習

テキストデータセットでの「言語モデル」のファインチューニング(または0からの学習)を行います。モデル毎に以下の損失で学習します。

・CLM(Causal Language Modeling): GPT、GPT-2
・MLM(Masked Language Modeling)  : ALBERT、BERT、DistilBERT、RoBERTa
・PLM(Permutation Language Modeling): XLNet

これらの目的の違いについての詳細は、モデルの概要を参照してください。 

【注意】 古いスクリプト「run_language_modeling.py」はここで入手できます。

次の例は、ハブでホストされているデータセット、または学習と検証用の独自のテキストファイルで実行されます。以下に両方の例を示します。

2. GPT-2/GPTとCLM

次の例では、WikiText-2でGPT-2をファインチューニングします。 生のWikiText-2を使用しています。 ここでの損失は、CLMです。

python run_clm.py \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm  

単一のK80 GPUで学習するのに約30分、評価を実行するのに約1分かかります。データセットでファインチューニングすると、~20の複雑さのスコアに達します。

独自の学習ファイルと検証ファイルで実行するには、次のコマンドを使用します。

python run_clm.py \
    --model_name_or_path gpt2 \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm

3. RoBERTa/BERT/DistilBERTとMLM

次の例では、WikiText-2でRoBERTaをファインチューニングします。ここでも、生のWikiText-2を使用しています。BERT/RoBERTaには双方向メカニズムがあるため、損失は異なります。したがって、事前学習中に使用されたのと同じ損失、つまりMLMを使用しています。

RoBERTaの論文によれば、静的マスキングではなく動的マスキングを使用しています。したがって、モデルの収束がわずかに遅くなる可能性があります(過剰適合にはより多くのエポックが必要です)。

python run_mlm.py \
    --model_name_or_path roberta-base \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm

独自の学習ファイルと検証ファイルで実行するには、次のコマンドを使用します。

python run_mlm.py \
    --model_name_or_path roberta-base \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm

データセットが1行に1つのサンプルで構成されている場合は、--line_by_lineを使用できます(そうでない場合、スクリプトはすべてのテキストを連結してから、同じ長さのブロックに分割します)。  

【注意】 TPUでは、--pad_to_max_length--line_by_lineと組み合わせて使用して、すべてのバッチの長さが同じであることを確認する必要があります。

4. Whole Word Masking

BERTの作成者は、2019年5月に「Whole Word Masking」を使用してBERTの新しいバージョンをリリースしました。ランダムに選択されたトークン(単語の一部である可能性があります)をマスクする代わりに、ランダムに選択された単語をマスクします(その単語に対応するすべてのトークンをマスクします)。 この論文では、この手法を中国語向けに改良しました。

単語全体のマスキングを使用してモデルをファインチューニングするには、次のスクリプトを使用します。

python run_mlm_wwm.py \
    --model_name_or_path roberta-base \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm-wwm

中国語モデルの場合、文字レベルでトークン化されるため、参照ファイル(ltpライブラリが必要)を生成する必要があります。

Q:なぜ参照ファイルなのですか? 

A:次のような中国語の文があるとします。我喜欢你元のChinese-BERTは、['我', '喜', '欢', '你'](文字レベル)としてトークン化します。 しかし、喜欢は一言です。 全単語マスキングプロキシのためには、['我', '喜', '##欢', '你']のような結果が必要なので、BERT元のトークンのどの位置に##を追加すべきかをモデルに伝えるための参照ファイルが必要です。
Q:なぜLTPなのですか?

A:最もよく知られている中国のWWM BERTは、HITによるChinese-BERT-wwmです。CLUE(Chinese GLUE)のような多くのChinesTaskでうまく機能します。彼らはLTPを使用しているので、モデルをファインチューニングしたい場合はLTPが必要です。

現在、LTPはTransformers==3.2.0でのみうまく機能します。 したがって、requirements.txtには追加しません。参照ファイルを作成するrun_chinese_ref.pyスクリプトを実行するには、このバージョンのTransformersを使用して別の環境を作成する必要があります。スクリプトはexamples/contribにあります。 適切な環境になったら、以下を実行します。

export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export LTP_RESOURCE=/path/to/ltp/tokenizer
export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt

python examples/contrib/run_chinese_ref.py \
    --file_name=path_to_train_or_eval_file \
    --ltp=path_to_ltp_tokenizer \
    --bert=path_to_bert_tokenizer \
    --save_path=path_to_reference_file

次に、次のようにスクリプトを実行できます。

python run_mlm_wwm.py \
    --model_name_or_path roberta-base \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --train_ref_file path_to_train_chinese_ref_file \
    --validation_ref_file path_to_validation_chinese_ref_file \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm-wwm
【注意】 TPUでは、--pad_to_max_lengthを使用して、すべてのバッチの長さが同じであることを確認する必要があります。

5. XLNetとPLM

XLNetは、PLMという別の学習目標を使用します。これは、入力シーケンスの因数分解順序のすべての順列にわたって予想される確率を最大化することにより、双方向コンテキストを学習する自己回帰法です。

--plm_probabilityを使用して、マスクされたトークンのスパンの長さと、PLMの周囲のコンテキストの長さの比率を定義します。

--max_span_lengthを使用して、PLMに使用されるマスクされたトークンのスパンの長さを制限することもできます。

wikitext-2でXLNetをファインチューニングする方法は次のとおりです。

python run_plm.py \
    --model_name_or_path=xlnet-base-cased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-plm

独自の学習および検証ファイルで微調整するには、次のコマンドを実行します。

python run_plm.py \
    --model_name_or_path=xlnet-base-cased \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-plm

データセットが1行に1つのサンプルで構成されている場合は、--line_by_lineを使用できます(そうでない場合、スクリプトはすべてのテキストを連結してから、同じ長さのブロックに分割します)。

【注意】 TPUでは、--pad_to_max_lengthを--line_by_lineフラグと組み合わせて使用して、すべてのバッチの長さが同じであることを確認する必要があります。

【おまけ】 run_clm.py のパラメータ

◎ モデルパラメータ

・model_name_or_path: モデルのチェックポイント(モデルを最初から学習しない場合)
・model_type: モデルの種別(モデルを最初から学習する場合)
・config_name: コンフィグ名(model_nameと同じでない場合)
・tokenizer_name: トークナイザー名(model_nameと同じでない場合)
・cache_dir: キャッシュフォルダ
・use_fast_tokenizer: Fastトークナイザーを使用するかどうか
・model_revision: 使用するモデルの特定のバージョン
・use_auth_token: 「transformers-cli login」の実行時に生成されたトークンを使用するかどうか

◎ 学習パラメータ

・dataset_name: データセット名
・dataset_config_name: データセットのコンフィグ名
・train_file: 学習データ(テキストファイル)
・validation_file: 検証データ(テキストファイル)
・overwrite_cache: キャッシュの上書き
・validation_split_percentage: 学習データから使われる検証データの割合(検証データがない場合)
・max_seq_length: トークン化後の最大合計入力シーケンス長
・preprocessing_num_workers: 前処理に使用するプロセス数
・block_size: トークン化後のオプションの入力シーケンス長
・max_train_samples: 学習データの最大数
・max_val_samples: 検証データの最大数

【おまけ】 run_mlm.py のパラメータ

◎ モデルパラメータ

・model_name_or_path: モデルのチェックポイント(モデルを最初から学習しない場合)
・model_type: モデルの種別(モデルを最初から学習する場合)
・config_name: コンフィグ名(model_nameと同じでない場合)
・tokenizer_name: トークナイザー名(model_nameと同じでない場合)
・cache_dir: キャッシュフォルダ
・use_fast_tokenizer: Fastトークナイザーを使用するかどうか
・model_revision: 使用するモデルの特定のバージョン
・use_auth_token: 「transformers-cli login」の実行時に生成されたトークンを使用するかどうか

◎ 学習パラメータ

・dataset_name: データセット名
・dataset_config_name: データセットのコンフィグ名
・train_file: 学習データ(テキストファイル)
・validation_file: 検証データ(テキストファイル)
・overwrite_cache: キャッシュの上書き
・validation_split_percentage: 学習データから使われる検証データの割合(検証データがない場合)
・max_seq_length: トークン化後の最大合計入力シーケンス長
・preprocessing_num_workers: 前処理に使用するプロセス数
・mlm_probability: MLMの損失に対するトークンとマスクの比率
・line_by_line: データセット内のテキスト行を個別シーケンスとして処理するかどうか
・pad_to_max_length: 全てのサンプルをmax_seq_lengthにパディングするかどうか
・max_train_samples: 学習データの最大数
・max_val_samples: 検証データの最大数

次回



この記事が気に入ったらサポートをしてみませんか?