This repository has been archived by the owner on Apr 13, 2023. It is now read-only.
forked from tensorflow/swift-models
-
Notifications
You must be signed in to change notification settings - Fork 1
/
TrainingLoop.swift
442 lines (350 loc) · 13.4 KB
/
TrainingLoop.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import ModelSupport
import TensorFlow
/// Boxes a differentiable loss function in a reference type.
///
/// Workaround for https://bugs.swift.org/browse/TF-1122, which prevents registering a loss
/// function directly inside the `TrainingLoop` struct: the closure is stored in this class and the
/// struct holds the class instance instead.
public final class LossFunctionWrapper<Output, Target> where Output: Differentiable {
  /// Signature of the wrapped loss. It is differentiable with respect to the model output only;
  /// the target is marked `@noDerivative`.
  public typealias F = @differentiable(Output, @noDerivative Target) -> Tensor<Float>

  /// The wrapped loss function.
  public var f: F

  /// Wraps `lossFunction`.
  init(_ lossFunction: @escaping F) {
    f = lossFunction
  }
}
/// Types that represent a training loop.
///
/// - Note: This protocol mainly exists to give us an easy type for a generic `TrainingLoop`.
/// Unless you need to rewrite your own training loop entirely, you should use `TrainingLoop`.
public protocol TrainingLoopProtocol {
// MARK: - Associated types
/// The type of the sequence of epochs for the training data.
///
/// Each element (one epoch) is a collection of `LabeledData` batches.
associatedtype Training
where
Training: Sequence, Training.Element: Collection,
Training.Element.Element == LabeledData<Opt.Model.Input, Target>
/// The type of the collection of batches for the validation data.
///
/// Unlike `Training`, this is a single collection of batches, not a sequence of epochs.
associatedtype Validation
where
Validation: Collection,
Validation.Element == LabeledData<Opt.Model.Input, Target>
/// The type of the target of our model.
associatedtype Target
/// The type of the optimizer used.
///
/// Its `Model` is the model being trained and must conform to `Module`.
associatedtype Opt: Optimizer where Opt.Model: Module
// MARK: - Typealiases
/// The type of the model.
typealias Model = Opt.Model
/// The type of the input of the model.
typealias Input = Opt.Model.Input
/// The type of the output of the model.
typealias Output = Opt.Model.Output
/// The type of a batch.
typealias Batch = LabeledData<Input, Target>
// Wrapped in a class (`LossFunctionWrapper`) for now to work around TF-1122.
/// The type of the loss function.
typealias LossFunction = LossFunctionWrapper<Output, Target>
// MARK: - Data
/// The training epochs.
var training: Training { get }
/// The validation batches.
var validation: Validation { get }
// MARK: - Optimizer and loss function
/// The optimizer.
var optimizer: Opt { get set }
/// The loss function.
var lossFunction: LossFunction { get set }
/// The metrics on which training is measured.
var metrics: [TrainingMetrics] { get set }
// MARK: - Callbacks
/// The callbacks used to customize the training loop.
///
/// Each callback receives the loop `inout`, so it may read and modify the state below.
var callbacks: [TrainingLoopCallback<Self>] { get set }
// Temporary data, updated while the loop runs.
// MARK: - Step-level data
/// The last input fed to the model.
var lastStepInput: Input? { get set }
/// The last target.
var lastStepTarget: Target? { get set }
/// The last predictions of the model.
var lastStepOutput: Output? { get set }
/// The last gradients computed.
var lastStepGradient: Model.TangentVector? { get set }
/// The last loss.
var lastStepLoss: Tensor<Float>? { get set }
/// The number of batches in the current collection of batches.
var batchCount: Int? { get set }
/// The index of the current batch.
var batchIndex: Int? { get set }
// MARK: - Epoch-level data
/// The number of epochs we are currently fitting for.
var epochCount: Int? { get set }
/// The index of the current epoch.
var epochIndex: Int? { get set }
// MARK: - Others
/// The most recent statistics log, as `(name, value)` pairs.
///
/// NOTE(review): presumably written by a statistics-recording callback — confirm.
var lastStatsLog: [(name: String, value: Float)]? { get set }
}
/// The events that occur during a call to `fit` in the `TrainingLoop`.
///
/// - Note: The method is called `fit` and not `train` because it both trains and validates the
///   model: each epoch is composed of a *training* phase followed by a *validation* phase.
public enum TrainingLoopEvent {
  /// The start of a fit.
  case fitStart
  /// The end of a fit.
  case fitEnd
  /// The start of one epoch (training + validation).
  case epochStart
  /// The end of one epoch (training + validation).
  case epochEnd
  /// The start of a training phase.
  case trainingStart
  /// The end of a training phase.
  case trainingEnd
  /// The start of a validation phase.
  case validationStart
  /// The end of a validation phase.
  case validationEnd
  /// The start of a training or inference step on a batch.
  case batchStart
  /// The end of a training or inference step on a batch.
  case batchEnd
  /// At the start of the optimizer update, just after the differentiable step.
  case updateStart
  /// Just after the model prediction at inference, before computing the loss.
  case inferencePredictionEnd
}
/// A closure that injects custom behavior into a training loop.
///
/// It is invoked with the loop passed `inout`, so it can read or modify the loop's state, and with
/// the `TrainingLoopEvent` that just occurred. It may throw, e.g. a `TrainingLoopAction` to alter
/// the loop's control flow.
public typealias TrainingLoopCallback<L: TrainingLoopProtocol> =
  (_ loop: inout L, _ event: TrainingLoopEvent) throws -> Void
/// A generic training loop.
///
/// Generic parameters:
/// - `Training`: the type of the sequence of epochs for training data.
/// - `Validation`: the type of the collection of batches for validation.
/// - `Target`: the type of the target.
/// - `Opt`: the type of the optimizer used.
public struct TrainingLoop<
  Training: Sequence, Validation: Collection, Target, Opt: Optimizer
>: TrainingLoopProtocol
where
  Training.Element: Collection, Training.Element.Element == LabeledData<Opt.Model.Input, Target>,
  Validation.Element == LabeledData<Opt.Model.Input, Target>, Opt.Model: Module
{
  // MARK: - Typealiases

  /// The type of the model.
  public typealias Model = Opt.Model
  /// The type of the input of the model.
  public typealias Input = Opt.Model.Input
  /// The type of the output of the model.
  public typealias Output = Opt.Model.Output
  /// The type of a batch.
  public typealias Batch = LabeledData<Input, Target>
  // Wrapped in a class (`LossFunctionWrapper`) for now to work around TF-1122.
  /// The type of the loss function.
  public typealias LossFunction = LossFunctionWrapper<Output, Target>

  // MARK: - Data

  /// The training epochs.
  public let training: Training
  /// The validation batches.
  public let validation: Validation

  // MARK: - Optimizer and loss function

  /// The optimizer.
  public var optimizer: Opt
  /// The loss function.
  public var lossFunction: LossFunction
  /// The metrics on which training is measured.
  public var metrics: [TrainingMetrics]

  // MARK: - Callbacks

  /// The callbacks used to customize the training loop, invoked in order on every event.
  public var callbacks: [TrainingLoopCallback<Self>]

  // MARK: - Default callback objects

  /// The callback that records the training statistics.
  ///
  /// Set by `init` only when `includeDefaultCallbacks` is `true`; otherwise `nil`.
  public var statisticsRecorder: StatisticsRecorder? = nil
  /// The callback that prints the training progress.
  ///
  /// Set by `init` only when `includeDefaultCallbacks` is `true`; otherwise `nil`.
  public var progressPrinter: ProgressPrinter? = nil

  // Temporary data, updated while the loop runs.

  // MARK: - Step-level data

  /// The last input fed to the model.
  public var lastStepInput: Input? = nil
  /// The last target.
  public var lastStepTarget: Target? = nil
  /// The last predictions of the model.
  public var lastStepOutput: Output? = nil
  /// The last gradients computed.
  public var lastStepGradient: Model.TangentVector? = nil
  /// The last loss.
  public var lastStepLoss: Tensor<Float>? = nil
  /// The number of batches in the current collection of batches.
  public var batchCount: Int? = nil
  /// The index of the current batch.
  public var batchIndex: Int? = nil

  // MARK: - Epoch-level data

  /// The number of epochs we are currently fitting for.
  public var epochCount: Int? = nil
  /// The index of the current epoch.
  public var epochIndex: Int? = nil

  // MARK: - Others

  /// The most recent statistics log, as `(name, value)` pairs.
  public var lastStatsLog: [(name: String, value: Float)]? = nil

  /// Creates an instance from `training` and `validation` data, an `optimizer` and a
  /// `lossFunction`.
  ///
  /// - Parameters:
  ///   - training: The sequence of training epochs, each a collection of batches.
  ///   - validation: The collection of validation batches.
  ///   - optimizer: The optimizer used to update the model during `fit`.
  ///   - lossFunction: The loss computed from the model output and the target.
  ///   - metrics: The metrics on which training is measured. When the default callbacks are
  ///     included, the statistics recorder tracks `.loss` in addition to these.
  ///   - callbacks: Callbacks that the `TrainingLoop` will use in every call to `fit`.
  ///   - includeDefaultCallbacks: Whether to create a `StatisticsRecorder` and a
  ///     `ProgressPrinter` and run their callbacks before `callbacks`.
  public init(
    training: Training, validation: Validation, optimizer: Opt,
    lossFunction: @escaping LossFunction.F,
    metrics: [TrainingMetrics] = [],
    callbacks: [TrainingLoopCallback<Self>] = [],
    includeDefaultCallbacks: Bool = true
  ) {
    self.training = training
    self.validation = validation
    self.optimizer = optimizer
    self.lossFunction = LossFunction(lossFunction)
    self.metrics = metrics
    if includeDefaultCallbacks {
      let statisticsRecorder = StatisticsRecorder(metrics: [.loss] + metrics)
      let progressPrinter = ProgressPrinter()
      self.statisticsRecorder = statisticsRecorder
      self.progressPrinter = progressPrinter
      // Default callbacks run first, so user callbacks observe up-to-date statistics.
      self.callbacks = [
        statisticsRecorder.record,
        progressPrinter.printProgress,
      ] + callbacks
    } else {
      self.callbacks = callbacks
    }
  }
}
extension TrainingLoop {
  /// The default differentiable step.
  ///
  /// Evaluates `model` on `lastStepInput`, stores the predictions in `lastStepOutput`, and stores
  /// the loss against `lastStepTarget` and its gradient with respect to `model` in `lastStepLoss`
  /// and `lastStepGradient`. Does nothing when the input or the target is missing.
  public mutating func differentiableStep(model: Model) throws {
    guard let data = lastStepInput else { return }
    guard let target = lastStepTarget else { return }
    (lastStepLoss, lastStepGradient) = valueWithGradient(at: model) {
      (model: Model) -> Tensor<Float> in
      let predictions = model(data)
      lastStepOutput = predictions
      return lossFunction.f(predictions, target)
    }
  }

  /// The step used for inference.
  ///
  /// Stores the predictions of `model` on `lastStepInput` in `lastStepOutput`. When a target is
  /// available, it then fires `.inferencePredictionEnd`, where callbacks may inspect or replace
  /// `lastStepOutput`. Finally, it stores the loss in `lastStepLoss`.
  public mutating func inferenceStep(model: Model) throws {
    guard let data = lastStepInput else { return }
    lastStepOutput = model(data)
    guard let target = lastStepTarget else { return }
    try handleEvent(.inferencePredictionEnd)
    // Callbacks receive the loop inout and may have cleared the output; skip the loss instead of
    // crashing on a force-unwrap.
    guard let output = lastStepOutput else { return }
    lastStepLoss = lossFunction.f(output, target)
  }

  /// The step used for training.
  ///
  /// Runs `differentiableStep` and fires `.updateStart`. It then updates `model` with `optimizer`
  /// along `lastStepGradient`. The update is skipped when no gradient was produced for this step.
  ///
  /// - Parameters:
  ///   - model: The model to update in place.
  ///   - differentiableStep: Expected to set `lastStepGradient` for the current batch, given the
  ///     model and the loop.
  public mutating func trainingStep(
    model: inout Model, differentiableStep: (Model, inout Self) throws -> Void
  ) throws {
    // Drop the previous batch's gradient so a step that computes nothing cannot re-apply it.
    lastStepGradient = nil
    try differentiableStep(model, &self)
    try handleEvent(.updateStart)
    guard let gradient = lastStepGradient else { return }
    optimizer.update(&model, along: gradient)
  }
}
/// Control flow of the training loop.
///
/// Throw one of these from a callback (or a custom step) to abort part of a `fit`.
///
/// - Note: After a "cancel" action, the "end" event of the cancelled part is still fired, so
///   callbacks can clean up.
public enum TrainingLoopAction: Error {
  /// Aborts the current training/inference step and goes to the next batch.
  case cancelBatch
  /// Aborts the current training phase and goes to the validation phase.
  case cancelTraining
  /// Aborts the current validation phase and goes to the next epoch.
  case cancelValidation
  /// Aborts the current epoch and goes to the next epoch.
  case cancelEpoch
  /// Aborts the current fit and ends fitting.
  case cancelFit
}
extension TrainingLoop {
  /// Invokes every registered callback, in order, with `event`.
  ///
  /// An error thrown by a callback (including a `TrainingLoopAction`) propagates immediately, so
  /// the remaining callbacks are not called for this event.
  mutating private func handleEvent(_ event: TrainingLoopEvent) throws {
    // Iterate over a copy taken up front: each callback receives `self` inout and may change
    // `callbacks` while we are still dispatching.
    let registered = callbacks
    for notify in registered {
      try notify(&self, event)
    }
  }
}
extension TrainingLoop {
  /// Runs `step` once for every element of `batches`, firing `.batchStart` and `.batchEnd` around
  /// each one.
  ///
  /// A `TrainingLoopAction.cancelBatch` thrown by the start event or by `step` skips the rest of
  /// that batch only; `.batchEnd` still fires. Any other error propagates to the caller.
  mutating private func multipleSteps<Batches: Collection>(
    on batches: Batches, step: (inout Self) throws -> Void
  ) throws where Batches.Element == Batch {
    batchCount = batches.count
    var position = 0
    for batch in batches {
      batchIndex = position
      lastStepInput = batch.data
      lastStepTarget = batch.label
      do {
        try handleEvent(.batchStart)
        try step(&self)
      } catch TrainingLoopAction.cancelBatch {
        // Abandon this batch and continue with the next one.
      }
      try handleEvent(.batchEnd)
      // NOTE(review): presumably materializes pending lazy-tensor work for this batch — confirm.
      LazyTensorBarrier()
      position += 1
    }
  }
}
extension TrainingLoop {
  /// Fits `model` for `epochs` epochs, using `callbacks` to customize the default training loop.
  ///
  /// Each epoch consumes one element of `training` in a training phase, then runs a validation
  /// phase over `validation`. Throwing a `TrainingLoopAction` cancels the corresponding part of
  /// the fit; the matching "end" event still fires. The thread-local learning phase is set to
  /// `.training` / `.inference` for each phase and is not restored afterwards.
  ///
  /// - Parameters:
  ///   - model: The model to train. It is moved to `device` and updated in place.
  ///   - epochs: The maximum number of epochs. Fitting stops earlier if `training` has fewer
  ///     elements.
  ///   - callbacks: Extra callbacks used only for this call, run after the loop's own
  ///     `callbacks`. They are removed again when `fit` returns or throws.
  ///   - device: The device the model and the optimizer are moved to.
  ///   - differentiableStep: Computes the loss and gradient for a batch during the training
  ///     phase. The default value uses the `differentiableStep(model:)` method of `TrainingLoop`.
  /// - Throws: Any error from a callback or step that is not caught as the `TrainingLoopAction`
  ///   expected at the point where it was thrown (e.g. `.cancelBatch` thrown on `.epochStart`).
  public mutating func fit(
    _ model: inout Model, epochs: Int, callbacks: [TrainingLoopCallback<Self>] = [],
    on device: Device = Device.default,
    differentiableStep: (Model, inout Self) throws -> Void = {
      try $1.differentiableStep(model: $0)
    }
  ) throws {
    // Per-call callbacks are appended, then truncated away on every exit path.
    let callbacksCount = self.callbacks.count
    self.callbacks += callbacks
    defer { self.callbacks = Array(self.callbacks.prefix(callbacksCount)) }
    epochCount = epochs
    model.move(to: device)
    optimizer = Opt(copying: optimizer, to: device)
    do {
      try handleEvent(.fitStart)
      LazyTensorBarrier()
      for (i, batches) in training.prefix(epochs).enumerated() {
        epochIndex = i
        do {
          try handleEvent(.epochStart)
          // Training phase
          do {
            Context.local.learningPhase = .training
            try handleEvent(.trainingStart)
            try multipleSteps(
              on: batches,
              step: {
                try $0.trainingStep(model: &model, differentiableStep: differentiableStep)
              })
          } catch TrainingLoopAction.cancelTraining {}
          try handleEvent(.trainingEnd)
          // Validation phase
          do {
            Context.local.learningPhase = .inference
            try handleEvent(.validationStart)
            try multipleSteps(on: validation, step: { try $0.inferenceStep(model: model) })
          } catch TrainingLoopAction.cancelValidation {}
          try handleEvent(.validationEnd)
        } catch TrainingLoopAction.cancelEpoch {}
        try handleEvent(.epochEnd)
      }
    } catch TrainingLoopAction.cancelFit {}
    try handleEvent(.fitEnd)
  }
}