From 597ec6fefa6c6d7190d886d0bcd3145b539a5362 Mon Sep 17 00:00:00 2001
From: Paul Novotny
Date: Sun, 28 Jul 2024 09:14:38 -0400
Subject: [PATCH] Fix ConvTranspose symmetric non-constant padding

ConvTranspose was unable to handle symmetric non-constant padding
(e.g., `pad=(1, 0)` for a 2D ConvTranspose). Constant padding (e.g.,
`pad=1`) and asymmetric non-constant padding (e.g., `pad=(1, 0, 2, 3)`)
worked correctly. This commit fixes symmetric non-constant padding and
adds unit tests to ensure it produces the same output size as the
equivalent fully expanded padding.

Fixes #2424
---
 ext/FluxAMDGPUExt/conv.jl | 10 +++++++++-
 src/layers/conv.jl        | 11 ++++++++++-
 test/ext_amdgpu/basic.jl  | 15 +++++++++++++++
 test/layers/conv.jl       | 13 +++++++++++++
 4 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/ext/FluxAMDGPUExt/conv.jl b/ext/FluxAMDGPUExt/conv.jl
index 681a38db6d..392f50d9cd 100644
--- a/ext/FluxAMDGPUExt/conv.jl
+++ b/ext/FluxAMDGPUExt/conv.jl
@@ -5,8 +5,16 @@ function Flux.conv_dims(c::Conv, x::T) where T <: ROCArray
 end
 
 function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray
+    # Calculate combined pad in each dimension
+    nd = ndims(x) - 2
+    if length(c.pad) == nd
+        # Handle symmetric non-constant padding
+        combined_pad = ntuple(i -> 2 * c.pad[i], nd)
+    else
+        combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
+    end
+
     # Calculate size of "input", from ∇conv_data()'s perspective...
-    combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
     I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+
         (size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad .+ c.outpad
     C_in = size(c.weight)[end - 1] * c.groups
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 338bb89725..8ba07b95a8 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -321,12 +321,21 @@ end
 @layer ConvTranspose
 
 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
+  # Calculate combined pad in each dimension
+  nd = ndims(x) - 2
+  if length(c.pad) == nd
+    # Handle symmetric non-constant padding
+    combined_pad = ntuple(i -> 2 * c.pad[i], nd)
+  else
+    combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
+  end
+
   # Calculate size of "input", from ∇conv_data()'s perspective...
   calc_dim(xsz, wsz, stride, dilation, pad, outpad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + outpad
-  combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2)
   I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad, c.outpad)
   C_in = size(c.weight)[end-1] * c.groups
   batch_size = size(x)[end]
+
   # Create DenseConvDims() that looks like the corresponding conv()
   w_size = size(c.weight)
   return DenseConvDims((I..., C_in, batch_size), w_size;
diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl
index 86b1cccf37..831b577d48 100644
--- a/test/ext_amdgpu/basic.jl
+++ b/test/ext_amdgpu/basic.jl
@@ -46,6 +46,21 @@
     end
 end
 
+@testset "Convolution with symmetric non-constant padding" begin
+    for conv_type in (Conv, ConvTranspose), nd in 1:3
+        kernel = tuple(fill(2, nd)...)
+        x = rand(Float32, fill(10, nd)..., 3, 5) |> gpu
+
+        pad = ntuple(i -> i, nd)
+        m = conv_type(kernel, 3 => 4, pad=pad) |> f32 |> gpu
+
+        expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * nd)
+        m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> f32 |> gpu
+
+        @test size(m(x)) == size(m_expanded(x))
+    end
+end
+
 @testset "ConvTranspose output padding" begin
     x = randn(Float32, 10, 11, 3, 2)
     m = ConvTranspose((3, 5), 3=>6, stride=3, outpad=(1, 0))
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index be26786495..2e75a1e39d 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -257,6 +257,19 @@
   end
 end
 
+@testset "$ltype $(nd)D symmetric non-constant padding" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), nd in (1, 2, 3)
+  kernel = ntuple(Returns(3), nd)
+  data = ones(Float32, (kernel .+ 5)..., 1,1)
+
+  pad = ntuple(i -> i, nd)
+  l = ltype(kernel, 1=>1, pad=pad)
+
+  expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * nd)
+  l_expanded = ltype(kernel, 1=>1, pad=expanded_pad)
+
+  @test size(l(data)) == size(l_expanded(data))
+end
+
 @testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8))
   data = ones(Float32, (k .+ 3)..., 1,1)
   l = ltype(k, 1=>1, pad=SamePad())
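
A quick illustration of the fixed behavior (a minimal sketch, not part
of the patch; the kernel size, channel counts, and input shape below
are arbitrary): a per-dimension pad tuple now produces the same output
size as its fully expanded form, mirroring the unit tests above.

    using Flux

    # `pad=(1, 2)` pads both sides of the first spatial dimension by 1
    # and both sides of the second by 2, i.e. it is shorthand for the
    # fully expanded `pad=(1, 1, 2, 2)`. Before this fix, ConvTranspose
    # mishandled the 2-tuple form.
    x  = rand(Float32, 10, 10, 3, 1)                      # W × H × C × N
    m1 = ConvTranspose((2, 2), 3 => 4, pad=(1, 2))        # symmetric per-dim pad
    m2 = ConvTranspose((2, 2), 3 => 4, pad=(1, 1, 2, 2))  # fully expanded pad
    @assert size(m1(x)) == size(m2(x))                    # both give (9, 7, 4, 1)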