TenorFlow Lite 入門 / Androidによる画像セグメンテーション
1. Androidによる画像セグメンテーション
「TensorFlow Lite」を使ってAndroidで画像セグメンテーションを行います。端末の背面カメラに映るものをリアルタイムに画像セグメンテーションを行い、ピクセル毎にクラスの色分けを行います。クラスは次の21種類。
・background
・aeroplane
・bicycle
・bird
・boat
・bottle
・bus
・car
・cat
・chair
・cow
・diningtable
・dog
・horse
・motorbike
・person
・pottedplant
・sheep
・sofa
・train
・tv
2. バージョン
・compileSdkVersion 29
・minSdkVersion 26
・targetSdkVersion 29
・tensorflow-lite:0.1.7
3. 依存関係の追加
「build.gradle(Mudule:app)」に、「CameraX」と「TensorFlow Lite」のプロジェクトの依存関係を追加します。
android {
<<省略>>
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
}
dependencies {
<<省略>>
// CameraX
def camerax_version = '1.0.0-alpha06'
implementation "androidx.camera:camera-core:${camerax_version}"
implementation "androidx.camera:camera-camera2:${camerax_version}"
// TensorFlow Lite
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
}
4. マニフェストファイルの設定
「CAMERA」のパーミッションを追加します。
<uses-permission android:name="android.permission.CAMERA" />
5. アセットの準備
プロジェクトの「app/src/main/assets」に、「Segmentation」のページからダウンロードした「モデル」を追加します。
・deeplabv3_257_mv_gpu.tflite
6. レイアウトの設定
「activity_main.xml」に「ImageView」を追加します。
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<RelativeLayout android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="#FFFFFF">
<TextureView
android:id="@+id/texture_view"
android:layout_width="match_parent"
android:layout_height="match_parent"/>
<ImageView
android:id="@+id/image_view"
android:layout_width="640px"
android:layout_height="640px"
android:layout_margin="16dp"
android:visibility="visible"
app:srcCompat="@mipmap/ic_launcher" />
</RelativeLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
7. UIの作成
画像分類を行うUIを作成します。
以下の処理を行なっています。
・パーミッション
・カメラのプレビューと解析
・ImageSegmentationInterpriterに画像を渡して推論(後ほど説明)
◎ MainActivity.java
package net.npaka.imageclassificationex;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import android.Manifest;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.Point;
import android.graphics.Rect;
import android.graphics.YuvImage;
import android.media.Image;
import android.os.Bundle;
import android.view.TextureView;
import android.view.ViewGroup;
import android.widget.ImageView;
import android.widget.RelativeLayout;
import android.widget.Toast;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.Executors;
// MainActivity
public class MainActivity extends AppCompatActivity {
// 定数
private final int REQUEST_CODE_PERMISSIONS = 101;
private final String[] REQUIRED_PERMISSIONS = new String[]{
Manifest.permission.CAMERA};
// UI
private TextureView textureView;
private ImageView imageView;
// 推論
private ImageSegmentationInterpriter interpriter;
// 生成時に呼ばれる
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// UI
this.textureView = findViewById(R.id.texture_view);
this.imageView = findViewById(R.id.image_view);
// 推論
this.interpriter = new ImageSegmentationInterpriter(this);
// パーミッションのチェック
if (allPermissionsGranted()) {
this.textureView.post(() -> startCamera());
} else {
ActivityCompat.requestPermissions(this,
REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
}
}
// パーミッション許可のリクエストの結果の取得
@Override
public void onRequestPermissionsResult(int requestCode,
String[] permissions, int[] grantResults) {
if (requestCode == REQUEST_CODE_PERMISSIONS) {
if (allPermissionsGranted()) {
startCamera();
} else {
Toast.makeText(this, "ユーザーから権限が許可されていません。",
Toast.LENGTH_SHORT).show();
finish();
}
}
}
// 全てのパーミッション許可
private boolean allPermissionsGranted() {
for (String permission : REQUIRED_PERMISSIONS) {
if (ContextCompat.checkSelfPermission(this, permission)
!= PackageManager.PERMISSION_GRANTED) {
return false;
}
}
return true;
}
// カメラの開始
private void startCamera() {
// プレビュー
PreviewConfig previewConfig = new PreviewConfig.Builder().build();
Preview preview = new Preview(previewConfig);
preview.setOnPreviewOutputUpdateListener(
output -> {
// SurfaceTextureの更新
ViewGroup parent = (ViewGroup)this.textureView.getParent();
parent.removeView(this.textureView);
parent.addView(this.textureView, 0);
// SurfaceTextureをTextureViewに指定
this.textureView.setSurfaceTexture(output.getSurfaceTexture());
// レイアウトの調整
Point point = new Point();
getWindowManager().getDefaultDisplay().getSize(point);
int w = point.x;
int h = point.x * 4 / 3;
RelativeLayout.LayoutParams params = new RelativeLayout.LayoutParams(w, h);
params.addRule(RelativeLayout.CENTER_IN_PARENT);
textureView.setLayoutParams(params);
params = new RelativeLayout.LayoutParams(w, w);
params.addRule(RelativeLayout.CENTER_IN_PARENT);
imageView.setLayoutParams(params);
});
// 画像の解析
ImageAnalysisConfig config = new ImageAnalysisConfig.Builder()
.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
.build();
ImageAnalysis imageAnalysis = new ImageAnalysis(config);
imageAnalysis.setAnalyzer(Executors.newSingleThreadExecutor(),
(image, rotationDegrees) -> {
// 推論
Bitmap bitmap = imageToToBitmap(image.getImage(), rotationDegrees);
Bitmap result = this.interpriter.predict(bitmap);
bitmap.recycle();
// 結果の表示
this.imageView.post(() -> {
imageView.setImageBitmap(result);
});
});
// バインド
CameraX.bindToLifecycle(this, preview, imageAnalysis);
}
// ImageProxy → Bitmap
private Bitmap imageToToBitmap(Image image, int rotationDegrees) {
byte[] data = imageToByteArray(image);
Bitmap bitmap = BitmapFactory.decodeByteArray(data, 0, data.length);
if (rotationDegrees == 0) {
return bitmap;
} else {
return rotateBitmap(bitmap, rotationDegrees);
}
}
// Bitmapの回転
private Bitmap rotateBitmap(Bitmap bitmap, int rotationDegrees) {
Matrix mat = new Matrix();
mat.postRotate(rotationDegrees);
return Bitmap.createBitmap(bitmap, 0, 0,
bitmap.getWidth(), bitmap.getHeight(), mat, true);
}
// Image → JPEGのバイト配列
private byte[] imageToByteArray(Image image) {
byte[] data = null;
if (image.getFormat() == ImageFormat.JPEG) {
Image.Plane[] planes = image.getPlanes();
ByteBuffer buffer = planes[0].getBuffer();
data = new byte[buffer.capacity()];
buffer.get(data);
return data;
} else if (image.getFormat() == ImageFormat.YUV_420_888) {
data = NV21toJPEG(YUV_420_888toNV21(image),
image.getWidth(), image.getHeight());
}
return data;
}
// YUV_420_888 → NV21
private byte[] YUV_420_888toNV21(Image image) {
byte[] nv21;
ByteBuffer yBuffer = image.getPlanes()[0].getBuffer();
ByteBuffer uBuffer = image.getPlanes()[1].getBuffer();
ByteBuffer vBuffer = image.getPlanes()[2].getBuffer();
int ySize = yBuffer.remaining();
int uSize = uBuffer.remaining();
int vSize = vBuffer.remaining();
nv21 = new byte[ySize + uSize + vSize];
yBuffer.get(nv21, 0, ySize);
vBuffer.get(nv21, ySize, vSize);
uBuffer.get(nv21, ySize + vSize, uSize);
return nv21;
}
// NV21 → JPEG
private byte[] NV21toJPEG(byte[] nv21, int width, int height) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
YuvImage yuv = new YuvImage(nv21, ImageFormat.NV21, width, height, null);
yuv.compressToJpeg(new Rect(0, 0, width, height), 100, out);
return out.toByteArray();
}
}
8. 画像セグメンテーション
画像を受け取り、ピクセル毎にクラスの色分けを行います。
◎ ImageSegmentationInterpriter.java
package net.npaka.imageclassificationex;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Rect;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
import java.io.FileInputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.Random;
// 画像セグメンテーションインタープリタ
public class ImageSegmentationInterpriter {
// パラメータ定数
private final static int BATCH_SIZE = 1; //バッチサイズ
private final static int INPUT_PIXELS = 3; //入力ピクセル
private final static int INPUT_SIZE = 257; // 入力サイズ
private final static int NUM_CLASSES = 21; // クラス数
private final boolean IS_QUANTIZED = false;
private final static float IMAGE_MEAN = 128.0f;
private final static float IMAGE_STD = 128.0f;
// システム
private Context context;
private Interpreter interpreter;
private int[] imageBuffer = new int[INPUT_SIZE * INPUT_SIZE];
private int[] colors = new int[NUM_CLASSES];
// 入力
private Bitmap inBitmap;
private Canvas inCanvas;
private Rect inBitmapSrc = new Rect();
private Rect inBitmapDst = new Rect(0, 0, INPUT_SIZE, INPUT_SIZE);
private ByteBuffer inBuffer;
// 出力
private ByteBuffer outSegmentationMasks;
// コンストラクタ
public ImageSegmentationInterpriter(Context context) {
this.context = context;
// 色の初期化
Random rand = new Random();
this.colors[0] = Color.TRANSPARENT;
for (int i = 1; i < NUM_CLASSES; i++) {
this.colors[i] = Color.rgb(rand.nextInt(255), rand.nextInt(255), rand.nextInt(255));
}
// モデルの読み込み
MappedByteBuffer model = loadModel("deeplabv3_257_mv_gpu.tflite");
// インタプリタの生成
Interpreter.Options options = new Interpreter.Options();
//options.setUseNNAPI(true); //NNAPI
options.addDelegate(new GpuDelegate()); //GPU
options.setNumThreads(1); // スレッド数
this.interpreter = new Interpreter(model, options);
// 入力の初期化
this.inBitmap = Bitmap.createBitmap(
INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
this.inCanvas = new Canvas(inBitmap);
int numBytesPerChannel = IS_QUANTIZED ? 1 : 4;
this.inBuffer = ByteBuffer.allocateDirect(
BATCH_SIZE * INPUT_SIZE * INPUT_SIZE * INPUT_PIXELS * numBytesPerChannel);
this.inBuffer.order(ByteOrder.nativeOrder());
this.inBuffer.rewind();
// 出力の初期化
this.outSegmentationMasks = ByteBuffer.allocateDirect(
1 * INPUT_SIZE * INPUT_SIZE * 21 * 4);
this.outSegmentationMasks.order(ByteOrder.nativeOrder());
}
// モデルの読み込み
private MappedByteBuffer loadModel(String modelPath) {
try {
AssetFileDescriptor fd = this.context.getAssets().openFd(modelPath);
FileInputStream in = new FileInputStream(fd.getFileDescriptor());
FileChannel fc = in.getChannel();
return fc.map(FileChannel.MapMode.READ_ONLY, fd.getStartOffset(), fd.getDeclaredLength());
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
// 推論
public Bitmap predict(Bitmap bitmap) {
// 入力画像の生成
int minSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
this.inBitmapSrc.set(
(bitmap.getWidth()-minSize)/2,
(bitmap.getHeight()-minSize)/2, minSize, minSize);
inCanvas.drawBitmap(bitmap, this.inBitmapSrc, this.inBitmapDst, null);
// 推論
bmpToInBuffer(inBitmap);
this.outSegmentationMasks.rewind();
this.interpreter.run(this.inBuffer, this.outSegmentationMasks);
// 結果の取得
return bufferToBitmap(this.outSegmentationMasks);
}
// Bitmap → 入力バッファ
private void bmpToInBuffer(Bitmap bitmap) {
this.inBuffer.rewind();
bitmap.getPixels(this.imageBuffer, 0, bitmap.getWidth(),
0, 0, bitmap.getWidth(), bitmap.getHeight());
int pixel = 0;
for (int i = 0; i < INPUT_SIZE; ++i) {
for (int j = 0; j < INPUT_SIZE; ++j) {
addPixelValue(imageBuffer[pixel++]);
}
}
}
// ピクセル値の追加
private void addPixelValue(int pixelValue) {
if (IS_QUANTIZED) {
inBuffer.put((byte)((pixelValue >> 16) & 0xFF));
inBuffer.put((byte)((pixelValue >> 8) & 0xFF));
inBuffer.put((byte)(pixelValue & 0xFF));
} else {
inBuffer.putFloat((((pixelValue >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
inBuffer.putFloat((((pixelValue >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
inBuffer.putFloat(((pixelValue & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
}
}
// ByteBuffer → Bitmap
private Bitmap bufferToBitmap(ByteBuffer segmentationMasks) {
Bitmap maskBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
for (int y = 0; y < INPUT_SIZE; y++) {
for (int x = 0; x < INPUT_SIZE; x++) {
float maxVal = 0f;
// 確率の高いクラス
int classIndex = 0;
for (int c = 0; c < NUM_CLASSES; c++) {
float value = segmentationMasks.getFloat((y*INPUT_SIZE*NUM_CLASSES+x*NUM_CLASSES+c)*4);
if (c == 0 || value > maxVal) {
maxVal = value;
classIndex = c;
}
}
// 色の指定
maskBitmap.setPixel(x, y, colors[classIndex]);
}
}
return maskBitmap;
}
}
この記事が気に入ったらサポートをしてみませんか?