diff --git a/Project.toml b/Project.toml index 5320962..fca921d 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ KrylovPreconditionersAMDGPUExt = "AMDGPU" KrylovPreconditionersCUDAExt = "CUDA" [compat] -AMDGPU = "0.8.2" +AMDGPU = "0.8.3" Adapt = "3, 4" CUDA = "5.1.1" KernelAbstractions = "0.9" diff --git a/ext/AMDGPU/operators.jl b/ext/AMDGPU/operators.jl index c5a7f77..74d2df5 100644 --- a/ext/AMDGPU/operators.jl +++ b/ext/AMDGPU/operators.jl @@ -1,3 +1,5 @@ +using AMDGPU.HIP + mutable struct AMD_KrylovOperator{T} <: AbstractKrylovOperator{T} type::Type{T} m::Int @@ -26,9 +28,20 @@ for (SparseMatrixType, BlasType) in ((:(ROCSparseMatrixCSR{T}), :BlasFloat), descY = rocSPARSE.ROCDenseVectorDescriptor(T, m) algo = rocSPARSE.rocSPARSE.rocsparse_spmv_alg_default buffer_size = Ref{Csize_t}() - rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX, + if HIP.runtime_version() ≥ v"6-" + rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX, + beta, descY, T, algo, rocSPARSE.rocsparse_spmv_stage_buffer_size, + buffer_size, C_NULL) + else + rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX, beta, descY, T, algo, buffer_size, C_NULL) + end buffer = ROCVector{UInt8}(undef, buffer_size[]) + if HIP.runtime_version() ≥ v"6-" + rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), transa, alpha, descA, descX, + beta, descY, T, algo, rocSPARSE.rocsparse_spmv_stage_preprocess, + buffer_size, buffer) + end return AMD_KrylovOperator{T}(T, m, n, nrhs, transa, descA, buffer_size, buffer) else descX = rocSPARSE.ROCDenseMatrixDescriptor(T, n, nrhs) @@ -62,8 +75,14 @@ function LinearAlgebra.mul!(y::ROCVector{T}, A::AMD_KrylovOperator{T}, x::ROCVec algo = rocSPARSE.rocsparse_spmv_alg_default alpha = Ref{T}(one(T)) beta = Ref{T}(zero(T)) - rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), A.transa, alpha, A.descA, descX, - beta, descY, T, algo, A.buffer_size, A.buffer) + if HIP.runtime_version() ≥ v"6-" + rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), A.transa, alpha, A.descA, descX, + beta, descY, T, algo, rocSPARSE.rocsparse_spmv_stage_compute, + A.buffer_size, A.buffer) + else + rocSPARSE.rocsparse_spmv(rocSPARSE.handle(), A.transa, alpha, A.descA, descX, + beta, descY, T, algo, A.buffer_size, A.buffer) + end end function LinearAlgebra.mul!(Y::ROCMatrix{T}, A::AMD_KrylovOperator{T}, X::ROCMatrix{T}) where T <: BlasFloat