見出し画像

TensorFlow Liteのカスタムオペレーション

1. はじめに

「TensorFlow Lite」は、TensorFlowのオペレーションをすべてサポートしているわけではありません。モデルにサポートされていないオペレーションが含まれている時は、ユーザー実装のC/C++コードを「カスタムオペレーション」として使用できます。
(情報源)

2. カスタムオペレーションの使用手順

カスタムオペレーションの使用手順は次の通りです。

(1)グラフ内で正しい名前のオペレーションを参照していることを確認。
(2)「カスタムオペレーション」を「TensorFlow Lite」に登録。
ランタイムがグラフ内のオペレーションをカスタムオペレーションにマップできるようになります。
(3)オペレーションの正確さとパフォーマンスをプロファイリング。
カスタムオペレーションのみをテストする場合は、カスタムオペレーションのみで、「benchmark_model」を使用してモデルを作成することをお勧めします。

3. Sinのカスタムオペレーションの作成

「TensorFlow Lite」にはないオペレーションをサポートする例を見てみましょう。Sinオペレーションを使用して、関数「y = sin(x + offset)」の非常に単純なモデルを構築していると仮定します。ここで、offsetは訓練可能です。

◎TensorFlowモデルの生成
TensorFlowモデルを訓練するコードは、次のようになります。

offset = tf.get_variable("offset", [1,], tf.float32)
x = tf.placeholder(tf.float32, shape=(None,))
y = tf.sin(x + offset)
y_ = tf.placeholder(tf.float32, shape=(None,))
loss = tf.reduce_sum(tf.square(y - y_))
optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)

◎TensorFlow Liteモデルへの変換と実行
「--allow_custom_ops」引数を指定して「TensorFlow Lite Optimizing Converter」を使用して、このモデルをTensorflow Lite形式に変換し、デフォルトのインタープリタで実行すると、次のエラーメッセージが表示されます。

Didn't find custom op for name 'Sin'
Registration failed.

◎TensorFlow Liteランタイムのカスタムオペレーションの定義
「TensorFlow Lite」で「カスタムオペレーション」を使用するために必要なことは、「Prepare」「Eval」の2つの関数を定義し、「TfLiteRegistration」を構築することのみです。コードは次のようになります。

TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
  using namespace tflite;
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);

  int num_dims = NumDimensions(input);

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims);
  for (int i=0; i<num_dims; ++i) {
    output_size->data[i] = input->dims->data[i];
  }

  return context->ResizeTensor(context, output, output_size);
}

TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
  using namespace tflite;
  const TfLiteTensor* input = GetInput(context, node,0);
  TfLiteTensor* output = GetOutput(context, node,0);

  float* input_data = input->data.f;
  float* output_data = output->data.f;

  size_t count = 1;
  int num_dims = NumDimensions(input);
  for (int i = 0; i < num_dims; ++i) {
    count *= input->dims->data[i];
  }

  for (size_t i=0; i<count; ++i) {
    output_data[i] = sin(input_data[i]);
  }
  return kTfLiteOk;
}

TfLiteRegistration* Register_SIN() {
  static TfLiteRegistration r = {nullptr, nullptr, SinPrepare, SinEval};
  return &r;
}

OpResolverを初期化するときに、「カスタムオペレーション」を追加します。これにより、「TensorFlow Lite」に「カスタムオペレーション」が追加され、使用できるようになります。

「TfLiteRegistration」の最後の2つの引数は、「カスタムオペレーション」に対して定義したSinPrepare()とSinEval()に対応します。オペレーションで使用される変数を初期化し、スペースを解放するために2つの関数を使用した場合、Init()とFree()は、TfLiteRegistrationの最初の2つの引数に追加されます。この例ではnullptrに設定されています。

tflite::ops::builtin::BuiltinOpResolver builtins;
builtins.AddCustom("Sin", Register_SIN());

Javaでカスタムオペレーションを作成する場合、現在、独自のカスタムJNIレイヤーを構築し、このJNIコードで独自のAARをコンパイルする必要があります。

同様に、これらのオペレーションをPythonで使用できるようにする場合は、Pythonラッパーコードに登録を配置できます。

単一のオペレーションではなく、一連の操作をサポートするために、上記と同様の手順で対応できます。必要な数のAddCustomオペレーションを追加するだけです。さらに、「BuiltinOpResolver」では、「AddBuiltin」を使用して組み込みの実装をオーバーライドすることもできます。


この記事が気に入ったらサポートをしてみませんか?