![見出し画像](https://assets.st-note.com/production/uploads/images/23102054/rectangle_large_type_2_1137f516c4dc909b19125688726ec9a9.png?width=800)
TensorFlow Lite 入門 / Androidによる姿勢推定
1. Androidによる姿勢推定
「TensorFlow Lite」を使ってAndroidで姿勢推定を行います。端末の背面カメラに映るものをリアルタイムに姿勢推定し、検出したボディパーツを表示します。
2. バージョン
・compileSdkVersion 29
・minSdkVersion 26
・targetSdkVersion 29
・tensorflow-lite:0.1.7
3. 依存関係の追加
「build.gradle(Module:app)」に、「CameraX」と「TensorFlow Lite」のプロジェクトの依存関係を追加します。
android {
<<省略>>
// Keep .tflite model files uncompressed in the APK so they can be
// memory-mapped directly from assets at runtime.
aaptOptions {
noCompress "tflite"
}
// Java 8 language features are required by CameraX and the TF Lite APIs.
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
}
dependencies {
<<省略>>
// CameraX
def camerax_version = '1.0.0-alpha06'
implementation "androidx.camera:camera-core:${camerax_version}"
implementation "androidx.camera:camera-camera2:${camerax_version}"
// TensorFlow Lite
// NOTE(review): "0.0.0-nightly" with changing=true pulls an unpinned snapshot;
// builds are not reproducible — consider pinning a released version.
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
}
4. マニフェストファイルの設定
「CAMERA」のパーミッションを追加します。
<uses-permission android:name="android.permission.CAMERA" />
5. アセットの準備
プロジェクトの「app/src/main/assets」に、「TensorFlow Lite PoseNet Android Demo」のページからダウンロードしたモデルを追加します。
・posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite
6. レイアウトの設定
「activity_main.xml」に「TextureView」と「ImageView」を追加します。
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: a full-screen camera preview (TextureView) with an
     ImageView stacked on top to display the pose-estimation overlay.
     Both views are re-sized programmatically in MainActivity.startCamera(),
     so the 640px size below is only an initial placeholder. -->
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<RelativeLayout android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="#FFFFFF">
<!-- Camera preview surface (CameraX renders into its SurfaceTexture). -->
<TextureView
android:id="@+id/texture_view"
android:layout_width="match_parent"
android:layout_height="match_parent"/>
<!-- Pose-estimation result overlay. NOTE(review): raw "px" units are
     density-dependent; MainActivity replaces these dimensions at runtime. -->
<ImageView
android:id="@+id/image_view"
android:layout_width="640px"
android:layout_height="640px"
android:layout_margin="16dp"
android:visibility="visible"
app:srcCompat="@mipmap/ic_launcher" />
</RelativeLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
7. UIの作成
姿勢推定を行うUIを作成します。
以下の処理を行なっています。
・パーミッション
・カメラのプレビューと解析
・PoseNetInterpriterにBitmapを渡して推論(後ほど説明)
◎ MainActivity.java
package net.npaka.posenetex;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import android.Manifest;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.Point;
import android.graphics.Rect;
import android.graphics.YuvImage;
import android.media.Image;
import android.os.Bundle;
import android.view.TextureView;
import android.view.ViewGroup;
import android.widget.ImageView;
import android.widget.RelativeLayout;
import android.widget.Toast;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.Executors;
// MainActivity
// MainActivity: shows a live back-camera preview via CameraX, runs PoseNet
// pose estimation on each analyzed frame, and overlays the result bitmap
// produced by PoseNetInterpriter on top of the preview.
public class MainActivity extends AppCompatActivity {
    // Constants
    private final int REQUEST_CODE_PERMISSIONS = 101;
    private final String[] REQUIRED_PERMISSIONS = new String[]{
        Manifest.permission.CAMERA};

    // UI
    private TextureView textureView; // camera preview surface
    private ImageView imageView;     // overlay for the pose-estimation result

    // Inference
    private PoseNetInterpriter interpriter;

    // Called when the activity is created.
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // UI
        this.textureView = findViewById(R.id.texture_view);
        this.imageView = findViewById(R.id.image_view);

        // Inference
        this.interpriter = new PoseNetInterpriter(this);

        // Start the camera only once the CAMERA permission is granted.
        if (allPermissionsGranted()) {
            this.textureView.post(() -> startCamera());
        } else {
            ActivityCompat.requestPermissions(this,
                REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
        }
    }

    // Receives the result of the runtime permission request.
    @Override
    public void onRequestPermissionsResult(int requestCode,
        String[] permissions, int[] grantResults) {
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera();
            } else {
                Toast.makeText(this, "ユーザーから権限が許可されていません。",
                    Toast.LENGTH_SHORT).show();
                finish();
            }
        }
    }

    // Returns true when every permission in REQUIRED_PERMISSIONS is granted.
    private boolean allPermissionsGranted() {
        for (String permission : REQUIRED_PERMISSIONS) {
            if (ContextCompat.checkSelfPermission(this, permission)
                != PackageManager.PERMISSION_GRANTED) {
                return false;
            }
        }
        return true;
    }

    // Configures the CameraX preview and the per-frame analysis pipeline,
    // then binds both use cases to this activity's lifecycle.
    private void startCamera() {
        // Preview use case.
        PreviewConfig previewConfig = new PreviewConfig.Builder().build();
        Preview preview = new Preview(previewConfig);
        preview.setOnPreviewOutputUpdateListener(
            output -> {
                // Re-attach the TextureView so the new SurfaceTexture is picked up.
                ViewGroup parent = (ViewGroup)this.textureView.getParent();
                parent.removeView(this.textureView);
                parent.addView(this.textureView, 0);
                this.textureView.setSurfaceTexture(output.getSurfaceTexture());

                // Size the preview to a 4:3 area of the display width and
                // center a square overlay of the same width on top of it.
                Point point = new Point();
                getWindowManager().getDefaultDisplay().getSize(point);
                int w = point.x;
                int h = point.x * 4 / 3;
                RelativeLayout.LayoutParams params = new RelativeLayout.LayoutParams(w, h);
                params.addRule(RelativeLayout.CENTER_IN_PARENT);
                textureView.setLayoutParams(params);
                params = new RelativeLayout.LayoutParams(w, w);
                params.addRule(RelativeLayout.CENTER_IN_PARENT);
                imageView.setLayoutParams(params);
            });

        // Image analysis use case: always analyze the most recent frame.
        ImageAnalysisConfig config = new ImageAnalysisConfig.Builder()
            .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
            .build();
        ImageAnalysis imageAnalysis = new ImageAnalysis(config);
        imageAnalysis.setAnalyzer(Executors.newSingleThreadExecutor(),
            (image, rotationDegrees) -> {
                // getImage() is documented as nullable — skip such frames.
                Image frame = image.getImage();
                if (frame == null) {
                    return;
                }
                Bitmap bitmap = imageToBitmap(frame, rotationDegrees);
                if (bitmap == null) {
                    return; // JPEG decode failed; skip this frame
                }
                Bitmap result = this.interpriter.predict(bitmap);
                bitmap.recycle();
                // Publish the result on the main thread.
                this.imageView.post(() -> imageView.setImageBitmap(result));
            });

        // Bind both use cases to the activity lifecycle.
        CameraX.bindToLifecycle(this, preview, imageAnalysis);
    }

    // Image → Bitmap, rotated upright according to rotationDegrees.
    // Returns null when the frame could not be decoded.
    private Bitmap imageToBitmap(Image image, int rotationDegrees) {
        byte[] data = imageToByteArray(image);
        if (data == null) {
            return null;
        }
        Bitmap bitmap = BitmapFactory.decodeByteArray(data, 0, data.length);
        if (bitmap == null || rotationDegrees == 0) {
            return bitmap;
        }
        Bitmap rotated = rotateBitmap(bitmap, rotationDegrees);
        // Free the intermediate un-rotated bitmap (fixes per-frame leak).
        if (rotated != bitmap) {
            bitmap.recycle();
        }
        return rotated;
    }

    // Returns a copy of the bitmap rotated by rotationDegrees.
    private Bitmap rotateBitmap(Bitmap bitmap, int rotationDegrees) {
        Matrix mat = new Matrix();
        mat.postRotate(rotationDegrees);
        return Bitmap.createBitmap(bitmap, 0, 0,
            bitmap.getWidth(), bitmap.getHeight(), mat, true);
    }

    // Image → JPEG byte array. Returns null for unsupported formats.
    private byte[] imageToByteArray(Image image) {
        if (image.getFormat() == ImageFormat.JPEG) {
            // JPEG frames carry the encoded bytes in a single plane.
            ByteBuffer buffer = image.getPlanes()[0].getBuffer();
            byte[] data = new byte[buffer.remaining()]; // remaining(), not capacity()
            buffer.get(data);
            return data;
        } else if (image.getFormat() == ImageFormat.YUV_420_888) {
            return NV21toJPEG(YUV_420_888toNV21(image),
                image.getWidth(), image.getHeight());
        }
        return null;
    }

    // YUV_420_888 → NV21 byte layout (Y plane, then interleaved V/U).
    // NOTE(review): assumes the U/V planes are already interleaved with
    // pixelStride 2 (true on many devices, not guaranteed by the API) — a
    // fully general conversion would honor each plane's row/pixel stride.
    private byte[] YUV_420_888toNV21(Image image) {
        ByteBuffer yBuffer = image.getPlanes()[0].getBuffer();
        ByteBuffer uBuffer = image.getPlanes()[1].getBuffer();
        ByteBuffer vBuffer = image.getPlanes()[2].getBuffer();
        int ySize = yBuffer.remaining();
        int uSize = uBuffer.remaining();
        int vSize = vBuffer.remaining();
        byte[] nv21 = new byte[ySize + uSize + vSize];
        yBuffer.get(nv21, 0, ySize);
        vBuffer.get(nv21, ySize, vSize);
        uBuffer.get(nv21, ySize + vSize, uSize);
        return nv21;
    }

    // NV21 → JPEG at maximum quality.
    private byte[] NV21toJPEG(byte[] nv21, int width, int height) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        YuvImage yuv = new YuvImage(nv21, ImageFormat.NV21, width, height, null);
        yuv.compressToJpeg(new Rect(0, 0, width, height), 100, out);
        return out.toByteArray();
    }
}
8. 姿勢推定
Bitmapを受け取り、姿勢推定結果を描画したBitmapを返します。
パラメータ定数を変更することで、別の姿勢推定モデルにも対応できます。
◎ PoseNetInterpriter.java
package net.npaka.posenetex;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.PointF;
import android.graphics.PorterDuff;
import android.graphics.Rect;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
import java.io.FileInputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
// 姿勢推定インタープリタ
// Pose-estimation interpreter: takes a camera Bitmap, runs the PoseNet
// TFLite model on it, and returns a 257x257 transparent Bitmap with the
// detected keypoints and skeleton drawn in. Adjusting the parameter
// constants below adapts the class to other PoseNet-style models.
public class PoseNetInterpriter {
    // Model parameters
    private static final int BATCH_SIZE = 1;           // batch size
    private static final int INPUT_PIXELS = 3;         // channels per pixel (RGB)
    private static final int INPUT_SIZE = 257;         // model input width/height
    private static final boolean IS_QUANTIZED = false; // quantized model? (constant, was a mutable field)
    private static final float IMAGE_MEAN = 128.0f;    // float-input normalization mean
    private static final float IMAGE_STD = 128.0f;     // float-input normalization scale

    // Drawing constants: keypoints 0-4 (face parts) are drawn as dots,
    // the index pairs below are drawn as skeleton lines.
    private static final int[] DRAW_POINT = {0, 1, 2, 3, 4};
    private static final int[][] DRAW_LINE = {
        {5,6},
        {5,7},
        {6,8},
        {7,9},
        {8,10},
        {5,11},
        {6,12},
        {11,12},
        {11,13},
        {12,14},
        {13,15},
        {14,16},
    };

    // System
    private Context context;
    private Interpreter interpreter;
    private GpuDelegate gpuDelegate; // kept so it can be released in close()
    private int[] imageBuffer = new int[INPUT_SIZE * INPUT_SIZE];

    // Input
    private ByteBuffer inBuffer;   // model input tensor data
    private Bitmap inBitmap;       // 257x257 scratch bitmap for the crop/scale
    private Canvas inCanvas;
    private Rect inBitmapSrc = new Rect();
    private Rect inBitmapDst = new Rect(0, 0, INPUT_SIZE, INPUT_SIZE);

    // Output tensors (shapes fixed by the 257x257 PoseNet model)
    private float[][][][] outHeatmaps;          // [1][9][9][17]
    private float[][][][] outOffsets;           // [1][9][9][34]
    private float[][][][] outDisplacementsFwd;  // [1][9][9][32]
    private float[][][][] outDisplacementsBwd;  // [1][9][9][32]
    private Bitmap outBitmap;  // reused result bitmap
    private Canvas outCanvas;
    private Paint outPaint;

    // Body parts, in the model's keypoint channel order.
    enum BodyPart {
        NOSE,
        LEFT_EYE,
        RIGHT_EYE,
        LEFT_EAR,
        RIGHT_EAR,
        LEFT_SHOULDER,
        RIGHT_SHOULDER,
        LEFT_ELBOW,
        RIGHT_ELBOW,
        LEFT_WRIST,
        RIGHT_WRIST,
        LEFT_HIP,
        RIGHT_HIP,
        LEFT_KNEE,
        RIGHT_KNEE,
        LEFT_ANKLE,
        RIGHT_ANKLE
    }

    // A 2D pixel position in the 257x257 input image.
    class Position {
        int x = 0;
        int y = 0;
    }

    // One detected keypoint: body part, position, and confidence score.
    class KeyPoint {
        public BodyPart bodyPart = BodyPart.NOSE;
        Position position = new Position();
        float score = 0.0f;
    }

    // A detected person: all keypoints plus the mean confidence score.
    class Person {
        ArrayList<KeyPoint> keyPoints = new ArrayList<>();
        float score = 0.0f;
    }

    // Constructor: loads the model from assets and prepares the interpreter
    // and the reusable input/output buffers.
    public PoseNetInterpriter(Context context) {
        this.context = context;

        // Load the model.
        MappedByteBuffer model = loadModel("posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite");

        // Create the interpreter.
        Interpreter.Options options = new Interpreter.Options();
        //options.setUseNNAPI(true); // NNAPI
        this.gpuDelegate = new GpuDelegate(); // GPU acceleration
        options.addDelegate(this.gpuDelegate);
        options.setNumThreads(1); // thread count
        this.interpreter = new Interpreter(model, options);

        // Input initialization.
        this.inBitmap = Bitmap.createBitmap(
            INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
        this.inCanvas = new Canvas(inBitmap);
        int numBytesPerChannel = IS_QUANTIZED ? 1 : 4; // uint8 vs float32
        this.inBuffer = ByteBuffer.allocateDirect(
            BATCH_SIZE * INPUT_SIZE * INPUT_SIZE * INPUT_PIXELS * numBytesPerChannel);
        this.inBuffer.order(ByteOrder.nativeOrder());

        // Output initialization.
        this.outHeatmaps = new float[1][9][9][17];
        this.outOffsets = new float[1][9][9][34];
        this.outDisplacementsFwd = new float[1][9][9][32];
        this.outDisplacementsBwd = new float[1][9][9][32];
        this.outBitmap = Bitmap.createBitmap(
            INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
        this.outCanvas = new Canvas(outBitmap);
        this.outPaint = new Paint();
    }

    // Releases the native interpreter and GPU delegate.
    // Call when the interpreter is no longer needed.
    public void close() {
        if (this.interpreter != null) {
            this.interpreter.close();
            this.interpreter = null;
        }
        if (this.gpuDelegate != null) {
            this.gpuDelegate.close();
            this.gpuDelegate = null;
        }
    }

    // Memory-maps the model file from assets.
    // Throws IllegalStateException on failure instead of returning null,
    // which previously deferred the crash to an uninformative NPE in the
    // Interpreter constructor. Also closes the stream/descriptor (fixes a
    // resource leak); the mapping stays valid after the channel is closed.
    private MappedByteBuffer loadModel(String modelPath) {
        try (AssetFileDescriptor fd = this.context.getAssets().openFd(modelPath);
             FileInputStream in = new FileInputStream(fd.getFileDescriptor());
             FileChannel fileChannel = in.getChannel()) {
            return fileChannel.map(FileChannel.MapMode.READ_ONLY,
                fd.getStartOffset(), fd.getDeclaredLength());
        } catch (Exception e) {
            throw new IllegalStateException("Failed to load model: " + modelPath, e);
        }
    }

    // Runs pose estimation on the given bitmap and returns a bitmap with the
    // detected keypoints/skeleton drawn on a transparent background.
    // The returned bitmap is reused across calls — do not recycle it.
    public Bitmap predict(Bitmap bitmap) {
        // Center-crop the source to a square and scale it to the input size.
        int minSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
        int dx = (bitmap.getWidth() - minSize) / 2;
        int dy = (bitmap.getHeight() - minSize) / 2;
        this.inBitmapSrc.set(dx, dy, dx + minSize, dy + minSize);
        inCanvas.drawBitmap(bitmap, this.inBitmapSrc, this.inBitmapDst, null);

        // Fill the input tensor buffer.
        bmpToInBuffer(inBitmap);

        // Inference.
        Object[] inputArray = {inBuffer};
        Map<Integer, Object> outputMap = new HashMap<>();
        outputMap.put(0, this.outHeatmaps);
        outputMap.put(1, this.outOffsets);
        outputMap.put(2, this.outDisplacementsFwd);
        outputMap.put(3, this.outDisplacementsBwd);
        this.interpreter.runForMultipleInputsOutputs(inputArray, outputMap);

        // Heatmap tensor layout: [1][height][width][numKeypoints].
        int height = outHeatmaps[0].length;
        int width = outHeatmaps[0][0].length;
        int numKeypoints = outHeatmaps[0][0][0].length;

        // For each keypoint, find the heatmap cell with the highest
        // activation. PointF.x stores the row, PointF.y the column.
        ArrayList<PointF> keypointPositions = new ArrayList<>();
        for (int keypoint = 0; keypoint < numKeypoints; keypoint++) {
            float maxVal = outHeatmaps[0][0][0][keypoint];
            float maxRow = 0;
            float maxCol = 0;
            for (int row = 0; row < height; row++) {
                for (int col = 0; col < width; col++) {
                    if (outHeatmaps[0][row][col][keypoint] > maxVal) {
                        maxVal = outHeatmaps[0][row][col][keypoint];
                        maxRow = row;
                        maxCol = col;
                    }
                }
            }
            keypointPositions.add(new PointF(maxRow, maxCol));
        }

        // Refine each keypoint into input-image pixel coordinates using the
        // offset tensor: y offsets occupy channels [0, numKeypoints),
        // x offsets [numKeypoints, 2*numKeypoints).
        int[] xCoords = new int[numKeypoints];
        int[] yCoords = new int[numKeypoints];
        float[] confidenceScores = new float[numKeypoints];
        for (int idx = 0; idx < keypointPositions.size(); idx++) {
            PointF position = keypointPositions.get(idx);
            int positionY = (int)position.x; // heatmap row
            int positionX = (int)position.y; // heatmap column
            yCoords[idx] = (int)(
                position.x / (float)(height - 1) * INPUT_SIZE +
                this.outOffsets[0][positionY][positionX][idx]);
            xCoords[idx] = (int)(
                position.y / (float)(width - 1) * INPUT_SIZE +
                this.outOffsets[0][positionY][positionX][idx + numKeypoints]);
            // Sigmoid converts the raw heatmap logit into a [0,1] confidence.
            confidenceScores[idx] = sigmoid(outHeatmaps[0][positionY][positionX][idx]);
        }

        // Build the Person result.
        Person person = new Person();
        ArrayList<KeyPoint> keyPoints = new ArrayList<>();
        float totalScore = 0.0f;
        for (int idx = 0; idx < numKeypoints; idx++) {
            KeyPoint keyPoint = new KeyPoint();
            keyPoint.bodyPart = BodyPart.values()[idx];
            keyPoint.position.x = xCoords[idx];
            keyPoint.position.y = yCoords[idx];
            keyPoint.score = confidenceScores[idx];
            totalScore += confidenceScores[idx];
            keyPoints.add(keyPoint);
        }
        person.keyPoints = keyPoints;
        person.score = totalScore / numKeypoints; // mean keypoint confidence

        // Draw the result: face keypoints as dots, limbs/torso as lines.
        this.outCanvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR);
        this.outPaint.setTextSize(12);
        this.outPaint.setAntiAlias(true);
        this.outPaint.setStyle(Paint.Style.FILL);
        this.outPaint.setColor(Color.RED);
        this.outPaint.setStrokeWidth(2);
        for (int idx : DRAW_POINT) {
            this.outCanvas.drawCircle(
                keyPoints.get(idx).position.x,
                keyPoints.get(idx).position.y,
                3, this.outPaint);
        }
        for (int[] idx : DRAW_LINE) {
            this.outCanvas.drawLine(
                keyPoints.get(idx[0]).position.x,
                keyPoints.get(idx[0]).position.y,
                keyPoints.get(idx[1]).position.x,
                keyPoints.get(idx[1]).position.y,
                this.outPaint);
        }
        return this.outBitmap;
    }

    // Logistic sigmoid.
    private float sigmoid(float x) {
        return (float)(1.0f / (1.0f + Math.exp(-x)));
    }

    // Bitmap → input tensor buffer (RGB, quantized bytes or normalized floats).
    private void bmpToInBuffer(Bitmap bitmap) {
        this.inBuffer.rewind();
        bitmap.getPixels(this.imageBuffer, 0, bitmap.getWidth(),
            0, 0, bitmap.getWidth(), bitmap.getHeight());
        int pixel = 0;
        for (int i = 0; i < INPUT_SIZE; ++i) {
            for (int j = 0; j < INPUT_SIZE; ++j) {
                int pixelValue = imageBuffer[pixel++];
                if (IS_QUANTIZED) {
                    // uint8 model: raw R, G, B bytes.
                    inBuffer.put((byte)((pixelValue >> 16) & 0xFF));
                    inBuffer.put((byte)((pixelValue >> 8) & 0xFF));
                    inBuffer.put((byte)(pixelValue & 0xFF));
                } else {
                    // float model: normalize each channel to [-1, 1].
                    inBuffer.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    inBuffer.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    inBuffer.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                }
            }
        }
    }
}
この記事が気に入ったらサポートをしてみませんか?