Skip to content

Commit

Permalink
feat: generalize diagm
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 12, 2024
1 parent a9a8b2a commit 995d957
Showing 1 changed file with 25 additions and 34 deletions.
59 changes: 25 additions & 34 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,7 @@ end
function materialize_traced_array(
x::LinearAlgebra.Tridiagonal{T,TracedRArray{T,1}}
) where {T}
scatter_indices = vcat(
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1),
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0),
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1),
)
scatter_indices = Ops.constant(scatter_indices)

updates = TracedRArray{T,1}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
[x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0
),
1,
),
(size(scatter_indices, 1),),
)

return simple_scatter_op(size(x), scatter_indices, updates)
return LinearAlgebra.diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
end

for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
Expand Down Expand Up @@ -251,13 +233,23 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
return TracedRArray{T,1}((), res, (diag_length,))
end

function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T}
return LinearAlgebra.diagm(length(v), length(v), v)
end
function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T}
m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check
indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :])
return simple_scatter_op((m, n), indices, materialize_traced_array(v))
function LinearAlgebra._diagm(
shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}...
) where {T}
m, n = LinearAlgebra.diagm_size(shape, kv...)
scatter_indices = Matrix{Int64}[]
concat_inputs = MLIR.IR.Value[]
for (k, v) in kv
push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :])
push!(concat_inputs, get_mlir_data(v))
end
scatter_indices = Ops.constant(reduce(vcat, scatter_indices))
values = TracedRArray{T,1}(
(),
MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
(size(scatter_indices, 1),),
)
return simple_scatter_op((m, n), scatter_indices, values)
end

# Common Utilities
Expand Down Expand Up @@ -309,15 +301,14 @@ function simple_scatter_op(
return TracedRArray{T,2}((), res, shape)
end

# The cartesian version doesn't exist in julia 1.10
## The cartesian version doesn't exist in julia 1.10
function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k))
Cstep = CartesianIndex(1, 1)
res = StepRangeLen(Cstart, Cstep, max(0, k <= 0 ? min(m + k, n) : min(m, n - k)))
indices = Matrix{Int}(undef, (length(res), 2))
for (i, idx) in enumerate(res)
indices[i, 1] = idx[1] - 1
indices[i, 2] = idx[2] - 1
idx1, idx2 = 1 + max(0, -k), 1 + max(0, k)
L = max(0, k 0 ? min(m + k, n) : min(m, n - k))
indices = Matrix{Int}(undef, (L, 2))
for i in axes(indices, 1)
indices[i, 1] = idx1 + i - 2
indices[i, 2] = idx2 + i - 2
end
return indices
end

0 comments on commit 995d957

Please sign in to comment.