Skip to content

Commit

Permalink
Merge branch 'main' into fix-type-unstability
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing authored Jul 22, 2024
2 parents 9de9d40 + 822057c commit 7a3f77f
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = [
"Sergio Sánchez Ramírez <[email protected]>",
"Paul Berg <[email protected]>",
]
version = "0.1.7"
version = "0.1.8"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down
20 changes: 0 additions & 20 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,26 +237,6 @@ extern "C" void FutureAwait(FutureType* Future) {
Future->Await();
}

extern "C" void RunPassPipeline(const char* pass_pipeline, MlirModule cmod) {
mlir::ModuleOp mod = cast<ModuleOp>(unwrap(cmod));

mlir::PassManager pm(mod.getContext());

std::string error_message;
llvm::raw_string_ostream error_stream(error_message);
error_stream << "Failed to parse pipeline\n";
mlir::LogicalResult result =
mlir::parsePassPipeline(pass_pipeline, pm, error_stream);
if (mlir::failed(result)) {
llvm::errs() << error_message << "\n";
exit(1);
}
if (!mlir::succeeded(pm.run(mod))) {
llvm::errs() << "Pipeline failed" << "\n";
exit(1);
}
}

extern "C" void XLAExecute(xla::PjRtLoadedExecutable* exec, int num_args, PjRtBuffer** op_args, uint8_t* is_arg_donatable, int num_results, PjRtBuffer** op_results, uint8_t *futures, FutureType** future_results) {
std::vector<std::vector<PjRtBuffer*>> argument_handles;
argument_handles.emplace_back(op_args, op_args + num_args);
Expand Down
1 change: 0 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ cc_library(
"-Wl,-exported_symbol,_FreeFuture",
"-Wl,-exported_symbol,_FutureIsReady",
"-Wl,-exported_symbol,_FutureAwait",
"-Wl,-exported_symbol,_RunPassPipeline",
"-Wl,-exported_symbol,_XLAExecute",
"-Wl,-exported_symbol,_RegisterDialects",
"-Wl,-exported_symbol,_InitializeRegistryAndPasses",
Expand Down
7 changes: 0 additions & 7 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@ module XLA

import ...MLIR

function RunPassPipeline(pass_pipeline, mod::MLIR.IR.Module)
GC.@preserve pass_pipeline mod begin
@ccall MLIR.API.mlir_c.RunPassPipeline(
pass_pipeline::Cstring, mod.module_::MLIR.API.MlirModule
)::Cvoid
end
end
mutable struct Client
client::Ptr{Cvoid}

Expand Down
6 changes: 5 additions & 1 deletion src/mlir/IR/Pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ end
Run the provided `passManager` on the given `module`.
"""
function run!(pm::PassManager, mod::Module)
status = LogicalResult(API.mlirPassManagerRun(pm, mod))
status = LogicalResult(@static if isdefined(API, :mlirPassManagerRunOnOp)
API.mlirPassManagerRunOnOp(pm, Operation(mod))
else
API.mlirPassManagerRun(pm, mod)
end)
if isfailure(status)
throw("failed to run pass manager on module")
end
Expand Down
2 changes: 1 addition & 1 deletion src/mlir/IR/SymbolTable.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct SymbolTable
mutable struct SymbolTable
st::API.MlirSymbolTable

function SymbolTable(st)
Expand Down
16 changes: 15 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,23 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
return MLIR.Dialects.func.return_(vals)
end

name2 = name

tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
for i in 0:10000
name2 = if i == 0
name
else
name * string(i)
end
if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2))
break
end
end

func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=name,
sym_name=name2,
function_type=MLIR.IR.FunctionType(in_tys, out_tys),
body=MLIR.IR.Region(),
sym_visibility,
Expand Down

0 comments on commit 7a3f77f

Please sign in to comment.