Use KernelAbstractions for fold/unfold #596

Merged · 2 commits · Jul 14, 2024
1 change: 0 additions & 1 deletion ext/NNlibCUDAExt/NNlibCUDAExt.jl
@@ -9,7 +9,6 @@ include("activations.jl")
include("batchedadjtrans.jl")
include("batchedmul.jl")
include("ctc.jl")
-include("fold.jl")
include("scatter.jl")
include("utils.jl")

111 changes: 0 additions & 111 deletions ext/NNlibCUDAExt/fold.jl

This file was deleted.
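The deleted file held the CUDA-only fold/unfold implementation; it is superseded by the backend-agnostic KernelAbstractions kernels added to src/fold.jl below. As a primer, a minimal sketch of the launch pattern those kernels use (the kernel here is a toy, not NNlib code):

using KernelAbstractions

# One work-item per linear index, guarded against overrun, the same
# launch shape the unfold/fold kernels below use.
@kernel function scale_kernel!(dst, src, alpha, max_idx)
    i = @index(Global)
    @inbounds if i ≤ max_idx
        dst[i] = alpha * src[i]
    end
end

x = rand(Float32, 16)
y = similar(x)
# get_backend(x) is CPU() for Array and, e.g., CUDABackend() for CuArray,
# so one kernel definition serves every supported device.
scale_kernel!(get_backend(x))(y, x, 2f0, length(x); ndrange = length(x))
KernelAbstractions.synchronize(get_backend(x))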

4 changes: 2 additions & 2 deletions ext/NNlibFFTWExt/stft.jl
@@ -41,7 +41,7 @@ function NNlib.stft(x;
ids = [
row + hop_length * col
for row in 1:n_fft, col in 0:(n_frames - 1)]
-x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
+x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
end

region = 1
@@ -113,7 +113,7 @@ function NNlib.istft(y;
# In case of batched input, reshape it (n_fft, n_frames, batch) -> (:, batch).
nd = ntuple(_ -> Colon(), ndims(x) - 2)
ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
-x = x[ids, nd...]
+x = @inbounds x[ids, nd...]

# Trim padding.
left = center ? (n_fft ÷ 2 + 1) : 1
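Both hunks gather with a precomputed index matrix, so per-element bounds checks are redundant and @inbounds elides them. A worked check of why the indices stay in range (that n_frames is derived so the last frame fits inside size(x, 1) is stft's framing invariant, assumed here):

n_fft, hop_length, n_frames = 4, 2, 3
ids = [row + hop_length * col for row in 1:n_fft, col in 0:(n_frames - 1)]
# The largest gathered index is n_fft + hop_length * (n_frames - 1) == 8,
# which the framing invariant guarantees is ≤ size(x, 1).
maximum(ids) == 8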
2 changes: 1 addition & 1 deletion src/audio/spectrogram.jl
@@ -41,7 +41,7 @@ function spectrogram(waveform;
window_normalized && (spec = spec .* inv(norm(window));)

if power > 0
-p = real(eltype(spec)(power))
+p = eltype(waveform)(power)
spec = abs.(spec).^p
end
return spec
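After stft the spectrogram is complex, so the old code built a complex number from power only to take its real part; the (real) eltype of waveform yields the same exponent type directly. A REPL sketch with toy values (not the PR's test data):

julia> waveform = rand(Float32, 8); spec = rand(ComplexF32, 5, 3);

julia> real(eltype(spec)(2))   # old: detour through ComplexF32
2.0f0

julia> eltype(waveform)(2)     # new: Float32 directly
2.0f0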
135 changes: 119 additions & 16 deletions src/fold.jl
@@ -1,4 +1,3 @@
-
"""
    unfold(x, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)

@@ -7,10 +6,10 @@
of kernel)*input_channels`. The number of sliding windows will match those of
convolution (`conv`) with the same kernel_size and arguments. Note that
by default `conv` flips the spatial dimensions of its kernel (default
-`flipped=false`), whereas `unfold` does not (default `flipped=true`).
-Uses `NNlib.im2col!` as backend.
+`flipped=false`), whereas `unfold` does not (default `flipped=true`).
+Uses `NNlib.im2col!` as backend.

-See also [`fold`](@ref), the adjoint/transpose operator
+See also [`fold`](@ref), the adjoint/transpose operator
and a potential inverse of `unfold`.

# Example
Expand All @@ -23,7 +22,7 @@

julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold

-julia> z = NNlib.unfold(x, size(w); kws...)
+julia> z = NNlib.unfold(x, size(w); kws...)
4×3×1 Array{Int64, 3}:
[:, :, 1] =
0 100 2
@@ -61,8 +60,8 @@

The adjoint/transpose operator of `unfold`. It accumulates sliding windows from
the output of `unfold` into a container tensor of size `output_size`. An inverse
-to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues
-with a divisor (see example). Uses `NNlib.col2im!` as backend.
+to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues
+with a divisor (see example). Uses `NNlib.col2im!` as backend.

See also [`unfold`](@ref).

@@ -101,7 +100,7 @@
2.0
1.0

-julia> z ./ divisor
+julia> z ./ divisor
7×1×1 Array{Float64, 3}:
[:, :, 1] =
100.0
@@ -133,30 +132,30 @@
end

function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}
-    x = similar(y, output_size)
+    x = similar(y, output_size)
    return fold!(x, y, cdims)
end

-# N < 5 -dimension in-place versions
+# N < 5 -dimension in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}
    unfold!(
-        y,
-        insert_singleton_spatial_dimension(x, 5-N),
-        insert_singleton_spatial_dimension(cdims, 5-N),
+        y,
+        insert_singleton_spatial_dimension(x, 5-N),
+        insert_singleton_spatial_dimension(cdims, 5-N),
    )
    return y
end

function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}
    fold!(
-        insert_singleton_spatial_dimension(x, 5-N),
+        insert_singleton_spatial_dimension(x, 5-N),
        y,
-        insert_singleton_spatial_dimension(cdims, 5-N),
+        insert_singleton_spatial_dimension(cdims, 5-N),
    )
    return x
end
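These wrappers let a single 5-D implementation serve lower-dimensional inputs by padding them with singleton spatial dimensions. A toy analogue of what insert_singleton_spatial_dimension does to the array argument (the real helper also adjusts cdims):

x = rand(Float32, 8, 3, 2)   # 1-D case: (width, channels, batch)
x5 = reshape(x, size(x, 1), 1, 1, size(x, 2), size(x, 3))
size(x5) == (8, 1, 1, 3, 2)  # now matches the 5-D methods below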

-# 5-dimension in-place versions
+# 5-dimension in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT}
    @threads for batch_idx in 1:size(x, 5)
        y_slice = view(y, :, :, batch_idx)
@@ -173,6 +172,110 @@
    return x
end

@kernel function unfold_kernel!(
    col::AbstractArray{T}, x, col_size,
    input_size, output_size, kernel_size,
    flipkernel, stride, pad_lo, dilation, max_idx,
) where T
    index = @index(Global)

    @inbounds if index ≤ max_idx
        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices
        w, h, d = CartesianIndices(output_size)[i].I # x indices

        # project
        w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation

        if !flipkernel
            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
        end

        # check out of bounds
        if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
            col[i, kw, kh, kd, c, b] = T(0)
        else
            xval::T = x[w, h, d, c, b]
            col[i, kw, kh, kd, c, b] = xval
        end
    end
end
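The projection line is the whole address computation: it maps an output spatial position plus a kernel offset back to an input coordinate, and padding can push it out of range, which the bounds check turns into a zero write. A worked one-dimensional sketch with illustrative parameters:

# input = (output - 1) * stride - pad_lo + 1 + (k - 1) * dilation
project(out, k; stride = 2, pad_lo = 1, dilation = 1) =
    (out - 1) * stride - pad_lo + 1 + (k - 1) * dilation

project(1, 1)  # 0: inside the left padding, so unfold_kernel! writes T(0)
project(1, 2)  # 1: the first real input column
project(2, 1)  # 2: the next window starts `stride` columns over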

@kernel function fold_kernel!(
    x::AbstractArray{T}, col, col_size,
    input_size, output_size, kernel_size,
    flipkernel, stride, pad_lo, dilation, max_idx,
) where T
    index = @index(Global)

    @inbounds if index ≤ max_idx
        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices
        w, h, d = CartesianIndices(output_size)[i].I # x indices

        # project
        w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation

        # check out of bounds
        if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
            if !flipkernel
                kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
            end

            cval::T = col[i, kw, kh, kd, c, b]
            @atomic x[w, h, d, c, b] += cval
        end
    end
end
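When the stride is smaller than the kernel extent, sliding windows overlap: distinct (window, offset) pairs project to the same input location, so the scatter must accumulate with @atomic to avoid write races. The divisor trick in the fold docstring above then normalizes exactly that overlap count. Continuing the one-dimensional sketch:

project(out, k; stride = 1, pad_lo = 0, dilation = 1) =
    (out - 1) * stride - pad_lo + 1 + (k - 1) * dilation

# Two distinct col entries target input index 2, hence the atomic add:
project(1, 2) == project(2, 1) == 2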

function unfold!(
    col::AnyGPUArray{cT,3}, x::AnyGPUArray{xT,5}, cdims::DenseConvDims,
) where {cT, xT}
    spatial_dims(cdims) != 3 && throw(DimensionMismatch(
        "unfold!() only accepts 3d convolutional inputs"))

    C_in = channels_in(cdims)
    ker_size = kernel_size(cdims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)

    out_size = output_size(cdims)
    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))

    max_idx = prod(size(col))
    unfold_kernel!(get_backend(x))(
        col_reshaped, x, size(col_reshaped),
        input_size(cdims), out_size, ker_size,
        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
        ndrange=max_idx)
    return col
end
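A usage sketch for the new GPU path (assumes a working CUDA setup; shapes follow the (spatial..., channels, batch) convention of the unfold docstring):

julia> using CUDA, NNlib

julia> x = CUDA.rand(Float32, 10, 1, 1);   # 1-D input, 1 channel, 1 batch

julia> z = NNlib.unfold(x, (4, 1, 1));     # dispatches to the method above

julia> size(z)                             # 7 windows × 4 kernel taps × 1 batch
(7, 4, 1)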

function fold!(
    x::AnyGPUArray{xT,5}, col::AnyGPUArray{cT,3}, cdims::DenseConvDims,
) where {xT, cT}
    spatial_dims(cdims) != 3 && throw(DimensionMismatch(
        "fold!() only accepts 3d convolutional inputs"))

    # going to accumulate into x
    fill!(x, xT(0))

    C_in = channels_in(cdims)
    ker_size = kernel_size(cdims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
    out_size = output_size(cdims)

    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))

    max_idx = prod(size(col))
    fold_kernel!(get_backend(x))(
        x, col_reshaped, size(col_reshaped),
        input_size(cdims), out_size, ker_size,
        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
        ndrange=max_idx)

    return x
end
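And the round trip, mirroring the divisor example from the fold docstring (continuing the hypothetical session above):

julia> y = NNlib.fold(z, size(x), (4, 1, 1));   # overlaps accumulate, so y is scaled

julia> divisor = NNlib.fold(NNlib.unfold(CUDA.ones(Float32, size(x)), (4, 1, 1)),
                            size(x), (4, 1, 1));

julia> y ./ divisor ≈ x
true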

# reverse diff rules
function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)
    function unfold_pullback(Δ)
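The rest of the rule is elided in this view. For orientation, unfold and fold are mutual adjoints, and the rrule follows the usual ChainRulesCore pattern; a sketch, not the verbatim source (which also threads the keyword arguments through):

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(unfold), x, cdims::DenseConvDims)
    function unfold_pullback(Δ)
        # The adjoint of gathering windows is scattering them back.
        return NoTangent(), fold(unthunk(Δ), size(x), cdims), NoTangent()
    end
    return unfold(x, cdims), unfold_pullback
end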
40 changes: 0 additions & 40 deletions test/fold.jl

This file was deleted.
