Skip to content

Commit

Permalink
Update spmv for AMD GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Feb 29, 2024
1 parent f3142b4 commit f44d83c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ KrylovPreconditionersAMDGPUExt = "AMDGPU"
KrylovPreconditionersCUDAExt = "CUDA"

[compat]
AMDGPU = "0.8.2"
AMDGPU = "0.8.3"
Adapt = "3, 4"
CUDA = "5.1.1"
KernelAbstractions = "0.9"
Expand Down
25 changes: 22 additions & 3 deletions ext/AMDGPU/operators.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using AMDGPU.HIP

mutable struct AMD_KrylovOperator{T} <: AbstractKrylovOperator{T}
type::Type{T}
m::Int
Expand Down Expand Up @@ -26,9 +28,20 @@ for (SparseMatrixType, BlasType) in ((:(ROCSparseMatrixCSR{T}), :BlasFloat),
descY = rocSPARSE.ROCDenseVectorDescriptor(T, m)
algo = rocSPARSE.rocSPARSE.rocsparse_spmv_alg_default
buffer_size = Ref{Csize_t}()
rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX,
if HIP.runtime_version() v"6-"
rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX,
beta, descY, T, algo, rocSPARSE.rocsparse_spmv_stage_buffer_size,
buffer_size, C_NULL)
else
rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX,
beta, descY, T, algo, buffer_size, C_NULL)
end
buffer = ROCVector{UInt8}(undef, buffer_size[])
if HIP.runtime_version() v"6-"
rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX,
beta, descY, T, algo, rocSPARSE.rocsparse_spmv_stage_preprocess,
buffer_size, buffer)
end
return AMD_KrylovOperator{T}(T, m, n, nrhs, transa, descA, buffer_size, buffer)
else
descX = rocSPARSE.ROCDenseMatrixDescriptor(T, n, nrhs)
Expand Down Expand Up @@ -62,8 +75,14 @@ function LinearAlgebra.mul!(y::ROCVector{T}, A::AMD_KrylovOperator{T}, x::ROCVec
algo = rocSPARSE.rocsparse_spmv_alg_default
alpha = Ref{T}(one(T))
beta = Ref{T}(zero(T))
rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), A.transa, alpha, A.descA, descX,
beta, descY, T, algo, A.buffer_size, A.buffer)
if HIP.runtime_version() v"6-"
rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), A.transa, alpha, A.descA, descX,
beta, descY, T, algo, rocSPARSE.rocsparse_spmv_stage_compute,
A.buffer_size, A.buffer)
else
rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), A.transa, alpha, A.descA, descX,
beta, descY, T, algo, A.buffer_size, A.buffer)
end
end

function LinearAlgebra.mul!(Y::ROCMatrix{T}, A::AMD_KrylovOperator{T}, X::ROCMatrix{T}) where T <: BlasFloat
Expand Down

0 comments on commit f44d83c

Please sign in to comment.