Skip to content

Commit

Permalink
Add batchnorm derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 8, 2023
1 parent 256a4fb commit 597bcd7
Showing 1 changed file with 147 additions and 0 deletions.
147 changes: 147 additions & 0 deletions ext/NNlibCUDACUDNNExt/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
cudnnBatchNormalizationForwardTraining
import NNlib: batchnorm, ∇batchnorm

using EnzymeCore

# TODO: replace with new cudnn normalization interface
# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl

Expand Down Expand Up @@ -153,3 +155,148 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
end



function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
y::OutType,
g,
b,
x,
running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}

if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
func.val(y.val, b.val, x.val, running_mean.val, running_var.val, momentum.val; kws...)
end

primal = if EnzymeCore.EnzymeRules.needs_primal(config)
y.val
else
nothing
end
shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
y.dval
else
nothing
end

cache_g = nothing
cache_x = nothing
cache_running_mean = nothing
cache_running_var = nothing

if !(typeof(y) <: EnzymeCore.Const)
if !(typeof(x) <: EnzymeCore.Const)
|| !(typeof(g) <: EnzymeCore.Const)
|| !(typeof(b) <: EnzymeCore.Const)

if EnzymeCore.EnzymeRules.overwritten(config)[3]
cache_g = copy(g.val)
end
if EnzymeCore.EnzymeRules.overwritten(config)[5]
cache_x = copy(x.val)
end
if EnzymeCore.EnzymeRules.overwritten(config)[6]
cache_running_mean = copy(running_mean.val)
end
if EnzymeCore.EnzymeRules.overwritten(config)[7]
cache_running_var = copy(running_var.val)
end

end
end

cache = (cache_g, cache_x, cache_running_mean, cache_running_var)

return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
cache,
y::OutType, g, b, x, running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}

cache_g, cache_x, cache_running_mean, cache_running_var = cache

if !(typeof(y) <: EnzymeCore.Const)
if !(typeof(x) <: EnzymeCore.Const)
|| !(typeof(g) <: EnzymeCore.Const)
|| !(typeof(b) <: EnzymeCore.Const)

if EnzymeCore.EnzymeRules.overwritten(config)[3]
cache_g = g.val
end
if EnzymeCore.EnzymeRules.overwritten(config)[5]
cache_x = x.val
end
if EnzymeCore.EnzymeRules.overwritten(config)[6]
cache_running_mean = running_mean.val
end
if EnzymeCore.EnzymeRules.overwritten(config)[7]
cache_running_var = running_var.val
end

end
end

dys = y.dval
dgs = (typeof(g) <: EnzymeCore.Const) ? dys : g.dval
dbs = (typeof(b) <: EnzymeCore.Const) ? dbs : b.dval
dxs = (typeof(x) <: EnzymeCore.Const) ? dxs : x.dval

if EnzymeCore.EnzymeRules.width(config) == 1
dys = (dys,)
dxs = (dxs,)
dgs = (dgs,)
dbs = (dbs,)
end

for (dy, dx, dg, db) in zip(dys, dxs, dgs, dbs)
if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

if !((typeof(x) <: EnzymeCore.Const) || dx === x.val)
|| !((typeof(g) <: EnzymeCore.Const) || dg === g.val)
|| !((typeof(b) <: EnzymeCore.Const) || db === b.val)

# dx values
alpha = T(1)
beta = T(1)

# dx = alpha * newVal + beta old(dx)
# if x is constant, we can use zero for both
# otherwise we want to do dx += newVal, aka alpha=beta=1
if x <: EnzymeCore.Const
alpha = T(0)
beta = T(0)
dx = similar(x.val)
end

# dg / db values
alpha = T(1)
beta = T(1)

if g <: EnzymeCore.Const && b <: EnzymeCore.Const
dalpha = T(0)
dbeta = T(0)
end

if g <: EnzymeCore.Const
dg = similar(g.val)
end

if b <: EnzymeCore.Const
db = similar(b.val)
end

cudnnBNBackward!(dg, cache_g, db, dx, cache_x, dy,
cache_running_mean, cache_running_var,
momentum.val; alpha, beta, dalpha, dbeta; kw...)

end

dy .= 0

end
end

return (nothing, nothing, nothing, nothing, nothing, nothing, nothing)
end

0 comments on commit 597bcd7

Please sign in to comment.