TensorFlow GraphDefモデルをTensorFlow.jsにインポートする
1. TensorFlowのモデルファイル
「TensorFlow GraphDefモデル」は、次のいずれかの形式で保存できます。
・SavedModel
・Frozen Model
・Session Bundle
・Tensorflow Hub モジュール
2. TensorFlow.js web format
上記ファイルは、「TensorFlow.jsコンバータ」を使うことで、「TensorFlow.js web format」に変換できます。これは、「TensorFlow.js」にロードできる形式です。
「TensorFlow.js web format」は、1つの「model.json」と複数の「重みファイル」を含むフォルダとして出力されます。「model.json」には、トポロジー(レイヤーの説明と接続方法)と重みのマニフェストが含まれています。
3. TensorFlowのモデルファイルをTensorFlow.js web formatに変換
(1) 「TensorFlow.js」のパッケージをインストール。
$ pip install tensorflowjs==1.7.3
(2) コンバータの実行
「TensorFlow GraphDefモデル」を「TensorFlow.js web format」に変換します。
◎ SavedModelの例
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
saved_model jsmodel
◎ FrozenModelの例
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
frozen_model.pb jsmodel
◎ Tensorflow Hub moduleの例
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' jsmodel
「引数」は次の2つです。
input_path : TensorFlowのモデルファイルのパス
output_path : 出力フォルダのパス
「オプション」は次の5つです。
--input_format : 入力モデルの形式
・tf_saved_model : SavedModel
・tf_frozen_model : Frozen Model
・tf_session_bundle : Session Bundle
・tf_hub : TensorFlow Hubモジュール
・keras : HDF5ファイル
--output_node_names : 出力ノードの名前。カンマ区切り。
--saved_model_tags : ロードするMetaGraphDefのタグ。カンマ区切り。SavedModelでのみ適用。
--signature_name : TensorFlow Hubモジュール変換にのみ適用。https://www.tensorflow.org/hub/common_signatures/を参照。
詳細なヘルプは以下のコマンドを実行してください。
$ tensorflowjs_converter --help
4. TensorFlow.js web formatをTensorFlow.jsに読み込む
(1) tfjs-converter npmパッケージのインストール。
$ npm install @tensorflow/tfjs
(2) Frozen Modelをインスタンス化し、推論を実行。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
MobileNetデモを参照してください。
loadGraphModel()は追加の LoadOptions パラメータを受け付けており、これを使用して認証情報やカスタムヘッダをリクエストと共に送信することができます。詳細はloadGraphModel()ドキュメントを参照してください。
5. サポートされるオペレーション
現在、TensorFlow.jsはTensorFlow演算の限定されたセットをサポートしています。モデルでサポートされていないオペレーションが使用されている場合、tensorflowjs_converterスクリプトは失敗し、モデルでサポートされていないオペレーションのリストが出力されます。
6. 重みのみのロード
重みのみをロードする場合は、次のコードを使用できます。
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
この記事が気に入ったらサポートをしてみませんか?