Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

precompile first interpretation #353

Merged
merged 22 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Expand Down Expand Up @@ -56,6 +57,7 @@ GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9.26"
OrderedCollections = "1"
PrecompileTools = "1"
Preferences = "1.4"
PythonCall = "0.9"
Random = "1.10"
Expand Down
23 changes: 16 additions & 7 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,18 @@ const cuFunc = Ref{UInt}(0)
const cuModule = Ref{UInt}(0)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

MLIR.IR.activate!(mod)
MLIR.IR.activate!(MLIR.IR.body(mod))
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results =
try
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
finally
MLIR.IR.deactivate!(MLIR.IR.body(mod))
MLIR.IR.deactivate!(mod)
end

concrete_seen = OrderedIdDict()
Expand Down Expand Up @@ -828,7 +834,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
ctx = MLIR.IR.Context(Reactant.registry[], false)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

return MLIR.IR.context!(ctx) do
MLIR.IR.activate!(ctx)
return try
# compile function to MLIR module
mod = MLIR.IR.Module(MLIR.IR.Location())
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
Expand All @@ -851,6 +858,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
return exec,
linear_args, linear_results, preserved_args, seen_args, concrete_result,
isclosure
finally
MLIR.IR.deactivate!(ctx)
end
end

Expand Down
12 changes: 6 additions & 6 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function ConcreteRNumber{T}(
data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
data::T2; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
) where {T<:Number,T2<:Number}
data = convert(T, data)
crarray = ConcreteRArray(fill(data); client, idx, device)
return ConcreteRNumber{T}(crarray.data)
end
function ConcreteRNumber(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
) where {T<:Number}
crarray = ConcreteRArray(fill(data); client, idx, device)
return ConcreteRNumber{T}(crarray.data)
Expand Down Expand Up @@ -37,7 +37,7 @@ end
Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x))

function ConcreteRArray(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
) where {T<:Number}
Base.depwarn(
"ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead",
Expand All @@ -52,9 +52,9 @@ Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x

function ConcreteRArray(
data::Array{T,N};
client=XLA.default_backend[],
idx=XLA.default_device_idx[],
device=nothing,
client::XLA.Client=XLA.default_backend[],
idx::Int=XLA.default_device_idx[],
device::Union{Nothing, XLA.Device}=nothing,
) where {T,N}
device = device === nothing ? XLA.ClientGetDevice(client, idx) : device
return ConcreteRArray{T,N}(
Expand Down
8 changes: 4 additions & 4 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ end
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
true, #=forward_rules=#
true, #=reverse_rules=#
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
Expand All @@ -80,8 +80,8 @@ else
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
true, #=forward_rules=#
true, #=forward_rules=#
false, #=forward_rules=#
false, #=forward_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
Expand Down
66 changes: 66 additions & 0 deletions src/Precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using PrecompileTools: @setup_workload, @compile_workload

function infer_sig(sig)
interp = ReactantInterpreter()

min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))

lookup_result = Reactant.lookup_world(
sig, interp.world, Core.Compiler.method_table(interp), min_world, max_world
)
match = lookup_result::Core.MethodMatch
# look up the method and code instance
mi = ccall(
:jl_specializations_get_linfo,
Ref{Core.MethodInstance},
(Any, Any, Any),
match.method,
match.spec_types,
match.sparams,
)

@static if VERSION < v"1.11"
# For older Julia versions, we vendor in some of the code to prevent
# having to build the MethodInstance twice.
result = CC.InferenceResult(mi, CC.typeinf_lattice(interp))
frame = CC.InferenceState(result, :no, interp)
@assert !isnothing(frame)
CC.typeinf(interp, frame)
ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing)
rt = CC.widenconst(CC.ignorelimited(result.result))
else
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
end
end

@setup_workload begin
initialize_dialect()
client = XLA.CPUClient(; checkcount=false)
@compile_workload begin
# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947
@static if VERSION < v"1.11"
else
# infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}})
# infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}})
x = ConcreteRNumber(2.0; client)
Reactant.compile(sin, (x,); client)

y = ConcreteRArray([2.0]; client)
Reactant.compile(Base.sum, (y,); client)
end
end
glou-nes marked this conversation as resolved.
Show resolved Hide resolved
XLA.free_client(client)
client.client = C_NULL
deinitialize_dialect()
# Opaque closures capture the worldage of their compilation and thus are not relocatable
# Therefore we explicitly purge all OC's we have created here
for v in oc_capture_vec
if v isa Base.RefValue
p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v))
Base.atomic_pointerset(p, C_NULL, :monotonic)
else
empty!(v)
end
end
end
15 changes: 13 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,23 @@ end
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace

const registry = Ref{MLIR.IR.DialectRegistry}()
function __init__()
const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

function initialize_dialect()
registry[] = MLIR.IR.DialectRegistry()
@ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
end

function deinitialize_dialect()
return registry[] = nothing
end

function __init__()
return initialize_dialect()
end

function set_default_backend(backend::XLA.Client)
return XLA.default_backend[] = backend
end
Expand All @@ -244,4 +253,6 @@ function set_default_backend(backend::String)
return set_default_backend(XLA.backends[backend])
end

include("Precompile.jl")

end # module
24 changes: 14 additions & 10 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ function make_mlir_fn(

N = length(args)
seen_args = OrderedIdDict()
traced_args = ntuple(N) do i
return Reactant.make_tracer(
traced_args = Vector{Any}(undef, N)
for i in 1:N
@inbounds traced_args[i] = Reactant.make_tracer(
seen_args,
args[i],
(:args, i),
Expand Down Expand Up @@ -166,7 +167,10 @@ function make_mlir_fn(

@assert MLIR.IR._has_block()

result = MLIR.IR.block!(fnbody) do
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate!(fnbody)
result = try
for (i, arg) in enumerate(linear_args)
if construct_function_without_args
arg.mlir_data = args[i].mlir_data
Expand All @@ -177,12 +181,9 @@ function make_mlir_fn(
end
end

# TODO fix it for kwargs
#if concretein
Reactant.call_with_reactant(f, traced_args...)
#else
# f(traced_args...)
#end
finally
MLIR.IR.deactivate!(fnbody)
end

seen_results = OrderedIdDict()
Expand Down Expand Up @@ -215,7 +216,8 @@ function make_mlir_fn(

out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results]

ret = MLIR.IR.block!(fnbody) do
MLIR.IR.activate!(fnbody)
ret = try
vals = MLIR.IR.Value[]
for res in linear_results
col_maj = if res isa MissingTracedValue
Expand All @@ -230,7 +232,9 @@ function make_mlir_fn(
!no_args_in_result && @assert length(vals) == length(linear_results)

dialect = getfield(MLIR.Dialects, return_dialect)
return dialect.return_(vals)
dialect.return_(vals)
finally
MLIR.IR.deactivate!(fnbody)
end

name2 = name
Expand Down
19 changes: 10 additions & 9 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ mutable struct Client
client::Ptr{Cvoid}

function Client(client::Ptr{Cvoid})
@assert client != C_NULL
return new(client)
#@assert client != C_NULL
#finalizer(new(client)) do client
# @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
#end
end
end

@inline function free_client(client::Client)
@ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
end

function to_row_major(x::Array{T,N}) where {T,N}
return permutedims(x, reverse(Base.OneTo(N)))
end
Expand Down Expand Up @@ -42,11 +43,11 @@ SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid

const cpuclientcount = Ref(0)
# TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing`
function CPUClient(asynchronous=false, node_id=0, num_nodes=1)
global cpuclientcount
@assert cpuclientcount[] == 0
cpuclientcount[] += 1
glou-nes marked this conversation as resolved.
Show resolved Hide resolved

function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true)
if checkcount
@assert cpuclientcount[] == 0
cpuclientcount[] += 1
end
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
Expand Down
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ function make_oc_dict(
)::Core.OpaqueClosure where {FT}
key = f
if haskey(oc_captures, key)
return oc_captures[key]
oc = oc_captures[key]
oc
else
ores = ccall(
:jl_new_opaque_closure_from_code_info,
Expand Down Expand Up @@ -527,7 +528,7 @@ function call_with_reactant_generator(
# octup = Tuple{method.sig.parameters[2:end]...}
octup = Tuple{tys[2:end]...}
ocva = false

# jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right
# inner code during compilation without special handling (i.e. call_in_world_total).
# Opaque closures also require taking the function argument. We can work around the latter
Expand Down
Loading