![見出し画像](https://assets.st-note.com/production/uploads/images/23187988/rectangle_large_type_2_7a66eb7988dfb4752811e6932c60bf50.png?width=800)
TensorFlow Lite Model Maker 入門 / テキスト分類
「TensorFlow Lite Model Maker」で、テキスト分類のモデルを学習する方法をまとめました。
1. テキスト分類
「TensorFlow Lite Model Maker」は、「TensorFlow」のモデルの学習を簡単に行うことができるライブラリです。テキスト分類は、「Averging Word Embedding」と「BERT」をサポートしています。
2. インストール
「Google Colab」に「TensorFlow Lite Model Maker」をインストールするには、以下のコマンドを入力します。
!pip install git+git://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker,metadata]
3. パッケージのインポート
パッケージをインポートします。
import numpy as np
import os
import tensorflow as tf
assert tf.__version__.startswith('2')
from tensorflow_examples.lite.model_maker.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_maker.core.model_export_format import ModelExportFormat
from tensorflow_examples.lite.model_maker.core.task.model_spec import AverageWordVecModelSpec
from tensorflow_examples.lite.model_maker.core.task.model_spec import BertModelSpec
from tensorflow_examples.lite.model_maker.core.task import text_classifier
4. 入力データの取得
入力データの取得方法は次のとおりです。
data_path = tf.keras.utils.get_file(
fname='aclImdb',
origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
untar=True)
以下の場所に入力データがダウンロードされます。
/root/.keras/datasets/aclImdb
このデータセットには、訓練用の25000の映画レビューと、テスト用の25000の映画レビューが含まれています。そして、ポジティブとネガティブという2つのクラスがあります。
aclImdb
|__ train
|______ pos: [1962_10.txt, 2499_10.txt, ...]
|______ neg: [104_3.txt, 109_2.txt, ...]
|______ unsup: [12099_0.txt, 1424_0.txt, ...]
|__ test
|______ pos: [1384_9.txt, 191_9.txt, ...]
|______ neg: [1629_1.txt, 21_1.txt]
5. 学習の実行
学習の実行方法は、次のとおりです。
(1) モデル仕様を生成します。
model_spec = AverageWordVecModelSpec()
(2) TextClassifierDataLoader.from_folder()で入力データの読み込み、入力データを訓練データとテストデータと検証データに分割します。
train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), model_spec=model_spec, class_labels=['pos', 'neg'])
test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), model_spec=model_spec, is_training=False, shuffle=False)
train_data, validation_data = train_data.split(0.9)
(2) text_classifier.create()でモデルを学習します。
model = text_classifier.create(train_data, model_spec=model_spec, validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Epoch 1/2
703/703 [==============================] - 5s 7ms/step - loss: 0.5212 - accuracy: 0.7638 - val_loss: 0.3176 - val_accuracy: 0.8802
Epoch 2/2
703/703 [==============================] - 5s 7ms/step - loss: 0.2859 - accuracy: 0.8883 - val_loss: 0.2711 - val_accuracy: 0.8914
(3) model.evaluate()でモデルを評価します。
loss, accuracy = model.evaluate(test_data)
782/782 [==============================] - 2s 3ms/step - loss: 0.3098 - accuracy: 0.8736
(4) model.export()でモデルをエクスポートします。「with_metadata=True」でメタデータを付加しています。
model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')
成功すると、「movie_review_classifier.tflite」と「sample_data text_label.txt」と「vocab.txt」が出力されます。
6. モデルの切り替え
モデルの切り替えは、text_classifier.create()の「model_spec」で行います。
# Averging Word Embedding
model_spec = AverageWordVecModelSpec()
# BERT
model_spec = BertModelSpec()
この記事が気に入ったらサポートをしてみませんか?