Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make accuracy and loss values available for upstream use-cases #662

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions TrainingLoop/TrainingProgress.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ let progressBarLength = 30
/// A progress bar that displays to the console as a model trains, and as validation is performed.
/// It hooks into a TrainingLoop via a callback method.
public class TrainingProgress {
public var accuracies: [Float] // accessible list of accuracies values
public var losses: [Float] // acceessible list of loss values
var statistics: TrainingStatistics?
let metrics: Set<TrainingMetrics>
let liveStatistics: Bool
Expand All @@ -34,6 +36,8 @@ public class TrainingProgress {
/// This has an impact on performance, due to materialization of tensors, and updating values
/// on every batch can reduce training speed by up to 30%.
public init(metrics: Set<TrainingMetrics> = [.accuracy, .loss], liveStatistics: Bool = true) {
self.accuracies = []
self.losses = []
self.metrics = metrics
self.liveStatistics = liveStatistics
if !metrics.isEmpty {
Expand Down Expand Up @@ -68,6 +72,15 @@ public class TrainingProgress {
return result
}

func updateMetrics() {
if metrics.contains(.loss) {
losses.append(statistics!.averageLoss())
}
if metrics.contains(.accuracy) {
accuracies.append(statistics!.accuracy())
}
}

/// The callback used to hook into the TrainingLoop. This is updated once per event.
///
/// - Parameters:
Expand All @@ -92,6 +105,7 @@ public class TrainingProgress {
let metricDescriptionComponent: String
if liveStatistics || (batchCount == (batchIndex + 1)) {
metricDescriptionComponent = metricDescription()
updateMetrics()
} else {
metricDescriptionComponent = ""
}
Expand Down