Skip to content

Commit

Permalink
wqtmp
Browse files Browse the repository at this point in the history
  • Loading branch information
William Moses committed Dec 5, 2024
1 parent 1209491 commit 95d5921
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@ end
const _kernel_instances = Dict{Any, Any}()



# compile to executable machine code
function compile(job)
# lower to LLVM IR (the `:llvm` target is requested below)
# TODO: on Julia 1.9, this actually creates a context; cache those.
modstr = JuliaContext() do ctx
mod, meta = GPUCompiler.compile(:llvm, job)
modstr = CUDA.GPUCompiler.JuliaContext() do ctx
mod, meta = CUDA.GPUCompiler.compile(:llvm, job)
string(mod)
end
println(string(modstr))
@show job
@show job.params
@show job.source
kernel = LLVMFunc{F,tt}(f, modstr)
return modstr
#=
# check if we'll need the device runtime
Expand Down Expand Up @@ -187,12 +191,23 @@ function link(job, compiled)
end

"""
    LLVMFunc{F,tt}

Handle for a compiled kernel.

# Fields
- `f::F`: the original Julia function the kernel was built from.
- `mod::String`: the compiled module as a string (presumably the textual LLVM IR
  produced by `compile`; confirm against the caller).

`tt` appears only as a type parameter; no field stores it.
"""
struct LLVMFunc{F,tt}
    # Each field is declared exactly once; Julia rejects duplicate field names.
    f::F
    mod::String
end

# Call overload for `LLVMFunc` handles.
# The body is currently a stub: any arguments are accepted and ignored, and the call
# evaluates to `nothing`.
# NOTE(review): presumably this is where the kernel launch will go; confirm.
function (func::LLVMFunc{F,tt})(args...) where {F,tt}
    return nothing
end

# Cache of compilation caches, one per MLIR context: each context maps to its own
# dictionary of compiled `LLVMFunc` entries.
const _compiler_caches = Dict{MLIR.IR.Context,Dict{Any,LLVMFunc}}()

"""
    compiler_cache(ctx::MLIR.IR.Context)

Return the compilation cache registered for `ctx`. If none exists yet, an empty
`Dict{Any,LLVMFunc}` is created, stored in `_compiler_caches` under `ctx`, and returned.
"""
function compiler_cache(ctx::MLIR.IR.Context)
    # `get!` looks up `ctx` and, only when it is absent, inserts the do-block's result.
    return get!(_compiler_caches, ctx) do
        Dict{Any,LLVMFunc}()
    end
end

function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
Expand All @@ -202,20 +217,17 @@ function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}

Base.@lock CUDA.cufunction_lock begin
# compile the function
cache = CUDA.compiler_cache(cuda.context)
cache = compiler_cache(MLIR.IR.context())
source = CUDA.methodinstance(F, tt)
config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)

@show fun
println(string(fun))
#@show fun.mod
# create a callable object that captures the function instance. we don't need to think
# about world age here, as GPUCompiler already does and will return a different object
key = (objectid(source))
kernel = get(_kernel_instances, key, nothing)
if kernel === nothing
kernel = LLVMFunc{F,tt}(f, fun)
_kernel_instances[key] = kernel
end
return kernel::LLVMFunc{F,tt}
Expand Down

0 comments on commit 95d5921

Please sign in to comment.