見出し画像

Android ✖︎ 画像認識


🤖 TensorFlow・TensorFlow Lite

📲 Androidでの機械学習

🐋 詳細

  • TensorFlow Lite用のMLモデル:
    Kaggleから訓練済みMLモデルをダウンロードする


  • アフリカ

  • アジア

  • ユーロッパ

  • 北米

  • 南アメリカ

  • オセアニア


  • モデルをPJに追加する

src/main/assetsに利用したいMLモデルを追加する
  • ライブラリ(CameraX・TensorFlow)を build.gradle.kts に追加する

// build.gradle.kts

val cameraXVersion = "1.3.2"

// CameraX
implementation("androidx.camera:camera-core:$cameraXVersion")
implementation("androidx.camera:camera-camera2:$cameraXVersion")
implementation("androidx.camera:camera-lifecycle:$cameraXVersion")
implementation("androidx.camera:camera-video:$cameraXVersion")
implementation("androidx.camera:camera-view:$cameraXVersion")
implementation("androidx.camera:camera-extensions:$cameraXVersion")

...

// TensorFlow
implementation("org.tensorflow:tensorflow-lite-task-vision:0.4.0")
implementation("org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0")
implementation("org.tensorflow:tensorflow-lite-gpu:2.9.0")
  • Manifestにカメラ権限を追加する

// AndroidManifest.xml

<uses-feature
    android:name="android.hardware.camera"
    android:required="false" />
<uses-permission android:name="android.permission.CAMERA" />
  • 追加したMLモデルを使う

    • 認識結果のモデルを定義する

// Classification.kt

data class Classification(
    val name: String,
    val score: Float
)
  • 認識ロジック

// Classifier.kt 

interface Classifier {
    fun classify(bitmap: Bitmap, rotation: Int): List<Classification>
}
// ImageClassifier.kt

class ImageClassifier(
    private val context: Context,
    private val threshold: Float = 0.5f,
    private val maxResults: Int = 3
): Classifier {

    private var classifier: ImageClassifier? = null

    private fun setupClassifier() {
        val baseOptions = BaseOptions
		        .builder()
            .setNumThreads(2)
            .build()
        val options = ImageClassifier.ImageClassifierOptions
		        .builder()
            .setBaseOptions(baseOptions)
            .setMaxResults(maxResults)
            .setScoreThreshold(threshold)
            .build()
            
        // モデルを使う
        try {
            classifier = ImageClassifier.createFromFileAndOptions(
                context,
                "landmark_asia.tflite",
                options
            )
        } catch (e: IllegalStateException) {
            e.printStackTrace()
        }
    }

    override fun classify(bitmap: Bitmap, rotation: Int): List<Classification> {
        if(classifier == null) {
            setupClassifier()
        }

        val imageProcessor = ImageProcessor.Builder().build()
        val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))

        val imageProcessingOptions = ImageProcessingOptions
		        .builder()
            .setOrientation(getOrientationFromRotation(rotation))
            .build()

        val results = classifier?.classify(tensorImage, imageProcessingOptions)

        return results?.flatMap { classification ->
            classification.categories.map { category ->
                Classification(
                    name = category.displayName,
                    score = category.score
                )
            }
        }?.distinctBy { it.name } ?: emptyList()
    }

		// インプットの角度処理
    private fun getOrientationFromRotation(rotation: Int): ImageProcessingOptions.Orientation {
        return when(rotation) {
            Surface.ROTATION_270 -> ImageProcessingOptions.Orientation.BOTTOM_RIGHT
            Surface.ROTATION_90 -> ImageProcessingOptions.Orientation.TOP_LEFT
            Surface.ROTATION_180 -> ImageProcessingOptions.Orientation.RIGHT_BOTTOM
            else -> ImageProcessingOptions.Orientation.RIGHT_TOP
        }
    }
}
  • CameraXの分析ロジック

// ImageAnalyzer.kt

class ImageAnalyzer(
    private val classifier: Classifier,
    private val onResults: (List<Classification>) -> Unit
): ImageAnalysis.Analyzer {

		// 認識結果が変動しすぎないように、counterを追加する
    private var frameSkipCounter = 0

    override fun analyze(image: ImageProxy) {
        if(frameSkipCounter % 60 == 0) {
            val rotationDegrees = image.imageInfo.rotationDegrees
            // インプット画像のフォマット
            val bitmap = image
                .toBitmap()
                .centerCrop(
                    desiredWidth = 321,
                    desiredHeight = 321
                )

            val results = classifier.classify(bitmap, rotationDegrees)
            onResults(results)
        }
        frameSkipCounter++

        image.close()
    }
}
// BitmapExtension.kt

fun Bitmap.centerCrop(desiredWidth: Int, desiredHeight: Int): Bitmap {
    val xStart = (width - desiredWidth) / 2
    val yStart = (height - desiredHeight) / 2

    if(xStart < 0 || yStart < 0 || desiredWidth > width || desiredHeight > height) {
        throw IllegalArgumentException("Invalid arguments for center cropping")
    }

    return Bitmap.createBitmap(this, xStart, yStart, desiredWidth, desiredHeight)
}
  • UI画面

// MainActivity.kt

class MainActivity : ComponentActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        
        if(!hasCameraPermission()) {
            ActivityCompat.requestPermissions(
                this, arrayOf(Manifest.permission.CAMERA), 0
            )
        }
        
        setContent {
            LandmarkRecognitionTensorflowTheme {
			          // 認識結果リスト
                var classifications by remember {
                    mutableStateOf(emptyList<Classification>())
                }
                
                // CameraXのアナライザ
                val analyzer = remember {
                    ImageAnalyzer(
                        classifier = ImageClassifier(
                            context = applicationContext
                        ),
                        onResults = { resultList ->
                            classifications = resultList
                        }
                    )
                }
                
                // CameraXのコントローラーを設置
                val controller = remember {
                    LifecycleCameraController(applicationContext).apply {
                        setEnabledUseCases(CameraController.IMAGE_ANALYSIS)
                        setImageAnalysisAnalyzer(
                            ContextCompat.getMainExecutor(applicationContext),
                            analyzer
                        )
                    }
                }
                
                Box(
                    modifier = Modifier
                        .fillMaxSize()
                ) {
                    CameraScreen(controller, Modifier.fillMaxSize())

                    Column(
                        modifier = Modifier
                            .fillMaxWidth()
                            .align(Alignment.TopCenter)
                    ) {
                        classifications.forEach { classification ->
                            val percentage = "%.1f".format(classification.score * 100)

														// テキストで認識結果を表示
                            Text(
                                text = "${classification.name}\n($percentage%)",
                                modifier = Modifier
                                    .fillMaxWidth()
                                    .background(MaterialTheme.colorScheme.primaryContainer)
                                    .padding(8.dp),
                                textAlign = TextAlign.Center,
                                fontSize = 20.sp,
                                color = MaterialTheme.colorScheme.primary
                            )
                        }
                    }
                }
            }
        }
    }

		// 権限チェック
    private fun hasCameraPermission() = ContextCompat.checkSelfPermission(
        this, Manifest.permission.CAMERA
    ) == PackageManager.PERMISSION_GRANTED
}
// CameraScreen.kt

@Composable
fun CameraScreen(
    controller: LifecycleCameraController,
    modifier: Modifier = Modifier
) {
    val lifecycleOwner = LocalLifecycleOwner.current
    
    AndroidView(
        factory = { context ->
		        // CameraXのPreviewViewを使う
            PreviewView(context).apply {
                this.controller = controller
                // ライフサイクルにbind
                controller.bindToLifecycle(lifecycleOwner)
            }
        },
        modifier = modifier
    )
}
  • 実機で確認する

タージ・マハル
アンコール・ワット

💭 その他

  • 認識結果の精度を上げたい場合は、自分で訓練したモデルを利用することをおすすめ

  • 世界名勝データセット


この記事が気に入ったらサポートをしてみませんか?