Rinnaの日本語言語モデルを試してみる
今回は、古いかもしれませんが、Rinnnaの日本語言語モデルを試してみます。
今回紹介するコードは、日本語の文章をトークン化して、文章内の特定の語をマスク化して、そのマスク化した部分を列挙するコードとなります。
各タスクごとにシンプルなコードなので、各ステップごとにわかりやすいなと思いましたので紹介します。
下記は、上記GitHubにあるコードをGoogle Colab用に少しだけ変えています。
!git clone https://github.com/rinnakk/japanese-pretrained-models.git
!pip install -r /content/japanese-pretrained-models/requirements.txt
!pip install sentencepiece
!pip install transformers
import torch
from transformers import T5Tokenizer, RobertaForMaskedLM
# load tokenizer
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
# load model
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
model = model.eval()
# original text
text = "4年に1度オリンピックは開かれる。"
# prepend [CLS]
text = "[CLS]" + text
# tokenize
tokens = tokenizer.tokenize(text)
print(tokens) # output: ['[CLS]', '▁4', '年に', '1', '度', 'オリンピック', 'は', '開かれる', '。']']
# mask a token
masked_idx = 5
tokens[masked_idx] = tokenizer.mask_token
print(tokens) # output: ['[CLS]', '▁4', '年に', '1', '度', '[MASK]', 'は', '開かれる', '。']
# convert to ids
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids) # output: [4, 1602, 44, 24, 368, 6, 11, 21583, 8]
# convert to tensor
token_tensor = torch.LongTensor([token_ids])
# provide position ids explicitly
position_ids = list(range(0, token_tensor.size(1)))
print(position_ids) # output: [0, 1, 2, 3, 4, 5, 6, 7, 8]
position_id_tensor = torch.LongTensor([position_ids])
# get the top 10 predictions of the masked token
with torch.no_grad():
outputs = model(input_ids=token_tensor, position_ids=position_id_tensor)
predictions = outputs[0][0, masked_idx].topk(10)
for i, index_t in enumerate(predictions.indices):
index = index_t.item()
token = tokenizer.convert_ids_to_tokens([index])[0]
print(i, token)
上記コードでは、masked_idxでオリンピックに相当する部分をマスク化して、日本語の言語モデルでマスク部分を推測させて、最後に10個候補を挙げています。
0 総会
1 サミット
2 ワールドカップ
3 フェスティバル
4 大会
5 オリンピック
6 全国大会
7 党大会
8 イベント
9 世界選手権
他の例で試してみますと、textの部分をtext="世界は君が思うよりも美しい"と書いて実行しますと、下記のように、思うの部分がマスク化されて、マスク化されたところの候補が10個列挙されます。
['[CLS]', '▁世界', 'は', '君', 'が', '思う', 'よりも', '美しい', '。']
['[CLS]', '▁世界', 'は', '君', 'が', '[MASK]', 'よりも', '美しい', '。']
[4, 13618, 11, 1783, 12, 6, 632, 5797, 8]
[0, 1, 2, 3, 4, 5, 6, 7, 8]
0 思う
1 望む
2 見た
3 考える
4 知っている
5 思った
6 考えた
7 言う
8 描く
9 みる
この記事が気に入ったらサポートをしてみませんか?