Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 7, 2025
1 parent eb7729b commit f32a2f9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
17 changes: 16 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,8 @@ cc_library(
"@com_google_absl//absl/log:initialize",
"@com_google_absl//absl/log:globals",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPILLVMObjects",
"@llvm-project//mlir:CAPILLVMObjects",
"@jax//jaxlib/mosaic:tpu_dialect_capi",
] + select({
"@xla//xla/tsl:is_cuda_enabled_and_oss":[
"@xla//xla/stream_executor/cuda:all_runtime",
Expand Down Expand Up @@ -682,6 +683,20 @@ gentbl_cc_library(
tblgen = "//:mlir-jl-tblgen",
)

gentbl_cc_library(
name = "MosaicTPUJLIncGen",
tbl_outs = [(
["--generator=jl-op-defs", "--disable-module-wrap=0"],
"MosaicTPU.jl"
)
],
td_file = "@jax//jaxlib/mosaic:dialect/tpu/tpu.td",
deps = [
"@jax//jaxlib/mosaic:tpu_td_files",
],
tblgen = "//:mlir-jl-tblgen",
)

gentbl_cc_library(
name = "StableHLOJLIncGen",
tbl_outs = [(
Expand Down
6 changes: 3 additions & 3 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

wrapper_tys = MLIR.IR.Type[]
ctx = MLIR.IR.context()
cullvm_ty = MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1)
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1))
for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
if sizeof(a) == 0
continue
Expand All @@ -399,7 +399,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
# Per below we assume we can inline all other types directly in
end

sym_name = gensym("call_$fname")
sym_name = String(gensym("call_$fname"))
mod = MLIR.IR.mmodule()
wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
Expand All @@ -414,7 +414,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
wrapargs = MLIR.IR.Value[]
argidx = 1

symtab = MLIR.IR.SymbolTable(mod)
symtab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
gpufunc = MLIR.IR.lookup(symtab, fname)
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))

Expand Down

0 comments on commit f32a2f9

Please sign in to comment.