link against cuda platform
wsmoses committed May 3, 2024
1 parent 81a889d commit 3157b2f
Showing 4 changed files with 24 additions and 6 deletions.
Project.toml: 3 changes (2 additions, 1 deletion)
@@ -7,6 +7,7 @@ version = "0.1.0"
 CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"

@@ -19,4 +20,4 @@ ReactantNNlibExt = "NNlib"
 [compat]
 CEnum = "0.4"
 Enzyme = "0.11, 0.12"
-Reactant_jll = "0.0.2"
+Reactant_jll = "0.0.3"
deps/ReactantExtra/BUILD: 9 changes (9 additions, 0 deletions)
@@ -189,6 +189,15 @@ cc_library(
"@stablehlo//:chlo_ops",
"@xla//xla/pjrt/cpu:cpu_client",
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",


"@xla//xla/stream_executor/cuda:all_runtime",
"@xla//xla/stream_executor/cuda:cuda_driver",
"@xla//xla/stream_executor/cuda:cuda_platform",
"@xla//xla/stream_executor/cuda:cudnn_plugin",
"@xla//xla/stream_executor/cuda:cufft_plugin",
"@xla//xla/stream_executor:cuda_platform",

"@xla//xla/pjrt:status_casters",
"@xla//xla/python/ifrt:ifrt",
"@xla//xla/python/pjrt_ifrt:xla_ifrt",
src/XLA.jl: 16 changes (12 additions, 4 deletions)
@@ -50,25 +50,33 @@ function CPUClient(asynchronous=true, node_id=0, num_nodes=1)
     global cpuclientcount
     @assert cpuclientcount[] == 0
     cpuclientcount[]+=1
-    client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
+
+    f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
+    client = ccall(f, Ptr{Cvoid}, (UInt8, Cint, Cint), asynchronous, node_id, num_nodes)
     return Client(client)
 end

-function GPUClient(node_id=0, num_nodes=1, platform="cuda")
+function GPUClient(node_id=0, num_nodes=1, platform="gpu")
     allowed_devices = [-1]
     GC.@preserve allowed_devices begin
-        client = @ccall MLIR.API.mlir_c.MakeGPUClient(node_id::Cint, num_nodes::Cint, pointer(allowed_devices)::Ptr{Cvoid}, length(allowed_devices)::Cint, platform::Cstring)::Ptr{Cvoid}
+        f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeGPUClient")
+        client = ccall(f, Ptr{Cvoid}, (Cint, Cint, Ptr{Cvoid}, Cint, Cstring), node_id, num_nodes, pointer(allowed_devices), length(allowed_devices), platform)
     end
     return Client(client)
 end

 const backends = Dict{String, Client}()
 const default_backend = Ref{Client}()
 const default_device_idx = Ref{Int}(0)
+using Reactant_jll
+using Libdl
 function __init__()
-    @ccall MLIR.API.mlir_c.InitializeLogs()::Cvoid
+    initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs")
+    ccall(initLogs, Cvoid, ())
     cpu = CPUClient()
     backends["cpu"] = cpu
+    gpu = GPUClient()
+    backends["gpu"] = gpu
     default_backend[] = cpu
 end

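The XLA.jl hunk replaces @ccall invocations that name the library by path with a runtime symbol lookup through the dlopen handle exposed by Reactant_jll. A minimal standalone sketch of that pattern, assuming Reactant_jll is installed and exports a "MakeCPUClient" symbol with this C signature (the helper name is made up for illustration):

# Sketch only (not part of the commit): the dlsym + ccall pattern used above.
using Libdl
using Reactant_jll

function make_cpu_client_sketch(asynchronous::Bool=true, node_id::Integer=0, num_nodes::Integer=1)
    # libReactantExtra_handle is the Ptr{Cvoid} the JLL obtained via dlopen;
    # dlsym resolves the exported C symbol from that handle at runtime.
    fptr = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
    # ccall through the raw function pointer, declaring the C return and argument types.
    return ccall(fptr, Ptr{Cvoid}, (UInt8, Cint, Cint), asynchronous, node_id, num_nodes)
end

Resolving the symbol from the handle rather than naming the library at the call site presumably keeps the calls working regardless of how the CUDA-linked library was loaded; that rationale is an inference, not stated in the commit.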
src/mlir/MLIR.jl: 2 changes (1 addition, 1 deletion)
@@ -5,7 +5,7 @@ using CEnum
 using Preferences
 using Reactant_jll

-const mlir_c = Reactant_jll.libReactantExtra
+const mlir_c = Reactant_jll.libReactantExtra_handle

 # MLIR C API
 let
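For context on this last hunk (an explanatory sketch, not part of the commit): JLL packages conventionally expose both the shared library's path and the handle returned by dlopen, and only the handle can be passed to Libdl.dlsym. The names below follow that convention and are assumptions rather than documented Reactant_jll API:

# Sketch only: path vs. dlopen handle as exposed by a JLL package.
using Libdl
using Reactant_jll

path   = Reactant_jll.libReactantExtra          # assumed: filesystem path to the shared library
handle = Reactant_jll.libReactantExtra_handle   # assumed: Ptr{Cvoid} handle from dlopen

# dlsym needs the handle; with only the path you would have to dlopen it yourself first.
initlogs = Libdl.dlsym(handle, "InitializeLogs")
ccall(initlogs, Cvoid, ())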
