diff --git a/src/KrylovPreconditioners.jl b/src/KrylovPreconditioners.jl index 4c90bd2..30bc8a3 100644 --- a/src/KrylovPreconditioners.jl +++ b/src/KrylovPreconditioners.jl @@ -5,6 +5,7 @@ using AMDGPU, AMDGPU.rocSPARSE using CUDA, CUDA.CUSPARSE using LinearAlgebra: checksquare, BlasReal, BlasFloat +import LinearAlgebra: ldiv! # Preconditioners include("ic0.jl") diff --git a/src/ic0.jl b/src/ic0.jl index 517fa00..9e783d4 100644 --- a/src/ic0.jl +++ b/src/ic0.jl @@ -1,3 +1,5 @@ +export ic0 + mutable struct NVIDIA_IC0{SM,DM} P::SM z::DM @@ -17,13 +19,13 @@ end for ArrayType in (:(CuVector{T}), :(CuMatrix{T})) @eval begin - function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{<:$ArrayType,CuSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat + function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{CuSparseMatrixCSR{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasFloat ldiv!(ic.z, LowerTriangular(ic.P), x) # Forward substitution with L ldiv!(y, LowerTriangular(ic.P)', ic.z) # Backward substitution with Lᴴ return y end - function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{<:$ArrayType,CuSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasFloat + function ldiv!(y::$ArrayType, ic::NVIDIA_IC0{CuSparseMatrixCSC{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasFloat ldiv!(ic.z, UpperTriangular(ic.P)', x) # Forward substitution with L ldiv!(y, UpperTriangular(ic.P), ic.z) # Backward substitution with Lᴴ return y diff --git a/src/ilu0.jl b/src/ilu0.jl index b7f2c23..8eeb759 100644 --- a/src/ilu0.jl +++ b/src/ilu0.jl @@ -1,3 +1,5 @@ +export ilu0 + mutable struct NVIDIA_ILU0{SM,DM} P::SM z::DM @@ -17,13 +19,13 @@ end for ArrayType in (:(CuVector{T}), :(CuMatrix{T})) @eval begin - function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{<:$ArrayType,CuSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat + function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{CuSparseMatrixCSR{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasFloat ldiv!(ilu.z, UnitLowerTriangular(ilu.P), x) # Forward substitution with L ldiv!(y, UpperTriangular(ilu.P), ilu.z) # Backward substitution with U return y end - function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{<:$ArrayType,CuSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasReal + function ldiv!(y::$ArrayType, ilu::NVIDIA_ILU0{CuSparseMatrixCSC{T,Cint},<:$ArrayType}, x::$ArrayType) where T <: BlasReal ldiv!(ilu.z, LowerTriangular(ilu.P), x) # Forward substitution with L ldiv!(y, UnitUpperTriangular(ilu.P), ilu.z) # Backward substitution with U return y