From 2fa1c7ce4e313ddf7c8abf4a51221ac7e386461b Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Thu, 13 Jan 2022 10:45:51 +0100 Subject: [PATCH 01/12] add step! --- src/optimise/Optimise.jl | 2 +- src/optimise/train.jl | 29 +++++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 010cbfc9bb..b269b1b0dd 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -3,7 +3,7 @@ module Optimise using LinearAlgebra import ArrayInterface -export train!, update!, +export train!, step!, update!, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 5690c9ea28..8083ecfc27 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -80,6 +80,32 @@ end batchmemaybe(x) = tuple(x) batchmemaybe(x::Tuple) = x +""" + step!(loss, params, opt) + +`step!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`) +based on a pluggable [Optimisers](@ref) (`opt`). It represents a single step in +the training loop `train!`. While there is a default implementation for +optimisers which are based on the `update!` function and only require gradient +information, this `step!` has to be overloaded for more general optimisers. + +While the loss function of `train!` still accepts data as input, the loss function +of `step!` accepts no input. `train!` cycles through the data in a loop +roughly like this + +```julia +for d in data + step!(ps, opt) do + loss(d) + end +``` + +""" +function step!(loss, params, opt) + gs = gradient(loss, params) + update!(opt, params, gs) +end + """ train!(loss, params, data, opt; cb) @@ -106,10 +132,9 @@ function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) @progress for d in data try - gs = gradient(ps) do + step!(ps, opt) do loss(batchmemaybe(d)...) end - update!(opt, ps, gs) cb() catch ex if ex isa StopException From c8f147bb36438d9091c9edf7615da9cb2dfc8f63 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Thu, 13 Jan 2022 10:48:43 +0100 Subject: [PATCH 02/12] add NEWS.md --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 372563b563..c2e0a56969 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,8 @@ ## v0.12.9 * Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781). * Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792). +* Add `step!` as a single training step of `train!` to allow for more exotic +optimisers (#666) ## v0.12.8 * Optimized inference and gradient calculation of OneHotMatrix[pr](https://github.com/FluxML/Flux.jl/pull/1756) From 1f34fd790c584321e7a844cb3ab3789882b227ed Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Sun, 16 Jan 2022 10:05:58 +0100 Subject: [PATCH 03/12] use withgradient instead --- src/optimise/train.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 8083ecfc27..fa6c6bee1b 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,5 +1,5 @@ using Juno -import Zygote: Params, gradient +import Zygote: Params, withgradient """ update!(x, x̄) @@ -102,8 +102,9 @@ for d in data """ function step!(loss, params, opt) - gs = gradient(loss, params) + val, gs = withgradient(loss, params) update!(opt, params, gs) + return val, gs end """ From 0765010c7e18d7bf85f737aec04d33c01d3acd69 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 18 Jan 2022 15:52:38 +0100 Subject: [PATCH 04/12] apply suggested changes to docstring Co-authored-by: Kyle Daruwalla --- src/optimise/train.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index fa6c6bee1b..288ff0c261 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -85,9 +85,12 @@ batchmemaybe(x::Tuple) = x `step!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`) based on a pluggable [Optimisers](@ref) (`opt`). It represents a single step in -the training loop `train!`. While there is a default implementation for -optimisers which are based on the `update!` function and only require gradient -information, this `step!` has to be overloaded for more general optimisers. +the training loop `train!`. + +The default implementation for `step!` is takes the gradient of `loss` +and calls `Flux.Optimise.update!` to adjust the parameters, but you can overload +`step!` for specific types of `opt`. This can be useful if your optimization routine +has does not follow the standard gradient descent procedure (e.g. gradient-free optimizers). While the loss function of `train!` still accepts data as input, the loss function of `step!` accepts no input. `train!` cycles through the data in a loop From 5b4fc17dcb6b62e06647020cd81a798da7722a50 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 18 Jan 2022 15:53:48 +0100 Subject: [PATCH 05/12] apply suggested changes to docstring Co-authored-by: Kyle Daruwalla --- src/optimise/train.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 288ff0c261..28be823a5d 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -92,17 +92,16 @@ and calls `Flux.Optimise.update!` to adjust the parameters, but you can overload `step!` for specific types of `opt`. This can be useful if your optimization routine has does not follow the standard gradient descent procedure (e.g. gradient-free optimizers). -While the loss function of `train!` still accepts data as input, the loss function -of `step!` accepts no input. `train!` cycles through the data in a loop -roughly like this - +Unlike `train!`, the loss function of `step!` accepts no input. +Instead, `train!` cycles through the data in a loop and calls `step!`: ```julia for d in data step!(ps, opt) do loss(d) end +end ``` - +If you are writing [Custom Training loops](@ref), then you should follow this pattern. """ function step!(loss, params, opt) val, gs = withgradient(loss, params) From 12a8284073898485f697f5adc99a22fab8660a95 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 18 Jan 2022 16:08:02 +0100 Subject: [PATCH 06/12] rename step! to optimstep! --- src/optimise/train.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 28be823a5d..468ccfb145 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -81,29 +81,29 @@ batchmemaybe(x) = tuple(x) batchmemaybe(x::Tuple) = x """ - step!(loss, params, opt) + optimstep!(loss, params, opt) -`step!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`) +`optimstep!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`) based on a pluggable [Optimisers](@ref) (`opt`). It represents a single step in the training loop `train!`. -The default implementation for `step!` is takes the gradient of `loss` +The default implementation for `optimstep!` is takes the gradient of `loss` and calls `Flux.Optimise.update!` to adjust the parameters, but you can overload -`step!` for specific types of `opt`. This can be useful if your optimization routine +`optimstep!` for specific types of `opt`. This can be useful if your optimization routine has does not follow the standard gradient descent procedure (e.g. gradient-free optimizers). -Unlike `train!`, the loss function of `step!` accepts no input. -Instead, `train!` cycles through the data in a loop and calls `step!`: +Unlike `train!`, the loss function of `optimstep!` accepts no input. +Instead, `train!` cycles through the data in a loop and calls `optimstep!`: ```julia for d in data - step!(ps, opt) do + optimstep!(ps, opt) do loss(d) end end ``` If you are writing [Custom Training loops](@ref), then you should follow this pattern. """ -function step!(loss, params, opt) +function optimstep!(loss, params, opt) val, gs = withgradient(loss, params) update!(opt, params, gs) return val, gs @@ -135,7 +135,7 @@ function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) @progress for d in data try - step!(ps, opt) do + optimstep!(ps, opt) do loss(batchmemaybe(d)...) end cb() From 68a05908dcefb7f0eedc583d15640b3221fd0f0e Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 18 Jan 2022 16:29:35 +0100 Subject: [PATCH 07/12] add optimstep! to docs --- docs/src/training/training.md | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 845a22d8a6..ccf825fefc 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -17,15 +17,27 @@ for d in datapoints # `d` should produce a collection of arguments # to the loss function - # Calculate the gradients of the parameters - # with respect to the loss function - grads = Flux.gradient(parameters) do + # Update the parameters based on the chosen + # optimiser (opt) + loss, grads = optimstep!(params, opt) do loss(d...) end +end +``` + +`optimstep!` is the optimiser implementation and thus dispatches depending on +the optimizer type. As an example, the default `optimstep!` for optimiser who +use the gradient to update the parameters (e.g. gradient descent, momentum, ADAM, etc.) looks like this +```julia +function optimstep!(loss, params, opt) + # Calculate the gradients of the parameters + # with respect to the loss function + val, grads = Flux.withgradient(loss, parameters) # Update the parameters based on the chosen # optimiser (opt) Flux.Optimise.update!(opt, parameters, grads) + return val, grads end ``` From c321e4f711ec7fd0e2c4525d7a04ec3783ef786f Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 18 Jan 2022 16:42:13 +0100 Subject: [PATCH 08/12] update optimisers section --- docs/src/training/optimisers.md | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 7f3ad6bf37..b4580b90e2 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -65,7 +65,35 @@ AdaBelief ## Optimiser Interface -Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it. We do this via the `apply!` function which takes the optimiser as the first argument followed by the parameter and its corresponding gradient. +Flux's optimisers are built around a `struct` that holds all the optimiser +parameters along with a definition (`optimstep!`) of how to apply the update +rule associated with it. The default implementation of `optimstep!` +looks like this + +```julia +function optimstep!(loss, params, opt) + # Calculate the gradients of the parameters + # with respect to the loss function + val, grads = Flux.withgradient(loss, parameters) + # Update the parameters based on the chosen + # optimiser (opt) + Flux.Optimise.update!(opt, parameters, grads) + return val, grads +end +``` + +and therefore assumes that its update rule only requires the optimisers internal +state `opt`, the `parameters` themselves and the gradients `grads`. For +optimisers which do not fit this pattern, you want to overload `optimstep!` +itself. + +In the following subsection we define a simple Momentum optimiser which fits the +`update!` pattern and therefore does not have to override `optimstep!` itself. + +### Gradient Based Optimiser + +To obtain an `update!` method applicable to your custom optimiser type, we +need to overload the `apply!` function. Flux internally calls on this function via the update! function. It shares the API with apply! but ensures that multiple parameters are handled gracefully. It takes the optimiser as the first argument followed by the parameter and its corresponding gradient. In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work this with a simple example. @@ -99,8 +127,6 @@ w = w - v The `apply!` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and subsequently updates the state of the optimiser. -Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully. - ## Composing Optimisers Flux defines a special kind of optimiser simply called `Optimiser` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient From bb7b5c543e4821e88c9861dd0825619831cfee9e Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 18 Jan 2022 16:43:39 +0100 Subject: [PATCH 09/12] apostrophes --- docs/src/training/optimisers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index b4580b90e2..9284fa8f69 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -93,7 +93,7 @@ In the following subsection we define a simple Momentum optimiser which fits the ### Gradient Based Optimiser To obtain an `update!` method applicable to your custom optimiser type, we -need to overload the `apply!` function. Flux internally calls on this function via the update! function. It shares the API with apply! but ensures that multiple parameters are handled gracefully. It takes the optimiser as the first argument followed by the parameter and its corresponding gradient. +need to overload the `apply!` function. Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully. It takes the optimiser as the first argument followed by the parameter and its corresponding gradient. In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work this with a simple example. From cc31acbcba479c4ce92c27666b049c523d2290bb Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 18 Jan 2022 16:45:17 +0100 Subject: [PATCH 10/12] bracket at the end of the sentence --- docs/src/training/optimisers.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 9284fa8f69..b4b28b2f11 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -66,8 +66,8 @@ AdaBelief ## Optimiser Interface Flux's optimisers are built around a `struct` that holds all the optimiser -parameters along with a definition (`optimstep!`) of how to apply the update -rule associated with it. The default implementation of `optimstep!` +parameters along with a definition of how to apply the update +rule associated with it (`optimstep!`). The default implementation of `optimstep!` looks like this ```julia From 5220fe966247dd3cc1e595efcf7ba52dce1edf25 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 24 May 2022 12:34:04 +0200 Subject: [PATCH 11/12] Update NEWS.md Co-authored-by: Brian Chen --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index c2e0a56969..de99959913 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,7 +3,7 @@ ## v0.12.9 * Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781). * Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792). -* Add `step!` as a single training step of `train!` to allow for more exotic +* Add `optimstep!` as a single training step of `train!` to allow for more exotic optimisers (#666) ## v0.12.8 From fa689939c639f37e8289e2838ccaa2d5706ff98f Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 24 May 2022 12:34:27 +0200 Subject: [PATCH 12/12] Update src/optimise/Optimise.jl Co-authored-by: Brian Chen --- src/optimise/Optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index b269b1b0dd..901c2ed1ee 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -3,7 +3,7 @@ module Optimise using LinearAlgebra import ArrayInterface -export train!, step!, update!, +export train!, optimstep!, update!, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,