diff --git a/Project.toml b/Project.toml index 0c00aa53..25aac329 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NNlib" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.26" +version = "0.9.27" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/conv.jl b/src/conv.jl index fead2ee2..f4050bdb 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -343,12 +343,12 @@ for conv in [:conv, :depthwiseconv] conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback) @eval function rrule(::typeof($conv), x, w, cdims; kw...) - function $conv_pullback(Δ) - Δ = colmajor(Δ) + function $conv_pullback(Δraw) + Δ = colmajor(unthunk(Δraw)) return ( NoTangent(), - @thunk($∇conv_data(unthunk(Δ), w, cdims, kw...)), - @thunk($∇conv_filter(x, unthunk(Δ), cdims, kw...)), + @thunk($∇conv_data(Δ, w, cdims, kw...)), + @thunk($∇conv_filter(x, Δ, cdims, kw...)), NoTangent(), ) end @@ -356,12 +356,12 @@ for conv in [:conv, :depthwiseconv] end @eval function rrule(::typeof($∇conv_data), x, w, cdims; kw...) - function $∇conv_data_pullback(Δ) - Δ = colmajor(Δ) + function $∇conv_data_pullback(Δraw) + Δ = colmajor(unthunk(Δraw)) return ( NoTangent(), - @thunk($conv(unthunk(Δ), w, cdims, kw...)), - @thunk($∇conv_filter(unthunk(Δ), x, cdims, kw...)), + @thunk($conv(Δ, w, cdims, kw...)), + @thunk($∇conv_filter(Δ, x, cdims, kw...)), NoTangent(), ) end