CoreML3を使用したオンデバイストレーニングの実践
1. はじめに
「CoreML3」では、iPhone上で更新可能なモデルを訓練できるようになりました。Appleのサイトでは、「MLUpdateTask」を使用して、手書き分類モデルを更新するサンプルが提供されています。
アプリの「Add Sticker」ボタンで、「ユーザーの手書き」と「絵文字」のペアのデータを追加し、モデルの訓練を行います。
訓練が完了すると、更新された手書き分類モデルを使用して、ユーザーからの手書きを絵文字に変換できるようになります。
2. モデル更新のための訓練データの準備
モデル更新のための訓練データを準備します。「ユーザーの手書き」と「絵文字」のペアのデータが訓練データになります。
訓練データの作成手順は次の通りです。
(1)MLFeatureValueでラップしたモデルの入力と出力の準備。
(2)MLFeatureProviderでMLFeatureValueをグループ化。
(3)MLBatchProviderでMLFeatureProviderをグループ化。
var featureProviders = [MLFeatureProvider]()
let inputName = "drawing"
let outputName = "label"
for drawing in trainingDrawings {
// MLFeatureValueでラップしたモデルの入力と出力の準備
let inputValue = drawing.featureValue
let outputValue = MLFeatureValue(string: String(emoji))
// MLFeatureProviderでMLFeatureValueをグループ化
let dataPointFeatures: [String: MLFeatureValue] = [
inputName: inputValue,
outputName: outputValue]
if let provider = try? MLDictionaryFeatureProvider(dictionary: dataPointFeatures) {
featureProviders.append(provider)
}
}
// MLBatchProviderでMLFeatureProviderをグループ化
return MLArrayBatchProvider(array: featureProviders)
◎MLFeatureValue
モデルの入力である「ユーザーの手書き」はMLFeatureValue(cgImage: preparedImage,constraint:)、モデルの出力である「絵文字」はMLFeatureValue(string:)でラップしています。
let imageFeatureValue = try? MLFeatureValue(
cgImage: preparedImage,
constraint: imageConstraint)
return imageFeatureValue!
◎MLFeatureProvider
2つの「MLFeatureValue」を以下のように辞書でまとめ、それを元にMLDictionaryFeatureProvider(dictionary:)でグループ化します。
・キー「drawing」と値「ユーザーの手書き」
・キー「label」と値「絵文字」
そして、MLFeatureProvider配列に追加します。
◎MLArrayBatchProvider
MLFeatureProvider配列を元に、MLArrayBatchProvider(array:)でグループ化します。
3. モデル更新タスクの生成
「MLUpdateTask」を生成します。コンストラクタの引数は次の通りです。
・モデルファイル(ModelName.mlmodelc)のパス
・更新データを持つMLBatchProvider
・モデル設定
・タスクが終了したときに呼び出すクロージャ
今回は、現在使用している手書き分類モデルを更新します。
// Update Taskの生成
guard let updateTask = try? MLUpdateTask(forModelAt: url,
trainingData: trainingData,
configuration: nil,
completionHandler: completionHandler)
else {
print("Could't create an MLUpdateTask.")
return
}
4. モデル更新タスクの実行
モデル更新タスクを実行するには、MLUpdateTaskのresume()を呼び出します。
updateTask.resume()
「Core ML」は、個別のスレッドでモデル訓練し、終了すると完了ハンドラを呼び出します。
5. モデルの保存
完了ハンドラを使用して、訓練されたモデルを保存します。
サンプルでは、最初にモデルを一時的な場所に書き込みます。その後、モデルを永続的な場所に移動します。この時、以前に保存されたモデルを上書きします。
let updatedModel = updateContext.model
let fileManager = FileManager.default
do {
// 更新されたモデルのディレクトリを作成
try fileManager.createDirectory(at: tempUpdatedModelURL,
withIntermediateDirectories: true,
attributes: nil)
// 更新されたモデルを一時ファイル名に保存
try updatedModel.write(to: tempUpdatedModelURL)
// 以前に更新されたモデルをこのモデルに置き換える
_ = try fileManager.replaceItemAt(updatedModelURL,
withItemAt: tempUpdatedModelURL)
print("Updated model saved to:\n\t\(updatedModelURL)")
} catch let error {
print("Could not save updated model to the file system: \(error)")
return
}
6. モデルの読み込み
モデルを読み込むには、UpdatableDrawingClassifier(contentsOf :)を使います。
// モデルの存在チェック
guard FileManager.default.fileExists(atPath: updatedModelURL.path) else {
return
}
// モデルの読み込み
guard let model = try? UpdatableDrawingClassifier(contentsOf: updatedModelURL) else {
return
}
// モデルを予測に利用
updatedDrawingClassifier = model
この記事が気に入ったらサポートをしてみませんか?