Skip to content

Commit

Permalink
feat: more indexing support (#608)
Browse files Browse the repository at this point in the history
* feat: overload LinearAlgebra.kron

* test: kron

* feat: more indexing support

* refactor: move tests around a bit

* fix: cleanup implementation and add tests
  • Loading branch information
avik-pal authored Jan 24, 2025
1 parent 534bea3 commit 2118ee2
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 274 deletions.
34 changes: 8 additions & 26 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {
end

function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
i isa AbstractArray{<:Bool} && return findall(i)
return i
end
indices = TracedUtils.normalize_indices(a, indices...)

use_gather_getindex = false
for idxs in indices
Expand All @@ -168,7 +163,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
use_gather_getindex = true
break
end
contiguous = all(isone, diff(idxs))
contiguous = all(isone, diff(vec(idxs)))
if typeof(contiguous) <: Bool && !contiguous
use_gather_getindex = true
break
Expand All @@ -181,19 +176,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
if any(i -> unwrapped_eltype(i) <: Bool, indices)
error("Boolean indexing with TracedRArrays isn't fully supported yet.")
end
idxs = map(indices) do i
i isa Number && return fill(i, 1)
return i
end
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), idxs)
indices_list = generate_index_list(indices_list...)
res = Ops.gather_getindex(a, indices_list)
res = Ops.reshape(res, length.(idxs)...)
ddims = findall(indices) do idx
return idx isa Integer || idx isa TracedRNumber{<:Integer}
end
isempty(ddims) || return materialize_traced_array(dropdims(res; dims=Tuple(ddims)))
return res
indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...)
res = Ops.gather_getindex(a, generate_index_list(indices...))
isempty(integer_indices) ||
(res = materialize_traced_array(dropdims(res; dims=integer_indices)))
return Ops.reshape(res, result_size)
end

start_indices = map(indices) do i
Expand Down Expand Up @@ -233,12 +220,7 @@ maybe_assert_scalar_setindexing(args...) = nothing
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
maybe_assert_scalar_setindexing(a, indices...)

indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
i isa AbstractArray{<:Bool} && return findall(i)
return i
end
indices = TracedUtils.normalize_indices(a, indices...)

use_scatter_setindex = false
for idxs in indices
Expand Down
62 changes: 42 additions & 20 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,30 +64,26 @@ function get_ancestor_indices(
x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, indices...
) where {T,N,M}
@assert length(indices) == N "Expected $N indices, got $(length(indices))"
indices = normalize_indices(x, indices...)
if any(is_traced, indices)
final_size = Vector{Int64}(undef, N)
ddims = Int64[]
for (i, idx) in enumerate(indices)
@assert ndims(idx) == 1 || ndims(idx) == 0 "Unsupported feature. Please file an issue."
ndims(idx) == 0 && push!(ddims, i)
final_size[i] = length(idx)
end
indices, integer_indices, result_size, flattened_size = traced_indices(indices...)
linear_indices = mapreduce(+, enumerate(indices)) do (i, idx)
bcasted_idxs = Ops.broadcast_in_dim(
idx, ndims(idx) == 0 ? Int64[] : Int64[i], final_size
idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size
)
Base.stride(x, i) .* (bcasted_idxs .- 1)
end
linear_indices = linear_indices .+ 1
parent_linear_indices_all = collect(LinearIndices(size(parent(x))))
parent_linear_indices = TracedUtils.promote_to(
parent_linear_indices = promote_to(
TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all
)[linear_indices]
isempty(ddims) || (
isempty(integer_indices) || (
parent_linear_indices = materialize_traced_array(
dropdims(parent_linear_indices; dims=Tuple(ddims))
dropdims(parent_linear_indices; dims=integer_indices)
)
)
parent_linear_indices = Ops.reshape(parent_linear_indices, result_size)
return (parent_linear_indices,)
else
# Have this as a separate code-path since we can generate non-dynamic indexing
Expand All @@ -106,7 +102,7 @@ function set_mlir_data!(
end

function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
ancestor_indices = TracedUtils.get_ancestor_indices(x, axes(x)...)
ancestor_indices = get_ancestor_indices(x, axes(x)...)
setindex!(Reactant.ancestor(x), TracedRArray{T}(data), ancestor_indices...)
return x
end
Expand Down Expand Up @@ -317,7 +313,7 @@ elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
struct TypeCast{T<:ReactantPrimitive} <: Function end

function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
return TracedUtils.promote_to(TracedRNumber{T}, x)
return promote_to(TracedRNumber{T}, x)
end

function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive}
Expand Down Expand Up @@ -434,7 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
batch_inputs = MLIR.IR.Value[]

for a in linear_args
idx, path = TracedUtils.get_argidx(a)
idx, path = get_argidx(a)
if idx == 1 && fnwrap
push_val!(batch_inputs, f, path[3:end])
else
Expand All @@ -455,20 +451,20 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
residx = 1

for a in linear_results
if TracedUtils.has_residx(a)
path = TracedUtils.get_residx(a)
TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx))
if has_residx(a)
path = get_residx(a)
set!(result, path[2:end], MLIR.IR.result(res, residx))
residx += 1
else
idx, path = TracedUtils.get_argidx(a)
idx, path = get_argidx(a)
if idx == 1 && fnwrap
TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx))
set!(f, path[3:end], MLIR.IR.result(res, residx))
residx += 1
else
if fnwrap
idx -= 1
end
TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
residx += 1
end
end
Expand Down Expand Up @@ -523,4 +519,30 @@ end
return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize))
end

function normalize_indices(a::AbstractArray, indices...)
return map(enumerate(indices)) do (i, idx)
idx isa Colon && return collect(Int64, 1:size(a, i))
idx isa CartesianIndex && return Tuple(idx)
idx isa AbstractArray{Bool} && return findall(idx)
return idx
end
end

function traced_indices(indices...)
integer_indices = Int64[]
result_size = Int64[]
flattened_size = Int64[length(idx) for idx in indices]
new_indices = map(enumerate(indices)) do (i, idx)
if idx isa Number
push!(integer_indices, i)
idx isa TracedRNumber && return idx
return promote_to(TracedRNumber{Int}, idx)
end
append!(result_size, [size(idx)...])
idx isa TracedRArray && return materialize_traced_array(vec(idx))
return promote_to(TracedRArray{Int,1}, vec(idx))
end
return new_indices, Tuple(integer_indices), result_size, flattened_size
end

end
Loading

0 comments on commit 2118ee2

Please sign in to comment.