Skip to content

Commit

Permalink
fix: dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored and wsmoses committed Dec 29, 2024
1 parent 0e6d376 commit ccc61d4
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@ using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_m
using LinearAlgebra

# Various Wrapper Arrays defined in LinearAlgebra
function materialize_traced_array(
function TracedUtils.materialize_traced_array(
x::Transpose{TracedRNumber{T},TracedRArray{T,N}}
) where {T,N}
px = parent(x)
A = ndims(px) == 1 ? reshape(px, :, 1) : px
return permutedims(A, (2, 1))
end

function materialize_traced_array(
function TracedUtils.materialize_traced_array(
x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}
) where {T,N}
return conj(materialize_traced_array(transpose(parent(x))))
end

function materialize_traced_array(
x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}
function TracedUtils.materialize_traced_array(
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
return LinearAlgebra.diagm(parent(x))
return diagm(parent(x))
end

function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
Expand All @@ -42,7 +42,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
uAT = Symbol(:Unit, AT)
@eval begin
function TracedUtils.materialize_traced_array(
x::$(AT){T,TracedRArray{T,2}}
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
Expand All @@ -52,7 +52,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
end

function TracedUtils.materialize_traced_array(
x::$(uAT){T,TracedRArray{T,2}}
x::$(uAT){TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
Expand All @@ -64,7 +64,9 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
end
end

function TracedUtils.materialize_traced_array(x::Symmetric{T,TracedRArray{T,2}}) where {T}
function TracedUtils.materialize_traced_array(
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
Expand Down Expand Up @@ -107,7 +109,9 @@ function TracedUtils.set_mlir_data!(
return x
end

function TracedUtils.set_mlir_data!(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data) where {T}
function TracedUtils.set_mlir_data!(
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data
) where {T}
parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data
return x
end
Expand All @@ -119,7 +123,7 @@ for (AT, dcomp, ocomp) in (
(:UnitUpperTriangular, "LT", "GE"),
)
@eval function TracedUtils.set_mlir_data!(
x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data
) where {T}
tdata = TracedRArray{T}(data)
z = zero(tdata)
Expand All @@ -137,17 +141,19 @@ for (AT, dcomp, ocomp) in (
end

function TracedUtils.set_mlir_data!(
x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data
) where {T}
if x.uplo == 'L'
set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data)
set_mlir_data!(LowerTriangular(parent(x)), data)
else
set_mlir_data!(LinearAlgebra.UpperTriangular(parent(x)), data)
set_mlir_data!(UpperTriangular(parent(x)), data)
end
return x
end

function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T}
function TracedUtils.set_mlir_data!(
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data
) where {T}
tdata = TracedRArray{T}(data)
set_mlir_data!(x.dl, diag(tdata, -1).mlir_data)
set_mlir_data!(x.d, diag(tdata, 0).mlir_data)
Expand Down

0 comments on commit ccc61d4

Please sign in to comment.