diff --git a/Project.toml b/Project.toml index cbcac208d..d40aa095d 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" @@ -56,6 +57,7 @@ GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" NNlib = "0.9.26" OrderedCollections = "1" +PrecompileTools = "1" Preferences = "1.4" PythonCall = "0.9" Random = "1.10" diff --git a/src/Compiler.jl b/src/Compiler.jl index ec2e96dfe..8c081cce2 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -334,12 +334,18 @@ const cuFunc = Ref{UInt}(0) const cuModule = Ref{UInt}(0) function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) - fnwrapped, - func2, traced_result, result, seen_args, ret, linear_args, in_tys, - linear_results = MLIR.IR.mmodule!(mod) do - MLIR.IR.block!(MLIR.IR.body(mod)) do - return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) - end + # Explicitly don't use block! 
to avoid creating a closure, which creates + # both compile-time and relocatability issues + + MLIR.IR.activate!(mod) + MLIR.IR.activate!(MLIR.IR.body(mod)) + fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, + linear_results = + try + Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) + finally + MLIR.IR.deactivate!(MLIR.IR.body(mod)) + MLIR.IR.deactivate!(mod) end concrete_seen = OrderedIdDict() @@ -828,7 +834,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false) ctx = MLIR.IR.Context(Reactant.registry[], false) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - return MLIR.IR.context!(ctx) do + MLIR.IR.activate!(ctx) + return try # compile function to MLIR module mod = MLIR.IR.Module(MLIR.IR.Location()) linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!( @@ -851,6 +858,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false) return exec, linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure + finally + MLIR.IR.deactivate!(ctx) end end diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 2c3561047..73979b4ca 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,12 +1,12 @@ function ConcreteRNumber{T}( - data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing + data::T2; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing ) where {T<:Number,T2<:Number} data = convert(T, data) crarray = ConcreteRArray(fill(data); client, idx, device) return ConcreteRNumber{T}(crarray.data) end function ConcreteRNumber( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing + data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing ) where {T<:Number} crarray = 
ConcreteRArray(fill(data); client, idx, device) return ConcreteRNumber{T}(crarray.data) @@ -37,7 +37,7 @@ end Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x)) function ConcreteRArray( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing + data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing ) where {T<:Number} Base.depwarn( "ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead", @@ -52,9 +52,9 @@ Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x function ConcreteRArray( data::Array{T,N}; - client=XLA.default_backend[], - idx=XLA.default_device_idx[], - device=nothing, + client::XLA.Client=XLA.default_backend[], + idx::Int=XLA.default_device_idx[], + device::Union{Nothing, XLA.Device}=nothing, ) where {T,N} device = device === nothing ? XLA.ClientGetDevice(client, idx) : device return ConcreteRArray{T,N}( diff --git a/src/Interpreter.jl b/src/Interpreter.jl index af7ead9da..d68041d2a 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -64,8 +64,8 @@ end ReactantCacheToken(), REACTANT_METHOD_TABLE, world, - true, #=forward_rules=# - true, #=reverse_rules=# + false, #=forward_rules=# + false, #=reverse_rules=# false, #=broadcast_rewrite=# set_reactant_abi, ) @@ -80,8 +80,8 @@ else REACTANT_CACHE, REACTANT_METHOD_TABLE, world, - true, #=forward_rules=# - true, #=forward_rules=# + false, #=forward_rules=# + false, #=reverse_rules=# false, #=broadcast_rewrite=# set_reactant_abi, ) diff --git a/src/Precompile.jl b/src/Precompile.jl new file mode 100644 index 000000000..4684287b7 --- /dev/null +++ b/src/Precompile.jl @@ -0,0 +1,66 @@ +using PrecompileTools: @setup_workload, @compile_workload + +function infer_sig(sig) + interp = ReactantInterpreter() + + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + lookup_result = 
Reactant.lookup_world( + sig, interp.world, Core.Compiler.method_table(interp), min_world, max_world + ) + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + + @static if VERSION < v"1.11" + # For older Julia versions, we vendor in some of the code to prevent + # having to build the MethodInstance twice. + result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) + frame = CC.InferenceState(result, :no, interp) + @assert !isnothing(frame) + CC.typeinf(interp, frame) + ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) + rt = CC.widenconst(CC.ignorelimited(result.result)) + else + ir, rt = CC.typeinf_ircode(interp, mi, nothing) + end +end + +@setup_workload begin + initialize_dialect() + client = XLA.CPUClient(; checkcount=false) + @compile_workload begin + # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 + @static if VERSION < v"1.11" + else + # infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}}) + # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) + x = ConcreteRNumber(2.0; client) + Reactant.compile(sin, (x,); client) + + y = ConcreteRArray([2.0]; client) + Reactant.compile(Base.sum, (y,); client) + end + end + XLA.free_client(client) + client.client = C_NULL + deinitialize_dialect() + # Opaque closures capture the worldage of their compilation and thus are not relocatable + # Therefore we explicitly purge all OC's we have created here + for v in oc_capture_vec + if v isa Base.RefValue + p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) + Base.atomic_pointerset(p, C_NULL, :monotonic) + else + empty!(v) + end + end +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 039a717a9..3102c7531 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -228,14 +228,23 @@ end using 
.Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace -const registry = Ref{MLIR.IR.DialectRegistry}() -function __init__() +const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() + +function initialize_dialect() registry[] = MLIR.IR.DialectRegistry() @ccall MLIR.API.mlir_c.InitializeRegistryAndPasses( registry[]::MLIR.API.MlirDialectRegistry )::Cvoid end +function deinitialize_dialect() + return registry[] = nothing +end + +function __init__() + return initialize_dialect() +end + function set_default_backend(backend::XLA.Client) return XLA.default_backend[] = backend end @@ -244,4 +253,6 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end +include("Precompile.jl") + end # module diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index cd7f8623a..6bd29764b 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -115,8 +115,9 @@ function make_mlir_fn( N = length(args) seen_args = OrderedIdDict() - traced_args = ntuple(N) do i - return Reactant.make_tracer( + traced_args = Vector{Any}(undef, N) + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( seen_args, args[i], (:args, i), @@ -166,7 +167,10 @@ function make_mlir_fn( @assert MLIR.IR._has_block() - result = MLIR.IR.block!(fnbody) do + # Explicitly don't use block! to avoid creating a closure, which creates + # both compile-time and relocatability issues + MLIR.IR.activate!(fnbody) + result = try for (i, arg) in enumerate(linear_args) if construct_function_without_args arg.mlir_data = args[i].mlir_data @@ -177,12 +181,9 @@ function make_mlir_fn( end end - # TODO fix it for kwargs - #if concretein Reactant.call_with_reactant(f, traced_args...) - #else - # f(traced_args...) 
- #end + finally + MLIR.IR.deactivate!(fnbody) end seen_results = OrderedIdDict() @@ -215,7 +216,8 @@ function make_mlir_fn( out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] - ret = MLIR.IR.block!(fnbody) do + MLIR.IR.activate!(fnbody) + ret = try vals = MLIR.IR.Value[] for res in linear_results col_maj = if res isa MissingTracedValue @@ -230,7 +232,9 @@ function make_mlir_fn( !no_args_in_result && @assert length(vals) == length(linear_results) dialect = getfield(MLIR.Dialects, return_dialect) - return dialect.return_(vals) + dialect.return_(vals) + finally + MLIR.IR.deactivate!(fnbody) end name2 = name diff --git a/src/XLA.jl b/src/XLA.jl index 6255737e4..c7a526f2b 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -6,14 +6,15 @@ mutable struct Client client::Ptr{Cvoid} function Client(client::Ptr{Cvoid}) + @assert client != C_NULL return new(client) - #@assert client != C_NULL - #finalizer(new(client)) do client - # @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - #end end end +@inline function free_client(client::Client) + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid +end + function to_row_major(x::Array{T,N}) where {T,N} return permutedims(x, reverse(Base.OneTo(N))) end @@ -42,11 +43,11 @@ SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid const cpuclientcount = Ref(0) # TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` -function CPUClient(asynchronous=false, node_id=0, num_nodes=1) - global cpuclientcount - @assert cpuclientcount[] == 0 - cpuclientcount[] += 1 - +function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true) + if checkcount + @assert cpuclientcount[] == 0 + cpuclientcount[] += 1 + end f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) #client = @ccall 
MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} diff --git a/src/utils.jl b/src/utils.jl index b91ad47ef..8ea259103 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -245,7 +245,8 @@ function make_oc_dict( )::Core.OpaqueClosure where {FT} key = f if haskey(oc_captures, key) - return oc_captures[key] + oc = oc_captures[key] + oc else ores = ccall( :jl_new_opaque_closure_from_code_info, @@ -527,7 +528,7 @@ function call_with_reactant_generator( # octup = Tuple{method.sig.parameters[2:end]...} octup = Tuple{tys[2:end]...} ocva = false - + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require taking the function argument. We can work around the latter