見出し画像

iOSアプリ開発 入門 (11) - CoreMLによる物体検出モデルの推論

「CreateML」による物体検出モデルの推論手順をまとめました。

・iOS 14.7 ※ iOS 14.5/14.6ではエラーが発生

前回

1. 物体検出

「物体検出」は、画像の中からあらかじめ定義された物体の位置とラベルを検出するタスクです。

「CoreML」で物体検出の推論を行うサンプルが、以下のサイトで提供されています。

2. Info.plist

Info.plist」に以下の項目を設定します。

・Privacy - Camera Usage Description : カメラの用途の説明。

3. コード

コードは、次のとおりです。

(1) 「ViewController」を以下のように編集。
「ビデオキャプチャー」と「プレビュー」を実装しています。StoryboardのViewControllerのビューとpreviewViewを関連付けてください。

import UIKit
import AVFoundation
import Vision

// ViewController
class ViewController: UIViewController, AVCaptureVideoDataOutputSampleBufferDelegate {
    // サブクラスで利用
    var bufferSize: CGSize = .zero // バッファサイズ
    var rootLayer: CALayer! = nil // ルートレイヤー
   
    //参照
    @IBOutlet weak private var previewView: UIView!
   
    // ビデオキャプチャー
    private let session = AVCaptureSession()
    private var previewLayer: AVCaptureVideoPreviewLayer! = nil
    private let videoDataOutput = AVCaptureVideoDataOutput()
    private let videoDataOutputQueue = DispatchQueue(label: "VideoDataOutput",
        qos: .userInitiated, attributes: [], autoreleaseFrequency: .workItem)
   
   
//====================
// ライフサイクル
//====================
    // ビューロード時に呼ばれる
    override func viewDidLoad() {
        super.viewDidLoad()
        
        // ビデオキャプチャーのセットアップ
        setupAVCapture()
    }
   
   
//====================
// ビデオキャプチャー
//====================
    // ビデオキャプチャーのセットアップ
    func setupAVCapture() {
        // ビデオ入力の生成
        var deviceInput: AVCaptureDeviceInput!
        let videoDevice = AVCaptureDevice.DiscoverySession(deviceTypes: [.builtInWideAngleCamera], mediaType: .video, position: .back).devices.first
        do {
            deviceInput = try AVCaptureDeviceInput(device: videoDevice!)
        } catch {
            print("ビデオ入力の生成に失敗しました。")
            return
        }
       
        // ビデオキャプチャーの設定開始
        session.beginConfiguration()
       
        // 画像解像度の設定 (モデルの画像サイズに近いものを指定)
        session.sessionPreset = .vga640x480
       
        // ビデオ入力の追加
        guard session.canAddInput(deviceInput) else {
            print("ビデオ入力の追加に失敗しました。")
            session.commitConfiguration()
            return
        }
        session.addInput(deviceInput)
       
        // ビデオ出力の追加
        if session.canAddOutput(videoDataOutput) {
            session.addOutput(videoDataOutput)
            videoDataOutput.alwaysDiscardsLateVideoFrames = true
            videoDataOutput.videoSettings = [kCVPixelBufferPixelFormatTypeKey as String: Int(kCVPixelFormatType_420YpCbCr8BiPlanarFullRange)]
            videoDataOutput.setSampleBufferDelegate(self, queue: videoDataOutputQueue)
        } else {
            print("ビデオ出力の追加に失敗しました。")
            session.commitConfiguration()
            return
        }
       
        // ビデオキャプチャーのコネクションの設定
        let captureConnection = videoDataOutput.connection(with: .video)
        captureConnection?.isEnabled = true
       
        // バッファサイズの取得
        do {
            try  videoDevice!.lockForConfiguration()
            let dimensions = CMVideoFormatDescriptionGetDimensions((videoDevice?.activeFormat.formatDescription)!)
            bufferSize.width = CGFloat(dimensions.width)
            bufferSize.height = CGFloat(dimensions.height)
            videoDevice!.unlockForConfiguration()
        } catch {
            print("ERROR1>>", error)
        }
       
        // セッションの設定終了
        session.commitConfiguration()
       
        // プレビューレイヤーの追加
        previewLayer = AVCaptureVideoPreviewLayer(session: session)
        previewLayer.videoGravity = AVLayerVideoGravity.resizeAspectFill
        rootLayer = previewView.layer
        previewLayer.frame = rootLayer.bounds
        rootLayer.addSublayer(previewLayer)
    }
   
    // ビデオキャプチャーの開始
    func startCaptureSession() {
        session.startRunning()
    }
   
    // ビデオキャプチャーの破棄
    func teardownAVCapture() {
        previewLayer.removeFromSuperlayer()
        previewLayer = nil
    }
   
   
//====================
// AVCaptureVideoDataOutputSampleBufferDelegate
//====================
    // ビデオフレーム更新時に呼ばれる
    func captureOutput(_ output: AVCaptureOutput,
        didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) {
        // サブクラスで実装
    }
   
    // ビデオフレームの破棄時に呼ばれる
    func captureOutput(_ captureOutput: AVCaptureOutput,
        didDrop didDropSampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) {
        // サブクラスで実装
    }
   
   
//====================
// ユーティリティ
//====================
    // デバイス向きに応じたExifOrientationの取得
    public func exifOrientationFromDeviceOrientation() -> CGImagePropertyOrientation {
        let curDeviceOrientation = UIDevice.current.orientation
        let exifOrientation: CGImagePropertyOrientation
        switch curDeviceOrientation {
        case UIDeviceOrientation.portraitUpsideDown:
            exifOrientation = .left
        case UIDeviceOrientation.landscapeLeft:
            exifOrientation = .upMirrored
        case UIDeviceOrientation.landscapeRight:
            exifOrientation = .down
        case UIDeviceOrientation.portrait:
            exifOrientation = .up
        default:
            exifOrientation = .up
        }
        return exifOrientation
    }
}

(2) 「ViewController」を継承した「VisionObjectRecognitionViewController」を追加。
captureOutput()をオーバーライドして、物体検出の機能を追加しています。

import UIKit
import AVFoundation
import Vision

// VisionObjectRecognitionViewController
class VisionObjectRecognitionViewController: ViewController {
    // オーバーレイ
    private var detectionOverlay: CALayer! = nil
   
    // リクエスト
    private var requests = [VNRequest]()
   
   
//====================
// Vision
//====================
    // Visionのセットアップ
    @discardableResult
    func setupVision() -> NSError? {
        let error: NSError! = nil
        
        // モデルの読み込み
        guard let modelURL = Bundle.main.url(forResource: "ObjectDetector", withExtension: "mlmodelc") else {
            return NSError(domain: "VisionObjectRecognitionViewController", code: -1, userInfo: [NSLocalizedDescriptionKey: "モデルの読み込みに失敗しました。"])
        }
        do {
            let visionModel = try VNCoreMLModel(for: MLModel(contentsOf: modelURL))
            
            // リクエストの生成
            let objectRecognition = VNCoreMLRequest(model: visionModel, completionHandler: {(request, error) in
                // リクエスト結果の処理
                DispatchQueue.main.async(execute: {
                    if let results = request.results {
                        self.drawVisionRequestResults(results)
                    }
                })
            })
            self.requests = [objectRecognition]
        } catch {
            print("モデルの読み込みに失敗しました。")
        }
        return error
    }
   
    // リクエスト結果の表示
    func drawVisionRequestResults(_ results: [Any]) {
        // トランザクションの開始
        CATransaction.begin()
        CATransaction.setValue(kCFBooleanTrue, forKey: kCATransactionDisableActions)
       
        // サブレイヤーの削除
        detectionOverlay.sublayers = nil
       
        // サブレイヤーの追加
        for observation in results where observation is VNRecognizedObjectObservation {
            guard let objectObservation = observation as? VNRecognizedObjectObservation else {
                continue
            }
            let topLabelObservation = objectObservation.labels[0] // ラベル
            let objectBounds = VNImageRectForNormalizedRect( // 領域
                objectObservation.boundingBox, Int(bufferSize.width), Int(bufferSize.height))
            let shapeLayer = self.createRoundedRectLayerWithBounds(objectBounds) // シェイプレイヤー
            let textLayer = self.createTextSubLayerInBounds(objectBounds, // テキストレイヤー
                identifier: topLabelObservation.identifier, confidence: topLabelObservation.confidence)
            shapeLayer.addSublayer(textLayer)
            detectionOverlay.addSublayer(shapeLayer)
        }
        self.updateLayerGeometry()
       
        // トランザクションの終了
        CATransaction.commit()
    }
   
   
//====================
// ビデオキャプチャー
//====================
    // ビデオキャプチャーのセットアップ
    override func setupAVCapture() {
        super.setupAVCapture()
        setupLayers()
        updateLayerGeometry()
        setupVision()
       
        // ビデオキャプチャーの開始
        startCaptureSession()
    }
   
    // オーバーレイの生成
    func setupLayers() {
        detectionOverlay = CALayer()
        detectionOverlay.name = "DetectionOverlay"
        detectionOverlay.bounds = CGRect(x: 0.0, y: 0.0, width: bufferSize.width, height: bufferSize.height)
        detectionOverlay.position = CGPoint(x: rootLayer.bounds.midX, y: rootLayer.bounds.midY)
        rootLayer.addSublayer(detectionOverlay)
    }
   
    // オーバーレイの配置の更新
    func updateLayerGeometry() {
        let bounds = rootLayer.bounds
        var scale: CGFloat
        let xScale: CGFloat = bounds.size.width / bufferSize.height
        let yScale: CGFloat = bounds.size.height / bufferSize.width
        scale = fmax(xScale, yScale)
        if scale.isInfinite {
            scale = 1.0
        }
       
        // トランザクションの開始
        CATransaction.begin()
        CATransaction.setValue(kCFBooleanTrue, forKey: kCATransactionDisableActions)
       
        // オーバーレイの配置の更新
        detectionOverlay.setAffineTransform(CGAffineTransform(
            rotationAngle: CGFloat(.pi / 2.0)).scaledBy(x: scale, y: -scale))
        detectionOverlay.position = CGPoint(x: bounds.midX, y: bounds.midY)
        
        // トランザクションのコミット
        CATransaction.commit()
    }
   
    // シェイプレイヤーの生成
    func createRoundedRectLayerWithBounds(_ bounds: CGRect) -> CALayer {
        let shapeLayer = CALayer()
        shapeLayer.bounds = bounds
        shapeLayer.position = CGPoint(x: bounds.midX, y: bounds.midY)
        shapeLayer.name = "Found Object"
        shapeLayer.backgroundColor = CGColor(colorSpace: CGColorSpaceCreateDeviceRGB(),
            components: [1.0, 1.0, 0.2, 0.4])
        shapeLayer.cornerRadius = 7
        return shapeLayer
    }
   
    // テキストレイヤーの生成
    func createTextSubLayerInBounds(_ bounds: CGRect, identifier: String,
        confidence: VNConfidence) -> CATextLayer {
        let textLayer = CATextLayer()
        textLayer.name = "Object Label"
        let formattedString = NSMutableAttributedString(
            string: String(format: "\(identifier)\nConfidence:  %.2f", confidence))
        let largeFont = UIFont(name: "Helvetica", size: 24.0)!
        formattedString.addAttributes([NSAttributedString.Key.font: largeFont],
            range: NSRange(location: 0, length: identifier.count))
        textLayer.string = formattedString
        textLayer.bounds = CGRect(x: 0, y: 0, width: bounds.size.height - 10, height: bounds.size.width - 10)
        textLayer.position = CGPoint(x: bounds.midX, y: bounds.midY)
        textLayer.shadowOpacity = 0.7
        textLayer.shadowOffset = CGSize(width: 2, height: 2)
        textLayer.foregroundColor = CGColor(colorSpace: CGColorSpaceCreateDeviceRGB(),
            components: [0.0, 0.0, 0.0, 1.0])
        textLayer.contentsScale = 2.0 // Retina
        textLayer.setAffineTransform(CGAffineTransform( // 回転
            rotationAngle: CGFloat(.pi / 2.0)).scaledBy(x: 1.0, y: -1.0))
        return textLayer
    }
   
   
//====================
// AVCaptureVideoDataOutputSampleBufferDelegate
//====================
    // ビデオフレーム更新時に呼ばれる
    override func captureOutput(_ output: AVCaptureOutput,
        didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) {
        // ピクセルバッファの取得
        guard let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer) else {
            return
        }
       
        // ExifOrientationの取得
        let exifOrientation = exifOrientationFromDeviceOrientation()
       
        // リクエストの実行
        let imageRequestHandler = VNImageRequestHandler(
            cvPixelBuffer: pixelBuffer, orientation: exifOrientation, options: [:])
        do {
            try imageRequestHandler.perform(self.requests)
        } catch {
            print("ERROR2>>>", error)
        }
    }
}

4. モデルの変更

CreateMLで学習した物体検出モデルは、プロジェクトのTargetに追加し、「VisionObjectRecognitionViewController」のモデル名を変更するだけで、物体検出の推論に利用できます。

// モデルの読み込み
guard let modelURL = Bundle.main.url(forResource: "ObjectDetector", withExtension: "mlmodelc") else {
    return NSError(domain: "VisionObjectRecognitionViewController", code: -1, userInfo: [NSLocalizedDescriptionKey: "モデルの読み込みに失敗しました。"])
}

次回


この記事が気に入ったらサポートをしてみませんか?