Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dims keyword to gather #448

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,28 @@ or multiple `dst` columns.

See [`gather`](@ref) for an allocating version.
"""
function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
dims = scatter_dims(src, dst, idx)
colons = ntuple(i -> Colon(), dims)
# function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
# dims = scatter_dims(src, dst, idx)
# colons = ntuple(i -> Colon(), dims)
# for k in CartesianIndices(idx)
# _view(dst, colons, k) .= _view(src, colons, idx[k])
# end
# return dst
# end

"""
dst[:, ... , k, :,...] .= src[:, ... , idx[k]..., :,...]
"""
function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray; dims = nothing)
nbefore, nafter = scatter_dims(src, dst, idx, dims)
colbefore = ntuple(i -> Colon(), nbefore)
colafter = ntuple(i -> Colon(), nafter)
for k in CartesianIndices(idx)
_view(dst, colons, k) .= _view(src, colons, idx[k])
_view(dst, colbefore, k, colafter) .= _view(src, colbefore, idx[k], colafter)
end
return dst
end


"""
NNlib.gather(src, idx) -> dst

Expand Down
40 changes: 28 additions & 12 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,41 @@ typelength(::Type{CartesianIndex{M}}) where M = M
Performs dimensional consistency checks and return the
dimensionality of the scattered objects.
"""

function scatter_dims(X::AbstractArray{Tx,Nx},
Y::AbstractArray{Ty,Ny},
idx::AbstractArray{Tidx,Nidx}) where {Tx,Ty,Tidx,Nx,Ny,Nidx}
M = typelength(Tidx)
dims = scatter_dims(Nx, Ny, M, Nidx)
size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
return dims
idx::AbstractArray{Tidx,Nidx},
dims::Union{Nothing, Integer} = nothing) where {Tx,Ty,Tidx,Nx,Ny,Nidx}
nsrcin = typelength(Tidx)
ndstin = Nidx
nbefore, nafter = scatter_dims(Nx, Ny, nsrcin, ndstin, dims)
size(Y)[1:nbefore] == size(X)[1:nbefore] || throw(ArgumentError("Incompatible input shapes."))
size(Y)[nbefore+1:nbefore+ndstin] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
size(Y)[nbefore+ndstin+1:end] == size(X)[nbefore+nsrcin+1:end] || throw(ArgumentError("Incompatible input shapes."))
return nbefore, nafter
end

function scatter_dims(Nx, Ny, M, Nidx)
@assert Nx - M == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)."
dims = Nx - M
dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
return dims
function scatter_dims(Nx, Ny, nsrcin, ndstin, dims = nothing)
@assert Nx - nsrcin == Ny - ndstin "Incompatible input shapes of (dst, src, idx, Tidx) = ($Nx, $Ny, $ndstin, $nsrcin)."

if dims === nothing
nbefore = Nx - nsrcin
nbefore < 0 && throw(ArgumentError("nbefore must be non-negative but got $nbefore."))
nafter = 0
return nbefore, nafter
else
nbefore = dims - 1
nafter = Ny - ndstin - nbefore
nbefore < 0 && throw(ArgumentError("nbefore must be non-negative but got $nbefore."))
nafter < 0 && throw(ArgumentError("nafter must be non-negative but got $nafter."))
return nbefore, nafter
end
end

_view(X, colons, k) = view(X, colons..., k...)
_view(X, colons, k::Union{Integer, CartesianIndex}) = view(X, colons..., k)
_view(X, colbefore, k, colafter) = view(X, colbefore..., k..., colafter...)
_view(X, colbefore, k::Union{Integer, CartesianIndex}, colafter) = view(X, colbefore..., k, colafter...)

"""
NNlib.scatter!(op, dst, src, idx)
Expand Down Expand Up @@ -78,7 +94,7 @@ function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractA
src_v = _view(src, colons, k)
dst_v .= (op).(dst_v, src_v)
end
dst
return dst
end

function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
Expand Down
15 changes: 15 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,18 @@ end
gradtest(xs -> gather!(dst, xs, index), src)
gradtest(xs -> gather(xs, index), src)
end

using NNlib: gather!, gather

@testset "gather! dims" begin

src = reshape([1:15;], 3, 5)
index = [1, 3, 2, 2]
dst = zeros(Int, 4, 5)
gather!(dst, src, index, dims=1)

@test dst == [1 4 7 10 13
3 6 9 12 15
2 5 8 11 14
2 5 8 11 14]
end