diff --git a/TrainingLoop/TrainingProgress.swift b/TrainingLoop/TrainingProgress.swift index 208760a5ec5..352d5b15c00 100644 --- a/TrainingLoop/TrainingProgress.swift +++ b/TrainingLoop/TrainingProgress.swift @@ -19,6 +19,8 @@ let progressBarLength = 30 /// A progress bar that displays to the console as a model trains, and as validation is performed. /// It hooks into a TrainingLoop via a callback method. public class TrainingProgress { + public var accuracies: [Float] // accessible list of accuracies values + public var losses: [Float] // acceessible list of loss values var statistics: TrainingStatistics? let metrics: Set let liveStatistics: Bool @@ -34,6 +36,8 @@ public class TrainingProgress { /// This has an impact on performance, due to materialization of tensors, and updating values /// on every batch can reduce training speed by up to 30%. public init(metrics: Set = [.accuracy, .loss], liveStatistics: Bool = true) { + self.accuracies = [] + self.losses = [] self.metrics = metrics self.liveStatistics = liveStatistics if !metrics.isEmpty { @@ -68,6 +72,15 @@ public class TrainingProgress { return result } + func updateMetrics() { + if metrics.contains(.loss) { + losses.append(statistics!.averageLoss()) + } + if metrics.contains(.accuracy) { + accuracies.append(statistics!.accuracy()) + } + } + /// The callback used to hook into the TrainingLoop. This is updated once per event. /// /// - Parameters: @@ -92,6 +105,7 @@ public class TrainingProgress { let metricDescriptionComponent: String if liveStatistics || (batchCount == (batchIndex + 1)) { metricDescriptionComponent = metricDescription() + updateMetrics() } else { metricDescriptionComponent = "" }