見出し画像

TensorFlow Lite 入門 / Androidによる姿勢推定

1. Androidによる姿勢推定

「TensorFlow Lite」を使ってAndroidで姿勢推定を行います。端末の背面カメラに映るものをリアルタイムに姿勢推定し、検出したボディパーツを表示します。

画像1

2. バージョン

・compileSdkVersion 29
・minSdkVersion 26
・targetSdkVersion 29
・tensorflow-lite:0.1.7

3. 依存関係の追加

「build.gradle(Module:app)」に、「CameraX」と「TensorFlow Lite」のプロジェクトの依存関係を追加します。

android {
    <<省略>>

    // Keep .tflite assets uncompressed so AssetManager.openFd() (used to memory-map the model) can open them.
    aaptOptions {
        noCompress "tflite"
    }
    // Java 8 language level; the Java code below uses lambdas and method references.
    compileOptions {
        sourceCompatibility = '1.8'
        targetCompatibility = '1.8'
    }
}
dependencies {
    <<省略>>

    // CameraX (pre-release alpha API: PreviewConfig / ImageAnalysisConfig / CameraX.bindToLifecycle)
    def camerax_version = '1.0.0-alpha06'
    implementation "androidx.camera:camera-core:${camerax_version}"
    implementation "androidx.camera:camera-camera2:${camerax_version}"

    // TensorFlow Lite: interpreter, GPU delegate, support library.
    // NOTE(review): 0.0.0-nightly with changing = true re-resolves to whatever nightly is current,
    // so builds are not reproducible; consider pinning a released version.
    // NOTE(review): tensorflow-lite-support is not referenced by the Java code in this article.
    implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
    implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
    implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
}

4. マニフェストファイルの設定

「CAMERA」のパーミッションを追加します。

<!-- Required for the CameraX preview and image analysis started in MainActivity. -->
<uses-permission android:name="android.permission.CAMERA" />

5. アセットの準備

プロジェクトの「app/src/main/assets」に、「TensorFlow Lite PoseNet Android Demo」のページからダウンロードしたモデルを追加します。

・posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite

6. レイアウトの設定

「activity_main.xml」に「TextureView」と「ImageView」を追加します。

<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: camera preview (TextureView) with the pose overlay (ImageView) drawn on top. -->
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <!-- MainActivity replaces both children's layout params at runtime (centered in this layout). -->
    <RelativeLayout android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:background="#FFFFFF">
        <!-- Camera preview; MainActivity re-adds it at index 0 so it stays behind the overlay. -->
        <TextureView
            android:id="@+id/texture_view"
            android:layout_width="match_parent"
            android:layout_height="match_parent"/>

        <!-- Overlay showing the Bitmap returned by PoseNetInterpriter.predict(); the 640px size is
             only the initial value before MainActivity resizes it. -->
        <ImageView
            android:id="@+id/image_view"
            android:layout_width="640px"
            android:layout_height="640px"
            android:layout_margin="16dp"
            android:visibility="visible"
            app:srcCompat="@mipmap/ic_launcher" />
    </RelativeLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

7. UIの作成

姿勢推定を行うUIを作成します。
以下の処理を行なっています。

・パーミッション
・カメラのプレビューと解析
・PoseNetInterpriterにBitmapを渡して推論(後ほど説明)

◎ MainActivity.java

package net.npaka.posenetex;

import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.Manifest;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.Point;
import android.graphics.Rect;
import android.graphics.YuvImage;
import android.media.Image;
import android.os.Bundle;
import android.view.TextureView;
import android.view.ViewGroup;
import android.widget.ImageView;
import android.widget.RelativeLayout;
import android.widget.Toast;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// MainActivity: shows the back-camera preview and overlays the pose drawn by PoseNetInterpriter.
public class MainActivity extends AppCompatActivity {
    // Request code for the runtime permission dialog
    private static final int REQUEST_CODE_PERMISSIONS = 101;
    // Runtime permissions that must be granted before the camera is started
    private static final String[] REQUIRED_PERMISSIONS = new String[]{
        Manifest.permission.CAMERA};

    // UI
    private TextureView textureView;
    private ImageView imageView;

    // Pose estimation
    private PoseNetInterpriter interpriter;

    // Single worker thread that runs the image analyzer; shut down in onDestroy()
    private ExecutorService analysisExecutor;

    // Sets up the UI and the interpreter, then starts the camera right away if the permission
    // is already granted, or asks the user for it otherwise.
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // UI
        this.textureView = findViewById(R.id.texture_view);
        this.imageView = findViewById(R.id.image_view);

        // Pose estimation
        this.interpriter = new PoseNetInterpriter(this);
        this.analysisExecutor = Executors.newSingleThreadExecutor();

        // Permission check
        if (allPermissionsGranted()) {
            // Posted to the view's queue, presumably so the camera starts once the
            // TextureView is attached.
            this.textureView.post(this::startCamera);
        } else {
            ActivityCompat.requestPermissions(this,
                REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS);
        }
    }

    // Stops the analysis worker thread so it does not outlive this activity instance
    // (a new executor is created for every instance, e.g. after rotation).
    @Override
    protected void onDestroy() {
        super.onDestroy();
        if (this.analysisExecutor != null) {
            // shutdown() lets an in-flight frame finish; no new frames are accepted.
            this.analysisExecutor.shutdown();
        }
    }

    // Receives the result of the permission request: starts the camera when granted,
    // otherwise tells the user and closes the activity.
    @Override
    public void onRequestPermissionsResult(int requestCode,
        String[] permissions, int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode != REQUEST_CODE_PERMISSIONS) {
            return;
        }
        if (allPermissionsGranted()) {
            // Same path as onCreate() so both start the camera the same way.
            this.textureView.post(this::startCamera);
        } else {
            Toast.makeText(this, "ユーザーから権限が許可されていません。",
                Toast.LENGTH_SHORT).show();
            finish();
        }
    }

    // Returns true when every permission in REQUIRED_PERMISSIONS has been granted.
    private boolean allPermissionsGranted() {
        for (String permission : REQUIRED_PERMISSIONS) {
            if (ContextCompat.checkSelfPermission(this, permission)
                != PackageManager.PERMISSION_GRANTED) {
                return false;
            }
        }
        return true;
    }

    // Binds a Preview (rendered into the TextureView) and an ImageAnalysis (fed to the
    // interpreter, result shown in the ImageView) to this activity's lifecycle.
    private void startCamera() {
        // Preview
        PreviewConfig previewConfig = new PreviewConfig.Builder().build();
        Preview preview = new Preview(previewConfig);
        preview.setOnPreviewOutputUpdateListener(
            output -> {
                // Remove and re-add the TextureView (at index 0, behind the overlay) before
                // assigning the new SurfaceTexture.
                ViewGroup parent = (ViewGroup) this.textureView.getParent();
                parent.removeView(this.textureView);
                parent.addView(this.textureView, 0);
                this.textureView.setSurfaceTexture(output.getSurfaceTexture());

                // Preview: screen width at a 3:4 (portrait) aspect ratio, centered.
                // Overlay: centered square of the same width.
                // NOTE(review): assumes the camera delivers 4:3 frames — confirm on target devices.
                Point screen = new Point();
                getWindowManager().getDefaultDisplay().getSize(screen);
                int w = screen.x;
                int h = screen.x * 4 / 3;
                RelativeLayout.LayoutParams previewParams = new RelativeLayout.LayoutParams(w, h);
                previewParams.addRule(RelativeLayout.CENTER_IN_PARENT);
                this.textureView.setLayoutParams(previewParams);
                RelativeLayout.LayoutParams overlayParams = new RelativeLayout.LayoutParams(w, w);
                overlayParams.addRule(RelativeLayout.CENTER_IN_PARENT);
                this.imageView.setLayoutParams(overlayParams);
            });

        // Image analysis
        ImageAnalysisConfig config = new ImageAnalysisConfig.Builder()
            .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
            .build();
        ImageAnalysis imageAnalysis = new ImageAnalysis(config);
        imageAnalysis.setAnalyzer(this.analysisExecutor,
            (image, rotationDegrees) -> {
                // Inference
                Bitmap bitmap = imageToBitmap(image.getImage(), rotationDegrees);
                if (bitmap == null) {
                    // Unsupported format or undecodable frame: skip it.
                    return;
                }
                Bitmap result = this.interpriter.predict(bitmap);
                bitmap.recycle();

                // predict() returns the same Bitmap object for every frame and redraws it on
                // the next call, so give the UI thread its own snapshot to display.
                Bitmap snapshot = result.copy(result.getConfig(), false);
                this.imageView.post(() -> this.imageView.setImageBitmap(snapshot));
            });

        // Bind to the lifecycle
        CameraX.bindToLifecycle(this, preview, imageAnalysis);
    }

    // Converts a camera frame to an upright Bitmap.
    // Returns null when the frame is null, in an unsupported format, or cannot be decoded.
    private Bitmap imageToBitmap(Image image, int rotationDegrees) {
        if (image == null) {
            return null;
        }
        byte[] jpeg = imageToByteArray(image);
        if (jpeg == null) {
            return null;
        }
        Bitmap bitmap = BitmapFactory.decodeByteArray(jpeg, 0, jpeg.length);
        if (bitmap == null || rotationDegrees == 0) {
            return bitmap;
        }
        return rotateBitmap(bitmap, rotationDegrees);
    }

    // Returns `bitmap` rotated by `rotationDegrees`; the source bitmap is recycled when a new
    // one was created, so callers must not use it afterwards.
    private Bitmap rotateBitmap(Bitmap bitmap, int rotationDegrees) {
        Matrix matrix = new Matrix();
        matrix.postRotate(rotationDegrees);
        Bitmap rotated = Bitmap.createBitmap(bitmap, 0, 0,
            bitmap.getWidth(), bitmap.getHeight(), matrix, true);
        // createBitmap may hand back the source itself; only recycle a distinct original.
        if (rotated != bitmap) {
            bitmap.recycle();
        }
        return rotated;
    }

    // Image -> JPEG bytes. Supports JPEG and YUV_420_888 frames; returns null for any other format.
    private byte[] imageToByteArray(Image image) {
        switch (image.getFormat()) {
            case ImageFormat.JPEG: {
                ByteBuffer buffer = image.getPlanes()[0].getBuffer();
                // remaining(), not capacity(): the valid data may be shorter than the buffer.
                byte[] data = new byte[buffer.remaining()];
                buffer.get(data);
                return data;
            }
            case ImageFormat.YUV_420_888:
                return NV21toJPEG(YUV_420_888toNV21(image),
                    image.getWidth(), image.getHeight());
            default:
                return null;
        }
    }

    // YUV_420_888 -> NV21 (full Y plane, then interleaved V/U samples).
    // Every sample is read through the plane's row and pixel stride. This handles
    // semi-planar and planar chroma layouts, and rows padded beyond the image width.
    // Assumes even width and height, as camera streams produce.
    private byte[] YUV_420_888toNV21(Image image) {
        int width = image.getWidth();
        int height = image.getHeight();
        int chromaWidth = width / 2;
        int chromaHeight = height / 2;
        byte[] nv21 = new byte[width * height + 2 * chromaWidth * chromaHeight];
        Image.Plane[] planes = image.getPlanes();
        int out = 0;

        // Luma
        ByteBuffer yBuffer = planes[0].getBuffer();
        int yRowStride = planes[0].getRowStride();
        int yPixelStride = planes[0].getPixelStride();
        for (int row = 0; row < height; row++) {
            int rowStart = row * yRowStride;
            for (int col = 0; col < width; col++) {
                nv21[out++] = yBuffer.get(rowStart + col * yPixelStride);
            }
        }

        // Chroma: NV21 stores V before U for each sample pair
        ByteBuffer uBuffer = planes[1].getBuffer();
        ByteBuffer vBuffer = planes[2].getBuffer();
        int uRowStride = planes[1].getRowStride();
        int uPixelStride = planes[1].getPixelStride();
        int vRowStride = planes[2].getRowStride();
        int vPixelStride = planes[2].getPixelStride();
        for (int row = 0; row < chromaHeight; row++) {
            for (int col = 0; col < chromaWidth; col++) {
                nv21[out++] = vBuffer.get(row * vRowStride + col * vPixelStride);
                nv21[out++] = uBuffer.get(row * uRowStride + col * uPixelStride);
            }
        }
        return nv21;
    }

    // NV21 -> JPEG bytes (quality 100) covering the whole frame.
    private byte[] NV21toJPEG(byte[] nv21, int width, int height) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        YuvImage yuv = new YuvImage(nv21, ImageFormat.NV21, width, height, null);
        yuv.compressToJpeg(new Rect(0, 0, width, height), 100, out);
        return out.toByteArray();
    }
}

8. 姿勢推定

Bitmapを受け取り、姿勢推定結果を描画したBitmapを返します。
パラメータ定数(入力サイズや量子化の有無など)を変更し、入出力をモデルのテンソル形状に合わせることで、別の姿勢推定モデルにも対応できます。

◎ PoseNetInterpriter.java

package net.npaka.posenetex;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.PointF;
import android.graphics.PorterDuff;
import android.graphics.Rect;

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;

import java.io.FileInputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

// Pose-estimation interpreter: runs a PoseNet TensorFlow Lite model on a Bitmap and returns a
// transparent INPUT_SIZE x INPUT_SIZE Bitmap with the estimated keypoints drawn on it.
// Call close() when finished to release the native interpreter and the GPU delegate.
public class PoseNetInterpriter implements AutoCloseable {
    // Model parameters
    private static final String MODEL_FILE =
        "posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite"; // asset file name
    private static final int BATCH_SIZE = 1; // batch size
    private static final int INPUT_PIXELS = 3; // channels per input pixel (R, G, B)
    private static final int INPUT_SIZE = 257; // input width and height in pixels
    private static final boolean IS_QUANTIZED = false; // true: 1 byte per channel; false: float
    private static final float IMAGE_MEAN = 128.0f; // float input = (value - IMAGE_MEAN) / IMAGE_STD
    private static final float IMAGE_STD = 128.0f;

    // Drawing parameters
    // BodyPart ordinals drawn as dots (nose, eyes, ears)
    private static final int[] DRAW_POINT = {0, 1, 2, 3, 4};
    // Pairs of BodyPart ordinals connected by a line (shoulders, arms, torso, legs)
    private static final int[][] DRAW_LINE = {
        {5, 6},
        {5, 7},
        {6, 8},
        {7, 9},
        {8, 10},
        {5, 11},
        {6, 12},
        {11, 12},
        {11, 13},
        {12, 14},
        {13, 15},
        {14, 16},
    };
    private static final float POINT_RADIUS = 3;
    // Cached because BodyPart.values() allocates a new array on every call
    private static final BodyPart[] BODY_PARTS = BodyPart.values();

    // System
    private final Context context;
    private final GpuDelegate gpuDelegate;
    private final Interpreter interpreter;
    private final int[] imageBuffer = new int[INPUT_SIZE * INPUT_SIZE];

    // Input
    private final ByteBuffer inBuffer;
    private final Bitmap inBitmap;
    private final Canvas inCanvas;
    private final Rect inBitmapSrc = new Rect();
    private final Rect inBitmapDst = new Rect(0, 0, INPUT_SIZE, INPUT_SIZE);

    // Output
    private final float[][][][] outHeatmaps;
    private final float[][][][] outOffsets;
    private final float[][][][] outDisplacementsFwd;
    private final float[][][][] outDisplacementsBwd;
    private final Map<Integer, Object> outputMap = new HashMap<>();
    private final Bitmap outBitmap;
    private final Canvas outCanvas;
    private final Paint outPaint;

    // Body parts; keypoint channel i of the model is mapped to BodyPart ordinal i
    enum BodyPart {
        NOSE,
        LEFT_EYE,
        RIGHT_EYE,
        LEFT_EAR,
        RIGHT_EAR,
        LEFT_SHOULDER,
        RIGHT_SHOULDER,
        LEFT_ELBOW,
        RIGHT_ELBOW,
        LEFT_WRIST,
        RIGHT_WRIST,
        LEFT_HIP,
        RIGHT_HIP,
        LEFT_KNEE,
        RIGHT_KNEE,
        LEFT_ANKLE,
        RIGHT_ANKLE
    }

    // Pixel position in input-image (INPUT_SIZE x INPUT_SIZE) coordinates
    static class Position {
        int x = 0;
        int y = 0;
    }

    // One estimated keypoint
    static class KeyPoint {
        public BodyPart bodyPart = BodyPart.NOSE;
        Position position = new Position();
        float score = 0.0f; // sigmoid of the heatmap peak value
    }

    // One person: every keypoint plus their mean score
    static class Person {
        ArrayList<KeyPoint> keyPoints = new ArrayList<>();
        float score = 0.0f;
    }

    // Loads the model from assets, creates the interpreter with the GPU delegate, and allocates
    // the input/output buffers that are reused on every predict() call.
    // Throws IllegalStateException if the model asset cannot be loaded.
    public PoseNetInterpriter(Context context) {
        this.context = context;

        // Load the model
        MappedByteBuffer model = loadModel(MODEL_FILE);

        // Create the interpreter. NNAPI is an alternative delegate: options.setUseNNAPI(true).
        // NOTE(review): the TFLite GPU delegate documentation asks that the delegate be created
        // and invoked on the same thread; confirm that holds for the caller of predict().
        this.gpuDelegate = new GpuDelegate();
        Interpreter.Options options = new Interpreter.Options();
        options.addDelegate(this.gpuDelegate);
        options.setNumThreads(1);
        this.interpreter = new Interpreter(model, options);

        // Input
        this.inBitmap = Bitmap.createBitmap(
            INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
        this.inCanvas = new Canvas(this.inBitmap);
        int numBytesPerChannel = IS_QUANTIZED ? 1 : 4;
        this.inBuffer = ByteBuffer.allocateDirect(
            BATCH_SIZE * INPUT_SIZE * INPUT_SIZE * INPUT_PIXELS * numBytesPerChannel);
        this.inBuffer.order(ByteOrder.nativeOrder());

        // Output buffers are sized from the model's own output tensors (indices 0-3:
        // heatmaps, offsets, forward and backward displacements), so other PoseNet
        // variants work without editing hard-coded shapes.
        this.outHeatmaps = newOutputBuffer(0);
        this.outOffsets = newOutputBuffer(1);
        this.outDisplacementsFwd = newOutputBuffer(2);
        this.outDisplacementsBwd = newOutputBuffer(3);
        this.outputMap.put(0, this.outHeatmaps);
        this.outputMap.put(1, this.outOffsets);
        this.outputMap.put(2, this.outDisplacementsFwd);
        this.outputMap.put(3, this.outDisplacementsBwd);
        this.outBitmap = Bitmap.createBitmap(
            INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
        this.outCanvas = new Canvas(this.outBitmap);

        // Overlay paint: configured once here, reused for every frame
        this.outPaint = new Paint();
        this.outPaint.setTextSize(12);
        this.outPaint.setAntiAlias(true);
        this.outPaint.setStyle(Paint.Style.FILL);
        this.outPaint.setColor(Color.RED);
        this.outPaint.setStrokeWidth(2);
    }

    // Memory-maps the given asset. The mapping stays valid after the file is closed,
    // so the descriptor and stream are closed right away.
    private MappedByteBuffer loadModel(String modelPath) {
        try (AssetFileDescriptor fd = this.context.getAssets().openFd(modelPath);
             FileInputStream in = new FileInputStream(fd.getFileDescriptor())) {
            FileChannel fileChannel = in.getChannel();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY,
                fd.getStartOffset(), fd.getDeclaredLength());
        } catch (java.io.IOException e) {
            throw new IllegalStateException("Failed to load model asset: " + modelPath, e);
        }
    }

    // Allocates a float buffer with the shape of the model's 4-D output tensor at `index`.
    private float[][][][] newOutputBuffer(int index) {
        int[] shape = this.interpreter.getOutputTensor(index).shape();
        if (shape.length != 4) {
            throw new IllegalStateException(
                "Model output " + index + " must be 4-D but has rank " + shape.length);
        }
        return new float[shape[0]][shape[1]][shape[2]][shape[3]];
    }

    // Estimates the pose in `bitmap` (its centered square crop) and returns the overlay.
    // The returned Bitmap is owned by this instance and is redrawn on the next call.
    public Bitmap predict(Bitmap bitmap) {
        cropToInput(bitmap);
        bmpToInBuffer(this.inBitmap);
        this.interpreter.runForMultipleInputsOutputs(new Object[]{this.inBuffer}, this.outputMap);
        Person person = decodePerson();
        drawPerson(person);
        return this.outBitmap;
    }

    // Scales the largest centered square of `bitmap` into the INPUT_SIZE x INPUT_SIZE input bitmap.
    private void cropToInput(Bitmap bitmap) {
        int side = Math.min(bitmap.getWidth(), bitmap.getHeight());
        int left = (bitmap.getWidth() - side) / 2;
        int top = (bitmap.getHeight() - side) / 2;
        this.inBitmapSrc.set(left, top, left + side, top + side);
        this.inCanvas.drawBitmap(bitmap, this.inBitmapSrc, this.inBitmapDst, null);
    }

    // Turns the heatmaps and offsets from the last run into one Person. For each keypoint it
    // takes the highest-scoring grid cell and maps it to input pixels, refined by the offsets.
    private Person decodePerson() {
        int height = this.outHeatmaps[0].length;
        int width = this.outHeatmaps[0][0].length;
        int numKeypoints = this.outHeatmaps[0][0][0].length;

        Person person = new Person();
        float totalScore = 0.0f;
        for (int k = 0; k < numKeypoints; k++) {
            // Highest heatmap cell for this keypoint; the first cell wins on ties
            int bestRow = 0;
            int bestCol = 0;
            float bestVal = this.outHeatmaps[0][0][0][k];
            for (int row = 0; row < height; row++) {
                for (int col = 0; col < width; col++) {
                    float value = this.outHeatmaps[0][row][col][k];
                    if (value > bestVal) {
                        bestVal = value;
                        bestRow = row;
                        bestCol = col;
                    }
                }
            }

            // Offsets hold the y adjustment in channel k and the x adjustment in k + numKeypoints
            float[] offsets = this.outOffsets[0][bestRow][bestCol];
            KeyPoint keyPoint = new KeyPoint();
            keyPoint.bodyPart = BODY_PARTS[k];
            keyPoint.position.y = (int) (
                bestRow / (float) (height - 1) * INPUT_SIZE + offsets[k]);
            keyPoint.position.x = (int) (
                bestCol / (float) (width - 1) * INPUT_SIZE + offsets[k + numKeypoints]);
            keyPoint.score = sigmoid(bestVal);
            totalScore += keyPoint.score;
            person.keyPoints.add(keyPoint);
        }
        person.score = totalScore / numKeypoints;
        return person;
    }

    // Clears the overlay, then draws the DRAW_POINT dots and DRAW_LINE segments for `person`.
    private void drawPerson(Person person) {
        ArrayList<KeyPoint> keyPoints = person.keyPoints;
        this.outCanvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR);
        for (int idx : DRAW_POINT) {
            Position p = keyPoints.get(idx).position;
            this.outCanvas.drawCircle(p.x, p.y, POINT_RADIUS, this.outPaint);
        }
        for (int[] line : DRAW_LINE) {
            Position from = keyPoints.get(line[0]).position;
            Position to = keyPoints.get(line[1]).position;
            this.outCanvas.drawLine(from.x, from.y, to.x, to.y, this.outPaint);
        }
    }

    // Logistic sigmoid: maps a heatmap logit to a 0..1 confidence.
    private float sigmoid(float x) {
        return (float) (1.0f / (1.0f + Math.exp(-x)));
    }

    // Writes `bitmap`'s RGB channels into inBuffer: raw bytes when IS_QUANTIZED,
    // otherwise normalized floats.
    private void bmpToInBuffer(Bitmap bitmap) {
        this.inBuffer.rewind();
        bitmap.getPixels(this.imageBuffer, 0, bitmap.getWidth(),
            0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < INPUT_SIZE * INPUT_SIZE; i++) {
            int pixelValue = this.imageBuffer[i];
            int r = (pixelValue >> 16) & 0xFF;
            int g = (pixelValue >> 8) & 0xFF;
            int b = pixelValue & 0xFF;
            if (IS_QUANTIZED) {
                this.inBuffer.put((byte) r);
                this.inBuffer.put((byte) g);
                this.inBuffer.put((byte) b);
            } else {
                this.inBuffer.putFloat((r - IMAGE_MEAN) / IMAGE_STD);
                this.inBuffer.putFloat((g - IMAGE_MEAN) / IMAGE_STD);
                this.inBuffer.putFloat((b - IMAGE_MEAN) / IMAGE_STD);
            }
        }
    }

    // Releases the native interpreter and then the GPU delegate; predict() must not be
    // called afterwards.
    @Override
    public void close() {
        this.interpreter.close();
        this.gpuDelegate.close();
    }
}


この記事が気に入ったらサポートをしてみませんか?