TensorFlow Lite入門 / iOSによる画像分類
1. iOSによる画像分類
iOSで「TensorFlow Lite」を使って画像分類を行ます。端末の背面カメラから見えるものをリアルタイムに画像分類し、可能性の高いラベル3つを表示します。
◎バージョン
・Xcode 10.3
・Swift 5
・TensorFlowLiteSwift 1.14.0
2. プロジェクトへのTensorFlow Liteフレームワークの追加
プロジェクトへのTensorFlow Liteフレームワークを追加するには、CocosPodsを使います。「pod init」でProfileを生成し、PodfileにTensorFlow Liteのフレームワークを追加し、最後に「pod install」でフレームワークの追加を実行します。
platform :ios, '12.0'
target 'CaptureClassificationEx' do
use_frameworks!
pod 'TensorFlowLiteSwift'
end
以降、プロジェクトを開く時、ImageClassification.xcworkspaceをダブルクリックします。
3. リソース
プロジェクトには、Image classificationからダウンロードしたTensorFlow Liteモデルとラベルを追加します。
・mobilenet_v1_1.0_224_quant.tflite
・labels_mobilenet_quant_v1_224.txt
4. 推論
「推論」を行っているのは、以下のコードです。
具体的には、interpreter.input()で入力バッファを取得し、interpreter.copy()で入力バッファを指定し、try interpreter.invoke()で推論し、interpreter.output()で出力バッファを取得するのみで簡単です。
難しい(めんどくさい)のは、「カメラの準備」と「入力バッファの生成」と「出力バッファの解析」になります。
//予測
func predict(_ sampleBuffer: CMSampleBuffer) {
//CMSampleBufferをCVPixelBufferに変換
let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer)!
//Pixelフォーマットの確認
let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
sourcePixelFormat == kCVPixelFormatType_32BGRA ||
sourcePixelFormat == kCVPixelFormatType_32RGBA)
//画像のクロップとスケーリング
let scaledSize = CGSize(width: INPUT_WIDTH, height: INPUT_HEIGHT)
guard let cropPixelBuffer = pixelBuffer.centerThumbnail(ofSize: scaledSize) else {
return
}
let outputTensor: Tensor
do {
//RGBデータの生成
let inputTensor = try interpreter.input(at: 0)
let rgbData = buffer2rgbData(
cropPixelBuffer,
byteCount: BATCH_SIZE * INPUT_WIDTH * INPUT_HEIGHT * INPUT_CHANNELS,
isModelQuantized: inputTensor.dataType == .uInt8)
//推論の実行
try interpreter.copy(rgbData!, toInputAt: 0)
try interpreter.invoke()
outputTensor = try interpreter.output(at: 0)
} catch let error {
print(error.localizedDescription)
return
}
var results: [Float] = []
switch outputTensor.dataType {
//量子化モデル
case .uInt8:
let quantization = outputTensor.quantizationParameters!
let quantizedResults = [UInt8](outputTensor.data)
results = quantizedResults.map{
quantization.scale * Float(Int($0) - quantization.zeroPoint)}
//浮動少数モデル
case .float32:
results = [Float32](unsafeData: outputTensor.data) ?? []
//その他
default:
return
}
//検出結果の取得
var text: String = "\n"
let zippedResults = zip(labels.indices, results)
let sortedResults = zippedResults.sorted {$0.1 > $1.1}.prefix(3) //上位3件ソート
for result in sortedResults {
let probabillity = Int(result.1*100) //信頼度
let label = labels[result.0] //ID
text += "\(label) : \(probabillity)%\n"
}
//UIの更新
DispatchQueue.main.async {
self.lblText.text = text
}
}
5. ソースコード全体
ソースコード全体は次の通りです。
import UIKit
import AVFoundation
import TensorFlowLite
import Accelerate
//画像分類(カメラ映像)
class ViewController: UIViewController,
AVCaptureVideoDataOutputSampleBufferDelegate {
//UI
@IBOutlet weak var lblText: UILabel!
@IBOutlet weak var drawView: UIView!
var previewLayer: AVCaptureVideoPreviewLayer!
//パラメータ
let BATCH_SIZE = 1 //バッチサイズ
let INPUT_CHANNELS = 3 //入力チャンネル
let INPUT_WIDTH = 224 //入力幅
let INPUT_HEIGHT = 224 //入力高さ
let THREAD_COUNT = 1 //スレッド数
//参照
var interpreter: Interpreter! //インタプリタ
var labels: [String]! //ラベル
//====================
//ライフサイクル
//====================
//ビュー表示時に呼ばれる
override func viewDidAppear(_ animated: Bool) {
do {
//モデルパスの生成
let modelPath = Bundle.main.path(
forResource: "mobilenet_v1_1.0_224_quant",
ofType: "tflite")!
//インタプリタオプションの生成
var options = InterpreterOptions()
options.threadCount = THREAD_COUNT
//インタプリタの生成
interpreter = try Interpreter(modelPath: modelPath, options: options)
try interpreter.allocateTensors()
//ラベルURLの生成
let labelURL = Bundle.main.url(
forResource: "labels_mobilenet_quant_v1_224",
withExtension: "txt")!
//ラベルの読み込み
let contents = try String(contentsOf: labelURL, encoding: .utf8)
labels = contents.components(separatedBy: .newlines)
} catch let error {
print(error.localizedDescription)
}
//カメラキャプチャの開始
startCapture()
}
//====================
//カメラキャプチャ
//====================
//カメラキャプチャの開始
func startCapture() {
//セッションの生成
let captureSession = AVCaptureSession()
captureSession.sessionPreset = AVCaptureSession.Preset.photo //プリセット
let captureDevice: AVCaptureDevice! = self.device(false)
//コンフィギュレーションの指定
do {
try captureDevice.lockForConfiguration()
captureDevice.activeVideoMinFrameDuration = CMTimeMake(value: 1, timescale: 20) //FPS
captureDevice.focusMode = .continuousAutoFocus //フォーカス
captureDevice.exposureMode = .continuousAutoExposure //露出
captureDevice.whiteBalanceMode = .continuousAutoWhiteBalance //ホワイトバランス
captureDevice.unlockForConfiguration()
} catch {
return
}
//入力の生成
guard let input = try? AVCaptureDeviceInput(device: captureDevice) else {return}
guard captureSession.canAddInput(input) else {return}
captureSession.addInput(input)
//出力の生成
let output: AVCaptureVideoDataOutput = AVCaptureVideoDataOutput()
output.setSampleBufferDelegate(self, queue: DispatchQueue(label: "VideoQueue"))
output.videoSettings = [String(kCVPixelBufferPixelFormatTypeKey) : kCMPixelFormat_32BGRA] //画像フォーマット
output.alwaysDiscardsLateVideoFrames = true //出力の遅延フレームの破棄
guard captureSession.canAddOutput(output) else {return}
captureSession.addOutput(output)
//画面の向き
let videoConnection = output.connection(with: AVMediaType.video)
videoConnection!.videoOrientation = .portrait
//プレビューの指定
previewLayer = AVCaptureVideoPreviewLayer(session: captureSession)
previewLayer.videoGravity = AVLayerVideoGravity.resizeAspectFill
previewLayer.frame = self.drawView.frame
self.view.layer.insertSublayer(previewLayer, at: 0)
//カメラキャプチャの開始
captureSession.startRunning()
}
//デバイスの取得
func device(_ frontCamera: Bool) -> AVCaptureDevice! {
//AVCaptureDeviceのリストの取得
let deviceDiscoverySession = AVCaptureDevice.DiscoverySession(
deviceTypes: [AVCaptureDevice.DeviceType.builtInWideAngleCamera],
mediaType: AVMediaType.video,
position: AVCaptureDevice.Position.unspecified)
let devices = deviceDiscoverySession.devices
//指定したポジションを持つAVCaptureDeviceの検索
let position: AVCaptureDevice.Position = frontCamera ? .front : .back
for device in devices {
if device.position == position {
return device
}
}
return nil
}
//カメラキャプチャの取得時に呼ばれる
func captureOutput(_ output: AVCaptureOutput,
didOutput sampleBuffer: CMSampleBuffer,
from connection: AVCaptureConnection) {
//予測
predict(sampleBuffer)
}
//====================
//画像分類(カメラ映像)
//====================
//予測
func predict(_ sampleBuffer: CMSampleBuffer) {
//CMSampleBufferをCVPixelBufferに変換
let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer)!
//Pixelフォーマットの確認
let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
sourcePixelFormat == kCVPixelFormatType_32BGRA ||
sourcePixelFormat == kCVPixelFormatType_32RGBA)
//画像のクロップとスケーリング
let scaledSize = CGSize(width: INPUT_WIDTH, height: INPUT_HEIGHT)
guard let cropPixelBuffer = pixelBuffer.centerThumbnail(ofSize: scaledSize) else {
return
}
let outputTensor: Tensor
do {
//RGBデータの生成
let inputTensor = try interpreter.input(at: 0)
let rgbData = buffer2rgbData(
cropPixelBuffer,
byteCount: BATCH_SIZE * INPUT_WIDTH * INPUT_HEIGHT * INPUT_CHANNELS,
isModelQuantized: inputTensor.dataType == .uInt8)
//推論の実行
try interpreter.copy(rgbData!, toInputAt: 0)
try interpreter.invoke()
outputTensor = try interpreter.output(at: 0)
} catch let error {
print(error.localizedDescription)
return
}
var results: [Float] = []
//量子化モデル
if outputTensor.dataType == .uInt8 {
let quantization = outputTensor.quantizationParameters!
let quantizedResults = [UInt8](outputTensor.data)
results = quantizedResults.map{
quantization.scale * Float(Int($0) - quantization.zeroPoint)}
}
//浮動少数モデル
else if outputTensor.dataType == .float32 {
results = [Float32](unsafeData: outputTensor.data) ?? []
}
//検出結果の取得
var text: String = "\n"
let zippedResults = zip(labels.indices, results)
let sortedResults = zippedResults.sorted {$0.1 > $1.1}.prefix(3) //上位3件ソート
for result in sortedResults {
let probabillity = Int(result.1*100) //信頼度
let label = labels[result.0] //ID
text += "\(label) : \(probabillity)%\n"
}
//UIの更新
DispatchQueue.main.async {
self.lblText.text = text
}
}
//PixelBuffer→rgbData
private func buffer2rgbData(_ buffer: CVPixelBuffer,
byteCount: Int, isModelQuantized: Bool) -> Data? {
//PixelBuffer→bufferData
CVPixelBufferLockBaseAddress(buffer, .readOnly)
defer { CVPixelBufferUnlockBaseAddress(buffer, .readOnly) }
guard let mutableRawPointer = CVPixelBufferGetBaseAddress(buffer) else {
return nil
}
let count = CVPixelBufferGetDataSize(buffer)
let bufferData = Data(bytesNoCopy: mutableRawPointer,
count: count, deallocator: .none)
//bufferData→rgbBytes
var rgbBytes = [UInt8](repeating: 0, count: byteCount)
var index = 0
for component in bufferData.enumerated() {
let offset = component.offset
let isAlphaComponent = (offset % 4) == 3
guard !isAlphaComponent else {continue}
rgbBytes[index] = component.element
index += 1
}
//rgbBytes→rgbData
if isModelQuantized {return Data(bytes: rgbBytes)}
return Data(copyingBufferOf: rgbBytes.map{Float($0)/255.0})
}
}
//====================
//拡張
//====================
//CVPixelBufferの拡張
extension CVPixelBuffer {
//画像のトリミングとスケーリング
func centerThumbnail(ofSize size: CGSize ) -> CVPixelBuffer? {
let imageWidth = CVPixelBufferGetWidth(self)
let imageHeight = CVPixelBufferGetHeight(self)
let pixelBufferType = CVPixelBufferGetPixelFormatType(self)
assert(pixelBufferType == kCVPixelFormatType_32BGRA)
let inputImageRowBytes = CVPixelBufferGetBytesPerRow(self)
let imageChannels = 4
let thumbnailSize = min(imageWidth, imageHeight)
CVPixelBufferLockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0))
var originX = 0
var originY = 0
if imageWidth > imageHeight {
originX = (imageWidth - imageHeight) / 2
}
else {
originY = (imageHeight - imageWidth) / 2
}
//PixelBufferで最大の正方形をみつける
guard let inputBaseAddress = CVPixelBufferGetBaseAddress(self)?.advanced(
by: originY * inputImageRowBytes + originX * imageChannels) else {
return nil
}
//入力画像から画像バッファを取得
var inputVImageBuffer = vImage_Buffer(
data: inputBaseAddress, height: UInt(thumbnailSize), width: UInt(thumbnailSize),
rowBytes: inputImageRowBytes)
let thumbnailRowBytes = Int(size.width) * imageChannels
guard let thumbnailBytes = malloc(Int(size.height) * thumbnailRowBytes) else {
return nil
}
//サムネイル画像にvImageバッファを割り当て
var thumbnailVImageBuffer = vImage_Buffer(data: thumbnailBytes,
height: UInt(size.height), width: UInt(size.width), rowBytes: thumbnailRowBytes)
//入力画像バッファでスケール操作を実行し、サムネイル画像バッファに保存
let scaleError = vImageScale_ARGB8888(&inputVImageBuffer, &thumbnailVImageBuffer, nil, vImage_Flags(0))
CVPixelBufferUnlockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0))
guard scaleError == kvImageNoError else {
return nil
}
let releaseCallBack: CVPixelBufferReleaseBytesCallback = {mutablePointer, pointer in
if let pointer = pointer {
free(UnsafeMutableRawPointer(mutating: pointer))
}
}
//サムネイルのvImageバッファをCVPixelBufferに変換
var thumbnailPixelBuffer: CVPixelBuffer?
let conversionStatus = CVPixelBufferCreateWithBytes(
nil, Int(size.width), Int(size.height), pixelBufferType, thumbnailBytes,
thumbnailRowBytes, releaseCallBack, nil, nil, &thumbnailPixelBuffer)
guard conversionStatus == kCVReturnSuccess else {
free(thumbnailBytes)
return nil
}
return thumbnailPixelBuffer
}
}
//Dataの拡張
extension Data {
//float配列→byte配列(長さ4倍)
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
//Arrayの拡張
extension Array {
//byte配列→float配列(長さ1/4倍)
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
この記事が気に入ったらサポートをしてみませんか?