Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: instance norm gradients with enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 31, 2024
1 parent 73b8961 commit 59145df
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 9 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
VectorizedStatistics = "3b853605-1c98-4422-8364-4bd93ee0529e"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand Down Expand Up @@ -77,7 +76,6 @@ Statistics = "1.10"
Test = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
VectorizedStatistics = "0.5.10"
Zygote = "0.6.70"
cuDNN = "1.3"
julia = "1.10"
Expand Down
1 change: 0 additions & 1 deletion src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector
using Statistics: Statistics, mean, var
using SLEEFPirates: SLEEFPirates
using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce
using VectorizedStatistics: vmean, vvar

@reexport using NNlib

Expand Down
2 changes: 1 addition & 1 deletion src/api/instancenorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
end

function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N}
N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least 2."))
N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least > 2."))
return nothing
end

Expand Down
4 changes: 0 additions & 4 deletions src/impl/fast_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@
# VectorizedStatistics.jl, we can will specialize the CPU dispatches to use them.
fast_mean(x::AbstractArray; dims=:) = fast_mean(internal_operation_mode(x), x; dims)
fast_mean(opmode, x::AbstractArray; dims=:) = mean(x; dims)
fast_mean(::LoopedArrayOp, x::AbstractArray; dims=:) = vmean(x; dims, multithreaded=true)

function fast_var(x::AbstractArray; mean=nothing, dims=:, corrected=true)
return fast_var(internal_operation_mode(x), x; mean, dims, corrected)
end
function fast_var(opmode, x::AbstractArray; mean=nothing, dims=:, corrected=true)
return var(x; mean, dims, corrected)
end
function fast_var(::LoopedArrayOp, x::AbstractArray; mean=nothing, dims=:, corrected=true)
return vvar(x; mean, dims, corrected, multithreaded=true)
end

function fast_mean_var(x::AbstractArray; dims=:, corrected=true)
return fast_mean_var(internal_operation_mode(x), x; dims, corrected)
Expand Down
5 changes: 4 additions & 1 deletion src/impl/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2)
__update_statistics!(opmode, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, 1 - m1)
return rμ2, rσ²2
end

CRC.@non_differentiable __update_statistics(::Any...)

function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3)
@tturbo for I in indices((rμ2, rσ²2))
rμ2[I] = m3 * rμ[I] + m1 * μ[I]
Expand All @@ -37,7 +40,7 @@ end
@inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I]
end

CRC.@non_differentiable __update_statistics(::Any...)
EnzymeRules.inactive(::typeof(__update_statistics!), ::Any...) = nothing

function _update_normalization_statistics(
x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N},
Expand Down

0 comments on commit 59145df

Please sign in to comment.