Skip to content

Commit

Permalink
Update XLA.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored May 16, 2024
1 parent 43fcd9f commit 5377b28
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu")
#allowed_devices = [-1]
# GC.@preserve allowed_devices begin
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeGPUClient")
client = ccall(f, Ptr{Cvoid}, (Cint, Cint, Ptr{Cvoid}, Cint, Cstring), node_id, num_nodes, C_NULL, 0, platform)
refstr = Ref{Cstring}()
client = ccall(f, Ptr{Cvoid}, (Cint, Cint, Ptr{Cvoid}, Cint, Cstring, Ptr{Cstring}), node_id, num_nodes, C_NULL, 0, platform, refstr)
if client == C_NULL
throw(AssertionError(refstr[]))
end
# end
return Client(client)
end
Expand All @@ -77,8 +81,12 @@ function __init__()
cpu = CPUClient()
backends["cpu"] = cpu
@static if !Sys.isapple()
gpu = GPUClient()
backends["gpu"] = gpu
try
gpu = GPUClient()
backends["gpu"] = gpu
catch e
println(stdout, e)
end
end
default_backend[] = cpu
end
Expand Down

0 comments on commit 5377b28

Please sign in to comment.