Commit be52876

wip

William Moses committed Dec 6, 2024
1 parent 95d5921

Showing 1 changed file with 27 additions and 17 deletions.

ext/ReactantCUDAExt.jl: 44 changes (27 additions & 17 deletions)
@@ -24,20 +24,13 @@ function compile(job)
         mod, meta = CUDA.GPUCompiler.compile(:llvm, job)
         string(mod)
     end
-    println(string(modstr))
-    @show job
-    @show job.params
-    @show job.source
-    kernel = LLVMFunc{F,tt}(f, modstr)
-    return modstr
-    #=
     # check if we'll need the device runtime
     undefined_fs = filter(collect(functions(meta.ir))) do f
-        isdeclaration(f) && !LLVM.isintrinsic(f)
+        isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
     end
     intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
                      "__nvvm_reflect" #= TODO: should have been optimized away =#]
-    needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns))
+    needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns))

     # prepare invocations of CUDA compiler tools
     ptxas_opts = String[]
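The `LLVM.` to `CUDA.LLVM.` renames above exist because this package extension does not depend on LLVM.jl directly; it reaches it through the module CUDA.jl itself loads. A minimal sketch of the device-runtime check in that qualified style, where `mod` is a hypothetical, already-compiled `LLVM.Module`:

    using CUDA  # CUDA.LLVM is LLVM.jl as loaded by CUDA.jl

    # Declared-but-undefined functions that are not known intrinsics indicate
    # that the kernel calls into the CUDA device runtime (libcudadevrt).
    function needs_device_runtime(mod::CUDA.LLVM.Module)
        undefined_fs = filter(collect(CUDA.LLVM.functions(mod))) do f
            CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
        end
        intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", "__nvvm_reflect"]
        return !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns))
    end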
@@ -59,7 +52,7 @@ function compile(job)
     arch = "sm_$(cap.major)$(cap.minor)"

     # validate use of parameter memory
-    argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
+    argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt
         !isghosttype(dt) && !Core.Compiler.isconstType(dt)
     end
     param_usage = sum(sizeof, argtypes)
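This check matters because the CUDA driver rejects launches whose packed kernel arguments exceed the architecture's limit (typically 4 KB). A hedged sketch of the same accounting, with a local stand-in for GPUCompiler's `isghosttype`:

    # Rough approximation of GPUCompiler.isghosttype: concrete types with no
    # runtime data take no parameter space.
    isghosttype(T::Type) = Base.isconcretetype(T) && sizeof(T) == 0

    # Bytes of kernel parameter space a signature will occupy.
    function param_usage(argtypes)
        real = filter(argtypes) do dt
            !isghosttype(dt) && !Core.Compiler.isconstType(dt)
        end
        return sum(sizeof, real; init=0)
    end

    param_usage([Int64, Float32, NTuple{4,Float64}])  # 8 + 4 + 32 = 44 bytes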
@@ -120,7 +113,7 @@ function compile(job)
         "--output-file", ptxas_output,
         ptx_input
     ])
-    proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`)
+    proc, log = CUDA.run_and_collect(`$(ptxas()) $ptxas_opts`)
     log = strip(log)
     if !success(proc)
         reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
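`run_and_collect` (now qualified as `CUDA.run_and_collect`) is an internal CUDA.jl helper rather than public API: it runs an external tool, merges stdout and stderr, and does not throw on failure, so the caller can fold the tool's log into its own error message. A rough Base-only stand-in:

    # Run `cmd` without throwing; return the process and its merged output.
    function run_and_collect(cmd::Cmd)
        out = Pipe()
        proc = run(pipeline(ignorestatus(cmd); stdout=out, stderr=out); wait=false)
        close(out.in)           # keep only the child's end of the pipe open
        log = read(out, String) # drain until the child exits
        wait(proc)
        return proc, log
    end

    proc, log = run_and_collect(`echo hello`)
    success(proc) || error("tool failed:\n" * strip(log))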
@@ -139,8 +132,7 @@ function compile(job)
         @debug "PTX compiler log:\n" * log
     end
     rm(ptx_input)
-    =#
-    #=
+
     # link device libraries, if necessary
     #
     # this requires relocatable device code, which prevents certain optimizations and
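The still-commented region continuing below would link in `libcudadevrt` when `needs_cudadevrt` is set; that path requires relocatable device code (`ptxas --compile-only`), which is why it stays optional. A hedged sketch of such a link step, assuming an `nvlink()` locator and a `libcudadevrt` path analogous to the `ptxas()` used above:

    nvlink_output = tempname() * ".cubin"
    nvlink_opts = [
        "--arch", arch,                        # e.g. "sm_75", from the code above
        "--library-path", dirname(libcudadevrt),
        "--library", "cudadevrt",
        "--output-file", nvlink_output,
        ptxas_output,                          # relocatable cubin from ptxas -c
    ]
    proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`)
    success(proc) || error("nvlink failed:\n" * strip(log))
    image = read(nvlink_output)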
Expand Down Expand Up @@ -180,8 +172,12 @@ function compile(job)
image = read(ptxas_output)
rm(ptxas_output)
end
=#
return (image, entry=LLVM.name(meta.entry))

println(string(modstr))
@show job
@show job.source
@show job.config
LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(meta.entry))
end

# link into an executable kernel
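The hunk above replaces the old `(image, entry)` return with an `LLVMFunc` that carries the compiled image and the kernel's entry-point name. A minimal sketch of how such a pair is typically turned into a launchable handle with CUDA.jl's driver wrappers, assuming `image::Vector{UInt8}` and `entry::String` from a successful compile:

    using CUDA

    md = CuModule(image)        # load the compiled image into the current context
    fn = CuFunction(md, entry)  # resolve the kernel by its (mangled) entry name

    # fn is now launchable, e.g. through the low-level driver wrapper:
    CUDA.cudacall(fn, Tuple{}; threads=1, blocks=1)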
@@ -193,10 +189,24 @@ end
 struct LLVMFunc{F,tt}
     f::F
     mod::String
+    image
+    entry::String
 end

-function (func::LLVMFunc{F,tt})(args...) where{F, tt}
+function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1,
+                                shmem::Integer=0) where{F, tt}
+    blockdim = CUDA.CuDim3(blocks)
+    threaddim = CUDA.CuDim3(threads)
+
+    @show args
+
+    # void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
+    #                          size_t opaque_len, XlaCustomCallStatus* status) {
+
+    CUDA.cuLaunchKernel(f,
+        blockdim.x, blockdim.y, blockdim.z,
+        threaddim.x, threaddim.y, threaddim.z,
+        shmem, stream, kernelParams, C_NULL)
 end

 # cache of compilation caches, per context
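As committed, the new call operator is still work in progress: `f`, `stream`, and `kernelParams` are undefined in its scope, and `cuLaunchKernel` expects a `CuFunction` handle plus a packed parameter buffer rather than the raw Julia callable. A hedged sketch of what a working body might look like, leaning on `CUDA.cudacall` for parameter packing (not the committed implementation; it assumes `tt` is the kernel's argument tuple type):

    function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1,
                                    shmem::Integer=0) where {F,tt}
        # Load the compiled image and resolve the entry point.
        md = CUDA.CuModule(func.image)
        fn = CUDA.CuFunction(md, func.entry)

        # cudacall packs `args` per the signature and forwards the launch
        # configuration (grid, block, shared memory) to cuLaunchKernel.
        CUDA.cudacall(fn, tt, args...; blocks, threads, shmem)
    end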
