BERTモデルを使ってことわざテストを解いてみた

この記事では、Google Colaboratory 上で BERT モデルを利用します。transformers ライブラリを用いて、穴埋めの推論を行うまでの流れを説明しています。

はじめに

この記事では、BERT モデルを使ってことわざテストを解いてみます。

下はその結果の一部です。

================================================================================
entered_text : [MASK]で鯛を釣る
predicted_text : 釣りで鯛を釣る
================================================================================
entered_text : 果報は[MASK]待て
predicted_text : 果報はしばらく待て
================================================================================
entered_text:井の中の蛙[MASK]を知らず
predicted_text:井の中の蛙であることを知らず
================================================================================

本記事は、主に以下のような方を対象としています。

Google Colaboratoryの基本的な操作がわかる方
機械学習のプログラムを作成したことがある方
自然言語処理について学んでみたい方

BERT モデルを使うと、自然言語についての推論を行うことができます。

今回は実際に事前学習済みのモデルを Google Colaboratory 上で使用します。

BERT とは

BERT とは 2018 年に Google が発表したニューラル言語モデルです。
自然言語処理について人間を超える精度を叩き出したことで有名です。

BERT では RNN を用いたモデルと同様に文章をトークンに分割したものを入力として受けて、それぞれのトークンに対応するベクトルを出力します。
BERT 以前の ELMo などの RNN を用いたモデルではそれぞれのモデル内の層で、文章の前方(または後方から)順々に処理が行われていました。
そのため、処理を経るにつれて先に処理が行われたトークンの情報が失われてしまう問題点がありました。

そこで  BERT では事前学習として「masked language model」というものを使っています。Masked language model では文章の一部を隠して、その隠した部分に入るトークンを予測するという方法で学習します。

このような技術的側面から、BERT は文章の一部が穴で隠された穴埋め問題を解くことが得意です。また BERT モデル自体も穴埋め問題で紹介されることが多いです。

今回この記事では、まず公式のサンプルを動かしてから
ことわざで穴埋めテストを解いてみたいと思います!

サンプルプログラムの実行

まずは BERT を実際に使ってみましょう。
今回は東北大学乾研究室より公開された訓練済み日本語 BERT モデルを使用します。
また以下のコードは「BERT」による自然言語処理入門
を参考にさせていただきました。

以下で利用している Transformers は huggingface 社が提供しているオープンソースのライブラリです。
今回使用する東北大学の BERT モデルも transformers から利用可能です。​

import numpy as np


import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM

最初は「明日の天気は [MASK] でしょう。」という文章の
"[MASK]" の部分を予測してみましょう。

まずはモデルとトークナイザを読み込みます。

model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
bert_mlm = BertForMaskedLM.from_pretrained(model_name)
bert_mlm = bert_mlm.cuda()

文章を BERT のようなニューラルネットワークに入力する際、
最初に文を適当な単位に分割します。

例えば「私は大学生だ」という文章があるとき、
["私","は","大学生", "だ"]
といったように分割します。

この "私" や "は" といった分割された単位の事をトークンと呼び、
トークンに分割する作業のことをトークン化と呼びます。
[MASK] もトークンの一つとなります。

では実際にトークナイザを用いてトークン化してみましょう。

text = '明日の天気は[MASK]でしょう。'
tokens = tokenizer.tokenize(text)
print(tokens)

以下のように、無事トークンごとに分けることができました。

['明日', 'の', '天気', 'は', '[MASK]', 'でしょ', 'う', '。']

トークン化したあとは、それぞれのトークンに対応する id に変換してから BERT に入力します。
トークン化と id に変換するところまでを tokenizer.encode() が行ってくれます。

またその後の input_ids = input_ids.cuda() によって
符号化された文章を GPU に配置しています。

bert_mlm には BertForMaskedLM が入っています。
id に変換された文章 input_ids を入力することで
特殊トークン [MASK] に入るトークンを予測してくれます。

# 文章を符号化し(GPUに配置する)
input_ids = tokenizer.encode(text, return_tensors='pt')
input_ids = input_ids.cuda()
# BERTに入力し、分類スコアを得る
# 系列長を揃える必要がないので、単にinput_idsのみを入力する
with torch.no_grad():
   output = bert_mlm (input_ids=input_ids)
   scores = output.logits
scores

無事に語彙に含まれる各トークンの分類スコアを得ることができました。

tensor([[[ -6.4771,   5.8697,  -1.0711,  ...,  -6.2113,  -4.9229,  -8.3482],
        [ -6.5532,   6.0818,  -6.6133,  ...,  -8.1145,  -5.0092,  -5.5071],
        [-10.6194,   5.3389,  -1.7737,  ...,  -9.5513,  -3.4223,  -4.0840],
        ...,
        [ -9.5266,   8.1490,  -1.9271,  ...,  -7.3054,  -7.8035, -10.3056],
        [ -4.9954,   8.2409,   1.7941,  ...,  -5.3433,  -3.5264,  -6.1000],
        [ -7.2727,   6.2676,  -2.0865,  ...,  -6.3508,  -7.1688, -10.7873]]])

上で得られたスコアが高いトークンほど予測の確度が高いことを意味しているので、
スコアが高いトークンで [MASK] を穴埋めすれば自然な文章になることが期待できます。
今回はスコアが最も良いトークンで穴埋めしてみましょう。

# ID列で'[MASK]'(IDは4)の位置を調べる
mask_position = input_ids[0].tolist().index(4)

# スコアが最も良いトークンのIDを取り出し、トークンに変換する
id_best = scores[0, mask_position].argmax(-1).item()
token_best = tokenizer.convert_ids_to_tokens(id_best)
token_best = token_best.replace('##', '')

# [MASK]を上で求めたトークンで置き換える
text = text.replace('[MASK]', token_best)
print(text)

無事に「明日の天気は [MASK] でしょう。」という文章の
"[MASK]" の部分を予測することができました。
自然な文章になっていますね!

明日の天気は晴れでしょう。

また以下の関数のように、スコア上位のトークンで置き換えてみるのも
とても面白いです。

def predict_mask_topk(text, tokenizer, bert_mlm, num_topk):
 """
 文章中の最初の[MASK]をスコアの上位のトークンに置き換える
 上位何位まで使うかは、num_topkで指定
 出力は穴埋めされた文章のリストと、置き換えられたトークンのスコアのリスト
 """
 
 # 文章を符号化し、BERTで分類スコアを得る
 input_ids = tokenizer.encode(text, return_tensors='pt')
 input_ids = input_ids.cuda()
 with torch.no_grad():
   output = bert_mlm(input_ids=input_ids)
 scores = output.logits
 
 # スコアが上位のトークンとスコアを求める
 mask_position = input_ids[0].tolist().index(4)
 topk = scores[0, mask_position].topk(num_topk)
 ids_topk = topk.indices
 tokens_topk = tokenizer.convert_ids_to_tokens(ids_topk)
 scores_topk = topk.values.cpu().numpy()
 
 # 文章中の[MASK]を上で求めたトークンで置き換える
 text_topk = []
 for token in tokens_topk:
   token = token.replace('##', '')
   text_topk. append(text.replace('[MASK]',token,1))
 return text_topk, scores_topk
 
text = '東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと[MASK]しました。'
text_topk, _ = predict_mask_topk(text,tokenizer,bert_mlm,10)
print(*text_topk, sep='\n')
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと発表しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと公表しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと報道しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと報告しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと判断しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと通知しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと説明しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと確認しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと認定しました。
東京都は31日、都内で新たに2909人が新型コロナウイルスに感染していることを確認したと回答しました。

上の例のように、
[MASK] で一部を隠された文章で [MASK] に入る言葉を予測する関数を作ることができました。

ことわざの穴埋めテストを解いてみる

それでは本題のことわざ穴埋めテストを試してみましょう。
今回ことわざテストは小学生向けのことわざプリントの「穴埋め問題 50問版」を使用させていただきました。

KOTOWAZA_TESTS = [
                '[MASK]隠して尻隠さず',
                '雨降って[MASK]固まる',
                '案ずるより産むが[MASK]',
                '石の上にも[MASK]年',
                '石橋を[MASK]渡る',
                '急がば[MASK]',
                '一寸の虫にも五分の[MASK]',
                '犬も歩けば[MASK]にあたる',
                '井の中の蛙[MASK]を知らず',
                '魚心あれば[MASK]心',
                '馬の[MASK]に念仏',
                '[MASK]で鯛を釣る',
                 '鬼に[MASK]',
                 '鬼の居ぬ間に[MASK]',
                 '鬼の目にも[MASK]',
                 '溺れるものは[MASK]をも掴む',
                 '火中の[MASK]を拾う',
                 '[MASK]の川流れ',
                 '果報は[MASK]待て',
                 '聞くは一時の恥[MASK]は一生の恥',
]

関数 predict_mask_topk を用いて、以下のコードで予測してみましょう。

for test in KOTOWAZA_TESTS:
 print(f"entered_text : {test}")
 text_topk, _ = predict_mask_topk(test,tokenizer,bert_mlm,1)
 print(f"predicted_text : {text_topk[0]}")
 print('='*80)

結果

entered_text : [MASK]隠して尻隠さず
predicted_text : 胸隠して尻隠さず
================================================================================
entered_text : 雨降って[MASK]固まる
predicted_text : 雨降ってから固まる
================================================================================
entered_text : 案ずるより産むが[MASK]
predicted_text : 案ずるより産むが。
================================================================================
entered_text : 石の上にも[MASK]年
predicted_text : 石の上にもある年
================================================================================
entered_text : 石橋を[MASK]渡る
predicted_text : 石橋を渡り渡る
================================================================================
entered_text : 急がば[MASK]
predicted_text : 急がば、
================================================================================
entered_text : 一寸の虫にも五分の[MASK]
predicted_text : 一寸の虫にも五分の一
================================================================================
entered_text : 犬も歩けば[MASK]にあたる
predicted_text : 犬も歩けば幸せにあたる
================================================================================
entered_text : 井の中の蛙[MASK]を知らず
predicted_text : 井の中の蛙井を知らず
================================================================================
entered_text : 魚心あれば[MASK]心
predicted_text : 魚心あれば魚心
================================================================================
entered_text : 馬の[MASK]に念仏
predicted_text : 馬の上に念仏
================================================================================
entered_text : [MASK]で鯛を釣る
predicted_text : 釣りで鯛を釣る
================================================================================
entered_text : 鬼に[MASK]
predicted_text : 鬼になる
================================================================================
entered_text : 鬼の居ぬ間に[MASK]
predicted_text : 鬼の居ぬ間に...
================================================================================
entered_text : 鬼の目にも[MASK]
predicted_text : 鬼の目にも見える
================================================================================
entered_text : 溺れるものは[MASK]をも掴む
predicted_text : 溺れるものは何をも掴む
================================================================================
entered_text : 火中の[MASK]を拾う
predicted_text : 火中の物を拾う
================================================================================
entered_text : [MASK]の川流れ
predicted_text : 川の川流れ
================================================================================
entered_text : 果報は[MASK]待て
predicted_text : 果報はしばらく待て
================================================================================
entered_text : 聞くは一時の恥[MASK]は一生の恥
predicted_text : 聞くは一時の恥聞くは一生の恥
================================================================================

考察

とても面白い予測結果が得られました。
ことわざは通常の文章とは異なる部分も多く、一つも正解することはできませんでした。

一方、完全に正解していなくても大体の意味はあっているかなというケースもあり、改めて BERT モデルの凄さを感じました。

おわりに

今回の記事では Google Colaboratory における BERT 環境の構築を行いました。また 構築した BERT 環境を用いて実際にことわざテストを解かせてみました。

BERT も今回用いた東北大のモデルの他にも複数のモデルが存在するので、
次回は複数のモデルで同じタスクを解いて比較してみたいと思います。

この記事が参考になった、という方は是非スキをお願いいたします。
今後の記事の方向性の参考にいたします。

参考

https://arxiv.org/abs/1810.04805
https://github.com/stockmarkteam/bert-book
https://xtech.nikkei.com/atcl/nxt/news/18/11068/
https://huggingface.co/rinna/japanese-roberta-base

この記事が気に入ったらサポートをしてみませんか?