見出し画像

TensorFlow Lite Model Maker 入門 / 画像分類

TensorFlow Lite Model Maker」で、画像分類のモデルの学習する方法をまとめました。

1. 画像分類

「TensorFlow Lite Model Maker」は、「TensorFlow」のモデルの学習を簡単に行うことができるライブラリです。画像分類は、「EfficientNet-Lite」「MobileNetV2」「ResNet50」をサポートしています。

2. インストール

「Google Colab」に「TensorFlow Lite Model Maker」をインストールするには、以下のコマンドを入力します。

!pip install git+git://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker,metadata]

3. パッケージのインポート

パッケージをインポートします。

import numpy as np

import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_examples.lite.model_maker.core.data_util.image_dataloader import ImageClassifierDataLoader
from tensorflow_examples.lite.model_maker.core.task import image_classifier
from tensorflow_examples.lite.model_maker.core.task.model_spec import mobilenet_v2_spec
from tensorflow_examples.lite.model_maker.core.task.model_spec import ImageModelSpec

import matplotlib.pyplot as plt

4. 入力データの取得

入力データの取得方法は次のとおりです。

image_path = tf.keras.utils.get_file(
     'flower_photos',
     'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
     untar=True)

以下の場所に入力データがダウンロードされます。

/root/.keras/datasets/flower_photos

このデータセットには、5つのクラスに属する3670枚の花の画像が含まれています。

flower_photos
|__ daisy
    |______ 100080576_f52e8ee070_n.jpg
    |______ 14167534527_781ceb1b7a_n.jpg
    |______ ...
|__ dandelion
    |______ 10043234166_e6dd915111_n.jpg
    |______ 1426682852_e62169221f_m.jpg
    |______ ...
|__ roses
    |______ 102501987_3cdb8e5394_n.jpg
    |______ 14982802401_a3dfb22afb.jpg
    |______ ...
|__ sunflowers
    |______ 12471791574_bb1be83df4.jpg
    |______ 15122112402_cafa41934f.jpg
    |______ ...
|__ tulips
    |______ 13976522214_ccec508fe7.jpg
    |______ 14487943607_651e8062a1_m.jpg
    |______ ...

5. 学習の実行

学習の実行方法は、次のとおりです。

(1) ImageClassifierDataLoader.from_folder()で入力データの読み込み、入力データを訓練データ(0.8)とテストデータ(0.1)と検証データ(0.1)に分割します。

data = ImageClassifierDataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

(2) image_classifier.create()でモデルを学習します。

model = image_classifier.create(train_data, validation_data=validation_data)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
hub_keras_layer_v1v2 (HubKer (None, 1280)              3413024
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0
_________________________________________________________________
dense (Dense)                (None, 5)                 6405
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Epoch 1/5
103/103 [==============================] - 16s 153ms/step - loss: 0.8684 - accuracy: 0.7715
Epoch 2/5
103/103 [==============================] - 16s 153ms/step - loss: 0.6587 - accuracy: 0.8932
Epoch 3/5
103/103 [==============================] - 16s 152ms/step - loss: 0.6267 - accuracy: 0.9072
Epoch 4/5
103/103 [==============================] - 16s 153ms/step - loss: 0.6029 - accuracy: 0.9284
Epoch 5/5
103/103 [==============================] - 16s 153ms/step - loss: 0.5916 - accuracy: 0.9311

(3) model.evaluate()でモデルを評価します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 131ms/step - loss: 0.6191 - accuracy: 0.9074

(4) model.export()でモデルをエクスポートします。「with_metadata=True」でメタデータを付加しています。

model.export('image_classifier.tflite', 'image_labels.txt', with_metadata=True)

成功すると、「image_classifier.tflite」と「image_labels.txt」が出力されます。

6. モデルの切り替え

モデルの切り替えは、image_classifier.create()の「model_spec」で行います。

# EfficinetNet-Lite2.
model = image_classifier.create(data, model_spec=efficientnet_lite2_spec)

# MobileNetV2
model = image_classifier.create(train_data, model_spec=mobilenet_v2_spec


# ResNet 50.
model = image_classifier.create(data, model_spec=resnet_50_spec)


この記事が気に入ったらサポートをしてみませんか?