見出し画像

TensorFlow Lite Model Maker 入門 / テキスト分類

TensorFlow Lite Model Maker」で、テキスト分類のモデルを学習する方法をまとめました。

1. テキスト分類

「TensorFlow Lite Model Maker」は、「TensorFlow」のモデルの学習を簡単に行うことができるライブラリです。テキスト分類は、「Averging Word Embedding」と「BERT」をサポートしています。

2. インストール

「Google Colab」に「TensorFlow Lite Model Maker」をインストールするには、以下のコマンドを入力します。

!pip install git+git://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker,metadata]

3. パッケージのインポート

パッケージをインポートします。

import numpy as np
import os

import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_examples.lite.model_maker.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_maker.core.model_export_format import ModelExportFormat
from tensorflow_examples.lite.model_maker.core.task.model_spec import AverageWordVecModelSpec
from tensorflow_examples.lite.model_maker.core.task.model_spec import BertModelSpec
from tensorflow_examples.lite.model_maker.core.task import text_classifier

4. 入力データの取得

入力データの取得方法は次のとおりです。

data_path = tf.keras.utils.get_file(
    fname='aclImdb',
    origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
    untar=True)

以下の場所に入力データがダウンロードされます。

/root/.keras/datasets/aclImdb

このデータセットには、訓練用の25000の映画レビューと、テスト用の25000の映画レビューが含まれています。そして、ポジティブとネガティブという2つのクラスがあります。

aclImdb
|__ train
    |______ pos: [1962_10.txt, 2499_10.txt, ...]
    |______ neg: [104_3.txt, 109_2.txt, ...]
    |______ unsup: [12099_0.txt, 1424_0.txt, ...]
|__ test
    |______ pos: [1384_9.txt, 191_9.txt, ...]
    |______ neg: [1629_1.txt, 21_1.txt]

5. 学習の実行

学習の実行方法は、次のとおりです。

(1) モデル仕様を生成します。

model_spec = AverageWordVecModelSpec()

(2) TextClassifierDataLoader.from_folder()で入力データの読み込み、入力データを訓練データとテストデータと検証データに分割します。

train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), model_spec=model_spec, class_labels=['pos', 'neg'])
test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), model_spec=model_spec, is_training=False, shuffle=False)
train_data, validation_data = train_data.split(0.9)

(2) text_classifier.create()でモデルを学習します。

model = text_classifier.create(train_data, model_spec=model_spec, validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Epoch 1/2
703/703 [==============================] - 5s 7ms/step - loss: 0.5212 - accuracy: 0.7638 - val_loss: 0.3176 - val_accuracy: 0.8802
Epoch 2/2
703/703 [==============================] - 5s 7ms/step - loss: 0.2859 - accuracy: 0.8883 - val_loss: 0.2711 - val_accuracy: 0.8914

(3) model.evaluate()でモデルを評価します。

loss, accuracy = model.evaluate(test_data)
782/782 [==============================] - 2s 3ms/step - loss: 0.3098 - accuracy: 0.8736

(4) model.export()でモデルをエクスポートします。「with_metadata=True」でメタデータを付加しています。

model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')

成功すると、「movie_review_classifier.tflite」と「sample_data text_label.txt」と「vocab.txt」が出力されます。

6. モデルの切り替え

モデルの切り替えは、text_classifier.create()の「model_spec」で行います。

# Averging Word Embedding
model_spec = AverageWordVecModelSpec()

# BERT
model_spec = BertModelSpec()


この記事が気に入ったらサポートをしてみませんか?