![見出し画像](https://assets.st-note.com/production/uploads/images/36013386/rectangle_large_type_2_ccc21d0eb16552ad18653d62936d24db.jpg?width=800)
Simple Transformers 入門 (5) - テキスト生成
「Simple Transformers」で「テキスト生成」を行う方法をまとめました。
1. テキスト生成
「テキスト生成」は、与えられたテキストに続くテキストを生成するタスクです
サポートモデルは、次のとおりです。
・CTRL
・GPT-2
・OpenAI-GPT
・Transformer-XL
・XLM
・XLNet
「テキスト生成」の最小限のコードは、次のとおりです。
import logging
from simpletransformers.language_generation import LanguageGenerationModel
# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
# モデルの作成
model = LanguageGenerationModel("gpt2", "gpt2")
# テキスト生成
model.generate("Let's give a minimal start to the model like")
2. テキスト生成のデータセット
主な「テキスト生成」のデータセットは、次のとおりです。
・Writing Scientific Paper Abstracts with GPT-2
3-1. LanguageGenerationModel
「LanguageGenerationModel」は、「テキスト生成」で使用するクラスです。
◎ コンストラクタ
コンストラクタの書式は、次のとおりです。
LanguageGenerationModel (self, model_type, model_name, args=None, use_cuda=True, cuda_device=-1, **kwargs)
パラメータは、次のとおりです。
・model_type : (required) str - モデル種別。
・model_name : (required) str - Huggingface Transformersの事前学習済みモデル名、またはモデルファイルを含むディレクトリへのパス。
・args : (optional) dict - オプション引数。
・use_cuda : (optional) bool - CUDAを使用するかどうか。
・cuda_device : (optional) - CUDAデバイス。
・**kwargs (optional) : プロキシ引数。
◎ クラス属性
クラス属性は、次のとおりです。
・tokenizer : トークナイザー。
・model : Huggingface Transformersの事前学習済みモデル名、またはモデルファイルを含むディレクトリへのパス。
・device : デバイス。
・cuda_device : (optional) - CUDAデバイス。
◎ generate()
テキスト生成を行います。
generate(self, prompt=None, args=None, verbose=True)
パラメータは、次のとおりです。
・prompt : (optional) - プロンプトテキスト。
・args : (optional) - オプション引数。
・verbose : (optional) - 詳細出力。
戻り値は、次のとおりです。
・generated_sequences : 生成したテキスト。
3-2. LanguageGenerationModelの追加オプション引数
デフォルト値は、次のとおりです。
"do_sample": True,
"prompt": "",
"max_length": 20,
"stop_token": None,
"temperature": 1.0,
"repetition_penalty": 1.0,
"k": 0,
"p": 0.9,
"padding_text": "",
"xlm_language": "",
"num_return_sequences": 1,
"config_name": None,
"tokenizer_name": None,
パラメータは、次のとおりです。
・do_sample: bool - Falseに設定すると、greedyデコードが使用される。それ以外の場合は、サンプリングが使用される。
・prompt: str - プロンプトテキスト。
・ max_length: int - 生成するテキストの長さ。
・stop_token: str - テキスト生成が停止するトークン。
・temperature: float - 1.0の温度がデフォルト。 これを下げると、サンプリングがよりgreedyになる。
・repetition_penalty: float - 主にCTRLモデルに役立つ。その場合は、1.2を使用。
・k: int - top-kサンプリングのk値。
・p: float - top-p(nucleus)サンプリングのp値。
・padding_text: str - Transfo-XLおよびXLNetのパディングテキスト。
・xlm_language: str - XLMモデルで使用する場合のオプションの言語。
・num_return_sequences: int - 生成するサンプルの数。
・config: dict - ここで指定されたキー値は、モデル構成で使用されるデフォルト値をオーバーライド。
4. 日本語データセットでのテキスト生成
日本語データセットで学習させた「言語モデル」を読み込んで、テキスト生成を実行します。
from simpletransformers.language_generation import LanguageGenerationModel
# モデルの作成
model = LanguageGenerationModel("gpt2", "outputs/")
# テキスト生成
model.generate("織田信長は、")
['織田信長は、こ永禄年)1月、これは5月、信長は5日、家康は日に・長久手の戦いしている']
簡単な学習(2800データを10エポック)のわりには、文章になってる。
5. 参考
この記事が気に入ったらサポートをしてみませんか?