TenorFlow Lite 入門 / Androidによる画像分類
「CameraX」を使ってAndroidによるTensorFlow Liteの画像分類のコードを書き直しました。Camera 2 APIよりだいぶシンプルになった(サンプルにしては長いけど)。
1. Androidによる画像分類
「TensorFlow Lite」を使ってAndroidで画像分類を行います。端末の背面カメラに映るものをリアルタイムに画像分類し、可能性の高いラベルを3つ表示します。
2. バージョン
・compileSdkVersion 29
・minSdkVersion 26
・targetSdkVersion 29
・tensorflow-lite:0.1.7
3. 依存関係の追加
「build.gradle(Mudule:app)」に、「CameraX」と「TensorFlow Lite」のプロジェクトの依存関係を追加します。
android {
<<省略>>
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
}
dependencies {
<<省略>>
// CameraX
def camerax_version = '1.0.0-alpha06'
implementation "androidx.camera:camera-core:${camerax_version}"
implementation "androidx.camera:camera-camera2:${camerax_version}"
// TensorFlow Lite
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
}
4. マニフェストファイルの設定
「CAMERA」のパーミッションを追加します。
<uses-permission android:name="android.permission.CAMERA" />
5. アセットの準備
プロジェクトの「app/src/main/assets」に、「Image classification」のページからダウンロードしたモデルとラベルを追加します。
・mobilenet_v1_1.0_224_quant.tflite
・labels_mobilenet_quant_v1_224.txt
6. レイアウトの設定
「activity_main.xml」に「TextureView」と「TextView」を追加します。
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<RelativeLayout android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="#FFFFFF">
<TextureView
android:id="@+id/texture_view"
android:layout_width="match_parent"
android:layout_height="match_parent"/>
<RelativeLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:background="#590de4"
android:padding="8dp">
<TextView
android:id="@+id/text_view"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_centerInParent="true"
android:text=""
android:textColor="@android:color/white"
android:textSize="18dp" />
</RelativeLayout>
</RelativeLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
7. UIの作成
画像分類を行うUIを作成します。
以下の処理を行なっています。
・パーミッション
・カメラのプレビューと解析
・ImageClassificationInterpriterにBitmapを渡して推論(後ほど説明)
◎ MainActivity.java
package net.npaka.imageclassificationex;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import android.Manifest;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.Rect;
import android.graphics.YuvImage;
import android.media.Image;
import android.os.Bundle;
import android.view.TextureView;
import android.view.ViewGroup;
import android.widget.TextView;
import android.widget.Toast;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
// MainActivity
public class MainActivity extends AppCompatActivity {
// 定数
private final int REQUEST_CODE_PERMISSIONS = 101;
private final String[] REQUIRED_PERMISSIONS = new String[]{
Manifest.permission.CAMERA};
// UI
private TextureView textureView;
private TextView textView;
// 推論
private ImageClassificationInterpriter interpriter;
// 生成時に呼ばれる
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// UI
this.textureView = findViewById(R.id.texture_view);
this.textView = findViewById(R.id.text_view);
// 推論
this.interpriter = new ImageClassificationInterpriter(this);
// パーミッションのチェック
if (allPermissionsGranted()) {
this.textureView.post(() -> startCamera());
} else {
ActivityCompat.requestPermissions(this,
REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
}
}
// パーミッション許可のリクエストの結果の取得
@Override
public void onRequestPermissionsResult(int requestCode,
String[] permissions, int[] grantResults) {
if (requestCode == REQUEST_CODE_PERMISSIONS) {
if (allPermissionsGranted()) {
startCamera();
} else {
Toast.makeText(this, "ユーザーから権限が許可されていません。",
Toast.LENGTH_SHORT).show();
finish();
}
}
}
// 全てのパーミッション許可
private boolean allPermissionsGranted() {
for (String permission : REQUIRED_PERMISSIONS) {
if (ContextCompat.checkSelfPermission(this, permission)
!= PackageManager.PERMISSION_GRANTED) {
return false;
}
}
return true;
}
// カメラの開始
private void startCamera() {
// プレビュー
PreviewConfig previewConfig = new PreviewConfig.Builder().build();
Preview preview = new Preview(previewConfig);
preview.setOnPreviewOutputUpdateListener(
output -> {
// SurfaceTextureの更新
ViewGroup parent = (ViewGroup)this.textureView.getParent();
parent.removeView(this.textureView);
parent.addView(this.textureView, 0);
// SurfaceTextureをTextureViewに指定
this.textureView.setSurfaceTexture(output.getSurfaceTexture());
// レイアウトの調整
Point point = new Point();
getWindowManager().getDefaultDisplay().getSize(point);
int w = point.x;
int h = point.x * 4 / 3;
RelativeLayout.LayoutParams params = new RelativeLayout.LayoutParams(w, h);
params.addRule(RelativeLayout.CENTER_IN_PARENT);
textureView.setLayoutParams(params);
});
// 画像の解析
ImageAnalysisConfig config = new ImageAnalysisConfig.Builder()
.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
.build();
ImageAnalysis imageAnalysis = new ImageAnalysis(config);
imageAnalysis.setAnalyzer(Executors.newSingleThreadExecutor(),
(image, rotationDegrees) -> {
// 推論
Bitmap bitmap = imageToToBitmap(image.getImage(), rotationDegrees);
List<Map.Entry<String, Float>> results = this.interpriter.predict(bitmap);
bitmap.recycle();
// ラベルに上位3件の表示
this.textView.post(() -> {
String text = "\n";
int count = 0;
for(Map.Entry<String, Float> entry : results) {
text += entry.getKey()+" : "+(int)(entry.getValue()*100)+"%\n";
count++;
if (count >= 3) break;
}
textView.setText(text);
});
});
// バインド
CameraX.bindToLifecycle(this, preview, imageAnalysis);
}
// ImageProxy → Bitmap
private Bitmap imageToToBitmap(Image image, int rotationDegrees) {
byte[] data = imageToByteArray(image);
Bitmap bitmap = BitmapFactory.decodeByteArray(data, 0, data.length);
if (rotationDegrees == 0) {
return bitmap;
} else {
return rotateBitmap(bitmap, rotationDegrees);
}
}
// Bitmapの回転
private Bitmap rotateBitmap(Bitmap bitmap, int rotationDegrees) {
Matrix mat = new Matrix();
mat.postRotate(rotationDegrees);
return Bitmap.createBitmap(bitmap, 0, 0,
bitmap.getWidth(), bitmap.getHeight(), mat, true);
}
// Image → JPEGのバイト配列
private byte[] imageToByteArray(Image image) {
byte[] data = null;
if (image.getFormat() == ImageFormat.JPEG) {
Image.Plane[] planes = image.getPlanes();
ByteBuffer buffer = planes[0].getBuffer();
data = new byte[buffer.capacity()];
buffer.get(data);
return data;
} else if (image.getFormat() == ImageFormat.YUV_420_888) {
data = NV21toJPEG(YUV_420_888toNV21(image),
image.getWidth(), image.getHeight());
}
return data;
}
// YUV_420_888 → NV21
private byte[] YUV_420_888toNV21(Image image) {
byte[] nv21;
ByteBuffer yBuffer = image.getPlanes()[0].getBuffer();
ByteBuffer uBuffer = image.getPlanes()[1].getBuffer();
ByteBuffer vBuffer = image.getPlanes()[2].getBuffer();
int ySize = yBuffer.remaining();
int uSize = uBuffer.remaining();
int vSize = vBuffer.remaining();
nv21 = new byte[ySize + uSize + vSize];
yBuffer.get(nv21, 0, ySize);
vBuffer.get(nv21, ySize, vSize);
uBuffer.get(nv21, ySize + vSize, uSize);
return nv21;
}
// NV21 → JPEG
private byte[] NV21toJPEG(byte[] nv21, int width, int height) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
YuvImage yuv = new YuvImage(nv21, ImageFormat.NV21, width, height, null);
yuv.compressToJpeg(new Rect(0, 0, width, height), 100, out);
return out.toByteArray();
}
}
8. 画像分類
Bitmapを受け取り、各ラベルの確率を返します。
パラメータ定数を変更することで、別の画像分類モデルにも対応できます。
◎ ImageClassificationInterpriter.java
package net.npaka.imageclassificationex;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Rect;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
// 画像分類インタープリタ
public class ImageClassificationInterpriter {
// パラメータ定数
private static final int BATCH_SIZE = 1; //バッチサイズ
private static final int INPUT_PIXELS = 3; //入力ピクセル
private final static int INPUT_SIZE = 224; // 入力サイズ
private boolean IS_QUANTIZED = true; //量子化
private static final float IMAGE_MEAN = 127.5f;
private static final float IMAGE_STD = 127.5f;
// システム
private Context context;
private Interpreter interpreter;
private List<String> labels;
private int[] imageBuffer = new int[INPUT_SIZE * INPUT_SIZE];
// 入力
private Bitmap inBitmap;
private Canvas inCanvas;
private Rect inBitmapSrc = new Rect();
private Rect inBitmapDst = new Rect(0, 0, INPUT_SIZE, INPUT_SIZE);
private ByteBuffer inBuffer;
// 出力
private byte[][] outByteProbs;
private float[][] outFloatProbs;
// コンストラクタ
public ImageClassificationInterpriter(Context context) {
this.context = context;
// モデルの読み込み
MappedByteBuffer model = loadModel("mobilenet_v1_1.0_224_quant.tflite");
// ラベルの読み込み
this.labels = loadLabel("labels_mobilenet_quant_v1_224.txt");
// インタプリタの生成
Interpreter.Options options = new Interpreter.Options();
//options.setUseNNAPI(true); //NNAPI
options.addDelegate(new GpuDelegate()); //GPU
options.setNumThreads(1); // スレッド数
this.interpreter = new Interpreter(model, options);
// 入力の初期化
this.inBitmap = Bitmap.createBitmap(
INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
this.inCanvas = new Canvas(inBitmap);
int numBytesPerChannel = IS_QUANTIZED ? 1 : 4;
this.inBuffer = ByteBuffer.allocateDirect(
BATCH_SIZE * INPUT_SIZE * INPUT_SIZE * INPUT_PIXELS * numBytesPerChannel);
inBuffer.order(ByteOrder.nativeOrder());
// 出力の初期化
if (IS_QUANTIZED) {
this.outByteProbs = new byte[1][labels.size()];
} else {
this.outFloatProbs = new float[1][labels.size()];
}
}
// モデルの読み込み
private MappedByteBuffer loadModel(String modelPath) {
try {
AssetFileDescriptor fd = this.context.getAssets().openFd(modelPath);
FileInputStream in = new FileInputStream(fd.getFileDescriptor());
FileChannel fileChannel = in.getChannel();
return fileChannel.map(FileChannel.MapMode.READ_ONLY,
fd.getStartOffset(), fd.getDeclaredLength());
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
// ラベルの読み込み
private List<String> loadLabel(String labelPath) {
try {
List<String> labels = new ArrayList<>();
BufferedReader reader = new BufferedReader(new InputStreamReader(
this.context.getAssets().open(labelPath)));
String line;
while ((line = reader.readLine()) != null) {
labels.add(line);
}
reader.close();
return labels;
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
// 推論
public List<Map.Entry<String, Float>> predict(Bitmap bitmap) {
// 入力画像の生成
int minSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
int dx = (bitmap.getWidth()-minSize)/2;
int dy = (bitmap.getHeight()-minSize)/2;
this.inBitmapSrc.set(dx, dy, dx+minSize, dy+minSize);
inCanvas.drawBitmap(bitmap, this.inBitmapSrc, this.inBitmapDst, null);
// 入力バッファの生成
bmpToInBuffer(inBitmap);
// 推論
if (IS_QUANTIZED) {
this.interpreter.run(this.inBuffer, this.outByteProbs);
} else {
this.interpreter.run(this.inBuffer, this.outFloatProbs);
}
// 結果の取得
Map<String, Float> map = new HashMap();
for (int i = 0; i < this.labels.size(); i++) {
String title = this.labels.size() > i ? this.labels.get(i) : "unknown";
float prob = getProb(i);
map.put(title, prob);
}
// ソート
List<Map.Entry<String, Float>> results = new ArrayList<>(map.entrySet());
Collections.sort(results, (obj1, obj2) -> obj2.getValue().compareTo(obj1.getValue()));
return results;
}
// Bitmap → 入力バッファ
private void bmpToInBuffer(Bitmap bitmap) {
this.inBuffer.rewind();
bitmap.getPixels(this.imageBuffer, 0, bitmap.getWidth(),
0, 0, bitmap.getWidth(), bitmap.getHeight());
int pixel = 0;
for (int i = 0; i < INPUT_SIZE; ++i) {
for (int j = 0; j < INPUT_SIZE; ++j) {
int pixelValue = imageBuffer[pixel++];
if (IS_QUANTIZED) {
inBuffer.put((byte)((pixelValue >> 16) & 0xFF));
inBuffer.put((byte)((pixelValue >> 8) & 0xFF));
inBuffer.put((byte)(pixelValue & 0xFF));
} else {
inBuffer.putFloat((((pixelValue >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
inBuffer.putFloat((((pixelValue >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
inBuffer.putFloat(((pixelValue & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
}
}
}
}
//確率の取得
private float getProb(int index) {
if (IS_QUANTIZED) {
return (this.outByteProbs[0][index] & 0xff)/255.0f;
} else {
return this.outFloatProbs[0][index];
}
}
}
この記事が気に入ったらサポートをしてみませんか?