Skip to content

Commit

Permalink
feat: support scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 24, 2025
1 parent 10b5683 commit ca8780b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
19 changes: 19 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,25 @@ function broadcast_in_dim(
return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size))
end

function broadcast_in_dim(
x::TracedRNumber{T},
dims::Vector{Int},
result_size::Vector{Int};
location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__),
) where {T}
@assert length(dims) == 0

res = MLIR.IR.result(
stablehlo.broadcast_in_dim(
x.mlir_data;
result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)),
broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1),
location,
),
)
return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size))
end

@noinline function sort(
xs::TracedRArray...;
comparator,
Expand Down
16 changes: 12 additions & 4 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,28 @@ function get_ancestor_indices(
) where {T,N,M}
@assert length(indices) == N "Expected $N indices, got $(length(indices))"
if any(is_traced, indices)
# XXX: scalars are not supported
final_size = Vector{Int64}(undef, N)
ddims = Int64[]
for (i, idx) in enumerate(indices)
@assert ndims(idx) == 1 "Unsupported feature. Please file an issue."
@assert ndims(idx) == 1 || ndims(idx) == 0 "Unsupported feature. Please file an issue."
ndims(idx) == 0 && push!(ddims, i)
final_size[i] = length(idx)
end
@show Base.strides(x)
linear_indices = mapreduce(+, enumerate(indices)) do (i, idx)
Base.stride(x, i) .* (Ops.broadcast_in_dim(idx, Int64[i], final_size) .- 1) .+ 1
bcasted_idxs = Ops.broadcast_in_dim(
idx, ndims(idx) == 0 ? Int64[] : Int64[i], final_size
)
Base.stride(x, i) .* (bcasted_idxs .- 1) .+ 1
end
parent_linear_indices_all = collect(LinearIndices(size(parent(x))))
parent_linear_indices = TracedUtils.promote_to(
TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all
)[linear_indices]
isempty(ddims) || (
parent_linear_indices = materialize_traced_array(
dropdims(parent_linear_indices; dims=Tuple(ddims))
)
)
return (parent_linear_indices,)
else
# Have this as a separate code-path since we can generate non-dynamic indexing
Expand Down

0 comments on commit ca8780b

Please sign in to comment.