diff --git a/src/normalization.jl b/src/normalization.jl
index 9b8011e4a..14862b554 100644
--- a/src/normalization.jl
+++ b/src/normalization.jl
@@ -39,8 +39,8 @@ _apply_scale_bias(x, scale, bias) = x .* scale .+ bias
 
 Shared code path for all built-in norm functions.
 
-`μ` and `σ²` should be calculated on the fly using [`NNlib.norm_stats`](@ref),
-or extracted from an existing collection such as [`NNlib.RunningStats`](@ref).
+`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
+or extracted from an existing collection such as [`RunningStats`](@ref).
 `bias` and `scale` are consistent with cuDNN and Flux.Scale.
 We opt for `scale` over `weight` to avoid confusion with dense layers.
 If the size of the statistics and affine parameters differ,
@@ -64,7 +64,7 @@ Contains running mean and variance estimates for stateful norm functions.
 If the parameters are mutable, they will be updated in-place.
 Otherwise, they will be replaced wholesale.
-See also [`NNlib.update_running_stats!`](@ref).
+See also [`update_running_stats!`](@ref).
 """
 mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
     mean::M
     variance::V
@@ -114,10 +114,10 @@ end
                           reduce_dims) where {N}
 
 Performs a moving average update for layers with tracked statistics.
-`μ` and `σ²` are the sample mean and variance, most likely from [`NNlib.norm_stats`](@ref).
-`reduce_dims` should also match the `dims` argument of [`NNlib.norm_stats`](@ref).
+`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
+`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).
 
-See also [`NNlib.RunningStats`](@ref).
+See also [`RunningStats`](@ref).
 """
 function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims)
     V = eltype(σ²)
@@ -153,7 +153,7 @@ Normalizes `x` along the first `S` dimensions.
 
 For an additional learned affine transform, provide a `S`-dimensional `scale` and `bias`.
 
-See also [`NNlib.batchnorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref).
+See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
 
 # Examples
 
@@ -190,14 +190,14 @@ Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.
 Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
 where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.
 
-Provide a [`NNlib.RunningStats`](@ref) to fix a estimated mean and variance.
+Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
 `batchnorm` will renormalize the input using these statistics during inference,
 and update them using batch-level statistics when training.
 To override this behaviour, manually set a value for `training`.
 
 If specified, `scale` and `bias` will be applied as an additional learned affine transform.
 
-See also [`NNlib.layernorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref).
+See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
 """
 function batchnorm(x::AbstractArray{<:Any, N},
                    running_stats::Union{RunningStats, Nothing} = nothing,
@@ -232,7 +232,7 @@ To override this behaviour, manually set a value for `training`.
 
 If specified, `scale` and `bias` will be applied as an additional learned affine transform.
 
-See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.groupnorm`](@ref).
+See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref).
""" function instancenorm(x::AbstractArray{<:Any, N}, running_stats::Union{RunningStats, Nothing} = nothing, @@ -266,7 +266,7 @@ The number of channels must be an integer multiple of the number of groups. If specified, `scale` and `bias` will be applied as an additional learned affine transform. -See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.instancenorm`](@ref). +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref). # Examples diff --git a/test/normalization.jl b/test/normalization.jl index 128610297..b38b1ca71 100644 --- a/test/normalization.jl +++ b/test/normalization.jl @@ -35,8 +35,10 @@ end # Group/InstanceNorm dimensions let W = 128, C = 2, N = 2, shape = (W, W, 1, 1) - x = [randn_sample(shape, 1, 1);;; randn_sample(shape, 2, 2);;;; - randn_sample(shape, 3, 3);;; randn_sample(shape, 4, 4)] + # Tile to W x W x 2 x 2 + x = cat(cat(randn_sample(shape, 1, 1), randn_sample(shape, 2, 2); dims = 3), + cat(randn_sample(shape, 3, 3), randn_sample(shape, 4, 4); dims = 3); + dims = 4) μ, σ² = NNlib.norm_stats(x, (1, 2)) @test vec(μ)≈1:(C * N) rtol=0.05 @test vec(σ²)≈abs2.(1:(C * N)) rtol=0.05 @@ -60,7 +62,9 @@ end (running_stats, true, y_ns, y_ns, dx_ns), (running_stats, false, meanvar, meanvar, NoTangent()), ] - @test NNlib.maybe_norm_stats(stats, x, dims, !training) == y + ŷ = NNlib.maybe_norm_stats(stats, x, dims, !training) + @test ŷ[1]≈y[1] rtol=1e-5 + @test ŷ[2]≈y[2] rtol=1e-5 ŷ, back = rrule(NNlib.maybe_norm_stats, stats, x, dims, !training) @test ŷ == y_ad @test back(meanvar) == (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent()) @@ -170,8 +174,7 @@ end @testset for use_stats in (true, false) stats = use_stats ? NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing y, back = Zygote.pullback(NNlib.instancenorm, x, stats, scale, bias, 1e-5) - @test y≈[-1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474;;; - -1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474] rtol=1e-5 + @test y≈repeat([-1.22474, 0.0, 1.22474], 1, 2, 2) rtol=1e-5 expected_mean, expected_var = [0.5, 0.8], [1.0, 1.0] if use_stats @@ -197,8 +200,7 @@ end end dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1)) - @test dx≈[3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474;;; - 3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474] rtol=1e-5 + @test dx≈repeat([3.6742, 1.22474, -1.22474], 1, 2, 2) rtol=1e-5 @test dscale == zeros(2) @test dbias == fill(6.0, 2) @test dstats === nothing