diff --git a/Project.toml b/Project.toml index 5a0c30def..910f4d1d8 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.1.0" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Cassette = "7057c7e9-c182-5462-911a-8362d720325c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" @@ -19,4 +20,4 @@ ReactantNNlibExt = "NNlib" [compat] CEnum = "0.4" Enzyme = "0.11, 0.12" -Reactant_jll = "0.0.2" +Reactant_jll = "0.0.3" diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 4f01abb3c..ebad2db70 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -189,6 +189,15 @@ cc_library( "@stablehlo//:chlo_ops", "@xla//xla/pjrt/cpu:cpu_client", "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", + + + "@xla//xla/stream_executor/cuda:all_runtime", + "@xla//xla/stream_executor/cuda:cuda_driver", + "@xla//xla/stream_executor/cuda:cuda_platform", + "@xla//xla/stream_executor/cuda:cudnn_plugin", + "@xla//xla/stream_executor/cuda:cufft_plugin", + "@xla//xla/stream_executor:cuda_platform", + "@xla//xla/pjrt:status_casters", "@xla//xla/python/ifrt:ifrt", "@xla//xla/python/pjrt_ifrt:xla_ifrt", diff --git a/src/XLA.jl b/src/XLA.jl index e1b33ef02..59e035d44 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -50,14 +50,17 @@ function CPUClient(asynchronous=true, node_id=0, num_nodes=1) global cpuclientcount @assert cpuclientcount[] == 0 cpuclientcount[]+=1 - client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} + + f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") + client = ccall(f, Ptr{Cvoid}, (UInt8, Cint, Cint), asynchronous, node_id, num_nodes) return Client(client) end -function GPUClient(node_id=0, num_nodes=1, platform="cuda") +function GPUClient(node_id=0, num_nodes=1, platform="gpu") allowed_devices = [-1] GC.@preserve allowed_devices begin - client = @ccall MLIR.API.mlir_c.MakeGPUClient(node_id::Cint, num_nodes::Cint, pointer(allowed_devices)::Ptr{Cvoid}, length(allowed_devices)::Cint, platform::Cstring)::Ptr{Cvoid} + f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeGPUClient") + client = ccall(f, Ptr{Cvoid}, (Cint, Cint, Ptr{Cvoid}, Cint, Cstring), node_id, num_nodes, pointer(allowed_devices), length(allowed_devices), platform) end return Client(client) end @@ -65,10 +68,15 @@ end const backends = Dict{String, Client}() const default_backend = Ref{Client}() const default_device_idx = Ref{Int}(0) +using Reactant_jll +using Libdl function __init__() - @ccall MLIR.API.mlir_c.InitializeLogs()::Cvoid + initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs") + ccall(initLogs, Cvoid, ()) cpu = CPUClient() backends["cpu"] = cpu + gpu = GPUClient() + backends["gpu"] = gpu default_backend[] = cpu end diff --git a/src/mlir/MLIR.jl b/src/mlir/MLIR.jl index dbe7b79a6..7c7b4cdfc 100644 --- a/src/mlir/MLIR.jl +++ b/src/mlir/MLIR.jl @@ -5,7 +5,7 @@ using CEnum using Preferences using Reactant_jll -const mlir_c = Reactant_jll.libReactantExtra +const mlir_c = Reactant_jll.libReactantExtra_handle # MLIR C API let