見出し画像

TensorFlow GraphDefモデルをTensorFlow.jsにインポートする

1. TensorFlowのモデルファイル

TensorFlow GraphDefモデル」は、次のいずれかの形式で保存できます。

SavedModel
Frozen Model
Session Bundle
Tensorflow Hub モジュール

2. TensorFlow.js web format

上記ファイルは、「TensorFlow.jsコンバータ」を使うことで、「TensorFlow.js web format」に変換できます。これは、「TensorFlow.js」にロードできる形式です。

「TensorFlow.js web format」は、1つの「model.json」と複数の「重みファイル」を含むフォルダとして出力されます。「model.json」には、トポロジー(レイヤーの説明と接続方法)と重みのマニフェストが含まれています。

3. TensorFlowのモデルファイルをTensorFlow.js web formatに変換

(1) 「TensorFlow.js」のパッケージをインストール。

$ pip install tensorflowjs==1.7.3

(2) コンバータの実行
「TensorFlow GraphDefモデル」を「TensorFlow.js web format」に変換します。

◎ SavedModelの例

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    saved_model jsmodel

◎ FrozenModelの例

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    frozen_model.pb jsmodel

◎ Tensorflow Hub moduleの例

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' jsmodel

「引数」は次の2つです。

input_path : TensorFlowのモデルファイルのパス
output_path : 出力フォルダのパス

「オプション」は次の5つです。

--input_format : 入力モデルの形式
 ・tf_saved_model : SavedModel
 ・tf_frozen_model : Frozen Model
 ・tf_session_bundle : Session Bundle
 ・tf_hub : TensorFlow Hubモジュール
 ・keras : HDF5ファイル
--output_node_names : 出力ノードの名前。カンマ区切り。
--saved_model_tags : ロードするMetaGraphDefのタグ。カンマ区切り。SavedModelでのみ適用。
--signature_name : TensorFlow Hubモジュール変換にのみ適用。https://www.tensorflow.org/hub/common_signatures/を参照。

詳細なヘルプは以下のコマンドを実行してください。

$ tensorflowjs_converter --help

4. TensorFlow.js web formatをTensorFlow.jsに読み込む

(1) tfjs-converter npmパッケージのインストール。

$ npm install @tensorflow/tfjs

(2) Frozen Modelをインスタンス化し、推論を実行。

import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

MobileNetデモを参照してください。

loadGraphModel()は追加の LoadOptions パラメータを受け付けており、これを使用して認証情報やカスタムヘッダをリクエストと共に送信することができます。詳細はloadGraphModel()ドキュメントを参照してください。

5. サポートされるオペレーション

現在、TensorFlow.jsはTensorFlow演算の限定されたセットをサポートしています。モデルでサポートされていないオペレーションが使用されている場合、tensorflowjs_converterスクリプトは失敗し、モデルでサポートされていないオペレーションのリストが出力されます。

6. 重みのみのロード

重みのみをロードする場合は、次のコードを使用できます。

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
    this.weightManifest, "https://example.org/model");


この記事が気に入ったらサポートをしてみませんか?