Skip to content

Commit

Permalink
reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
glou-nes committed Dec 14, 2024
1 parent 3161fd9 commit e301a71
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/precompile.jl → src/Precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ using PrecompileTools: @setup_workload, @compile_workload
@static if haskey(ENV, "REACTANT_TEST_GROUP")
return
end
@info "enable precompilation" gethostname() Base.active_project()
@compile_workload begin
Reactant.__init__()
initialize_dialect()
cpu = XLA.CPUClient()
x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu)
@code_hlo optimize = false sum(x)
XLA.free_client(cpu)
deinitialize_dialect()
end
XLA.cpuclientcount[] = 0
end
15 changes: 12 additions & 3 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,23 @@ include("Compiler.jl")
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace

const registry = Ref{MLIR.IR.DialectRegistry}()
function __init__()
const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

function initialize_dialect()
registry[] = MLIR.IR.DialectRegistry()
@ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
end

function deinitialize_dialect()
return registry[] = nothing
end

function __init__()
return initialize_dialect()
end

function set_default_backend(backend::XLA.Client)
return XLA.default_backend[] = backend
end
Expand All @@ -130,6 +139,6 @@ function set_default_backend(backend::String)
return set_default_backend(XLA.backends[backend])
end

include("precompile.jl")
include("Precompile.jl")

end # module
14 changes: 8 additions & 6 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ mutable struct Client

function Client(client::Ptr{Cvoid})
@assert client != C_NULL
client = new(client)
#TODO: Client are also constructed from MLIR.API.mlir_c.BufferToClient so the pointer cannot be free when Client is cleaned
#finalizer(client) do client
# @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
#end
return client
return new(client)
end
end

@inline function free_client(client::Client)
@ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
end

function to_row_major(x::Array{T,N}) where {T,N}
return permutedims(x, reverse(Base.OneTo(N)))
end
Expand All @@ -42,8 +41,11 @@ end

SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid

global cpuclientcount = Ref(0)
# TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing`
function CPUClient(asynchronous=false, node_id=0, num_nodes=1)
@assert cpuclientcount[] == 0
cpuclientcount[] += 1
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
Expand Down

0 comments on commit e301a71

Please sign in to comment.