Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 7, 2025
1 parent f32a2f9 commit 6dd132b
Show file tree
Hide file tree
Showing 9 changed files with 597 additions and 1,921 deletions.
13 changes: 13 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,24 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
SMDiagnostic Err;
auto llvmModule =
llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
if (!llvmModule) {
std::string err_str;
llvm::raw_string_ostream err_stream(err_str);
Err.print(/*ProgName=*/"LLVMToMLIR", err_stream);
err_stream.flush();
if (ReactantThrowError) {
ReactantThrowError(err_str.c_str());
return wrap((mlir::ModuleOp)nullptr);
}
}
mlir::MLIRContext &context = *unwrap(cctx);
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context,
/*emitExpensiveWarnings*/ false,
/*dropDICompositeElements*/ false)
.release();
if (!res) {
ReactantThrowError("Could not translate LLVM IR to MLIR Module");
}
return wrap(res);
}

Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "92074225b9546e332042b915c76b7561f7fa038d"
ENZYMEXLA_COMMIT = "d601991a87b010023f85ec9f11fa2eb827bf1b90"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down
1 change: 1 addition & 0 deletions deps/ReactantExtra/make-bindings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ for file in [
"Nvvm.jl",
"Gpu.jl",
"Affine.jl",
"MosaicTPU.jl"
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
Expand Down
25 changes: 14 additions & 11 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ function compile(job)
# TODO: on 1.9, this actually creates a context. cache those.
entry = GPUCompiler.JuliaContext() do ctx
mod, meta = GPUCompiler.compile(
:llvm, job; optimize=false, cleanup=false, validate=false
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
)

entryname = LLVM.name(meta.entry)
Expand Down Expand Up @@ -322,8 +322,6 @@ function compile(job)
LLVM.strip_debuginfo!(mod)
modstr = string(mod)

println(modstr)

# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
# it is probably safer to reparse a string using the right llvm module api, so we will do that.

Expand All @@ -332,14 +330,14 @@ function compile(job)
modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext
)::MLIR.API.MlirModule
)
@assert mmod != C_NULL

linkRes = @ccall MLIR.API.mlir_c.LinkInModule(
MLIR.IR.mmodule()::MLIR.API.MlirModule,
mmod::MLIR.API.MlirModule,
entryname::Cstring,
)::MLIR.API.MlirOperation

entry = MLIR.IR.Operation(linkRes)
String(Reactant.TracedUtils.get_attribute_by_name(linkRes, "sym_name"))
end

Expand Down Expand Up @@ -401,11 +399,16 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

sym_name = String(gensym("call_$fname"))
mod = MLIR.IR.mmodule()
CConv=MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel))
voidty = MLIR.IR.Type(MLIR.API.mlirLLVMVoidTypeGet(ctx))
wrapftype = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false))
wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
return MLIR.Dialects.llvm.func(;
sym_name,
function_type=MLIR.IR.FunctionType(wrapper_tys, []),
body=MLIR.IR.Region()
sym_visibility=MLIR.IR.Attribute("private"),
function_type=wrapftype,
body=MLIR.IR.Region(),
CConv
)
end
wrapbody = MLIR.IR.Block(wrapper_tys, [MLIR.IR.Location() for _ in wrapper_tys])
Expand All @@ -416,6 +419,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

symtab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
gpufunc = MLIR.IR.lookup(symtab, fname)
MLIR.IR.attr!(gpufunc, "CConv", MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)))
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))


Expand Down Expand Up @@ -467,13 +471,12 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
end

MLIR.IR.block!(wrapbody) do
MLIR.Dialects.func.call(wrapargs; result_0=MLIR.IR.Type[], callee=wrapfn)
MLIR.Dialects.func.return_(MLIR.IR.Value[])
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
MLIR.Dialects.llvm.return_(nothing)
end

output_operand_aliases = MLIR.IR.Attribute(aliases)


blk_operands = MLIR.IR.Value[]
for idx in
(blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem)
Expand All @@ -488,7 +491,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
blk_operands...,
mlir_args;
result_0=restys,
fn=sym_name,
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
)
for (i, res) in enumerate(rarrays)
Expand Down
2 changes: 1 addition & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
if isdefined(Reactant_jll, :ptxas_path)
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"

opt_passes = optimization_passes(; no_nan)

Expand Down
27 changes: 18 additions & 9 deletions src/mlir/Dialects/Llvm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ import ...API



function ashr(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, location=Location())
function ashr(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, isExact=nothing, location=Location())
op_ty_results = IR.Type[]
operands = Value[lhs, rhs, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(res) && push!(op_ty_results, res)
!isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact))

create_operation(
"llvm.ashr", location;
Expand Down Expand Up @@ -1155,13 +1156,14 @@ function func(; sym_name, sym_visibility=nothing, function_type, linkage=nothing
end


function lshr(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, location=Location())
function lshr(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, isExact=nothing, location=Location())
op_ty_results = IR.Type[]
operands = Value[lhs, rhs, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(res) && push!(op_ty_results, res)
!isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact))

create_operation(
"llvm.lshr", location;
Expand Down Expand Up @@ -1245,7 +1247,7 @@ Examples:
See the following link for more details:
https://llvm.org/docs/LangRef.html#load-instruction
"""
function load(addr::Value; res::IR.Type, alignment=nothing, volatile_=nothing, nontemporal=nothing, invariant=nothing, ordering=nothing, syncscope=nothing, access_groups=nothing, alias_scopes=nothing, noalias_scopes=nothing, tbaa=nothing, location=Location())
function load(addr::Value; res::IR.Type, alignment=nothing, volatile_=nothing, nontemporal=nothing, invariant=nothing, invariantGroup=nothing, ordering=nothing, syncscope=nothing, access_groups=nothing, alias_scopes=nothing, noalias_scopes=nothing, tbaa=nothing, location=Location())
op_ty_results = IR.Type[res, ]
operands = Value[addr, ]
owned_regions = Region[]
Expand All @@ -1255,6 +1257,7 @@ function load(addr::Value; res::IR.Type, alignment=nothing, volatile_=nothing, n
!isnothing(volatile_) && push!(attributes, namedattribute("volatile_", volatile_))
!isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal))
!isnothing(invariant) && push!(attributes, namedattribute("invariant", invariant))
!isnothing(invariantGroup) && push!(attributes, namedattribute("invariantGroup", invariantGroup))
!isnothing(ordering) && push!(attributes, namedattribute("ordering", ordering))
!isnothing(syncscope) && push!(attributes, namedattribute("syncscope", syncscope))
!isnothing(access_groups) && push!(attributes, namedattribute("access_groups", access_groups))
Expand Down Expand Up @@ -1318,13 +1321,14 @@ function mlir_none(; res=nothing::Union{Nothing, IR.Type}, location=Location())
end


function or(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, location=Location())
function or(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, isDisjoint=nothing, location=Location())
op_ty_results = IR.Type[]
operands = Value[lhs, rhs, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(res) && push!(op_ty_results, res)
!isnothing(isDisjoint) && push!(attributes, namedattribute("isDisjoint", isDisjoint))

create_operation(
"llvm.or", location;
Expand Down Expand Up @@ -1415,13 +1419,14 @@ function return_(arg=nothing::Union{Nothing, Value}; location=Location())
end


function sdiv(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, location=Location())
function sdiv(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, isExact=nothing, location=Location())
op_ty_results = IR.Type[]
operands = Value[lhs, rhs, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(res) && push!(op_ty_results, res)
!isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact))

create_operation(
"llvm.sdiv", location;
Expand Down Expand Up @@ -1557,7 +1562,7 @@ llvm.store %val, %ptr atomic monotonic {alignment = 8 : i64}
See the following link for more details:
https://llvm.org/docs/LangRef.html#store-instruction
"""
function store(value::Value, addr::Value; alignment=nothing, volatile_=nothing, nontemporal=nothing, ordering=nothing, syncscope=nothing, access_groups=nothing, alias_scopes=nothing, noalias_scopes=nothing, tbaa=nothing, location=Location())
function store(value::Value, addr::Value; alignment=nothing, volatile_=nothing, nontemporal=nothing, invariantGroup=nothing, ordering=nothing, syncscope=nothing, access_groups=nothing, alias_scopes=nothing, noalias_scopes=nothing, tbaa=nothing, location=Location())
op_ty_results = IR.Type[]
operands = Value[value, addr, ]
owned_regions = Region[]
Expand All @@ -1566,6 +1571,7 @@ function store(value::Value, addr::Value; alignment=nothing, volatile_=nothing,
!isnothing(alignment) && push!(attributes, namedattribute("alignment", alignment))
!isnothing(volatile_) && push!(attributes, namedattribute("volatile_", volatile_))
!isnothing(nontemporal) && push!(attributes, namedattribute("nontemporal", nontemporal))
!isnothing(invariantGroup) && push!(attributes, namedattribute("invariantGroup", invariantGroup))
!isnothing(ordering) && push!(attributes, namedattribute("ordering", ordering))
!isnothing(syncscope) && push!(attributes, namedattribute("syncscope", syncscope))
!isnothing(access_groups) && push!(attributes, namedattribute("access_groups", access_groups))
Expand Down Expand Up @@ -1634,13 +1640,14 @@ function trunc(arg::Value; res::IR.Type, location=Location())
end


function udiv(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, location=Location())
function udiv(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, isExact=nothing, location=Location())
op_ty_results = IR.Type[]
operands = Value[lhs, rhs, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(res) && push!(op_ty_results, res)
!isnothing(isExact) && push!(attributes, namedattribute("isExact", isExact))

create_operation(
"llvm.udiv", location;
Expand All @@ -1651,12 +1658,13 @@ function udiv(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, loca
end


function uitofp(arg::Value; res::IR.Type, location=Location())
function uitofp(arg::Value; res::IR.Type, nonNeg=nothing, location=Location())
op_ty_results = IR.Type[res, ]
operands = Value[arg, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(nonNeg) && push!(attributes, namedattribute("nonNeg", nonNeg))

create_operation(
"llvm.uitofp", location;
Expand Down Expand Up @@ -1763,12 +1771,13 @@ function xor(lhs::Value, rhs::Value; res=nothing::Union{Nothing, IR.Type}, locat
end


function zext(arg::Value; res::IR.Type, location=Location())
function zext(arg::Value; res::IR.Type, nonNeg=nothing, location=Location())
op_ty_results = IR.Type[res, ]
operands = Value[arg, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(nonNeg) && push!(attributes, namedattribute("nonNeg", nonNeg))

create_operation(
"llvm.zext", location;
Expand Down
Loading

0 comments on commit 6dd132b

Please sign in to comment.