From 4d27e39c51efa780d514615c2008659fcbd8ab1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sat, 14 Dec 2024 18:52:47 +0100 Subject: [PATCH] review --- src/Precompile.jl | 4 +++- src/Reactant.jl | 13 +++++++++++-- src/XLA.jl | 11 +++++------ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/Precompile.jl b/src/Precompile.jl index 5acb30067..acd8d4a8e 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -5,9 +5,11 @@ using PrecompileTools: @setup_workload, @compile_workload return end @compile_workload begin - Reactant.__init__() + initialize_dialect() cpu = XLA.CPUClient() x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) @code_hlo optimize = false sum(x) + XLA.free_client(cpu) + deinitialize_dialect() end end diff --git a/src/Reactant.jl b/src/Reactant.jl index 5f8d1623a..8424dc519 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -114,14 +114,23 @@ include("Compiler.jl") using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace -const registry = Ref{MLIR.IR.DialectRegistry}() -function __init__() +const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() + +function initialize_dialect() registry[] = MLIR.IR.DialectRegistry() @ccall MLIR.API.mlir_c.InitializeRegistryAndPasses( registry[]::MLIR.API.MlirDialectRegistry )::Cvoid end +function deinitialize_dialect() + return registry[] = nothing +end + +function __init__() + return initialize_dialect() +end + function set_default_backend(backend::XLA.Client) return XLA.default_backend[] = backend end diff --git a/src/XLA.jl b/src/XLA.jl index bf4e311af..4000a17c1 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -7,15 +7,14 @@ mutable struct Client function Client(client::Ptr{Cvoid}) @assert client != C_NULL - client = new(client) - #TODO: Client are also constructed from MLIR.API.mlir_c.BufferToClient so the pointer cannot be free when Client is cleaned - #finalizer(client) do client - # @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - #end - return client + return new(client) end end +@inline function free_client(client::Client) + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid +end + function to_row_major(x::Array{T,N}) where {T,N} return permutedims(x, reverse(Base.OneTo(N))) end