見出し画像

Huggingface Transformers 入門 (25) - 日本語の要約の学習

「Huggingface Transformers」による日本語の要約の学習手順をまとめました。

・Huggingface Transformers 4.4.2
・Huggingface Datasets 1.2.1

前回

1. 日本語T5事前学習済みモデル

モデルは、「日本語T5事前学習済みモデル」が公開されたので、ありがたく使わせてもらいます。

2. 要約データセットの準備

はじめに、「要約データセット」を準備します。

「ThreeLineSummaryDataset」の「train.csv」にLivedoorニュースのIDのリストがあるので、それを使います。

Livedoorニュースの「3行要約」と「本文」をスクレイピングで取得します。bs4でスクレイピングするコードは次のとおりです。サーバーに負荷をかけないように、10秒に1回だけ通信するようにしています。

from urllib.request import urlopen
from bs4 import BeautifulSoup
from bs4.element import NavigableString
from pprint import pprint
import time


# コンテンツの取得
def get_content(id):
   time.sleep(10)
   URL = 'https://news.livedoor.com/article/detail/'+id+'/'
   print(URL)
   try:
       with urlopen(URL) as res:
           # 本文
           output1 = ''
           html = res.read().decode('euc_jp', 'ignore')
           soup = BeautifulSoup(html, 'html.parser')
           lineList = soup.select('.articleBody p')
           for line in lineList:
               if len(line.contents) > 0 and type(line.contents[0]) == NavigableString:
                   output1 += line.contents[0].strip()
           if output1 == '': # 記事がない
               return
           output1 += '\n'

           # 3行要約
           output0 = ''
           summaryList = soup.select('.summaryList li')
           for summary in summaryList:
               output0 += summary.contents[0].strip()+'\t'
           if output0 == '': # 記事がない
               return

           # 出力
           print(output0+output1)
           with open('output.tsv', mode='a') as f:
               f.writelines(output0+output1)
   except Exception:
       print('Exception')

# IDリストの生成の取得
idList = []
with open('train.csv', mode='r') as f:
   lines = f.readlines()
   for line in lines:
       id = line.strip().split(',')[3].split('.')[0]
       idList.append(id)

# コンテンツの取得
for i in range(0, 10): # 取得したい記事のINDEXをここで指定
   print('index:', i)
   get_content(idList[i])

「output.tsv」に以下のようなtsvのデータ形式で取得できます。<tab>はタブ(\t)になります。

要約1 <tab> 要約2 <tab> 要約3 <tab> 本文
   :​

今回は練習のため1000件ほど取得しました。

3. 要約データセットの書式の変換

「output.tsv」を「Numbers」(またはExcelやテキストエディタ)で以下の書式に変換し、「train.csv」(900件)と「dev.csv」(100件)に分割します。

text,summary
本文,要約1
本文,要約1
本文,要約1
    :

​4. 日本語の要約の学習

(1) データの永続化。

# データの永続化
from google.colab import drive 
drive.mount('/content/drive')
!mkdir -p '/content/drive/My Drive/work/'
%cd '/content/drive/My Drive/work/'

(2) ソースからの「Huggingface Transformers」のインストール。

# ソースからのHuggingface Transformersのインストール
!git clone https://github.com/huggingface/transformers -b v4.4.2
!pip install -e transformers

(3) メニュー「ランタイム → ランタイムを再起動」で「Google Colab」を再起動し、作業フォルダに戻る。

# メニュー「ランタイム → ランタイムを再起動」で「Google Colab」を再起動

# 作業フォルダに戻る
%cd '/content/drive/My Drive/work/'

(4) 「Huggingface Datasets」のインストール。

# Huggingface Datasetsのインストール
!pip install datasets==1.2.1

(5) 「rouge_score」と「sentencepiece」のインストール。

!pip install rouge_score
!pip install sentencepiece

(6) ファインチューニングの実行。

%%time

# ファインチューニングの実行
!python ./transformers/examples/seq2seq/run_summarization.py \
    --model_name_or_path=sonoisa/t5-base-japanese \
    --do_train \
    --do_eval \
    --train_file=train2.csv \
    --validation_file=dev2.csv \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --save_steps=5000 \
    --save_total_limit=3 \
    --output_dir=output/ \
    --predict_with_generate \
    --use_fast_tokenizer=False
CPU times: user 3.87 s, sys: 542 ms, total: 4.41 s
Wall time: 8min 4s

5. 日本語の要約の推論

日本語の要約の推論のコードは、次のとおりです。

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# モデルとトークナイザーの準備
model = AutoModelForSeq2SeqLM.from_pretrained('output/')    
tokenizer = AutoTokenizer.from_pretrained('sonoisa/t5-base-japanese') 

# テキスト
text = r"""
地球の人里離れた山奥に住む尻尾の生えた少年・孫悟空はある日、西の都からやって来た少女ブルマと出会う。そこで、7つ集めると神龍が現れ、どんな願いでも一つだけ叶えてくれるというドラゴンボールの存在を、さらに育ての親である孫悟飯の形見として大切に持っていた球がその1つ「四星球」であることを知り、ブルマと共に残りのドラゴンボールを探す旅に出る。人さらいのウーロンや盗賊のヤムチャなどを巻き込んだボール探しの末、世界征服を企むピラフ一味にボールを奪われ神龍を呼び出されるが、ウーロンがとっさに言い放った下らない願いを叶えてもらうことで一味の野望を阻止する。
"""

# テキストをテンソルに変換
inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)

# 推論
model.eval()
with torch.no_grad():
    summary_ids = model.generate(inputs) #, max_length=512, min_length=5, length_penalty=5., num_beams=2)
    summary = tokenizer.decode(summary_ids[0])
    print(summary)
<pad><extra_id_0>ドラゴンボールを探す旅に出た孫悟空は、7つ集めると神龍が現れるというドラゴンボールを探す旅に出る。</s>

次回



この記事が気に入ったらサポートをしてみませんか?