Skip to content

Commit

Permalink
fix CI on 1.6 and MacOS
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir committed Jan 3, 2023
1 parent d950245 commit 3a52deb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
22 changes: 11 additions & 11 deletions src/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ _apply_scale_bias(x, scale, bias) = x .* scale .+ bias
Shared code path for all built-in norm functions.
`μ` and `σ²` should be calculated on the fly using [`NNlib.norm_stats`](@ref),
or extracted from an existing collection such as [`NNlib.RunningStats`](@ref).
`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
or extracted from an existing collection such as [`RunningStats`](@ref).
`bias` and `scale` are consistent with cuDNN and Flux.Scale.
We opt for `scale` over `weight` to avoid confusion with dense layers.
If the size of the statistics and affine parameters differ,
Expand All @@ -64,7 +64,7 @@ Contains running mean and variance estimates for stateful norm functions.
If the parameters are mutable, they will be updated in-place.
Otherwise, they will be replaced wholesale.
See also [`NNlib.update_running_stats!`](@ref).
See also [`update_running_stats!`](@ref).
"""
mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
mean::M
Expand Down Expand Up @@ -114,10 +114,10 @@ end
reduce_dims) where {N}
Performs a moving average update for layers with tracked statistics.
`μ` and `σ²` are the sample mean and variance, most likely from [`NNlib.norm_stats`](@ref).
`reduce_dims` should also match the `dims` argument of [`NNlib.norm_stats`](@ref).
`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).
See also [`NNlib.RunningStats`](@ref).
See also [`RunningStats`](@ref).
"""
function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims)
V = eltype(σ²)
Expand Down Expand Up @@ -153,7 +153,7 @@ Normalizes `x` along the first `S` dimensions.
For an additional learned affine transform, provide a `S`-dimensional `scale` and `bias`.
See also [`NNlib.batchnorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref).
See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
# Examples
Expand Down Expand Up @@ -190,14 +190,14 @@ Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.
Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.
Provide a [`NNlib.RunningStats`](@ref) to fix a estimated mean and variance.
Provide a [`RunningStats`](@ref) to fix a estimated mean and variance.
`batchnorm` will renormalize the input using these statistics during inference,
and update them using batch-level statistics when training.
To override this behaviour, manually set a value for `training`.
If specified, `scale` and `bias` will be applied as an additional learned affine transform.
See also [`NNlib.layernorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref).
See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
"""
function batchnorm(x::AbstractArray{<:Any, N},
running_stats::Union{RunningStats, Nothing} = nothing,
Expand Down Expand Up @@ -232,7 +232,7 @@ To override this behaviour, manually set a value for `training`.
If specified, `scale` and `bias` will be applied as an additional learned affine transform.
See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.groupnorm`](@ref).
See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref).
"""
function instancenorm(x::AbstractArray{<:Any, N},
running_stats::Union{RunningStats, Nothing} = nothing,
Expand Down Expand Up @@ -266,7 +266,7 @@ The number of channels must be an integer multiple of the number of groups.
If specified, `scale` and `bias` will be applied as an additional learned affine transform.
See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.instancenorm`](@ref).
See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref).
# Examples
Expand Down
16 changes: 9 additions & 7 deletions test/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ end

# Group/InstanceNorm dimensions
let W = 128, C = 2, N = 2, shape = (W, W, 1, 1)
x = [randn_sample(shape, 1, 1);;; randn_sample(shape, 2, 2);;;;
randn_sample(shape, 3, 3);;; randn_sample(shape, 4, 4)]
# Tile to W x W x 2 x 2
x = cat(cat(randn_sample(shape, 1, 1), randn_sample(shape, 2, 2); dims = 3),
cat(randn_sample(shape, 3, 3), randn_sample(shape, 4, 4); dims = 3);
dims = 4)
μ, σ² = NNlib.norm_stats(x, (1, 2))
@test vec(μ)1:(C * N) rtol=0.05
@test vec(σ²)abs2.(1:(C * N)) rtol=0.05
Expand All @@ -60,7 +62,9 @@ end
(running_stats, true, y_ns, y_ns, dx_ns),
(running_stats, false, meanvar, meanvar, NoTangent()),
]
@test NNlib.maybe_norm_stats(stats, x, dims, !training) == y
= NNlib.maybe_norm_stats(stats, x, dims, !training)
@test ŷ[1]y[1] rtol=1e-5
@test ŷ[2]y[2] rtol=1e-5
ŷ, back = rrule(NNlib.maybe_norm_stats, stats, x, dims, !training)
@test== y_ad
@test back(meanvar) == (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent())
Expand Down Expand Up @@ -170,8 +174,7 @@ end
@testset for use_stats in (true, false)
stats = use_stats ? NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing
y, back = Zygote.pullback(NNlib.instancenorm, x, stats, scale, bias, 1e-5)
@test y[-1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474;;;
-1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474] rtol=1e-5
@test yrepeat([-1.22474, 0.0, 1.22474], 1, 2, 2) rtol=1e-5

expected_mean, expected_var = [0.5, 0.8], [1.0, 1.0]
if use_stats
Expand All @@ -197,8 +200,7 @@ end
end

dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1))
@test dx[3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474;;;
3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474] rtol=1e-5
@test dxrepeat([3.6742, 1.22474, -1.22474], 1, 2, 2) rtol=1e-5
@test dscale == zeros(2)
@test dbias == fill(6.0, 2)
@test dstats === nothing
Expand Down

0 comments on commit 3a52deb

Please sign in to comment.