Skip to content

Commit

Permalink
Fix libtpu auto download (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 2, 2024
1 parent 65e1c2a commit ebb3661
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
9 changes: 3 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = [
"William Moses <[email protected]>",
"Valentin Churavy <[email protected]>",
"Sergio Sánchez Ramírez <[email protected]>",
"Paul Berg <[email protected]>",
]
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>"]
version = "0.1.9"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
13 changes: 12 additions & 1 deletion src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu")
return Client(client)
end


function TPUClient(tpu_path::String)
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeTPUClient")
refstr = Ref{Cstring}()
Expand All @@ -88,17 +89,27 @@ const default_backend = Ref{Client}()
const default_device_idx = Ref{Int}(0)
using Reactant_jll
using Libdl
using Scratch, Downloads
function __init__()
initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs")
ccall(initLogs, Cvoid, ())
cpu = CPUClient()
backends["cpu"] = cpu
default_backend[] = cpu

@static if !Sys.isapple()
if isfile("/usr/lib/libtpu.so")
dataset_dir = @get_scratch!("libtpu")
if !isfile(dataset_dir*"/libtpu.so")
Downloads.download("https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20240829-py3-none-any.whl", dataset_dir*"/tpu.zip")
run(`unzip -qq $(dataset_dir*"/tpu.zip") -d $(dataset_dir)/tmp`)
run(`mv $(dataset_dir)/tmp/libtpu/libtpu.so $(dataset_dir)/libtpu.so`)
rm(dataset_dir*"/tmp", recursive=true)
rm(dataset_dir*"/tpu.zip", recursive=true)
end
try
tpu = TPUClient(
"/home/wmoses/.local/lib/python3.8/site-packages/libtpu/libtpu.so"
dataset_dir*"/libtpu.so"
)
backends["tpu"] = tpu
default_backend[] = tpu
Expand Down

0 comments on commit ebb3661

Please sign in to comment.