diff --git a/NEWS.md b/NEWS.md index 372563b563..de99959913 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,8 @@ ## v0.12.9 * Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781). * Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792). +* Add `optimstep!` as a single training step of `train!` to allow for more exotic +optimisers (#666) ## v0.12.8 * Optimized inference and gradient calculation of OneHotMatrix[pr](https://github.com/FluxML/Flux.jl/pull/1756) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 7f3ad6bf37..b4b28b2f11 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -65,7 +65,35 @@ AdaBelief ## Optimiser Interface -Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it. We do this via the `apply!` function which takes the optimiser as the first argument followed by the parameter and its corresponding gradient. +Flux's optimisers are built around a `struct` that holds all the optimiser +parameters along with a definition of how to apply the update +rule associated with it (`optimstep!`). The default implementation of `optimstep!` +looks like this + +```julia +function optimstep!(loss, params, opt) + # Calculate the gradients of the parameters + # with respect to the loss function + val, grads = Flux.withgradient(loss, parameters) + # Update the parameters based on the chosen + # optimiser (opt) + Flux.Optimise.update!(opt, parameters, grads) + return val, grads +end +``` + +and therefore assumes that its update rule only requires the optimisers internal +state `opt`, the `parameters` themselves and the gradients `grads`. For +optimisers which do not fit this pattern, you want to overload `optimstep!` +itself. + +In the following subsection we define a simple Momentum optimiser which fits the +`update!` pattern and therefore does not have to override `optimstep!` itself. + +### Gradient Based Optimiser + +To obtain an `update!` method applicable to your custom optimiser type, we +need to overload the `apply!` function. Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully. It takes the optimiser as the first argument followed by the parameter and its corresponding gradient. In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work this with a simple example. @@ -99,8 +127,6 @@ w = w - v The `apply!` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and subsequently updates the state of the optimiser. -Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully. - ## Composing Optimisers Flux defines a special kind of optimiser simply called `Optimiser` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 845a22d8a6..ccf825fefc 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -17,15 +17,27 @@ for d in datapoints # `d` should produce a collection of arguments # to the loss function - # Calculate the gradients of the parameters - # with respect to the loss function - grads = Flux.gradient(parameters) do + # Update the parameters based on the chosen + # optimiser (opt) + loss, grads = optimstep!(params, opt) do loss(d...) end +end +``` + +`optimstep!` is the optimiser implementation and thus dispatches depending on +the optimizer type. As an example, the default `optimstep!` for optimiser who +use the gradient to update the parameters (e.g. gradient descent, momentum, ADAM, etc.) looks like this +```julia +function optimstep!(loss, params, opt) + # Calculate the gradients of the parameters + # with respect to the loss function + val, grads = Flux.withgradient(loss, parameters) # Update the parameters based on the chosen # optimiser (opt) Flux.Optimise.update!(opt, parameters, grads) + return val, grads end ``` diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 010cbfc9bb..901c2ed1ee 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -3,7 +3,7 @@ module Optimise using LinearAlgebra import ArrayInterface -export train!, update!, +export train!, optimstep!, update!, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 5690c9ea28..468ccfb145 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,5 +1,5 @@ using Juno -import Zygote: Params, gradient +import Zygote: Params, withgradient """ update!(x, x̄) @@ -80,6 +80,35 @@ end batchmemaybe(x) = tuple(x) batchmemaybe(x::Tuple) = x +""" + optimstep!(loss, params, opt) + +`optimstep!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`) +based on a pluggable [Optimisers](@ref) (`opt`). It represents a single step in +the training loop `train!`. + +The default implementation for `optimstep!` is takes the gradient of `loss` +and calls `Flux.Optimise.update!` to adjust the parameters, but you can overload +`optimstep!` for specific types of `opt`. This can be useful if your optimization routine +has does not follow the standard gradient descent procedure (e.g. gradient-free optimizers). + +Unlike `train!`, the loss function of `optimstep!` accepts no input. +Instead, `train!` cycles through the data in a loop and calls `optimstep!`: +```julia +for d in data + optimstep!(ps, opt) do + loss(d) + end +end +``` +If you are writing [Custom Training loops](@ref), then you should follow this pattern. +""" +function optimstep!(loss, params, opt) + val, gs = withgradient(loss, params) + update!(opt, params, gs) + return val, gs +end + """ train!(loss, params, data, opt; cb) @@ -106,10 +135,9 @@ function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) @progress for d in data try - gs = gradient(ps) do + optimstep!(ps, opt) do loss(batchmemaybe(d)...) end - update!(opt, ps, gs) cb() catch ex if ex isa StopException