diff --git a/README.md b/README.md index 5d0742af20..69b3620b9b 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. -Works best with [Julia 1.8](https://julialang.org/downloads/) or later. Here's a very short example to try it out: +Works best with [Julia 1.9](https://julialang.org/downloads/) or later. Here's a very short example to try it out: ```julia using Flux, Plots data = [([x], 2x-x^3) for x in -2:0.1f0:2] diff --git a/docs/src/gpu.md b/docs/src/gpu.md index a2acdc32ac..8b94f5af94 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -1,11 +1,22 @@ # GPU Support -NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme. +Starting with v0.14, Flux no longer forces a specific GPU backend and its package dependencies on users. +Thanks to the [package extension mechanism](
https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) introduced in Julia v1.9, Flux conditionally loads GPU-specific code once a GPU package is made available (e.g. through `using CUDA`). + +NVIDIA GPU support requires the packages `CUDA.jl` and `cuDNN.jl` to be installed in the environment. In the Julia REPL, type `] add CUDA, cuDNN` to install them. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme. AMD GPU support is available since Julia 1.9 on systems with ROCm and MIOpen installed. For more details refer to the [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl) repository. Metal GPU acceleration is available on Apple Silicon hardware. For more details refer to the [Metal.jl](https://github.com/JuliaGPU/Metal.jl) repository. Metal support in Flux is experimental and many features are not yet available. +To enable GPU support in Flux, you need to call `using CUDA`, `using AMDGPU` or `using Metal` +in your code. Note that for CUDA, explicitly loading `cuDNN` is not required, but the package has to be installed in the environment. + + +!!! compat "Flux ≤ 0.13" Old versions of Flux automatically installed CUDA.jl to provide GPU support. Starting from Flux v0.14, CUDA.jl is no longer a dependency and has to be installed manually. + ## Checking GPU Availability By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following: diff --git a/docs/src/index.md b/docs/src/index.md index 48364f74a8..3c17ff4c9d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,7 +8,8 @@ Flux is a library for machine learning. It comes "batteries-included" with many ### Installation -Download [Julia 1.9](https://julialang.org/downloads/) or later, preferably the current stable release. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt. This will automatically install several other packages, including [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) for Nvidia GPU support. +Download [Julia 1.9](https://julialang.org/downloads/) or later, preferably the current stable release. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt. 
+For Nvidia GPU support, you will also need to install the `CUDA` and `cuDNN` packages. For AMD GPU support, install the `AMDGPU` package. For acceleration on Apple Silicon, install the `Metal` package. ### Learning Flux diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 0fe05414b4..eada6b9ef8 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -69,9 +69,9 @@ However, doing this requires the `struct` to have a corresponding constructor th When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`. -!!! compat "Flux ≤ 0.13" +!!! compat "Flux ≤ 0.14" The mechanism described here is for Flux's old "implicit" training style. - When upgrading for Flux 0.14, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. + When upgrading to Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain this using the slicing features `Chain` provides: diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index d83d7f5050..31db9cd204 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -29,7 +29,7 @@ Perhaps `Scale` isn't quite fully connected, but it may be thought of as `Dense( !!! compat "Flux ≤ 0.12" Old versions of Flux accepted only `Dense(in, out, act)` and not `Dense(in => out, act)`. - This notation makes a `Pair` object. If you get an error like `MethodError: no method matching Dense(::Pair{Int64,Int64})`, this means that you should upgrade to Flux 0.13. + This notation makes a `Pair` object. If you get an error like `MethodError: no method matching Dense(::Pair{Int64,Int64})`, this means that you should upgrade to a newer version of Flux. ## Convolution Models diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index 3e4939bc2a..dfef1f0c04 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -5,8 +5,8 @@ If you have used neural networks before, then this simple example might be helpf If you haven't, then you might prefer the [Fitting a Straight Line](overview.md) page. ```julia -# With Julia 1.7+, this will prompt if neccessary to install everything, including CUDA: -using Flux, Statistics, ProgressMeter +# This will prompt if necessary to install everything, including CUDA: +using Flux, CUDA, Statistics, ProgressMeter # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix: noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32} @@ -102,7 +102,7 @@ for epoch in 1:1_000 end ``` -!!! compat "Implicit-style training, Flux ≤ 0.13" +!!! compat "Implicit-style training, Flux ≤ 0.14" Until recently Flux's training worked a bit differently. Any code which looks like ``` @@ -113,5 +113,5 @@ end train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt) ``` (with `Flux.params`) is in the old "implicit" style. - This still works on Flux 0.13, but will be removed from Flux 0.14. + This still works on Flux 0.14, but will be removed from Flux 0.15. See the [training section](@ref man-training) for more details. 
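The GPU-related doc changes above boil down to one pattern: install the backend packages for your hardware, then load the backend alongside Flux. A minimal sketch, assuming `CUDA.jl` and `cuDNN.jl` are already installed in the active environment (the layer and input sizes are arbitrary):

```julia
using Flux, CUDA   # loading CUDA triggers Flux's CUDA extension; cuDNN only needs to be installed

CUDA.functional()                 # true if a usable NVIDIA GPU was found
model = Dense(2 => 3) |> gpu      # moves the layer's parameters to the GPU when one is available
x = rand(Float32, 2, 5) |> gpu    # move the input as well
y = model(x)                      # runs on the GPU (a CuArray result) when a GPU is available
```

The same pattern applies to the other backends, with `using AMDGPU` or `using Metal` in place of `using CUDA`.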
diff --git a/docs/src/training/callbacks.md b/docs/src/training/callbacks.md index 6e9840ad1d..148aa02128 100644 --- a/docs/src/training/callbacks.md +++ b/docs/src/training/callbacks.md @@ -2,8 +2,6 @@ ```@docs Flux.throttle -Flux.stop -Flux.skip ``` ## Patience Helpers @@ -26,7 +24,7 @@ end es = early_stopping(loss, 2; init_score = 9) # this will stop at the 6th (4 decreasing + 2 increasing calls) epoch -@epochs 10 begin +for epoch in 1:10 es() && break end ``` @@ -43,7 +41,7 @@ end es = early_stopping(acc, 3; delta = (best_score, score) -> score - best_score) # this will iterate until the 10th epoch -@epochs 10 begin +for epoch in 1:10 es() && break end ``` @@ -60,12 +58,12 @@ Both `predicate` in `patience` and `f` in `early_stopping` / `plateau` can accep trigger = patience((a; b) -> a > b, 3) # this will iterate until the 10th epoch -@epochs 10 begin +for epoch in 1:10 trigger(1; b = 2) && break end # this will stop at the 3rd epoch -@epochs 10 begin +for epoch in 1:10 trigger(3; b = 2) && break end ``` diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index 0d71040267..dc7ecf1314 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -10,7 +10,7 @@ Because of this: * Flux defines its own version of `setup` which checks this assumption. (Using instead `Optimisers.setup` will also work, they return the same thing.) -The new implementation of rules such as Adam in the Optimisers is quite different from the old one in `Flux.Optimise`. In Flux 0.13, `Flux.Adam()` returns the old one, with supertype `Flux.Optimise.AbstractOptimiser`, but `setup` will silently translate it to its new counterpart. +The new implementation of rules such as Adam in Optimisers.jl is quite different from the old one in `Flux.Optimise`. In Flux 0.14, `Flux.Adam()` returns the old one, with supertype `Flux.Optimise.AbstractOptimiser`, but `setup` will silently translate it to its new counterpart. The available rules are listed the [optimisation rules](@ref man-optimisers) page here; see the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the new rules work. @@ -37,11 +37,11 @@ Optimisers.freeze! Optimisers.thaw! ``` -## Implicit style (Flux ≤ 0.13) +## Implicit style (Flux ≤ 0.14) Flux used to handle gradients, training, and optimisation rules quite differently. The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 is the transitional version which supports both; Flux 0.14 will remove the old. +Flux 0.13 and 0.14 are the transitional versions which support both; Flux 0.15 will remove the old. !!! compat "How to upgrade" The blue-green boxes in the [training section](@ref man-training) describe @@ -62,26 +62,6 @@ Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::AbstractArray, g Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) ``` -Note that, by default, `train!` only loops over the data once (a single "epoch"). -A convenient way to run multiple epochs from the REPL is provided by `@epochs`. - -```julia -julia> using Flux: @epochs - -julia> @epochs 2 println("hello") -[ Info: Epoch 1 -hello -[ Info: Epoch 2 -hello - -julia> @epochs 2 Flux.train!(...) -# Train for two epochs -``` - -```@docs -Flux.@epochs -``` - ## Callbacks Implicit `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. 
For example: @@ -98,14 +78,9 @@ A more typical callback might look like this: test_x, test_y = # ... create single batch of test data ... evalcb() = @show(loss(test_x, test_y)) throttled_cb = throttle(evalcb, 5) -Flux.@epochs 20 Flux.train!(objective, ps, data, opt, cb = throttled_cb) -``` - -Calling `Flux.stop()` in a callback will exit the training loop early. - -```julia -cb = function () - accuracy() > 0.9 && Flux.stop() +for epoch in 1:20 + @info "Epoch $epoch" + Flux.train!(objective, ps, data, opt, cb = throttled_cb) end ``` diff --git a/docs/src/training/training.md b/docs/src/training/training.md index aba255af47..3070494188 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -65,14 +65,14 @@ It is also important that every `update!` step receives a newly gradient compute as this will be change whenever the model's parameters are changed, and for each new data point. !!! compat "Implicit gradients" - Flux ≤ 0.13 used Zygote's "implicit" mode, in which `gradient` takes a zero-argument function. + Flux ≤ 0.14 used Zygote's "implicit" mode, in which `gradient` takes a zero-argument function. It looks like this: ``` pars = Flux.params(model) grad = gradient(() -> loss(model(input), label), pars) ``` Here `pars::Params` and `grad::Grads` are two dictionary-like structures. - Support for this will be removed from Flux 0.14, and these blue (teal?) boxes + Support for this will be removed from Flux 0.15, and these blue (teal?) boxes explain what needs to change. ## Loss Functions @@ -90,7 +90,7 @@ like [`mse`](@ref Flux.Losses.mse) for mean-squared error or [`crossentropy`](@r are available from the [`Flux.Losses`](../models/losses.md) module. !!! compat "Implicit-style loss functions" - Flux ≤ 0.13 needed a loss function which closed over a reference to the model, + Flux ≤ 0.14 needed a loss function which closed over a reference to the model, instead of being a pure function. Thus in old code you may see something like ``` loss(x, y) = sum((model(x) .- y).^2) @@ -211,7 +211,7 @@ Or explicitly writing the anonymous function which this `do` block creates, !!! compat "Implicit-style `train!`" This is a new method of `train!`, which takes the result of `setup` as its 4th argument. The 1st argument is a function which accepts the model itself. - Flux versions ≤ 0.13 provided a method of `train!` for "implicit" parameters, + Flux versions ≤ 0.14 provided a method of `train!` for "implicit" parameters, which works like this: ``` train!((x,y) -> loss(model(x), y), Flux.params(model), train_set, Adam()) @@ -342,7 +342,7 @@ for epoch in 1:1000 end ``` -!!! compat "Flux ≤ 0.13" +!!! compat "Flux ≤ 0.14" With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to directly mutate the `Adam` struct, `opt.eta = 0.001`. @@ -374,7 +374,7 @@ train!(loss, bimodel, data, opt_state) Flux.thaw!(opt_state) ``` -!!! compat "Flux ≤ 0.13" +!!! compat "Flux ≤ 0.14" The earlier "implicit" equivalent was to pass to `gradient` an object referencing only part of the model, such as `Flux.params(bimodel.layers.enc)`. @@ -383,7 +383,7 @@ Flux.thaw!(opt_state) Flux used to handle gradients, training, and optimisation rules quite differently. The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 is the transitional version which supports both. +Flux 0.13 and 0.14 are the transitional versions which support both. The blue-green boxes above describe the changes. 
For more details on training in the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). diff --git a/docs/src/training/zygote.md b/docs/src/training/zygote.md index f25d151bb8..385e7dde7b 100644 --- a/docs/src/training/zygote.md +++ b/docs/src/training/zygote.md @@ -18,10 +18,10 @@ Zygote.hessian_reverse Zygote.diaghessian ``` -## Implicit style (Flux ≤ 0.13) +## Implicit style (Flux ≤ 0.14) Flux used to use what Zygote calls "implicit" gradients, [described here](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) in its documentation. -However, support for this will be removed from Flux 0.14. +However, support for this will be removed from Flux 0.15. !!! compat "Training" The blue-green boxes in the [training section](@ref man-training) describe diff --git a/docs/src/tutorials/2021-01-26-mlp.md b/docs/src/tutorials/2021-01-26-mlp.md index 7b2b530518..7f29543345 100644 --- a/docs/src/tutorials/2021-01-26-mlp.md +++ b/docs/src/tutorials/2021-01-26-mlp.md @@ -7,7 +7,7 @@ To run this example, we need the following packages: ```julia using Flux, Statistics using Flux.Data: DataLoader -using Flux: onehotbatch, onecold, logitcrossentropy, throttle, @epochs, params +using Flux: onehotbatch, onecold, logitcrossentropy, throttle, params using Base.Iterators: repeated using CUDA using MLDatasets @@ -138,8 +138,11 @@ function train(; kws...) ## Training evalcb = () -> @show(loss_all(train_data, m)) opt = Adam(args.rate) - - @epochs args.epochs Flux.train!(loss, params(m), train_data, opt, cb = evalcb) + + for epoch in 1:args.epochs + @info "Epoch $epoch" + Flux.train!(loss, params(m), train_data, opt, cb = evalcb) + end @show accuracy(train_data, m) @@ -153,7 +156,7 @@ end * **Initializes the model parameters:** Creates the `args` object that contains the defult values for training our model. * **Loads the train and test data:** Calls the function `getdata` we defined above. * **Constructs the model:** Builds the model and loads the train and test data sets, and our model onto the GPU (if available). -* **Trains the model:** Defines the *callback* function `evalcb` to show the value of the `loss_all` function during the training process. Then, it sets [Adam](@ref Flux.Optimise.Adam) as the optimiser for training out model. Finally, it runs the training process with the macro `@epochs` for `10` epochs (as defined in the `args` object) and shows the `accuracy` value for the train and test data. +* **Trains the model:** Defines the *callback* function `evalcb` to show the value of the `loss_all` function during the training process. Then, it sets [Adam](@ref Flux.Optimise.Adam) as the optimiser for training our model. Finally, it runs the training process for `10` epochs (as defined in the `args` object) and shows the `accuracy` value for the train and test data. To see the full version of this example, see [Simple multi-layer perceptron - model-zoo](https://github.com/FluxML/model-zoo/blob/master/vision/mlp_mnist/mlp_mnist.jl). 
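Across these doc and tutorial changes, `@epochs N ...` is replaced by an ordinary `for` loop over epochs. A minimal sketch of that loop, written in the new explicit style with `Flux.setup` (the model, data and hyper-parameters below are invented purely for illustration):

```julia
using Flux

model = Chain(Dense(4 => 8, relu), Dense(8 => 2))
data  = [(rand(Float32, 4, 16), rand(Float32, 2, 16)) for _ in 1:10]  # dummy (x, y) batches
opt_state = Flux.setup(Adam(1e-3), model)

loss(m, x, y) = Flux.Losses.mse(m(x), y)

# What `@epochs n ...` used to provide: just an ordinary loop over epochs.
for epoch in 1:10
    @info "Epoch $epoch"
    Flux.train!(loss, model, data, opt_state)
end
```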
diff --git a/docs/src/utilities.md b/docs/src/utilities.md index e23a4d1f70..bc1200124e 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -49,7 +49,6 @@ These functions call: ```@docs Flux.rng_from_array -Flux.default_rng_value Flux.nfan ``` diff --git a/src/Flux.jl b/src/Flux.jl index 132231c9ad..d522b91e78 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -10,7 +10,7 @@ using MacroTools: @forward using MLUtils import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions using Optimisers: freeze!, thaw!, adjust! - +using Random: default_rng using Zygote, ChainRulesCore using Zygote: Params, @adjoint, gradient, pullback using Zygote.ForwardDiff: value @@ -32,8 +32,6 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion include("optimise/Optimise.jl") using .Optimise -using .Optimise: @epochs -using .Optimise: skip export Descent, Adam, Momentum, Nesterov, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, OAdam, AdamW, RAdam, AdaBelief, InvDecay, ExpDecay, diff --git a/src/deprecations.jl b/src/deprecations.jl index b796c498d7..4703a1202f 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,19 +1,3 @@ -# v0.12 deprecations - -function ones(dims...) - Base.depwarn("Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", :ones, force=true) - Base.ones(Float32, dims...) -end -ones(T::Type, dims...) = Base.ones(T, dims...) - -function zeros(dims...) - Base.depwarn("Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", :zeros, force=true) - Base.zeros(Float32, dims...) -end -zeros(T::Type, dims...) = Base.zeros(T, dims...) - -ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type")) -zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type")) # v0.13 deprecations @@ -59,7 +43,7 @@ function loadparams!(m, xs) end # Channel notation: Changed to match Conv, but very softly deprecated! -# Perhaps change to @deprecate for v0.14, but there is no plan to remove these. +# Perhaps change to @deprecate for v0.15, but there is no plan to remove these. Dense(in::Integer, out::Integer, σ = identity; kw...) = Dense(in => out, σ; kw...) Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; kw...) = @@ -86,7 +70,7 @@ Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. @deprecate paramtype(T,m) _paramtype(T,m) false # internal method, renamed to make this clear -@deprecate rng_from_array() default_rng_value() +@deprecate rng_from_array() Random.default_rng() function istraining() Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining) @@ -216,13 +200,17 @@ ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...) # v0.14 deprecations +@deprecate default_rng_value() Random.default_rng() + + +# v0.15 deprecations -# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc: +# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc: # Base.@deprecate_binding Optimiser OptimiserChain # Base.@deprecate_binding ClipValue ClipGrad # train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( -# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`. +# """On Flux 0.15, `train!` no longer accepts implicit `Zygote.Params`. 
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` # it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` # where `loss_mxy` accepts the model as its first argument. diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 62661765de..d38be00df3 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -71,9 +71,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG} active::Union{Bool, Nothing} rng::R end -Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value()) +Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng()) -function Dropout(p::Real; dims=:, active::Union{Bool,Nothing} = nothing, rng = default_rng_value()) +function Dropout(p::Real; dims=:, active::Union{Bool,Nothing} = nothing, rng = default_rng()) 0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p")) Dropout(p, dims, active, rng) end @@ -125,8 +125,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG} rng::R end -AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value()) -function AlphaDropout(p; rng = default_rng_value(), active::Union{Bool,Nothing} = nothing) +AlphaDropout(p, active) = AlphaDropout(p, active, default_rng()) +function AlphaDropout(p; rng = default_rng(), active::Union{Bool,Nothing} = nothing) 0 ≤ p ≤ 1 || throw(ArgumentError("AlphaDropout expects 0 ≤ p ≤ 1, got p = $p")) AlphaDropout(p, active, rng) end @@ -455,10 +455,12 @@ function Base.show(io::IO, l::InstanceNorm) end """ - GroupNorm(channels::Integer, G::Integer, λ=identity; - initβ=zeros32, initγ=ones32, - affine=true, track_stats=false, - eps=1f-5, momentum=0.1f0) + GroupNorm(channels::Int, G::Int, λ = identity; + initβ = zeros32, + initγ = ones32, + affine = true, + eps = 1f-5, + momentum = 0.1f0) [Group Normalization](https://arxiv.org/abs/1803.08494) layer. @@ -476,8 +478,6 @@ For `WHCN` images it's the usual channel dimension. If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias `β` and scale `γ` parameters. -If `track_stats=true`, accumulates mean and var statistics in training phase -that will be used to renormalize the input in test phase. # Examples ```jldoctest @@ -496,13 +496,13 @@ julia> isapprox(std(y[:, :, 3:4, 2]), 1, atol=0.1) && std(xs[:, :, 3:4, 2]) != s true ``` """ -mutable struct GroupNorm{F,V,N,W} +mutable struct GroupNorm{F,V,N} G::Int # number of groups λ::F # activation function β::V # bias γ::V # scale - μ::W # moving mean - σ²::W # moving std + μ::Nothing # moving mean + σ²::Nothing # moving std ϵ::N momentum::N affine::Bool @@ -516,20 +516,18 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;) function GroupNorm(chs::Int, G::Int, λ=identity; initβ=zeros32, initγ=ones32, - affine::Bool=true, track_stats::Bool=false, active::Union{Bool,Nothing}=nothing, + affine::Bool=true, active::Union{Bool,Nothing}=nothing, eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing) - if track_stats - Base.depwarn("`track_stats=true` will be removed from GroupNorm in Flux 0.14. The default value is `track_stats=false`, which will work as before.", :GroupNorm) - end ε = _greek_ascii_depwarn(ϵ => eps, :GroupNorm, "ϵ" => "eps") chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)") β = affine ? initβ(chs) : nothing γ = affine ? initγ(chs) : nothing - μ = track_stats ? zeros32(G) : nothing - σ² = track_stats ? 
ones32(G) : nothing + μ = nothing + σ² = nothing + track_stats = false return GroupNorm(G, λ, β, γ, diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 48f660ffdb..3ca01e93fa 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -5,7 +5,7 @@ using LinearAlgebra export train!, update!, Descent, Adam, Momentum, Nesterov, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW,RAdam, OAdam, AdaBelief, - InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, + InvDecay, ExpDecay, WeightDecay, Optimiser, ClipValue, ClipNorm include("optimisers.jl") diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 0149490059..9da9b1472f 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -566,7 +566,7 @@ that will be fed into the next, and this is finally applied to the parameter as usual. !!! note - This will be replaced by `Optimisers.OptimiserChain` in Flux 0.14. + This will be replaced by `Optimisers.OptimiserChain` in Flux 0.15. """ mutable struct Optimiser <: AbstractOptimiser os::Vector{Any} @@ -704,7 +704,7 @@ end Clip gradients when their absolute value exceeds `thresh`. !!! note - This will be replaced by `Optimisers.ClipGrad` in Flux 0.14. + This will be replaced by `Optimisers.ClipGrad` in Flux 0.15. """ mutable struct ClipValue{T} <: AbstractOptimiser thresh::T diff --git a/src/optimise/train.jl b/src/optimise/train.jl index cf47013064..883a7210c4 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -16,7 +16,7 @@ As a result, the parameters are mutated and the optimiser's internal state may c The gradient could be mutated as well. !!! compat "Deprecated" - This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14. + This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.15. The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. """ function update!(opt::AbstractOptimiser, x::AbstractArray, x̄) @@ -37,54 +37,6 @@ call(f, xs...) = f(xs...) runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) -struct SkipException <: Exception end - -""" - skip() - -Call `Flux.skip()` in a callback to indicate when a callback condition is met. -This will trigger the train loop to skip the current data point and not update with the calculated gradient. - -!!! note - `Flux.skip()` will be removed from Flux 0.14 - -# Examples -```julia -cb = function () - loss() > 1e7 && Flux.skip() -end -``` -""" -function skip() - Base.depwarn("""Flux.skip() will be removed from Flux 0.14. - and should be replaced with `continue` in an ordinary `for` loop.""", :skip) - throw(SkipException()) -end - - -struct StopException <: Exception end - -""" - stop() - -Call `Flux.stop()` in a callback to indicate when a callback condition is met. -This will trigger the train loop to stop and exit. - -!!! note - `Flux.stop()` will be removed from Flux 0.14. It should be replaced with `break` in an ordinary `for` loop. - -# Examples -```julia -cb = function () - accuracy() > 0.9 && Flux.stop() -end -``` -""" -function stop() - Base.depwarn("""Flux.stop() will be removed from Flux 0.14. - It should be replaced with `break` in an ordinary `for` loop.""", :stop) - throw(StopException()) -end batchmemaybe(x) = tuple(x) batchmemaybe(x::Tuple) = x @@ -96,7 +48,7 @@ Uses a `loss` function and training `data` to improve the model's parameters according to a particular optimisation rule `opt`. !!! 
compat "Deprecated" - This method with implicit `Params` will be removed from Flux 0.14. + This method with implicit `Params` will be removed from Flux 0.15. It should be replaced with the explicit method `train!(loss, model, data, opt)`. For each `d in data`, first the gradient of the `loss` is computed like this: @@ -118,7 +70,7 @@ Different optimisers can be combined using [`Flux.Optimise.Optimiser`](@ref Flux This training loop iterates through `data` once. It will stop with a `DomainError` if the loss is `NaN` or infinite. -You can use [`@epochs`](@ref) to do this several times, or +You can use use `train!` inside a for loop to do this several times, or use for instance `Itertools.ncycle` to make a longer `data` iterator. ## Callbacks @@ -128,8 +80,6 @@ For example, this will print "training" every 10 seconds (using [`Flux.throttle` ``` train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10)) ``` - -The callback can call [`Flux.stop`](@ref) to interrupt the training loop. Multiple callbacks can be passed to `cb` as array. """ @@ -138,51 +88,15 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ()) itrsz = Base.IteratorSize(typeof(data)) n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0 @withprogress for (i, d) in enumerate(data) - try - l, gs = withgradient(ps) do - loss(batchmemaybe(d)...) - end - if !isfinite(l) - throw(DomainError("Loss is $l on data item $i, stopping training")) - end - update!(opt, ps, gs) - cb() - catch ex - if ex isa StopException - break - elseif ex isa SkipException - continue - else - rethrow(ex) - end + l, gs = withgradient(ps) do + loss(batchmemaybe(d)...) + end + if !isfinite(l) + throw(DomainError("Loss is $l on data item $i, stopping training")) end + update!(opt, ps, gs) + cb() + @logprogress iszero(n) ? nothing : i / n end end - -""" - @epochs N body - -Run `body` `N` times. Mainly useful for quickly doing multiple epochs of -training in a REPL. - -!!! note - The macro `@epochs` will be removed from Flux 0.14. Please just write an ordinary `for` loop. - -# Examples -```julia -julia> Flux.@epochs 2 println("hello") -[ Info: Epoch 1 -hello -[ Info: Epoch 2 -hello -``` -""" -macro epochs(n, ex) - Base.depwarn("""The macro `@epochs` will be removed from Flux 0.14. - As an alternative, you can write a simple `for i in 1:epochs` loop.""", Symbol("@epochs"), force=true) - :(@progress for i = 1:$(esc(n)) - @info "Epoch $i" - $(esc(ex)) - end) -end diff --git a/src/utils.jl b/src/utils.jl index 00d8df0cba..082d9dcb1c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -34,36 +34,20 @@ ofeltype(x, y) = convert(float(eltype(x)), y) epseltype(x) = eps(float(eltype(x))) """ - rng_from_array([x]) + rng_from_array(x) Create an instance of the RNG most appropriate for `x`. The current defaults are: -- `x isa CuArray`: `CUDA.default_rng()`, else: -- `x isa AbstractArray`, or no `x` provided: - - Julia version is < 1.7: `Random.GLOBAL_RNG` - - Julia version is >= 1.7: `Random.default_rng()` +- `x isa CuArray`: `CUDA.default_rng()` +- `x isa AbstractArray`: `Random.default_rng() """ -rng_from_array(::AbstractArray) = default_rng_value() +rng_from_array(::AbstractArray) = Random.default_rng() @non_differentiable rng_from_array(::Any) -if VERSION >= v"1.7" - default_rng_value() = Random.default_rng() -else - default_rng_value() = Random.GLOBAL_RNG -end - -""" - default_rng_value() - -Create an instance of the default RNG depending on Julia's version. 
-- Julia version is < 1.7: `Random.GLOBAL_RNG` -- Julia version is >= 1.7: `Random.default_rng()` -""" -default_rng_value """ - glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array + glorot_uniform([rng], size...; gain = 1) -> Array glorot_uniform([rng]; kw...) -> Function Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform @@ -102,13 +86,13 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) scale = Float32(gain) * sqrt(24.0f0 / sum(nfan(dims...))) (rand(rng, Float32, dims...) .- 0.5f0) .* scale end -glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng_value(), dims...; kw...) -glorot_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) +glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng(), dims...; kw...) +glorot_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable glorot_uniform(::Any...) """ - glorot_normal([rng = default_rng_value(), size...; gain = 1) -> Array + glorot_normal([rng], size...; gain = 1) -> Array glorot_normal([rng]; kw...) -> Function Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal @@ -145,13 +129,13 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) std = Float32(gain) * sqrt(2.0f0 / sum(nfan(dims...))) randn(rng, Float32, dims...) .* std end -glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng_value(), dims...; kwargs...) -glorot_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) +glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng(), dims...; kwargs...) +glorot_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable glorot_normal(::Any...) """ - kaiming_uniform([rng = default_rng_value()], size...; gain = √2) -> Array + kaiming_uniform([rng], size...; gain = √2) -> Array kaiming_uniform([rng]; kw...) -> Function Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform distribution @@ -180,13 +164,13 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = √2) return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound end -kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(default_rng_value(), dims...; kwargs...) -kaiming_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) +kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(default_rng(), dims...; kwargs...) +kaiming_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable kaiming_uniform(::Any...) """ - kaiming_normal([rng = default_rng_value()], size...; gain = √2) -> Array + kaiming_normal([rng], size...; gain = √2) -> Array kaiming_normal([rng]; kw...) -> Function Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal @@ -217,13 +201,13 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real = √2f0) return randn(rng, Float32, dims...) 
.* std end -kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(default_rng_value(), dims...; kwargs...) +kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(default_rng(), dims...; kwargs...) kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable kaiming_normal(::Any...) """ - truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array + truncated_normal([rng], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array truncated_normal([rng]; kw...) -> Function Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution. @@ -263,13 +247,13 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1, return xs end -truncated_normal(dims::Integer...; kwargs...) = truncated_normal(default_rng_value(), dims...; kwargs...) -truncated_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...) +truncated_normal(dims::Integer...; kwargs...) = truncated_normal(default_rng(), dims...; kwargs...) +truncated_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable truncated_normal(::Any...) """ - orthogonal([rng = default_rng_value()], size...; gain = 1) -> Array + orthogonal([rng], size...; gain = 1) -> Array orthogonal([rng]; kw...) -> Function Return an `Array{Float32}` of the given `size` which is a (semi) orthogonal matrix, as described in [1]. @@ -324,13 +308,13 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...) return reshape(orthogonal(rng, rows, cols; kwargs...), dims) end -orthogonal(dims::Integer...; kwargs...) = orthogonal(default_rng_value(), dims...; kwargs...) -orthogonal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...) +orthogonal(dims::Integer...; kwargs...) = orthogonal(default_rng(), dims...; kwargs...) +orthogonal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable orthogonal(::Any...) """ - sparse_init([rng = default_rng_value()], rows, cols; sparsity, std = 0.01) -> Array + sparse_init([rng], rows, cols; sparsity, std = 0.01) -> Array sparse_init([rng]; kw...) -> Function Return a `Matrix{Float32}` of size `rows, cols` where each column contains a fixed fraction of @@ -372,8 +356,8 @@ function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01) return mapslices(shuffle, sparse_array, dims=1) end -sparse_init(dims::Integer...; kwargs...) = sparse_init(default_rng_value(), dims...; kwargs...) -sparse_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...) +sparse_init(dims::Integer...; kwargs...) = sparse_init(default_rng(), dims...; kwargs...) +sparse_init(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable sparse_init(::Any...) @@ -463,7 +447,7 @@ end # For consistency, it accepts an RNG, but ignores it: identity_init(::AbstractRNG, dims::Integer...; kwargs...) = identity_init(dims...; kwargs...) 
-identity_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...) +identity_init(rng::AbstractRNG=default_rng(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...) ChainRulesCore.@non_differentiable identity_init(::Any...) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index c83b2c18d3..5b4b80d918 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -299,7 +299,5 @@ end c3 = ConvTranspose((3,), 2=>4, relu) @test c3(x) isa Array{Float32, 3} - if VERSION >= v"1.8" - @test (@inferred c3(x); true) # fails on 1.6 - end + @test (@inferred c3(x); true) end diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 917ca20a17..35f11a4adc 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -69,11 +69,7 @@ evalwgrad(f, x...) = pullback(f, x...)[1] # CPU RNGs map onto CPU ok if isempty(rng_kwargs) - if VERSION >= v"1.7" - @test cpu(m).rng isa Random.TaskLocalRNG - else - @test cpu(m).rng isa Random._GLOBAL_RNG - end + @test cpu(m).rng isa Random.TaskLocalRNG else @test cpu(m).rng === only(values(rng_kwargs)) end @@ -118,11 +114,7 @@ end # CPU RNGs map onto CPU ok if isempty(rng_kwargs) - if VERSION >= v"1.7" - @test cpu(m).rng isa Random.TaskLocalRNG - else - @test cpu(m).rng isa Random._GLOBAL_RNG - end + @test cpu(m).rng isa Random.TaskLocalRNG else @test cpu(m).rng === only(values(rng_kwargs)) end @@ -388,7 +380,7 @@ end # begin tests squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions - let m = GroupNorm(4,2, track_stats=true), sizes = (3,4,2), + let m = GroupNorm(4,2), sizes = (3,4,2), x = reshape(collect(1:prod(sizes)), sizes) @test length(Flux.params(m)) == 2 @@ -396,47 +388,19 @@ end @test m.β == [0, 0, 0, 0] # initβ(32) @test m.γ == [1, 1, 1, 1] # initγ(32) - y = evalwgrad(m, x) - - #julia> x - #[:, :, 1] = - # 1.0 4.0 7.0 10.0 - # 2.0 5.0 8.0 11.0 - # 3.0 6.0 9.0 12.0 - # - #[:, :, 2] = - # 13.0 16.0 19.0 22.0 - # 14.0 17.0 20.0 23.0 - # 15.0 18.0 21.0 24.0 - # - # μ will be - # (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5 - # (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5 - # - # (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5 - # (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5 - # - # μ = - # 3.5 15.5 - # 9.5 21.5 - # - # ∴ update rule with momentum: - # (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95 - # (1. 
- .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55 - @test m.μ ≈ [0.95, 1.55] - n = prod(size(x)) ÷ m.G ÷ size(x)[end] - corr = n / (n-1) - z = reshape(x,3,2,2,2) - σ² = var(z, dims=(1,2), corrected=false) - @test m.σ² ≈ 0.1*corr*vec(mean(σ², dims=4)) .+ 0.9 * 1 + ŷ = evalwgrad(m, x) + + @test m.μ === nothing + @test m.σ² === nothing + ŷ = m(x) + y = [-1.4638476 0.29276943 -1.4638476 0.29276943; -0.87830865 0.87830853 -0.8783088 0.8783083; -0.29276967 1.4638474 -0.2927699 1.4638472;;; -1.4638476 0.29276943 -1.4638472 0.29276943; -0.8783083 0.8783083 -0.8783083 0.8783083; -0.29276943 1.4638472 -0.29276943 1.4638472] - y = m(x) - out = (z .- reshape(m.μ, 1,1,2,1)) ./ sqrt.(reshape(m.σ², 1,1,2,1) .+ 1f-5) - @test y ≈ reshape(out, size(x)) atol=1.0e-5 + @test ŷ ≈ y atol=1.0e-5 end # with activation function - let m = GroupNorm(4,2, sigmoid, track_stats=true), sizes = (3, 4, 2), + let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2), x = reshape(collect(1:prod(sizes)), sizes) + x = Float32.(x) μ_affine_shape = ones(Int,length(sizes) + 1) μ_affine_shape[end-1] = 2 # Number of groups @@ -449,28 +413,18 @@ end og_shape = size(x) - y = m(x) - x_ = reshape(x,affine_shape...) - out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape) - @test y ≈ out atol=1e-7 + ŷ = m(x) + y = [0.18787955 0.57267404 0.18787955 0.57267404; 0.2935284 0.70647156 0.29352835 0.70647156; 0.42732593 0.81212044 0.42732587 0.8121204;;; 0.18787955 0.57267404 0.1878796 0.57267404; 0.29352847 0.70647156 0.29352847 0.70647156; 0.42732602 0.8121204 0.42732602 0.8121204] + @test ŷ ≈ y atol=1e-7 end - let m = trainmode!(GroupNorm(2,2, track_stats=true)), sizes = (2, 4, 1, 2, 3), + let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) 
@test m(x) == y end - # check that μ, σ², and the output are the correct size for higher rank tensors - let m = GroupNorm(4,2, track_stats=true), sizes = (5, 5, 3, 4, 4, 6), - x = Float32.(reshape(collect(1:prod(sizes)), sizes)) - y = evalwgrad(m, x) - @test size(m.μ) == (m.G,) - @test size(m.σ²) == (m.G,) - @test size(y) == sizes - end - # show that group norm is the same as instance norm when the group size is the same as the number of channels let IN = trainmode!(InstanceNorm(4; affine=true)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 8e959955c2..7df8b0d4c2 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -58,17 +58,10 @@ end @test primal[1] ≈ e - if VERSION < v"1.7" - @test ∇Wi ≈ grads[:Wi] - @test ∇Wh ≈ grads[:Wh] - @test ∇b ≈ grads[:b] - @test_broken ∇state0 ≈ grads[:state0] - else - @test_broken ∇Wi ≈ grads[:Wi] - @test_broken ∇Wh ≈ grads[:Wh] - @test_broken ∇b ≈ grads[:b] - @test_broken ∇state0 ≈ grads[:state0] - end + @test_broken ∇Wi ≈ grads[:Wi] + @test_broken ∇Wh ≈ grads[:Wh] + @test_broken ∇b ≈ grads[:b] + @test_broken ∇state0 ≈ grads[:state0] end # Ref FluxML/Flux.jl#1209 1D input diff --git a/test/optimise.jl b/test/optimise.jl index e09ba06b4d..c79ce7f5e8 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -46,34 +46,6 @@ end end @testset "Training Loop" begin - i = 0 - l = 1 - Flux.train!( - () -> (sleep(0.1); Flux.skip(); i+=1), - Params([]), - Iterators.repeated((), 10), - Descent() - ) - - @test i==0 #all skipped - - Flux.train!( - () -> (sleep(0.1); i==8 && Flux.skip(); i+=1), - Params([]), - Iterators.repeated((), 10), - Descent() - ) - - @test i==8 #skip after i hit 8 - - i = 0 - Flux.train!(() -> (sleep(0.1); i += 1; l), - Params([]), - Iterators.repeated((), 100), - Descent(), - cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1)) - - @test 3 < i < 50 # Test multiple callbacks x = 0