From 597bcd730334bdceefb753979157e3f4579f966c Mon Sep 17 00:00:00 2001
From: "William S. Moses"
Date: Sun, 24 Sep 2023 22:45:06 -0500
Subject: [PATCH] Add batchnorm derivatives

---
 ext/NNlibCUDACUDNNExt/batchnorm.jl | 183 +++++++++++++++++++++++++++++
 1 file changed, 183 insertions(+)

diff --git a/ext/NNlibCUDACUDNNExt/batchnorm.jl b/ext/NNlibCUDACUDNNExt/batchnorm.jl
index 2c38f009e..2c83d92c2 100644
--- a/ext/NNlibCUDACUDNNExt/batchnorm.jl
+++ b/ext/NNlibCUDACUDNNExt/batchnorm.jl
@@ -3,6 +3,8 @@ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
              cudnnBatchNormalizationForwardTraining
 import NNlib: batchnorm, ∇batchnorm
 
+using EnzymeCore
+
 # TODO: replace with new cudnn normalization interface
 # https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
 
@@ -153,3 +155,184 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
             scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
             xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
 end
+
+# Enzyme reverse-mode rule for `cudnnBNForward!`: forward (augmented primal) half.
+#
+# The arguments mirror `cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum)`,
+# each wrapped in an Enzyme activity annotation; `momentum` must be `Const`.
+#
+# Returns an `AugmentedReturn` holding the primal and shadow of `y` (each only when
+# `config` requests it) and the cache tuple `(g, x, running_mean, running_var)`. A cache
+# entry is a copy of that input when `overwritten(config)` flags it, else `nothing`.
+function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
+                                                 y::OutType,
+                                                 g,
+                                                 b,
+                                                 x,
+                                                 running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}
+
+    # NOTE(review): the primal only runs when `y` carries a shadow; confirm that
+    # skipping it for a `Const` output is intended.
+    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
+        func.val(y.val, g.val, b.val, x.val, running_mean.val, running_var.val, momentum.val; kws...)
+    end
+
+    primal = if EnzymeCore.EnzymeRules.needs_primal(config)
+        y.val
+    else
+        nothing
+    end
+    shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
+        y.dval
+    else
+        nothing
+    end
+
+    cache_g = nothing
+    cache_x = nothing
+    cache_running_mean = nothing
+    cache_running_var = nothing
+
+    # The reverse pass only needs these inputs when `y` has a shadow and at least one
+    # of `x`, `g`, `b` is not `Const`.
+    if !(typeof(y) <: EnzymeCore.Const)
+        if !(typeof(x) <: EnzymeCore.Const) ||
+           !(typeof(g) <: EnzymeCore.Const) ||
+           !(typeof(b) <: EnzymeCore.Const)
+
+            # Indexed by argument position with the function itself at 1:
+            # g = 3, x = 5, running_mean = 6, running_var = 7.
+            clobbered = EnzymeCore.EnzymeRules.overwritten(config)
+            if clobbered[3]
+                cache_g = copy(g.val)
+            end
+            if clobbered[5]
+                cache_x = copy(x.val)
+            end
+            if clobbered[6]
+                cache_running_mean = copy(running_mean.val)
+            end
+            if clobbered[7]
+                cache_running_var = copy(running_var.val)
+            end
+
+        end
+    end
+
+    cache = (cache_g, cache_x, cache_running_mean, cache_running_var)
+
+    return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
+end
+
+# Enzyme reverse-mode rule for `cudnnBNForward!`: reverse half.
+#
+# `cache` is the tuple built by the matching `augmented_primal`. For each shadow of
+# `y` (a single one when `width(config) == 1`) this calls `cudnnBNBackward!` with the
+# shadows of `x`, `g` and `b` as outputs, then zeroes that shadow of `y`.
+#
+# Always returns a tuple of seven `nothing`s, one per annotated argument.
+function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(cudnnBNForward!)}, ::Type{RT},
+                                        cache,
+                                        y::OutType, g, b, x, running_mean, running_var, momentum::EnzymeCore.Const{<:Real}; kws...) where {OutType, RT}
+
+    no_grads = (nothing, nothing, nothing, nothing, nothing, nothing, nothing)
+
+    # A `Const` output has no `dval`, so there is nothing to propagate.
+    if typeof(y) <: EnzymeCore.Const
+        return no_grads
+    end
+
+    x_const = typeof(x) <: EnzymeCore.Const
+    g_const = typeof(g) <: EnzymeCore.Const
+    b_const = typeof(b) <: EnzymeCore.Const
+
+    cache_g, cache_x, cache_running_mean, cache_running_var = cache
+
+    if !(x_const && g_const && b_const)
+        # An input that was NOT flagged as overwritten was not copied by
+        # `augmented_primal` (its cache slot is `nothing`), so read the live value.
+        clobbered = EnzymeCore.EnzymeRules.overwritten(config)
+        if !clobbered[3]
+            cache_g = g.val
+        end
+        if !clobbered[5]
+            cache_x = x.val
+        end
+        if !clobbered[6]
+            cache_running_mean = running_mean.val
+        end
+        if !clobbered[7]
+            cache_running_var = running_var.val
+        end
+    end
+
+    T = eltype(y.val)
+
+    # `Const` arguments have no `dval`; `y.dval` is only a placeholder of matching
+    # length for `zip` and is replaced by a scratch buffer before any write.
+    dys = y.dval
+    dgs = g_const ? dys : g.dval
+    dbs = b_const ? dys : b.dval
+    dxs = x_const ? dys : x.dval
+
+    if EnzymeCore.EnzymeRules.width(config) == 1
+        dys = (dys,)
+        dxs = (dxs,)
+        dgs = (dgs,)
+        dbs = (dbs,)
+    end
+
+    for (dy, dx, dg, db) in zip(dys, dxs, dgs, dbs)
+        if dy !== y.val
+
+            # A shadow that is the primal array itself is treated like `Const`.
+            x_active = !x_const && dx !== x.val
+            g_active = !g_const && dg !== g.val
+            b_active = !b_const && db !== b.val
+
+            if x_active || g_active || b_active
+
+                # dx = alpha * newVal + beta * old(dx)
+                # Active x: dx += newVal, i.e. alpha = beta = 1. Otherwise the result
+                # is discarded: use zeros and a scratch buffer.
+                alpha = T(1)
+                beta = T(1)
+                if !x_active
+                    alpha = T(0)
+                    beta = T(0)
+                    dx = similar(x.val)
+                end
+
+                # dalpha / dbeta scale dg and db the same way; they are shared, so
+                # they can only be zeroed when both g and b are inactive.
+                dalpha = T(1)
+                dbeta = T(1)
+                if !g_active && !b_active
+                    dalpha = T(0)
+                    dbeta = T(0)
+                end
+                if !g_active
+                    dg = similar(g.val)
+                end
+                if !b_active
+                    db = similar(b.val)
+                end
+
+                # The scaling factors come last so they win over same-named entries
+                # in `kws`.
+                # NOTE(review): `kws` are the forward call's keywords; confirm that
+                # `cudnnBNBackward!` accepts all of them.
+                cudnnBNBackward!(dg, cache_g, db, dx, cache_x, dy,
+                                 cache_running_mean, cache_running_var,
+                                 momentum.val; kws..., alpha=alpha, beta=beta,
+                                 dalpha=dalpha, dbeta=dbeta)
+
+            end
+
+            dy .= 0
+
+        end
+    end
+
+    return no_grads
+end