Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix ConvTranspose symmetric non-constant padding #2463

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion ext/FluxAMDGPUExt/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@ function Flux.conv_dims(c::Conv, x::T) where T <: ROCArray
end

function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray
# Calculate combined pad in each dimension
nd = ndims(x) - 2
if length(c.pad) == nd
# Handle symmetric non-constant padding
combined_pad = ntuple(i -> 2 * c.pad[i], nd)
else
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
end

# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+
(size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad .+ c.outpad
C_in = size(c.weight)[end - 1] * c.groups
Expand Down
11 changes: 10 additions & 1 deletion src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,21 @@ end
@layer ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate combined pad in each dimension
nd = ndims(x) - 2
if length(c.pad) == nd
# Handle symmetric non-constant padding
combined_pad = ntuple(i -> 2 * c.pad[i], nd)
else
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
end

# Calculate size of "input", from ∇conv_data()'s perspective...
calc_dim(xsz, wsz, stride, dilation, pad, outpad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + outpad
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2)
I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad, c.outpad)
C_in = size(c.weight)[end-1] * c.groups
batch_size = size(x)[end]

# Create DenseConvDims() that looks like the corresponding conv()
w_size = size(c.weight)
return DenseConvDims((I..., C_in, batch_size), w_size;
Expand Down
15 changes: 15 additions & 0 deletions test/ext_amdgpu/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ end
end
end

@testset "Convolution with symmetric non-constant padding" begin
    for layer in (Conv, ConvTranspose), dims in 1:3
        ksize = ntuple(Returns(2), dims)
        input = gpu(rand(Float32, ntuple(Returns(10), dims)..., 3, 5))

        # One pad value per spatial dimension, applied to both sides.
        sym_pad = ntuple(identity, dims)
        layer_sym = gpu(f32(layer(ksize, 3 => 4, pad=sym_pad)))

        # The same padding written out explicitly, one entry per side:
        # (p1, p1, p2, p2, ...). Both layers must produce equally-sized output.
        full_pad = ntuple(j -> sym_pad[(j + 1) ÷ 2], 2 * dims)
        layer_full = gpu(f32(layer(ksize, 3 => 4, pad=full_pad)))

        @test size(layer_sym(input)) == size(layer_full(input))
    end
end

@testset "ConvTranspose output padding" begin
x = randn(Float32, 10, 11, 3, 2)
m = ConvTranspose((3, 5), 3=>6, stride=3, outpad=(1, 0))
Expand Down
13 changes: 13 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ end
end
end

@testset "$ltype $(nd)D symmetric non-constant padding" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), nd in (1, 2, 3)
    ksize = ntuple(Returns(3), nd)
    x = ones(Float32, (ksize .+ 5)..., 1, 1)

    # One pad value per spatial dimension (applied to both sides).
    sym = ntuple(identity, nd)
    layer_sym = ltype(ksize, 1 => 1, pad=sym)

    # Equivalent padding spelled out per side: (p1, p1, p2, p2, ...).
    full = ntuple(i -> sym[(i + 1) ÷ 2], 2 * nd)
    layer_full = ltype(ksize, 1 => 1, pad=full)

    # Symmetric shorthand and the fully-expanded form must agree on output size.
    @test size(layer_sym(x)) == size(layer_full(x))
end

@testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8))
data = ones(Float32, (k .+ 3)..., 1,1)
l = ltype(k, 1=>1, pad=SamePad())
Expand Down
Loading