見出し画像

Simple Transformers 入門 (5) - テキスト生成

Simple Transformers」で「テキスト生成」を行う方法をまとめました。

1. テキスト生成

テキスト生成」は、与えられたテキストに続くテキストを生成するタスクです

サポートモデルは、次のとおりです。

・CTRL
・GPT-2
・OpenAI-GPT
・Transformer-XL
・XLM
・XLNet

「テキスト生成」の最小限のコードは、次のとおりです。

import logging
from simpletransformers.language_generation import LanguageGenerationModel

# ログの設定
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# モデルの作成
model = LanguageGenerationModel("gpt2", "gpt2")

# テキスト生成
model.generate("Let's give a minimal start to the model like")

2. テキスト生成のデータセット

主な「テキスト生成」のデータセットは、次のとおりです。

Writing Scientific Paper Abstracts with GPT-2

3-1. LanguageGenerationModel

「LanguageGenerationModel」は、「テキスト生成」で使用するクラスです。

◎ コンストラクタ
コンストラクタの書式は、次のとおりです。

LanguageGenerationModel (self, model_type, model_name, args=None, use_cuda=True, cuda_device=-1, **kwargs)

パラメータは、次のとおりです。

・model_type : (required) str - モデル種別。
・model_name : (required) str - Huggingface Transformersの事前学習済みモデル名、またはモデルファイルを含むディレクトリへのパス。
・args : (optional) dict - オプション引数。
・use_cuda : (optional) bool - CUDAを使用するかどうか。
・cuda_device : (optional) - CUDAデバイス。
・**kwargs (optional) : プロキシ引数。

◎ クラス属性
クラス属性は、次のとおりです。

・tokenizer : トークナイザー。
・model : Huggingface Transformersの事前学習済みモデル名、またはモデルファイルを含むディレクトリへのパス。
・device : デバイス。
・cuda_device : (optional) - CUDAデバイス。

◎ generate()
テキスト生成を行います。

generate(self, prompt=None, args=None, verbose=True)

パラメータは、次のとおりです。

・prompt : (optional) - プロンプトテキスト。
・args : (optional) - オプション引数。
・verbose : (optional) - 詳細出力。

戻り値は、次のとおりです。

・generated_sequences : 生成したテキスト。

3-2. LanguageGenerationModelの追加オプション引数

デフォルト値は、次のとおりです。

"do_sample": True,
"prompt": "",
"max_length": 20,
"stop_token": None,
"temperature": 1.0,
"repetition_penalty": 1.0,
"k": 0,
"p": 0.9,
"padding_text": "",
"xlm_language": "",
"num_return_sequences": 1,
"config_name": None,
"tokenizer_name": None,

パラメータは、次のとおりです。

・do_sample: bool - Falseに設定すると、greedyデコードが使用される。それ以外の場合は、サンプリングが使用される。
・prompt: str - プロンプトテキスト。
・ max_length: int - 生成するテキストの長さ。
・stop_token: str - テキスト生成が停止するトークン。
・temperature: float - 1.0の温度がデフォルト。 これを下げると、サンプリングがよりgreedyになる。
・repetition_penalty: float - 主にCTRLモデルに役立つ。その場合は、1.2を使用。
・k: int - top-kサンプリングのk値。
・p: float - top-p(nucleus)サンプリングのp値。
padding_text: str - Transfo-XLおよびXLNetのパディングテキスト。
xlm_language: str - XLMモデルで使用する場合のオプションの言語。
num_return_sequences: int - 生成するサンプルの数。
・config: dict - ここで指定されたキー値は、モデル構成で使用されるデフォルト値をオーバーライド。

4. 日本語データセットでのテキスト生成

日本語データセットで学習させた「言語モデル」を読み込んで、テキスト生成を実行します。

from simpletransformers.language_generation import LanguageGenerationModel

# モデルの作成
model = LanguageGenerationModel("gpt2", "outputs/")

# テキスト生成
model.generate("織田信長は、")
['織田信長は、こ永禄年)1月、これは5月、信長は5日、家康は日に・長久手の戦いしている']

簡単な学習(2800データを10エポック)のわりには、文章になってる。

5. 参考



この記事が気に入ったらサポートをしてみませんか?