見出し画像

CoreML3を使用したオンデバイストレーニングの実践

1. はじめに

「CoreML3」では、iPhone上で更新可能なモデルを訓練できるようになりました。Appleのサイトでは、「MLUpdateTask」を使用して、手書き分類モデルを更新するサンプルが提供されています。

アプリの「Add Sticker」ボタンで、「ユーザーの手書き」と「絵文字」のペアのデータを追加し、モデルの訓練を行います。

画像1

訓練が完了すると、更新された手書き分類モデルを使用して、ユーザーからの手書きを絵文字に変換できるようになります。

画像2

2. モデル更新のための訓練データの準備

モデル更新のための訓練データを準備します。「ユーザーの手書き」と「絵文字」のペアのデータが訓練データになります。
訓練データの作成手順は次の通りです。

(1)MLFeatureValueでラップしたモデルの入力と出力の準備。
(2)MLFeatureProviderでMLFeatureValueをグループ化。
(3)MLBatchProviderでMLFeatureProviderをグループ化。

var featureProviders = [MLFeatureProvider]()

let inputName = "drawing"
let outputName = "label"

for drawing in trainingDrawings {
    // MLFeatureValueでラップしたモデルの入力と出力の準備
    let inputValue = drawing.featureValue
    let outputValue = MLFeatureValue(string: String(emoji))


    // MLFeatureProviderでMLFeatureValueをグループ化
    let dataPointFeatures: [String: MLFeatureValue] = [
        inputName: inputValue,
        outputName: outputValue]
    if let provider = try? MLDictionaryFeatureProvider(dictionary: dataPointFeatures) {
         featureProviders.append(provider)
    }
}

// MLBatchProviderでMLFeatureProviderをグループ化
return MLArrayBatchProvider(array: featureProviders)

◎MLFeatureValue
モデルの入力である「ユーザーの手書き」はMLFeatureValue(cgImage: preparedImage,constraint:)、モデルの出力である「絵文字」はMLFeatureValue(string:)でラップしています。

let imageFeatureValue = try? MLFeatureValue(
    cgImage: preparedImage,
    constraint: imageConstraint)
return imageFeatureValue!

◎MLFeatureProvider
2つの「MLFeatureValue」を以下のように辞書でまとめ、それを元にMLDictionaryFeatureProvider(dictionary:)でグループ化します。

・キー「drawing」と値「ユーザーの手書き」
・キー「label」と値「絵文字」

そして、MLFeatureProvider配列に追加します。

◎MLArrayBatchProvider
MLFeatureProvider配列を元に、MLArrayBatchProvider(array:)でグループ化します。

3. モデル更新タスクの生成

「MLUpdateTask」を生成します。コンストラクタの引数は次の通りです。

・モデルファイル(ModelName.mlmodelc)のパス
・更新データを持つMLBatchProvider
・モデル設定
・タスクが終了したときに呼び出すクロージャ

今回は、現在使用している手書き分類モデルを更新します。

// Update Taskの生成
guard let updateTask = try? MLUpdateTask(forModelAt: url,
        trainingData: trainingData,
        configuration: nil,
        completionHandler: completionHandler)
    else {
        print("Could't create an MLUpdateTask.")
        return
}

4. モデル更新タスクの実行

モデル更新タスクを実行するには、MLUpdateTaskのresume()を呼び出します。

updateTask.resume()

「Core ML」は、個別のスレッドでモデル訓練し、終了すると完了ハンドラを呼び出します。

5. モデルの保存

完了ハンドラを使用して、訓練されたモデルを保存します。

サンプルでは、最初にモデルを一時的な場所に書き込みます。その後、モデルを永続的な場所に移動します。この時、以前に保存されたモデルを上書きします。

let updatedModel = updateContext.model
let fileManager = FileManager.default
do {
    // 更新されたモデルのディレクトリを作成
    try fileManager.createDirectory(at: tempUpdatedModelURL,
        withIntermediateDirectories: true,
        attributes: nil)

    // 更新されたモデルを一時ファイル名に保存
    try updatedModel.write(to: tempUpdatedModelURL)

    // 以前に更新されたモデルをこのモデルに置き換える
    _ = try fileManager.replaceItemAt(updatedModelURL,
        withItemAt: tempUpdatedModelURL)

    print("Updated model saved to:\n\t\(updatedModelURL)")
} catch let error {
    print("Could not save updated model to the file system: \(error)")
    return
}

6. モデルの読み込み

モデルを読み込むには、UpdatableDrawingClassifier(contentsOf :)を使います。

// モデルの存在チェック
guard FileManager.default.fileExists(atPath: updatedModelURL.path) else {
    return
}

// モデルの読み込み
guard let model = try? UpdatableDrawingClassifier(contentsOf: updatedModelURL) else {
    return
}

// モデルを予測に利用
updatedDrawingClassifier = model


この記事が気に入ったらサポートをしてみませんか?