Skip to content

Commit

Permalink
review
Browse files Browse the repository at this point in the history
  • Loading branch information
glou-nes committed Dec 14, 2024
1 parent 5878d9f commit 4d27e39
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/Precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ using PrecompileTools: @setup_workload, @compile_workload
return
end
@compile_workload begin
Reactant.__init__()
initialize_dialect()
cpu = XLA.CPUClient()
x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu)
@code_hlo optimize = false sum(x)
XLA.free_client(cpu)
deinitialize_dialect()
end
end
13 changes: 11 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,23 @@ include("Compiler.jl")
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace

const registry = Ref{MLIR.IR.DialectRegistry}()
function __init__()
const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

function initialize_dialect()
registry[] = MLIR.IR.DialectRegistry()
@ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
end

function deinitialize_dialect()
return registry[] = nothing
end

function __init__()
return initialize_dialect()
end

function set_default_backend(backend::XLA.Client)
return XLA.default_backend[] = backend
end
Expand Down
11 changes: 5 additions & 6 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ mutable struct Client

function Client(client::Ptr{Cvoid})
@assert client != C_NULL
client = new(client)
#TODO: Client are also constructed from MLIR.API.mlir_c.BufferToClient so the pointer cannot be free when Client is cleaned
#finalizer(client) do client
# @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
#end
return client
return new(client)
end
end

@inline function free_client(client::Client)
@ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
end

function to_row_major(x::Array{T,N}) where {T,N}
return permutedims(x, reverse(Base.OneTo(N)))
end
Expand Down

0 comments on commit 4d27e39

Please sign in to comment.