From d39a8837b1ca8c2e33272092f68de040c85cf0c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 9 Dec 2024 17:04:21 +0100 Subject: [PATCH 01/21] precompile first try --- Project.toml | 6 ++++-- src/Reactant.jl | 2 ++ src/precompile.jl | 9 +++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 src/precompile.jl diff --git a/Project.toml b/Project.toml index d5e57ef82..356c9d55d 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" @@ -30,8 +31,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources.ReactantCore] -path = "lib/ReactantCore" +[sources] +ReactantCore = {path = "lib/ReactantCore"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" @@ -57,6 +58,7 @@ GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" NNlib = "0.9.26" OrderedCollections = "1" +PrecompileTools = "1.2.1" Preferences = "1.4" PythonCall = "0.9" Random = "1.10" diff --git a/src/Reactant.jl b/src/Reactant.jl index 039a717a9..b2daddf1b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -244,4 +244,6 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end +include("precompile.jl") + end # module diff --git a/src/precompile.jl b/src/precompile.jl new file mode 100644 index 000000000..da631b3f4 --- /dev/null +++ b/src/precompile.jl @@ -0,0 +1,9 @@ +using PrecompileTools: @setup_workload, @compile_workload + +@setup_workload begin + @compile_workload begin + __init__() + x = Reactant.ConcreteRArray(randn(Float64, 2, 2)) + @jit sum(x) + end +end From 702c3b1e8bfded2f9b1f1bf1722aaf14f92b11ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 9 Dec 2024 20:41:41 +0100 Subject: [PATCH 02/21] add __init__ & assert remove --- src/XLA.jl | 2 +- src/precompile.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/XLA.jl b/src/XLA.jl index 6255737e4..9533bb004 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -44,7 +44,7 @@ const cpuclientcount = Ref(0) # TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` function CPUClient(asynchronous=false, node_id=0, num_nodes=1) global cpuclientcount - @assert cpuclientcount[] == 0 + #@assert cpuclientcount[] == 0 cpuclientcount[] += 1 f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") diff --git a/src/precompile.jl b/src/precompile.jl index da631b3f4..532c03361 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,8 +1,9 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin + Reactant.__init__() + XLA.__init__() @compile_workload begin - __init__() x = Reactant.ConcreteRArray(randn(Float64, 2, 2)) @jit sum(x) end From 6a0dc9bbec53b9e4a98ad276037320146aa328da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Tue, 10 Dec 2024 13:26:04 +0100 Subject: [PATCH 03/21] remove counter & fix precompilation --- Project.toml | 2 +- src/XLA.jl | 16 ++++++---------- src/precompile.jl | 13 +++++++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 356c9d55d..0dd5a31a8 100644 --- a/Project.toml +++ b/Project.toml @@ -58,7 +58,7 @@ GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" NNlib = "0.9.26" OrderedCollections = "1" -PrecompileTools = "1.2.1" +PrecompileTools = "1" Preferences = "1.4" PythonCall = "0.9" Random = "1.10" diff --git a/src/XLA.jl b/src/XLA.jl index 9533bb004..c9940c904 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -6,11 +6,12 @@ mutable struct Client client::Ptr{Cvoid} function Client(client::Ptr{Cvoid}) - return new(client) - #@assert client != C_NULL - #finalizer(new(client)) do client - # @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - #end + @assert client != C_NULL + client = new(client) + finalizer(client) do client + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + end + return client end end @@ -40,13 +41,8 @@ end SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid -const cpuclientcount = Ref(0) # TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` function CPUClient(asynchronous=false, node_id=0, num_nodes=1) - global cpuclientcount - #@assert cpuclientcount[] == 0 - cpuclientcount[] += 1 - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) #client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} diff --git a/src/precompile.jl b/src/precompile.jl index 532c03361..dc09d21c1 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,10 +1,15 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin - Reactant.__init__() - XLA.__init__() @compile_workload begin - x = Reactant.ConcreteRArray(randn(Float64, 2, 2)) - @jit sum(x) + Reactant.__init__() + cpu = XLA.CPUClient() + x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) + @code_hlo optimize = false sum(x) + end + + @compile_workload begin + interp = Reactant.ReactantInterpreter() + Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) end end From 9ebf3da39e814cd5e54ec9aa8268fc6a42c7b487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Tue, 10 Dec 2024 15:12:34 +0100 Subject: [PATCH 04/21] keep only one workload --- src/precompile.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/precompile.jl b/src/precompile.jl index dc09d21c1..83d709626 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -7,9 +7,4 @@ using PrecompileTools: @setup_workload, @compile_workload x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) @code_hlo optimize = false sum(x) end - - @compile_workload begin - interp = Reactant.ReactantInterpreter() - Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) - end end From 9d2c0d26dcd757092e89152b985c60ffda4bba73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Tue, 10 Dec 2024 20:53:41 +0100 Subject: [PATCH 05/21] try to fix CI --- src/precompile.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/precompile.jl b/src/precompile.jl index 83d709626..9aba6f293 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -2,9 +2,7 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin @compile_workload begin - Reactant.__init__() - cpu = XLA.CPUClient() - x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) - @code_hlo optimize = false sum(x) + interp = Reactant.ReactantInterpreter() + Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) end end From 67ca30467518157192c66e73e77404ccc2a36435 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Tue, 10 Dec 2024 21:22:37 +0100 Subject: [PATCH 06/21] forgotten init --- src/precompile.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/precompile.jl b/src/precompile.jl index 9aba6f293..b3cf02e73 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -2,6 +2,7 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin @compile_workload begin + Reactant.__init__() interp = Reactant.ReactantInterpreter() Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) end From 6ade5e70134200a40a7ad9244d850a78831061f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Thu, 12 Dec 2024 19:51:48 +0100 Subject: [PATCH 07/21] compact --- Project.toml | 4 ++-- src/Reactant.jl | 5 ++++- src/XLA.jl | 7 ++++--- src/precompile.jl | 5 +++-- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 0dd5a31a8..a568d6397 100644 --- a/Project.toml +++ b/Project.toml @@ -31,8 +31,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources] -ReactantCore = {path = "lib/ReactantCore"} +[sources.ReactantCore] +path = "lib/ReactantCore" [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" diff --git a/src/Reactant.jl b/src/Reactant.jl index b2daddf1b..c9bff26dd 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -244,6 +244,9 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end -include("precompile.jl") +@static if !haskey(ENV, "REACTANT_TEST_GROUP") + @info "enable precompilation" gethostname() Base.active_project() + include("precompile.jl") +end end # module diff --git a/src/XLA.jl b/src/XLA.jl index c9940c904..7d9822c54 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -8,9 +8,10 @@ mutable struct Client function Client(client::Ptr{Cvoid}) @assert client != C_NULL client = new(client) - finalizer(client) do client - @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - end + #TODO: Client are also constructed from MLIR.API.mlir_c.BufferToClient so the pointer cannot be free when Client is cleaned + #finalizer(client) do client + # @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + #end return client end end diff --git a/src/precompile.jl b/src/precompile.jl index b3cf02e73..83d709626 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -3,7 +3,8 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin @compile_workload begin Reactant.__init__() - interp = Reactant.ReactantInterpreter() - Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) + cpu = XLA.CPUClient() + x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) + @code_hlo optimize = false sum(x) end end From 795269482ac0dd177d3a57a9075ee7e61ba6ff18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Thu, 12 Dec 2024 20:46:15 +0100 Subject: [PATCH 08/21] apply CI check inside setup_workload --- src/Reactant.jl | 5 +---- src/precompile.jl | 4 ++++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index c9bff26dd..b2daddf1b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -244,9 +244,6 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end -@static if !haskey(ENV, "REACTANT_TEST_GROUP") - @info "enable precompilation" gethostname() Base.active_project() - include("precompile.jl") -end +include("precompile.jl") end # module diff --git a/src/precompile.jl b/src/precompile.jl index 83d709626..dbd364374 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,6 +1,10 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin + @static if haskey(ENV, "REACTANT_TEST_GROUP") + return + end + @info "enable precompilation" gethostname() Base.active_project() @compile_workload begin Reactant.__init__() cpu = XLA.CPUClient() From 72395334e29d063fec15f607d48f4feffdb1d17d Mon Sep 17 00:00:00 2001 From: glounes <130663280+glou-nes@users.noreply.github.com> Date: Sat, 14 Dec 2024 19:20:01 +0100 Subject: [PATCH 09/21] reviews --- src/{precompile.jl => Precompile.jl} | 6 ++++-- src/Reactant.jl | 15 ++++++++++++--- src/XLA.jl | 14 ++++++++------ 3 files changed, 24 insertions(+), 11 deletions(-) rename src/{precompile.jl => Precompile.jl} (74%) diff --git a/src/precompile.jl b/src/Precompile.jl similarity index 74% rename from src/precompile.jl rename to src/Precompile.jl index dbd364374..c3c86f55d 100644 --- a/src/precompile.jl +++ b/src/Precompile.jl @@ -4,11 +4,13 @@ using PrecompileTools: @setup_workload, @compile_workload @static if haskey(ENV, "REACTANT_TEST_GROUP") return end - @info "enable precompilation" gethostname() Base.active_project() @compile_workload begin - Reactant.__init__() + initialize_dialect() cpu = XLA.CPUClient() x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) @code_hlo optimize = false sum(x) + XLA.free_client(cpu) + deinitialize_dialect() end + XLA.cpuclientcount[] = 0 end diff --git a/src/Reactant.jl b/src/Reactant.jl index b2daddf1b..3102c7531 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -228,14 +228,23 @@ end using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace -const registry = Ref{MLIR.IR.DialectRegistry}() -function __init__() +const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() + +function initialize_dialect() registry[] = MLIR.IR.DialectRegistry() @ccall MLIR.API.mlir_c.InitializeRegistryAndPasses( registry[]::MLIR.API.MlirDialectRegistry )::Cvoid end +function deinitialize_dialect() + return registry[] = nothing +end + +function __init__() + return initialize_dialect() +end + function set_default_backend(backend::XLA.Client) return XLA.default_backend[] = backend end @@ -244,6 +253,6 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end -include("precompile.jl") +include("Precompile.jl") end # module diff --git a/src/XLA.jl b/src/XLA.jl index 7d9822c54..01174e81c 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -7,15 +7,14 @@ mutable struct Client function Client(client::Ptr{Cvoid}) @assert client != C_NULL - client = new(client) - #TODO: Client are also constructed from MLIR.API.mlir_c.BufferToClient so the pointer cannot be free when Client is cleaned - #finalizer(client) do client - # @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - #end - return client + return new(client) end end +@inline function free_client(client::Client) + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid +end + function to_row_major(x::Array{T,N}) where {T,N} return permutedims(x, reverse(Base.OneTo(N))) end @@ -42,8 +41,11 @@ end SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid +global cpuclientcount = Ref(0) # TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` function CPUClient(asynchronous=false, node_id=0, num_nodes=1) + @assert cpuclientcount[] == 0 + cpuclientcount[] += 1 f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) #client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} From 36c03cc55ac16d114640944e38f44c8838557a72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 15 Dec 2024 02:08:19 +0100 Subject: [PATCH 10/21] typo --- src/XLA.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/XLA.jl b/src/XLA.jl index 01174e81c..6d4b00197 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -41,7 +41,7 @@ end SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid -global cpuclientcount = Ref(0) +const cpuclientcount = Ref(0) # TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` function CPUClient(asynchronous=false, node_id=0, num_nodes=1) @assert cpuclientcount[] == 0 From 8c883bd548964952b90c2c99683261a1a8ae99e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 15 Dec 2024 02:09:23 +0100 Subject: [PATCH 11/21] remove CI ENV hack --- src/Precompile.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Precompile.jl b/src/Precompile.jl index c3c86f55d..8fe9a5e6d 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -1,9 +1,6 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin - @static if haskey(ENV, "REACTANT_TEST_GROUP") - return - end @compile_workload begin initialize_dialect() cpu = XLA.CPUClient() From 6c4d00cff83ab07944ae2f43d9582547c43705bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 15 Dec 2024 03:47:45 +0100 Subject: [PATCH 12/21] Revert "remove CI ENV hack" This reverts commit a688556ec469f6bdb0a5b3ec53ede9c8b6bc4ac3. --- src/Precompile.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Precompile.jl b/src/Precompile.jl index 8fe9a5e6d..c3c86f55d 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -1,6 +1,9 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin + @static if haskey(ENV, "REACTANT_TEST_GROUP") + return + end @compile_workload begin initialize_dialect() cpu = XLA.CPUClient() From 9d82b9f84c2316e873701c860bd3097282eb8fc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 15 Dec 2024 21:17:03 +0100 Subject: [PATCH 13/21] test CI --- src/Precompile.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/Precompile.jl b/src/Precompile.jl index c3c86f55d..a125f0ca5 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -1,16 +1,13 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin - @static if haskey(ENV, "REACTANT_TEST_GROUP") - return - end + initialize_dialect() + cpu = XLA.CPUClient() @compile_workload begin - initialize_dialect() - cpu = XLA.CPUClient() x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) @code_hlo optimize = false sum(x) - XLA.free_client(cpu) - deinitialize_dialect() end + XLA.free_client(cpu) + deinitialize_dialect() XLA.cpuclientcount[] = 0 end From d4c71bb2099217b980643b7d5306ff98bcc120bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 15 Dec 2024 22:05:00 +0100 Subject: [PATCH 14/21] test CI 2 --- src/Precompile.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/Precompile.jl b/src/Precompile.jl index a125f0ca5..3cbb964e2 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -2,12 +2,9 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin initialize_dialect() - cpu = XLA.CPUClient() @compile_workload begin - x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) - @code_hlo optimize = false sum(x) + interp = Reactant.ReactantInterpreter() + Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) end - XLA.free_client(cpu) deinitialize_dialect() - XLA.cpuclientcount[] = 0 end From 4031e9cff412daae24c6cc1925ef1c67d944eff6 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 3 Jan 2025 16:12:04 -0500 Subject: [PATCH 15/21] de initialize opaque closure cache --- src/Precompile.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/Precompile.jl b/src/Precompile.jl index 3cbb964e2..4323eba00 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -7,4 +7,12 @@ using PrecompileTools: @setup_workload, @compile_workload Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) end deinitialize_dialect() + for v in oc_capture_vec + if v isa Base.RefValue + p = Ptr{Ptr{Cvoid}}(pointer_from_objref(r)) + Base.atomic_pointerset(p, C_NULL, :monotonic) + else + empty!(v) + end + end end From 77fe08255db29a56b3f9290cfd160d3ae8fbb4f7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 3 Jan 2025 17:47:26 -0500 Subject: [PATCH 16/21] ongoing --- src/Compiler.jl | 23 ++++++++++++++++------- src/Precompile.jl | 11 ++++++++--- src/TracedUtils.jl | 24 ++++++++++++++---------- src/XLA.jl | 8 +++++--- src/utils.jl | 8 ++++++++ 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index ec2e96dfe..8c081cce2 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -334,12 +334,18 @@ const cuFunc = Ref{UInt}(0) const cuModule = Ref{UInt}(0) function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) - fnwrapped, - func2, traced_result, result, seen_args, ret, linear_args, in_tys, - linear_results = MLIR.IR.mmodule!(mod) do - MLIR.IR.block!(MLIR.IR.body(mod)) do - return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) - end + # Explicitly don't use block! to avoid creating a closure, which creates + # both compile-time and relocatability issues + + MLIR.IR.activate!(mod) + MLIR.IR.activate!(MLIR.IR.body(mod)) + fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, + linear_results = + try + Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) + finally + MLIR.IR.deactivate!(MLIR.IR.body(mod)) + MLIR.IR.deactivate!(mod) end concrete_seen = OrderedIdDict() @@ -828,7 +834,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false) ctx = MLIR.IR.Context(Reactant.registry[], false) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - return MLIR.IR.context!(ctx) do + MLIR.IR.activate!(ctx) + return try # compile function to MLIR module mod = MLIR.IR.Module(MLIR.IR.Location()) linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!( @@ -851,6 +858,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false) return exec, linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure + finally + MLIR.IR.deactivate!(ctx) end end diff --git a/src/Precompile.jl b/src/Precompile.jl index 4323eba00..9d6a0d601 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -2,14 +2,19 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin initialize_dialect() + client = XLA.CPUClient(; checkcount=false) @compile_workload begin - interp = Reactant.ReactantInterpreter() - Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) + x = ConcreteRNumber(2.0; client) + Reactant.compile(sin, (x,); client) + # interp = Reactant.ReactantInterpreter() + # Base.code_ircode(Base.sin, (Reactant.TracedRNumber{Float64},); interp) end + XLA.free_client(client) + client.client = C_NULL deinitialize_dialect() for v in oc_capture_vec if v isa Base.RefValue - p = Ptr{Ptr{Cvoid}}(pointer_from_objref(r)) + p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) Base.atomic_pointerset(p, C_NULL, :monotonic) else empty!(v) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index cd7f8623a..6bd29764b 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -115,8 +115,9 @@ function make_mlir_fn( N = length(args) seen_args = OrderedIdDict() - traced_args = ntuple(N) do i - return Reactant.make_tracer( + traced_args = Vector{Any}(undef, N) + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( seen_args, args[i], (:args, i), @@ -166,7 +167,10 @@ function make_mlir_fn( @assert MLIR.IR._has_block() - result = MLIR.IR.block!(fnbody) do + # Explicitly don't use block! to avoid creating a closure, which creates + # both compile-time and relocatability issues + MLIR.IR.activate!(fnbody) + result = try for (i, arg) in enumerate(linear_args) if construct_function_without_args arg.mlir_data = args[i].mlir_data @@ -177,12 +181,9 @@ function make_mlir_fn( end end - # TODO fix it for kwargs - #if concretein Reactant.call_with_reactant(f, traced_args...) - #else - # f(traced_args...) - #end + finally + MLIR.IR.deactivate!(fnbody) end seen_results = OrderedIdDict() @@ -215,7 +216,8 @@ function make_mlir_fn( out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] - ret = MLIR.IR.block!(fnbody) do + MLIR.IR.activate!(fnbody) + ret = try vals = MLIR.IR.Value[] for res in linear_results col_maj = if res isa MissingTracedValue @@ -230,7 +232,9 @@ function make_mlir_fn( !no_args_in_result && @assert length(vals) == length(linear_results) dialect = getfield(MLIR.Dialects, return_dialect) - return dialect.return_(vals) + dialect.return_(vals) + finally + MLIR.IR.deactivate!(fnbody) end name2 = name diff --git a/src/XLA.jl b/src/XLA.jl index 6d4b00197..c7a526f2b 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -43,9 +43,11 @@ SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid const cpuclientcount = Ref(0) # TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` -function CPUClient(asynchronous=false, node_id=0, num_nodes=1) - @assert cpuclientcount[] == 0 - cpuclientcount[] += 1 +function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true) + if checkcount + @assert cpuclientcount[] == 0 + cpuclientcount[] += 1 + end f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) #client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} diff --git a/src/utils.jl b/src/utils.jl index b91ad47ef..de9e19f33 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -527,6 +527,10 @@ function call_with_reactant_generator( # octup = Tuple{method.sig.parameters[2:end]...} octup = Tuple{tys[2:end]...} ocva = false + + # safe_print("src.relocatability", src.relocatability) + # We explicitly embed the global cache here, so it is definitionally not relocatable + # src.relocatability = 0 # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). @@ -572,6 +576,10 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + # We explicitly embed the global cache here, so it is definitionally not relocatable + # safe_print("code_info.relocatability", code_info.relocatability) + # code_info.relocatability = 0 + if DEBUG_INTERP[] safe_print("code_info", code_info) end From 598860ab4175737115ef02514488003407333880 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 3 Jan 2025 18:16:23 -0500 Subject: [PATCH 17/21] more attempts --- src/Precompile.jl | 44 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/Precompile.jl b/src/Precompile.jl index 9d6a0d601..4c9e425f8 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -1,16 +1,46 @@ using PrecompileTools: @setup_workload, @compile_workload +function infer_sig(sig) + interp = ReactantInterpreter() + + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + lookup_result = Reactant.lookup_world( + sig, interp.world, Core.Compiler.method_table(interp), min_world, max_world + ) + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + + @static if VERSION < v"1.11" + # For older Julia versions, we vendor in some of the code to prevent + # having to build the MethodInstance twice. + result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) + frame = CC.InferenceState(result, :no, interp) + @assert !isnothing(frame) + CC.typeinf(interp, frame) + ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) + rt = CC.widenconst(CC.ignorelimited(result.result)) + else + ir, rt = CC.typeinf_ircode(interp, mi, nothing) + end +end + @setup_workload begin initialize_dialect() - client = XLA.CPUClient(; checkcount=false) @compile_workload begin - x = ConcreteRNumber(2.0; client) - Reactant.compile(sin, (x,); client) - # interp = Reactant.ReactantInterpreter() - # Base.code_ircode(Base.sin, (Reactant.TracedRNumber{Float64},); interp) + # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) + + infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) end - XLA.free_client(client) - client.client = C_NULL deinitialize_dialect() for v in oc_capture_vec if v isa Base.RefValue From 26d0c0affe233718d628ad15c7697d5be2752882 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 3 Jan 2025 22:57:35 -0500 Subject: [PATCH 18/21] fix --- src/Precompile.jl | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/Precompile.jl b/src/Precompile.jl index 4c9e425f8..d18dda59b 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -36,18 +36,29 @@ end @setup_workload begin initialize_dialect() + client = XLA.CPUClient(; checkcount=false) @compile_workload begin - # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) - - infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) - end - deinitialize_dialect() - for v in oc_capture_vec - if v isa Base.RefValue - p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) - Base.atomic_pointerset(p, C_NULL, :monotonic) + # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 + @static if VERSION < v"1.10" else - empty!(v) + # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) + # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) + x = ConcreteRNumber(2.0; client) + Reactant.compile(sin, (x,); client) + + y = ConcreteRArray([2.0]; client) + Reactant.compile(Base.sum, (y,); client) end end + XLA.free_client(client) + client.client = C_NULL + deinitialize_dialect() + # for v in oc_capture_vec + # if v isa Base.RefValue + # p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) + # Base.atomic_pointerset(p, C_NULL, :monotonic) + # else + # empty!(v) + # end + # end end From adc2622197dc24b58eef80620407fd6e3174f6be Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 3 Jan 2025 23:06:48 -0500 Subject: [PATCH 19/21] fix --- src/Precompile.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Precompile.jl b/src/Precompile.jl index d18dda59b..e7a02c722 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -34,21 +34,20 @@ function infer_sig(sig) end end +@static if VERSION < v"1.10" +else @setup_workload begin initialize_dialect() client = XLA.CPUClient(; checkcount=false) @compile_workload begin # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 - @static if VERSION < v"1.10" - else - # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) - # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) - x = ConcreteRNumber(2.0; client) - Reactant.compile(sin, (x,); client) - - y = ConcreteRArray([2.0]; client) - Reactant.compile(Base.sum, (y,); client) - end + # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) + # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) + x = ConcreteRNumber(2.0; client) + Reactant.compile(sin, (x,); client) + + y = ConcreteRArray([2.0]; client) + Reactant.compile(Base.sum, (y,); client) end XLA.free_client(client) client.client = C_NULL @@ -62,3 +61,4 @@ end # end # end end +end From 03ff85e18f7c0133678369dd82ded73e83e116ab Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 3 Jan 2025 23:26:23 -0500 Subject: [PATCH 20/21] fix --- src/ConcreteRArray.jl | 12 ++++++------ src/Precompile.jl | 38 ++++++++++++++++++++------------------ src/utils.jl | 13 +++---------- 3 files changed, 29 insertions(+), 34 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 2c3561047..73979b4ca 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,12 +1,12 @@ function ConcreteRNumber{T}( - data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing + data::T2; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing ) where {T<:Number,T2<:Number} data = convert(T, data) crarray = ConcreteRArray(fill(data); client, idx, device) return ConcreteRNumber{T}(crarray.data) end function ConcreteRNumber( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing + data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing ) where {T<:Number} crarray = ConcreteRArray(fill(data); client, idx, device) return ConcreteRNumber{T}(crarray.data) @@ -37,7 +37,7 @@ end Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x)) function ConcreteRArray( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing + data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing ) where {T<:Number} Base.depwarn( "ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead", @@ -52,9 +52,9 @@ Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x function ConcreteRArray( data::Array{T,N}; - client=XLA.default_backend[], - idx=XLA.default_device_idx[], - device=nothing, + client::XLA.Client=XLA.default_backend[], + idx::Int=XLA.default_device_idx[], + device::Union{Nothing, XLA.Device}=nothing, ) where {T,N} device = device === nothing ? XLA.ClientGetDevice(client, idx) : device return ConcreteRArray{T,N}( diff --git a/src/Precompile.jl b/src/Precompile.jl index e7a02c722..4684287b7 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -34,31 +34,33 @@ function infer_sig(sig) end end -@static if VERSION < v"1.10" -else @setup_workload begin initialize_dialect() client = XLA.CPUClient(; checkcount=false) @compile_workload begin # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 - # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) - # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) - x = ConcreteRNumber(2.0; client) - Reactant.compile(sin, (x,); client) - - y = ConcreteRArray([2.0]; client) - Reactant.compile(Base.sum, (y,); client) + @static if VERSION < v"1.11" + else + # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) + # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) + x = ConcreteRNumber(2.0; client) + Reactant.compile(sin, (x,); client) + + y = ConcreteRArray([2.0]; client) + Reactant.compile(Base.sum, (y,); client) + end end XLA.free_client(client) client.client = C_NULL deinitialize_dialect() - # for v in oc_capture_vec - # if v isa Base.RefValue - # p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) - # Base.atomic_pointerset(p, C_NULL, :monotonic) - # else - # empty!(v) - # end - # end -end + # Opaque closures capture the worldage of their compilation and thus are not relocatable + # Therefore we explicitly purge all OC's we have created here + for v in oc_capture_vec + if v isa Base.RefValue + p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) + Base.atomic_pointerset(p, C_NULL, :monotonic) + else + empty!(v) + end + end end diff --git a/src/utils.jl b/src/utils.jl index de9e19f33..8ea259103 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -245,7 +245,8 @@ function make_oc_dict( )::Core.OpaqueClosure where {FT} key = f if haskey(oc_captures, key) - return oc_captures[key] + oc = oc_captures[key] + oc else ores = ccall( :jl_new_opaque_closure_from_code_info, @@ -527,11 +528,7 @@ function call_with_reactant_generator( # octup = Tuple{method.sig.parameters[2:end]...} octup = Tuple{tys[2:end]...} ocva = false - - # safe_print("src.relocatability", src.relocatability) - # We explicitly embed the global cache here, so it is definitionally not relocatable - # src.relocatability = 0 - + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require taking the function argument. We can work around the latter @@ -576,10 +573,6 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - # We explicitly embed the global cache here, so it is definitionally not relocatable - # safe_print("code_info.relocatability", code_info.relocatability) - # code_info.relocatability = 0 - if DEBUG_INTERP[] safe_print("code_info", code_info) end From 7bdcaab3580ca33618f01652b232f0c3a7058f81 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 3 Jan 2025 23:36:40 -0500 Subject: [PATCH 21/21] Disable rules for now as unused --- src/Interpreter.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index af7ead9da..d68041d2a 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -64,8 +64,8 @@ end ReactantCacheToken(), REACTANT_METHOD_TABLE, world, - true, #=forward_rules=# - true, #=reverse_rules=# + false, #=forward_rules=# + false, #=reverse_rules=# false, #=broadcast_rewrite=# set_reactant_abi, ) @@ -80,8 +80,8 @@ else REACTANT_CACHE, REACTANT_METHOD_TABLE, world, - true, #=forward_rules=# - true, #=forward_rules=# + false, #=forward_rules=# + false, #=forward_rules=# false, #=broadcast_rewrite=# set_reactant_abi, )