Skip to content

Commit

Permalink
Improve model recompilation logic in CoreMLBackend
Browse files Browse the repository at this point in the history
Changed the `model` property in `CoreMLBackend` from a constant to a variable to allow reassignment when recompiling the model.

- Updated the `safelyPredict` function to handle prediction failures more gracefully:
  - Reorganized the logic to include a loop that attempts compilation and prediction with both cached and recompilation strategies.
  - Introduced a new private method `compileAndPredict` to encapsulate the model compilation and prediction logic, improving code readability and maintainability.

- Enhanced the `KataGoModel` class by modifying the `compileBundleMLModel` and `compileMLModel` methods to accept a `mustCompile` parameter, allowing conditional recompilation of the model based on input flags.

- This change addresses issues where the model fails to produce valid predictions by ensuring a fresh compilation under specific circumstances, improving overall reliability in predicting with CoreML models.
  • Loading branch information
ChinChangYang committed Aug 29, 2024
1 parent 6182cc8 commit 71129f5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
41 changes: 29 additions & 12 deletions cpp/neuralnet/coremlbackend.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class CoreMLBackend {
return "KataGoModel\(xLen)x\(yLen)fp\(precision)\(encoder)"
}

let model: KataGoModel
var model: KataGoModel
let xLen: Int
let yLen: Int
public let version: Int32
Expand Down Expand Up @@ -117,8 +117,8 @@ public class CoreMLBackend {

let inputBatch = KataGoModelInputBatch(inputArray: inputArray)
let options = MLPredictionOptions()
let outputBatch = safelyPredict(from: inputBatch, options: options)

let outputBatch = safelyPredict(from: inputBatch, options: options)!
assert(outputBatch.count == batchSize)

outputBatch.outputArray.enumerated().forEach { index, output in
Expand Down Expand Up @@ -152,17 +152,34 @@ public class CoreMLBackend {
}

func safelyPredict(from inputBatch: KataGoModelInputBatch,
options: MLPredictionOptions) -> KataGoModelOutputBatch {
if let firstTry = try? model.prediction(from: inputBatch, options: options) {
return firstTry
} else if let secondTry = try? model.prediction(from: inputBatch, options: options) {
return secondTry
} else {
let mlmodel = KataGoModel.compileBundleMLModel(modelName: modelName, computeUnits: .cpuOnly)!
let model = KataGoModel(model: mlmodel)
let cpuTry = try! model.prediction(from: inputBatch, options: options)
return cpuTry
options: MLPredictionOptions) -> KataGoModelOutputBatch? {
if let prediction = try? model.prediction(from: inputBatch, options: options) {
return prediction
}

let computeUnits = model.model.configuration.computeUnits

for mustCompile in [false, true] {
if let prediction = compileAndPredict(with: computeUnits, from: inputBatch, options: options, mustCompile: mustCompile) {
return prediction
}
}

return nil
}

private func compileAndPredict(with computeUnits: MLComputeUnits,
from inputBatch: KataGoModelInputBatch,
options: MLPredictionOptions,
mustCompile: Bool) -> KataGoModelOutputBatch? {
if let mlmodel = KataGoModel.compileBundleMLModel(modelName: modelName, computeUnits: computeUnits, mustCompile: mustCompile) {
model = KataGoModel(model: mlmodel)
if let outputBatch = try? model.prediction(from: inputBatch, options: options) {
return outputBatch
}
}

return nil
}
}

Expand Down
13 changes: 7 additions & 6 deletions cpp/neuralnet/coremlmodel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class KataGoModel {
return bundleModelURL
}

class func compileBundleMLModel(modelName: String, computeUnits: MLComputeUnits) -> MLModel? {
class func compileBundleMLModel(modelName: String, computeUnits: MLComputeUnits, mustCompile: Bool = false) -> MLModel? {
var mlmodel: MLModel?

do {
Expand All @@ -114,7 +114,8 @@ class KataGoModel {
// Compile MLModel
mlmodel = try compileMLModel(modelName: modelName,
modelURL: bundleModelURL,
computeUnits: computeUnits)
computeUnits: computeUnits,
mustCompile: mustCompile)
} catch {
printError("An error occurred: \(error)")
}
Expand Down Expand Up @@ -247,14 +248,14 @@ class KataGoModel {
return savedDigestURL
}

class func compileMLModel(modelName: String, modelURL: URL, computeUnits: MLComputeUnits) throws -> MLModel {
class func compileMLModel(modelName: String, modelURL: URL, computeUnits: MLComputeUnits, mustCompile: Bool) throws -> MLModel {
let permanentURL = try getMLModelCPermanentURL(modelName: modelName)
let savedDigestURL = try getSavedDigestURL(modelName: modelName)
let digest = try getDigest(modelURL: modelURL)

let shouldCompileModel = checkShouldCompileModel(permanentURL: permanentURL,
savedDigestURL: savedDigestURL,
digest: digest)
let shouldCompileModel = mustCompile || checkShouldCompileModel(permanentURL: permanentURL,
savedDigestURL: savedDigestURL,
digest: digest)

if shouldCompileModel {
try compileAndSaveModel(permanentURL: permanentURL,
Expand Down

0 comments on commit 71129f5

Please sign in to comment.