Skip to content

Commit

Permalink
hotfix
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 20, 2023
1 parent 6cf2f3d commit fe2bc2d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/KrylovPreconditioners.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using AMDGPU, AMDGPU.rocSPARSE
using CUDA, CUDA.CUSPARSE

using LinearAlgebra: checksquare, BlasReal, BlasFloat
import LinearAlgebra: ldiv!

# Preconditioners
include("ic0.jl")
Expand Down
6 changes: 4 additions & 2 deletions src/ic0.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export ic0

mutable struct NVIDIA_IC0{SM,DM}
P::SM
z::DM
Expand All @@ -17,13 +19,13 @@ end

for ArrayType in (:(CuVector{T}), :(CuMatrix{T}))
@eval begin
function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{<:$ArrayType,CuSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat
function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{CuSparseMatrixCSR{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasFloat
ldiv!(ic.z, LowerTriangular(ic.P), x) # Forward substitution with L
ldiv!(y, LowerTriangular(ic.P)', ic.z) # Backward substitution with Lᴴ
return y
end

function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{<:$ArrayType,CuSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasFloat
function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{CuSparseMatrixCSC{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasFloat
ldiv!(ic.z, UpperTriangular(ic.P)', x) # Forward substitution with L
ldiv!(y, UpperTriangular(ic.P), ic.z) # Backward substitution with Lᴴ
return y
Expand Down
6 changes: 4 additions & 2 deletions src/ilu0.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export ilu0

mutable struct NVIDIA_ILU0{SM,DM}
P::SM
z::DM
Expand All @@ -17,13 +19,13 @@ end

for ArrayType in (:(CuVector{T}), :(CuMatrix{T}))
@eval begin
function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{<:$ArrayType,CuSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat
function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{CuSparseMatrixCSR{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasFloat
ldiv!(ilu.z, UnitLowerTriangular(ilu.P), x) # Forward substitution with L
ldiv!(y, UpperTriangular(ilu.P), ilu.z) # Backward substitution with U
return y
end

function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{<:$ArrayType,CuSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasReal
function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{CuSparseMatrixCSC{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasReal
ldiv!(ilu.z, LowerTriangular(ilu.P), x) # Forward substitution with L
ldiv!(y, UnitUpperTriangular(ilu.P), ilu.z) # Backward substitution with U
return y
Expand Down

0 comments on commit fe2bc2d

Please sign in to comment.