見出し画像

Rinnaの日本語言語モデルを試してみる

今回は、古いかもしれませんが、Rinnnaの日本語言語モデルを試してみます。


今回紹介するコードは、日本語の文章をトークン化して、文章内の特定の語をマスク化して、そのマスク化した部分を列挙するコードとなります。

各タスクごとにシンプルなコードなので、各ステップごとにわかりやすいなと思いましたので紹介します。

下記は、上記GitHubにあるコードをGoogle Colab用に少しだけ変えています。

!git clone https://github.com/rinnakk/japanese-pretrained-models.git
!pip install -r /content/japanese-pretrained-models/requirements.txt
!pip install sentencepiece
!pip install transformers
import torch
from transformers import T5Tokenizer, RobertaForMaskedLM

# load tokenizer
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
tokenizer.do_lower_case = True  # due to some bug of tokenizer config loading

# load model
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
model = model.eval()

# original text
text = "4年に1度オリンピックは開かれる。"

# prepend [CLS]
text = "[CLS]" + text

# tokenize
tokens = tokenizer.tokenize(text)
print(tokens)  # output: ['[CLS]', '▁4', '年に', '1', '度', 'オリンピック', 'は', '開かれる', '。']']

# mask a token
masked_idx = 5
tokens[masked_idx] = tokenizer.mask_token
print(tokens)  # output: ['[CLS]', '▁4', '年に', '1', '度', '[MASK]', 'は', '開かれる', '。']

# convert to ids
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)  # output: [4, 1602, 44, 24, 368, 6, 11, 21583, 8]

# convert to tensor
token_tensor = torch.LongTensor([token_ids])

# provide position ids explicitly
position_ids = list(range(0, token_tensor.size(1)))
print(position_ids)  # output: [0, 1, 2, 3, 4, 5, 6, 7, 8]
position_id_tensor = torch.LongTensor([position_ids])

# get the top 10 predictions of the masked token
with torch.no_grad():
    outputs = model(input_ids=token_tensor, position_ids=position_id_tensor)
    predictions = outputs[0][0, masked_idx].topk(10)

for i, index_t in enumerate(predictions.indices):
    index = index_t.item()
    token = tokenizer.convert_ids_to_tokens([index])[0]
    print(i, token)


上記コードでは、masked_idxでオリンピックに相当する部分をマスク化して、日本語の言語モデルでマスク部分を推測させて、最後に10個候補を挙げています。

0 総会
1 サミット
2 ワールドカップ
3 フェスティバル
4 大会
5 オリンピック
6 全国大会
7 党大会
8 イベント
9 世界選手権


他の例で試してみますと、textの部分をtext="世界は君が思うよりも美しい"と書いて実行しますと、下記のように、思うの部分がマスク化されて、マスク化されたところの候補が10個列挙されます。

['[CLS]', '▁世界', 'は', '君', 'が', '思う', 'よりも', '美しい', '。']
['[CLS]', '▁世界', 'は', '君', 'が', '[MASK]', 'よりも', '美しい', '。']
[4, 13618, 11, 1783, 12, 6, 632, 5797, 8]
[0, 1, 2, 3, 4, 5, 6, 7, 8]
0 思う
1 望む
2 見た
3 考える
4 知っている
5 思った
6 考えた
7 言う
8 描く
9 みる


いいなと思ったら応援しよう!