見出し画像

TensorFlow Lite wrapper code generator

1. TensorFlow Lite wrapper code generator

TensorFlow Lite wrapper code generator」は、メタデータを付加した「TensorFlow Liteモデル」から、プラットフォーム固有のラッパーコードを生成するツールです。

メタデータは、「TensorFlow Lite Model Maker」「AI Hub」で生成したモデルには付加されています。ラッパーコードにより、「ByteBuffer」と直接対話する必要がなくなります。開発者は代わりに「Bitmap」や「Rect」などのオブジェクトを使用して「TensorFlow Liteモデル」を操作できます。

【注意】「TensorFlow Lite wrapper code generator」は現在はベータ版です。Androidのみをサポートしています。

2. ラッパーコードの生成

ラッパーコードを生成するには、「tflite-support」が必要です。「Google Colab」で「tflite-support」をインストールするコマンドは、次のとおりです。

!pip install tflite-support

「TensorFlow Model Maker」で生成した「画像分類のモデル」のラッパーコードを生成するコマンドは、次のとおりです。

!tflite_codegen --model=./image_classifier.tflite \
   --package_name=org.tensorflow.lite.classify \
   --model_class_name=ImageClassificationModel \
   --destination=./classify_wrapper

結果は、「destination」で指定した「./classify_wrapper」フォルダに出力されます。

結果をzip圧縮して、「Android Studioプロジェクト」にダウンロードして利用します。

# zip圧縮
!zip -r classify_wrapper.zip classify_wrapper/

# ダウンロード
from google.colab import files
files.download('classify_wrapper.zip')

3. ラッパーコードの使用

(1) ダウンロードしたzip圧縮したラッパーコードを解凍。
(2) 「File → New → Import Module」で解凍したラッパーコードを選択。
モジュール名は「:classify_wrapper」になります。

画像1

(3) build.gradle (Module: app)に以下を追加。

android {
    <<省略>>

    aaptOptions {
        noCompress "tflite"
    }
}
dependencies {
    <<省略>>

    implementation project(":classify_wrapper")
}​

(4) モデルを使用します。

// 推論の準備
try {
    this.imageClassifier = new ImageClassificationModel(this);
} catch (Exception e){
    e.printStackTrace();
}
// 推論の実行
ImageClassificationModel.Inputs inputs = imageClassifier.createInputs();
inputs.loadImage(bitmap);
ImageClassificationModel.Outputs outputs = imageClassifier.run(inputs);
Map<String, Float> map = outputs.getProbability();

4. モデルの推論の高速化

ラッパーコードは、デリゲートとスレッド数によって高速化する方法を提供します。モデルオブジェクトを初期化するときに設定できます。

NNAPIデリゲートと最大3つのスレッドを使用する設定は、次のとおりです。

try {
    this.imageClassifier = new ImageClassificationModel(this, Model.Device.NNAPI, 3);
} catch (Exception e) {
    e.printStackTrace();
}

5. 全ソースコード

全ソースコードは次のとおりです。

package net.npaka.imageclassificationex;

import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.Manifest;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.Point;
import android.graphics.Rect;
import android.graphics.YuvImage;
import android.media.Image;
import android.os.Bundle;
import android.view.TextureView;
import android.view.ViewGroup;
import android.widget.RelativeLayout;
import android.widget.TextView;
import android.widget.Toast;

import org.tensorflow.lite.classify.ImageClassificationModel;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;

// MainActivity
public class MainActivity extends AppCompatActivity {
    // 定数
    private final int REQUEST_CODE_PERMISSIONS = 101;
    private final String[] REQUIRED_PERMISSIONS = new String[]{
        Manifest.permission.CAMERA};

    // UI
    private TextureView textureView;
    private TextView textView;

    // 推論
    private ImageClassificationModel imageClassifier;

    // 生成時に呼ばれる
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // UI
        this.textureView = findViewById(R.id.texture_view);
        this.textView = findViewById(R.id.text_view);

        // 推論の準備
        try {
            this.imageClassifier = new ImageClassificationModel(this);
        } catch (Exception e){
            e.printStackTrace();
        }

        // パーミッションのチェック
        if (allPermissionsGranted()) {
            this.textureView.post(() -> startCamera());
        } else {
            ActivityCompat.requestPermissions(this,
                REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
        }
    }

    // パーミッション許可のリクエストの結果の取得
    @Override
    public void onRequestPermissionsResult(int requestCode,
        String[] permissions, int[] grantResults) {
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera();
            } else {
                Toast.makeText(this, "ユーザーから権限が許可されていません。",
                    Toast.LENGTH_SHORT).show();
                finish();
            }
        }
    }

    // 全てのパーミッション許可
    private boolean allPermissionsGranted() {
        for (String permission : REQUIRED_PERMISSIONS) {
            if (ContextCompat.checkSelfPermission(this, permission)
                != PackageManager.PERMISSION_GRANTED) {
                return false;
            }
        }
        return true;
    }

    // カメラの開始
    private void startCamera() {
        // プレビュー
        PreviewConfig previewConfig = new PreviewConfig.Builder().build();
        Preview preview = new Preview(previewConfig);
        preview.setOnPreviewOutputUpdateListener(
            output -> {
                // SurfaceTextureの更新
                ViewGroup parent = (ViewGroup)this.textureView.getParent();
                parent.removeView(this.textureView);
                parent.addView(this.textureView, 0);

                // SurfaceTextureをTextureViewに指定
                this.textureView.setSurfaceTexture(output.getSurfaceTexture());

                // レイアウトの調整
                Point point = new Point();
                getWindowManager().getDefaultDisplay().getSize(point);
                int w = point.x;
                int h = point.x * 4 / 3;
                RelativeLayout.LayoutParams params = new RelativeLayout.LayoutParams(w, h);
                params.addRule(RelativeLayout.CENTER_IN_PARENT);
                textureView.setLayoutParams(params);
            });

        // 画像の解析
        ImageAnalysisConfig config = new ImageAnalysisConfig.Builder()
            .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
            .build();
        ImageAnalysis imageAnalysis = new ImageAnalysis(config);
        imageAnalysis.setAnalyzer(Executors.newSingleThreadExecutor(),
            (image, rotationDegrees) -> {
                // Bitmapの生成
                Bitmap bitmap = imageToToBitmap(image.getImage(), rotationDegrees);

                // 推論の実行
                ImageClassificationModel.Inputs inputs = imageClassifier.createInputs();
                inputs.loadImage(bitmap);
                ImageClassificationModel.Outputs outputs = imageClassifier.run(inputs);
                Map<String, Float> map = outputs.getProbability();

                // ソート
                List<Map.Entry<String, Float>> results = new ArrayList<>(map.entrySet());
                Collections.sort(results, (obj1, obj2) -> obj2.getValue().compareTo(obj1.getValue()));

                // Bitmapの解放
                bitmap.recycle();

                // ラベルに上位3件の表示
                this.textView.post(() -> {
                    String text = "\n";
                    int count = 0;
                    for(Map.Entry<String, Float> entry : results) {
                        text += entry.getKey()+" : "+(int)(entry.getValue()*100)+"%\n";
                        count++;
                        if (count >= 3) break;
                    }
                    textView.setText(text);
                });
            });

        // バインド
        CameraX.bindToLifecycle(this, preview, imageAnalysis);
    }

    // ImageProxy → Bitmap
    private Bitmap imageToToBitmap(Image image, int rotationDegrees) {
        byte[] data = imageToByteArray(image);
        Bitmap bitmap = BitmapFactory.decodeByteArray(data, 0, data.length);
        if (rotationDegrees == 0) {
            return bitmap;
        } else {
            return rotateBitmap(bitmap, rotationDegrees);
        }
    }

    // Bitmapの回転
    private Bitmap rotateBitmap(Bitmap bitmap, int rotationDegrees) {
        Matrix mat = new Matrix();
        mat.postRotate(rotationDegrees);
        return Bitmap.createBitmap(bitmap, 0, 0,
            bitmap.getWidth(), bitmap.getHeight(), mat, true);
    } 

    // Image → JPEGのバイト配列
    private byte[] imageToByteArray(Image image) {
        byte[] data = null;
        if (image.getFormat() == ImageFormat.JPEG) {
            Image.Plane[] planes = image.getPlanes();
            ByteBuffer buffer = planes[0].getBuffer();
            data = new byte[buffer.capacity()];
            buffer.get(data);
            return data;
        } else if (image.getFormat() == ImageFormat.YUV_420_888) {
            data = NV21toJPEG(YUV_420_888toNV21(image),
                    image.getWidth(), image.getHeight());
        }
        return data;
    }

    // YUV_420_888 → NV21
    private byte[] YUV_420_888toNV21(Image image) {
        byte[] nv21;
        ByteBuffer yBuffer = image.getPlanes()[0].getBuffer();
        ByteBuffer uBuffer = image.getPlanes()[1].getBuffer();
        ByteBuffer vBuffer = image.getPlanes()[2].getBuffer();
        int ySize = yBuffer.remaining();
        int uSize = uBuffer.remaining();
        int vSize = vBuffer.remaining();
        nv21 = new byte[ySize + uSize + vSize];
        yBuffer.get(nv21, 0, ySize);
        vBuffer.get(nv21, ySize, vSize);
        uBuffer.get(nv21, ySize + vSize, uSize);
        return nv21;
    }

    // NV21 → JPEG
    private byte[] NV21toJPEG(byte[] nv21, int width, int height) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        YuvImage yuv = new YuvImage(nv21, ImageFormat.NV21, width, height, null);
        yuv.compressToJpeg(new Rect(0, 0, width, height), 100, out);
        return out.toByteArray();
    }
}

マニフェストファイルに「CAMERA」のパーミッションを追加する必要もあります。

<uses-permission android:name="android.permission.CAMERA" />


この記事が気に入ったらサポートをしてみませんか?