Skip to content

Commit

Permalink
Add device to ext
Browse files Browse the repository at this point in the history
  • Loading branch information
luraess committed Feb 7, 2024
1 parent b85fcd5 commit 25c5729
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
12 changes: 11 additions & 1 deletion ext/ChmyAMDGPUExt/ChmyAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
module ChmyAMDGPUExt

using AMDGPU, KernelAbstractions, Chmy
using AMDGPU, AMDGPU.ROCKernels, KernelAbstractions, Chmy

import Chmy.Architectures: heuristic_groupsize, set_device!, get_device

Base.unsafe_wrap(::ROCBackend, ptr::Ptr, dims) = unsafe_wrap(ROCArray, ptr, dims; lock=false)

Chmy.pointertype(::ROCBackend, T::DataType) = Ptr{T}

set_device!(dev::HIPDevice) = AMDGPU.device!(dev)

get_device(::ROCBackend, id::Integer) = HIPDevice(id)

heuristic_groupsize(::HIPDevice, ::Val{1}) = (256, )
heuristic_groupsize(::HIPDevice, ::Val{2}) = (128, 2, )
heuristic_groupsize(::HIPDevice, ::Val{3}) = (128, 2, 1, )

end
12 changes: 11 additions & 1 deletion ext/ChmyCUDAExt/ChmyCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
module ChmyCUDAExt

using CUDA, KernelAbstractions, Chmy
using CUDA, CUDA.CUDAKernels, KernelAbstractions, Chmy

import Chmy.Architectures: heuristic_groupsize, set_device!, get_device

Base.unsafe_wrap(::CUDABackend, ptr::CuPtr, dims) = unsafe_wrap(CuArray, ptr, dims)

Chmy.pointertype(::CUDABackend, T::DataType) = CuPtr{T}

set_device!(dev::CuDevice) = CUDA.device!(dev)

get_device(::CUDABackend, id::Integer) = CuDevice(id - 1)

heuristic_groupsize(::CuDevice, ::Val{1}) = (256,)
heuristic_groupsize(::CuDevice, ::Val{2}) = (32, 8)
heuristic_groupsize(::CuDevice, ::Val{3}) = (32, 8, 1)

end

0 comments on commit 25c5729

Please sign in to comment.