Skip to content

Commit

Permalink
Non-diff shape handling in norm layers
Browse files Browse the repository at this point in the history
This reduces some latency when using Zygote.
  • Loading branch information
ToucheSir committed Dec 19, 2023
1 parent f4b4761 commit 790eb84
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ function _norm_layer_forward(
l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
) where {T, N}
if !_isactive(l, x) && l.track_stats # testmode with tracked stats
stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
stats_shape = ChainRulesCore.ignore_derivatives() do
ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
end
μ = reshape(l.μ, stats_shape)
σ² = reshape(l.σ², stats_shape)
else # trainmode or testmode without tracked stats
Expand Down Expand Up @@ -347,7 +349,9 @@ trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
_size_check(BN, x, N-1 => BN.chs)
reduce_dims = [1:N-2; N]
affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
affine_shape = ChainRulesCore.ignore_derivatives() do
ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
end
return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
end

Expand Down Expand Up @@ -439,7 +443,9 @@ trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
_size_check(l, x, N-1 => l.chs)
reduce_dims = 1:N-2
affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
affine_shape = ChainRulesCore.ignore_derivatives() do
ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
end
return _norm_layer_forward(l, x; reduce_dims, affine_shape)
end

Expand All @@ -456,10 +462,10 @@ end

"""
GroupNorm(channels::Int, G::Int, λ = identity;
initβ = zeros32,
initβ = zeros32,
initγ = ones32,
affine = true,
eps = 1f-5,
affine = true,
eps = 1f-5,
momentum = 0.1f0)
[Group Normalization](https://arxiv.org/abs/1803.08494) layer.
Expand Down Expand Up @@ -538,12 +544,14 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
end

function (gn::GroupNorm)(x::AbstractArray)
_size_check(gn, x, ndims(x)-1 => gn.chs)
_size_check(gn, x, ndims(x)-1 => gn.chs)
sz = size(x)
x2 = reshape(x, sz[1:end-2]..., sz[end-1]÷gn.G, gn.G, sz[end])
N = ndims(x2) # == ndims(x)+1
reduce_dims = 1:N-2
affine_shape = ntuple(i -> i (N-1, N-2) ? size(x2, i) : 1, N)
affine_shape = ChainRulesCore.ignore_derivatives() do
ntuple(i -> i (N-1, N-2) ? size(x2, i) : 1, N)
end
x3 = _norm_layer_forward(gn, x2; reduce_dims, affine_shape)
return reshape(x3, sz)
end
Expand Down

0 comments on commit 790eb84

Please sign in to comment.