Fix ConvTranspose symmetric non-constant padding
The ConvTranspose layer could not handle symmetric non-constant
padding (i.e., `pad=(1, 0)` for a 2D ConvTranspose). Constant padding (i.e.,
`pad=1` for a 2D ConvTranspose) and asymmetric non-constant padding (i.e.,
`pad=(1, 0, 2, 3)`) worked correctly. This commit fixes symmetric
non-constant padding and adds unit tests to ensure it produces the same
output size as the equivalent fully expanded padding.

Fixes #2424
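
For context, each of the three padding forms reduces to a per-dimension combined (low + high) pad. A minimal sketch of that reduction, mirroring the `combined_pad` logic in the diffs below (the standalone `combined` helper is hypothetical, for illustration only):

    # Combined per-dimension pad for the three padding forms;
    # `nd` is the number of spatial dimensions.
    combined(pad::Integer, nd) = ntuple(_ -> 2 * pad, nd)    # pad=1 -> (2, 2)
    combined(pad::Tuple, nd) = length(pad) == nd ?
        ntuple(i -> 2 * pad[i], nd) :                        # pad=(1, 0) -> (2, 0)
        ntuple(i -> pad[2i-1] + pad[2i], nd)                 # pad=(1, 0, 2, 3) -> (1, 5)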
paulnovo committed Jun 22, 2024
1 parent 2f19e68 commit 2af2862
Showing 4 changed files with 47 additions and 2 deletions.
ext/FluxAMDGPUExt/conv.jl (10 changes: 9 additions & 1 deletion)

@@ -5,8 +5,16 @@ function Flux.conv_dims(c::Conv, x::T) where T <: ROCArray
 end
 
 function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray
+    # Calculate combined pad in each dimension
+    nd = ndims(x) - 2
+    if length(c.pad) == nd
+        # Handle symmetric non-constant padding
+        combined_pad = ntuple(i -> 2 * c.pad[i], nd)
+    else
+        combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
+    end
+
     # Calculate size of "input", from ∇conv_data()'s perspective...
-    combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
     I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+
         (size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad
     C_in = size(c.weight)[end - 1] * c.groups
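
The `I = ...` expression computes the spatial size of the equivalent forward convolution's input, from ∇conv_data()'s perspective. As a sanity check, a sketch of the per-dimension formula (a hypothetical standalone helper, assuming zero output padding):

    # (x - 1) * stride + 1 + (w - 1) * dilation - (low pad + high pad)
    transpose_in_size(x, w; stride=1, dilation=1, pad=(0, 0)) =
        (x - 1) * stride + 1 + (w - 1) * dilation - sum(pad)

    transpose_in_size(10, 2)              # 11
    transpose_in_size(10, 2, pad=(1, 1))  # 9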
src/layers/conv.jl (11 changes: 10 additions & 1 deletion)

@@ -312,12 +312,21 @@ end
 @layer ConvTranspose
 
 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
+    # Calculate combined pad in each dimension
+    nd = ndims(x) - 2
+    if length(c.pad) == nd
+        # Handle symmetric non-constant padding
+        combined_pad = ntuple(i -> 2 * c.pad[i], nd)
+    else
+        combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
+    end
+
     # Calculate size of "input", from ∇conv_data()'s perspective...
     calc_dim(xsz, wsz, stride, dilation, pad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad
-    combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2)
     I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad)
     C_in = size(c.weight)[end-1] * c.groups
     batch_size = size(x)[end]
 
     # Create DenseConvDims() that looks like the corresponding conv()
     w_size = size(c.weight)
     return DenseConvDims((I..., C_in, batch_size), w_size;
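
A minimal repro of the behavior this hunk fixes (a sketch based on the unit tests below; the kernel, channel, and input sizes are illustrative):

    using Flux

    x = rand(Float32, 10, 10, 3, 5)                # W, H, C, N
    m = ConvTranspose((2, 2), 3 => 4, pad=(1, 2))  # symmetric non-constant pad
    m_expanded = ConvTranspose((2, 2), 3 => 4, pad=(1, 1, 2, 2))  # fully expanded

    # After this fix, both layers produce the same output size; per the
    # commit message, the symmetric tuple form previously failed.
    size(m(x)) == size(m_expanded(x))  # true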
test/ext_amdgpu/basic.jl (15 changes: 15 additions & 0 deletions)

@@ -46,6 +46,21 @@ end
     end
 end
 
+@testset "Convolution with symmetric non-constant padding" begin
+    for conv_type in (Conv, ConvTranspose), nd in 1:3
+        kernel = tuple(fill(2, nd)...)
+        x = rand(Float32, fill(10, nd)..., 3, 5) |> gpu
+
+        pad = ntuple(i -> i, nd)
+        m = conv_type(kernel, 3 => 4, pad=pad) |> f32 |> gpu
+
+        expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * nd)
+        m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> f32 |> gpu
+
+        @test size(m(x)) == size(m_expanded(x))
+    end
+end
+
 @testset "Chain(Conv)" begin
     m = Chain(Conv((3, 3), 3 => 3)) |> f32
     x = rand(Float32, 10, 10, 3, 2)
test/layers/conv.jl (13 changes: 13 additions & 0 deletions)

@@ -246,6 +246,19 @@ end
     end
 end
 
+@testset "$ltype $(nd)D symmetric non-constant padding" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), nd in (1, 2, 3)
+    kernel = ntuple(Returns(3), nd)
+    data = ones(Float32, (kernel .+ 5)..., 1,1)
+
+    pad = ntuple(i -> i, nd)
+    l = ltype(kernel, 1=>1, pad=pad)
+
+    expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * nd)
+    l_expanded = ltype(kernel, 1=>1, pad=expanded_pad)
+
+    @test size(l(data)) == size(l_expanded(data))
+end
+
 @testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8))
     data = ones(Float32, (k .+ 3)..., 1,1)
     l = ltype(k, 1=>1, pad=SamePad())
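
Both test sets expand a symmetric per-dimension pad into explicit (low, high) pairs with the same `ntuple` expression; a quick worked example with illustrative values:

    pad = (1, 2, 3)  # symmetric pads for nd = 3 spatial dimensions
    expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * 3)
    # expanded_pad == (1, 1, 2, 2, 3, 3)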
