Skip to content

Commit

Permalink
Add KrylovOperator and TriangularOperator for oneAPI.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Apr 4, 2024
1 parent 45988dc commit 6e548be
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 61 deletions.
2 changes: 2 additions & 0 deletions ext/KrylovPreconditionersoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module KrylovPreconditionersoneAPIExt
using LinearAlgebra
using SparseArrays
using oneAPI
using oneAPI: global_queue, sycl_queue, context, device
using oneAPI.oneMKL
using LinearAlgebra: checksquare, BlasReal, BlasFloat
import LinearAlgebra: ldiv!, mul!
Expand All @@ -13,5 +14,6 @@ using KernelAbstractions
const KA = KernelAbstractions

include("oneAPI/blockjacobi.jl")
include("oneAPI/operators.jl")

end
95 changes: 95 additions & 0 deletions ext/oneAPI/operators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
mutable struct INTEL_KrylovOperator{T} <: AbstractKrylovOperator{T}
type::Type{T}
m::Int
n::Int
nrhs::Int
transa::Char
matrix::oneSparseMatrixCSR{T}
end

eltype(A::INTEL_KrylovOperator{T}) where T = T
size(A::INTEL_KrylovOperator) = (A.m, A.n)

for (SparseMatrixType, BlasType) in ((:(oneSparseMatrixCSR{T}), :BlasFloat),)
@eval begin
function KP.KrylovOperator(A::$SparseMatrixType; nrhs::Int=1, transa::Char='N') where T <: $BlasType
m,n = size(A)
if nrhs == 1
oneMKL.sparse_optimize_gemv!(transa, A)
end
# sparse_optimize_gemm! is only available with oneAPI 2024.1.0
return INTEL_KrylovOperator{T}(T, m, n, nrhs, transa, A)
end

function KP.update!(A::INTEL_KrylovOperator{T}, B::$SparseMatrixType) where T <: $BlasFloat
error("The update of an INTEL_KrylovOperator is not supported.")
end
end
end

function LinearAlgebra.mul!(y::oneVector{T}, A::INTEL_KrylovOperator{T}, x::oneVector{T}) where T <: BlasFloat
(length(y) != A.m) && throw(DimensionMismatch("length(y) != A.m"))
(length(x) != A.n) && throw(DimensionMismatch("length(x) != A.n"))
(A.nrhs == 1) || throw(DimensionMismatch("A.nrhs != 1"))
alpha = one(T)
beta = zero(T)
oneMKL.sparse_gemv!(A.transa, alpha, A.matrix, x, beta, y)
end

function LinearAlgebra.mul!(Y::oneMatrix{T}, A::INTEL_KrylovOperator{T}, X::oneMatrix{T}) where T <: BlasFloat
mY, nY = size(Y)
mX, nX = size(X)
(mY != A.m) && throw(DimensionMismatch("mY != A.m"))
(mX != A.n) && throw(DimensionMismatch("mX != A.n"))
(nY == nX == A.nrhs) || throw(DimensionMismatch("nY != A.nrhs or nX != A.nrhs"))
alpha = one(T)
beta = zero(T)
oneMKL.sparse_gemm!(A.transa, 'N', alpha, A.matrix, X, beta, Y)
end

mutable struct INTEL_TriangularOperator{T} <: AbstractTriangularOperator{T}
type::Type{T}
m::Int
n::Int
nrhs::Int
uplo::Char
diag::Char
transa::Char
matrix::oneSparseMatrixCSR{T}
end

eltype(A::INTEL_TriangularOperator{T}) where T = T
size(A::INTEL_TriangularOperator) = (A.m, A.n)

for (SparseMatrixType, BlasType) in ((:(oneSparseMatrixCSR{T}), :BlasFloat),)
@eval begin
function KP.TriangularOperator(A::$SparseMatrixType, uplo::Char, diag::Char; nrhs::Int=1, transa::Char='N') where T <: $BlasType
m,n = size(A)
if nrhs == 1
oneMKL.sparse_optimize_trsv!(uplo, transa, diag, A)
end
# sparse_optimize_trsm! is only available with oneAPI 2024.1.0
return INTEL_TriangularOperator{T}(T, m, n, nrhs, uplo, diag, transa, A)
end

function KP.update!(A::INTEL_TriangularOperator{T}, B::$SparseMatrixType) where T <: $BlasFloat
return error("The update of an INTEL_TriangularOperator is not supported.")
end
end
end

function LinearAlgebra.ldiv!(y::oneVector{T}, A::INTEL_TriangularOperator{T}, x::oneVector{T}) where T <: BlasFloat
(length(y) != A.m) && throw(DimensionMismatch("length(y) != A.m"))
(length(x) != A.n) && throw(DimensionMismatch("length(x) != A.n"))
(A.nrhs == 1) || throw(DimensionMismatch("A.nrhs != 1"))
oneMKL.sparse_trsv!(A.uplo, A.transa, A.diag, A.matrix, x, y)
end

function LinearAlgebra.ldiv!(Y::oneMatrix{T}, A::INTEL_TriangularOperator{T}, X::oneMatrix{T}) where T <: BlasFloat
mY, nY = size(Y)
mX, nX = size(X)
(mY != A.m) && throw(DimensionMismatch("mY != A.m"))
(mX != A.n) && throw(DimensionMismatch("mX != A.n"))
(nY == nX == A.nrhs) || throw(DimensionMismatch("nY != A.nrhs or nX != A.nrhs"))
error("The routine sparse_trsm! is only available with oneAPI 2024.1.0")
end
130 changes: 69 additions & 61 deletions test/gpu/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,41 +90,45 @@ function test_operator(FC, V, DM, SM)
mul!(y_gpu, opA_gpu, x_gpu)
@test collect(y_gpu) y_cpu
end
for j = 1:5
y_cpu = rand(FC, m)
x_cpu = rand(FC, n)
A_cpu2 = A_cpu + j*I
mul!(y_cpu, A_cpu2, x_cpu)
y_gpu = V(y_cpu)
x_gpu = V(x_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
mul!(y_gpu, opA_gpu, x_gpu)
@test collect(y_gpu) y_cpu
if V.body.name.name != :oneArray
for j = 1:5
y_cpu = rand(FC, m)
x_cpu = rand(FC, n)
A_cpu2 = A_cpu + j*I
mul!(y_cpu, A_cpu2, x_cpu)
y_gpu = V(y_cpu)
x_gpu = V(x_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
mul!(y_gpu, opA_gpu, x_gpu)
@test collect(y_gpu) y_cpu
end
end

nrhs = 3
opA_gpu = KrylovOperator(A_gpu; nrhs)
for i = 1:5
Y_cpu = rand(FC, m, nrhs)
X_cpu = rand(FC, n, nrhs)
mul!(Y_cpu, A_cpu, X_cpu)
Y_gpu = DM(Y_cpu)
X_gpu = DM(X_cpu)
mul!(Y_gpu, opA_gpu, X_gpu)
@test collect(Y_gpu) Y_cpu
end
for j = 1:5
Y_cpu = rand(FC, m, nrhs)
X_cpu = rand(FC, n, nrhs)
A_cpu2 = A_cpu + j*I
mul!(Y_cpu, A_cpu2, X_cpu)
Y_gpu = DM(Y_cpu)
X_gpu = DM(X_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
mul!(Y_gpu, opA_gpu, X_gpu)
@test collect(Y_gpu) Y_cpu
if V.body.name.name != :oneArray
nrhs = 3
opA_gpu = KrylovOperator(A_gpu; nrhs)
for i = 1:5
Y_cpu = rand(FC, m, nrhs)
X_cpu = rand(FC, n, nrhs)
mul!(Y_cpu, A_cpu, X_cpu)
Y_gpu = DM(Y_cpu)
X_gpu = DM(X_cpu)
mul!(Y_gpu, opA_gpu, X_gpu)
@test collect(Y_gpu) Y_cpu
end
for j = 1:5
Y_cpu = rand(FC, m, nrhs)
X_cpu = rand(FC, n, nrhs)
A_cpu2 = A_cpu + j*I
mul!(Y_cpu, A_cpu2, X_cpu)
Y_gpu = DM(Y_cpu)
X_gpu = DM(X_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
mul!(Y_gpu, opA_gpu, X_gpu)
@test collect(Y_gpu) Y_cpu
end
end
end

Expand Down Expand Up @@ -152,43 +156,47 @@ function test_triangular(FC, V, DM, SM)
ldiv!(y_gpu, opA_gpu, x_gpu)
@test collect(y_gpu) y_cpu
end
for j = 1:5
y_cpu = rand(FC, n)
x_cpu = rand(FC, n)
A_cpu2 = A_cpu + j*tril(A_cpu,-1) + j*triu(A_cpu,1)
ldiv!(y_cpu, triangle(A_cpu2), x_cpu)
y_gpu = V(y_cpu)
x_gpu = V(x_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
ldiv!(y_gpu, opA_gpu, x_gpu)
@test collect(y_gpu) y_cpu
if V.body.name.name != :oneArray
for j = 1:5
y_cpu = rand(FC, n)
x_cpu = rand(FC, n)
A_cpu2 = A_cpu + j*tril(A_cpu,-1) + j*triu(A_cpu,1)
ldiv!(y_cpu, triangle(A_cpu2), x_cpu)
y_gpu = V(y_cpu)
x_gpu = V(x_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
ldiv!(y_gpu, opA_gpu, x_gpu)
@test collect(y_gpu) y_cpu
end
end

nrhs = 3
opA_gpu = TriangularOperator(A_gpu, uplo, diag; nrhs)
for i = 1:5
Y_cpu = rand(FC, n, nrhs)
X_cpu = rand(FC, n, nrhs)
ldiv!(Y_cpu, triangle(A_cpu), X_cpu)
Y_gpu = DM(Y_cpu)
X_gpu = DM(X_cpu)
ldiv!(Y_gpu, opA_gpu, X_gpu)
@test collect(Y_gpu) Y_cpu
end
if V.body.name.name != :CuArray
for j = 1:5
if V.body.name.name != :oneArray
nrhs = 3
opA_gpu = TriangularOperator(A_gpu, uplo, diag; nrhs)
for i = 1:5
Y_cpu = rand(FC, n, nrhs)
X_cpu = rand(FC, n, nrhs)
A_cpu2 = A_cpu + j*tril(A_cpu,-1) + j*triu(A_cpu,1)
ldiv!(Y_cpu, triangle(A_cpu2), X_cpu)
ldiv!(Y_cpu, triangle(A_cpu), X_cpu)
Y_gpu = DM(Y_cpu)
X_gpu = DM(X_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
ldiv!(Y_gpu, opA_gpu, X_gpu)
@test collect(Y_gpu) Y_cpu
end
if V.body.name.name != :CuArray
for j = 1:5
Y_cpu = rand(FC, n, nrhs)
X_cpu = rand(FC, n, nrhs)
A_cpu2 = A_cpu + j*tril(A_cpu,-1) + j*triu(A_cpu,1)
ldiv!(Y_cpu, triangle(A_cpu2), X_cpu)
Y_gpu = DM(Y_cpu)
X_gpu = DM(X_cpu)
A_gpu2 = SM(A_cpu2)
update!(opA_gpu, A_gpu2)
ldiv!(Y_gpu, opA_gpu, X_gpu)
@test collect(Y_gpu) Y_cpu
end
end
end
end
end
Expand Down
12 changes: 12 additions & 0 deletions test/gpu/intel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ include("gpu.jl")
@test oneAPI.functional()
oneAPI.allowscalar(false)

@testset "KrylovOperator" begin
@testset "oneSparseMatrixCSR -- $FC" for FC in (Float64, ComplexF64)
test_operator(FC, oneVector{FC}, oneMatrix{FC}, oneSparseMatrixCSR)
end
end

@testset "TriangularOperator" begin
@testset "oneSparseMatrixCSR -- $FC" for FC in (Float64, ComplexF64)
test_triangular(FC, oneVector{FC}, oneMatrix{FC}, oneSparseMatrixCSR)
end
end

@testset "Block Jacobi preconditioner" begin
test_block_jacobi(oneAPIBackend(), oneArray, oneSparseMatrixCSR)
end
Expand Down

0 comments on commit 6e548be

Please sign in to comment.