diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index c3c324d2b..605915b94 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -7,13 +7,15 @@ function __init__() for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh)) @eval begin if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast - function Reactant.elem_apply(::typeof($jlop), - lhs::Reactant.TracedRArray{ElType,Shape,N}) where {ElType, - Shape, - N} - return Reactant.TracedRArray{ElType,Shape,N}((), - Reactant.MLIR.IR.result(Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), - 1)) + function Reactant.elem_apply( + ::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} + return Reactant.TracedRArray{ElType,Shape,N}( + (), + Reactant.MLIR.IR.result( + Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1 + ), + ) end end end @@ -21,8 +23,9 @@ function __init__() end # TODO handle non finite cases -function NNlib.softmax!(out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; - dims=1) where {T,Shape,N} +function NNlib.softmax!( + out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; dims=1 +) where {T,Shape,N} max_ = NNlib.fast_maximum(x; dims) #if all(isfinite, max_) @fastmath out .= exp.(x .- max_) diff --git a/src/Reactant.jl b/src/Reactant.jl index d8bfe6e52..2389f083c 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -11,19 +11,19 @@ abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType,N} end @inline Base.ndims(::RArray{ElType,Shape,N}) where {ElType,Shape,N} = N @inline Base.ndims(::Type{<:RArray{ElType,Shape,N}}) where {ElType,Shape,N} = N -@inline mlir_type(::RArray{ElType,Shape,N}) where {ElType,Shape,N} = MLIR.IR.TensorType(Shape, - MLIR.IR.Type(ElType)) +@inline mlir_type(::RArray{ElType,Shape,N}) where {ElType,Shape,N} = + MLIR.IR.TensorType(Shape, MLIR.IR.Type(ElType)) -struct XLAArray{ElType,Shape,N} <: RArray{ElType,Shape,N} -end +struct XLAArray{ElType,Shape,N} <: RArray{ElType,Shape,N} end mutable struct ConcreteRArray{ElType,Shape,N} <: RArray{ElType,Shape,N} data::XLA.AsyncBuffer # data::XLAArray{ElType, Shape, N} end -function Base.convert(::Type{T}, - X::ConcreteRArray{ElType,Shape,N}) where {T<:Array,ElType,Shape,N} +function Base.convert( + ::Type{T}, X::ConcreteRArray{ElType,Shape,N} +) where {T<:Array,ElType,Shape,N} data = Array{ElType,N}(undef, Shape...) XLA.await(X.data) buf = X.data.buffer @@ -52,15 +52,16 @@ function Base.isapprox(x, y::ConcreteRArray{ElType,(),0}; kwargs...) where {ElTy return Base.isapprox(to_float(x), y; kwargs...) end -function Base.isapprox(x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; - kwargs...) where {ElType,ElType2} +function Base.isapprox( + x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; kwargs... +) where {ElType,ElType2} return Base.isapprox(to_float(x), y; kwargs...) end function Base.print_array(io::IO, X::ConcreteRArray) if X.data == XLA.AsyncEmptyBuffer println(io, "") - return + return nothing end return Base.print_array(io, convert(Array, X)) end @@ -68,13 +69,14 @@ end function Base.show(io::IO, X::ConcreteRArray) if X.data == XLA.AsyncEmptyBuffer println(io, "") - return + return nothing end return Base.show(io, convert(Array, X)) end -@inline function Base.getindex(a::ConcreteRArray{ElType,Shape}, - args::Vararg{Int,N}) where {ElType,Shape,N} +@inline function Base.getindex( + a::ConcreteRArray{ElType,Shape}, args::Vararg{Int,N} +) where {ElType,Shape,N} if a.data == XLA.AsyncEmptyBuffer throw("Cannot getindex from empty buffer") end @@ -98,13 +100,13 @@ end return convert(Array, a)[args...] end -@inline function ConcreteRArray(data::Array{ElType,N}; client=XLA.default_backend[], - idx=XLA.default_device_idx[]) where {ElType,N} +@inline function ConcreteRArray( + data::Array{ElType,N}; client=XLA.default_backend[], idx=XLA.default_device_idx[] +) where {ElType,N} device = XLA.ClientGetDevice(client, idx) - return ConcreteRArray{ElType,size(data),N}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, - data, - device), - nothing)) + return ConcreteRArray{ElType,size(data),N}( + XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing) + ) # ConcreteRArray{ElType, size(data), N}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, XLA.to_row_major(data), device), nothing)) end @@ -117,10 +119,9 @@ end mutable struct TracedRArray{ElType,Shape,N} <: RArray{ElType,Shape,N} paths::Tuple mlir_data::Union{Nothing,MLIR.IR.Value} - function TracedRArray{ElType,Shape,N}(paths::Tuple, - mlir_data::Union{Nothing,MLIR.IR.Value}) where {ElType, - Shape, - N} + function TracedRArray{ElType,Shape,N}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} + ) where {ElType,Shape,N} if mlir_data !== nothing @assert size(MLIR.IR.type(mlir_data)) == Shape end @@ -130,26 +131,25 @@ end using Enzyme -@inline function Enzyme.Compiler.active_reg_inner(::Type{TracedRArray{ElType,Shape,N}}, - seen::ST, world::Union{Nothing,UInt}, - ::Val{justActive}=Val(false), - ::Val{UnionSret}=Val(false))::Enzyme.Compiler.ActivityState where {ST, - ElType, - Shape, - N, - justActive, - UnionSret} - if Enzyme.Compiler.active_reg_inner(ElType, seen, world, Val(justActive), - Val(UnionSret)) == Enzyme.Compiler.AnyState +@inline function Enzyme.Compiler.active_reg_inner( + ::Type{TracedRArray{ElType,Shape,N}}, + seen::ST, + world::Union{Nothing,UInt}, + ::Val{justActive}=Val(false), + ::Val{UnionSret}=Val(false), +)::Enzyme.Compiler.ActivityState where {ST,ElType,Shape,N,justActive,UnionSret} + if Enzyme.Compiler.active_reg_inner( + ElType, seen, world, Val(justActive), Val(UnionSret) + ) == Enzyme.Compiler.AnyState return Enzyme.Compiler.AnyState else return Enzyme.Compiler.DupState end end -@inline function Enzyme.make_zero(::Type{RT}, seen::IdDict, prev::RT, - ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, - RT<:RArray} +@inline function Enzyme.make_zero( + ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) +)::RT where {copy_if_inactive,RT<:RArray} if haskey(seen, prev) return seen[prev] end @@ -169,8 +169,9 @@ end return res end -function Base.promote_rule(A::Type{TracedRArray{T,Shape,N}}, - B::Type{TracedRArray{S,Shape,N}}) where {T,S,Shape,N} +function Base.promote_rule( + A::Type{TracedRArray{T,Shape,N}}, B::Type{TracedRArray{S,Shape,N}} +) where {T,S,Shape,N} return TracedRArray{Base.promote_type(T, S),Shape,N} end @@ -197,13 +198,13 @@ using Enzyme TracedSetPath = 5 end -@inline is_concrete_tuple(x::T2) where {T2} = (x <: Tuple) && !(x === Tuple) && - !(x isa UnionAll) +@inline is_concrete_tuple(x::T2) where {T2} = + (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) @inline function traced_type(val::Type{T}, seen::ST, ::Val{mode}) where {ST,T,mode} if T <: ConcreteRArray if mode == ConcreteToTraced - @inline base_typet(TV::TT) where {TT<:UnionAll} = UnionAll(TV.var, - base_typet(TV.body)) + @inline base_typet(TV::TT) where {TT<:UnionAll} = + UnionAll(TV.var, base_typet(TV.body)) @inline base_typet(TV::TT) where {TT<:DataType} = TracedRArray{TV.parameters...} return base_typet(T) elseif mode == TracedToConcrete @@ -216,9 +217,10 @@ end if mode == ConcreteToTraced throw("TracedRArray $T cannot be traced") elseif mode == TracedToConcrete - @inline base_typec(TV::TT) where {TT<:UnionAll} = UnionAll(TV.var, - base_typec(TV.body)) - @inline base_typec(TV::TT) where {TT<:DataType} = ConcreteRArray{TV.parameters...} + @inline base_typec(TV::TT) where {TT<:UnionAll} = + UnionAll(TV.var, base_typec(TV.body)) + @inline base_typec(TV::TT) where {TT<:DataType} = + ConcreteRArray{TV.parameters...} return base_typec(T) elseif mode == TracedTrack || mode == TracedSetPath return T @@ -286,8 +288,9 @@ end if mode == ArrayToConcrete && eltype(T) <: AbstractFloat return (ConcreteRArray{eltype(T),Shape,ndims(T)} where {Shape}) else - return Array{traced_type(Enzyme.Compiler.ptreltype(T), seen, Val(mode)), - ndims(T)} + return Array{ + traced_type(Enzyme.Compiler.ptreltype(T), seen, Val(mode)),ndims(T) + } end end @@ -496,8 +499,9 @@ end return y end -@inline function make_tracer(seen::IdDict, prev::ConcreteRArray{ElType,Shape,N}, path, mode, - data) where {ElType,Shape,N} +@inline function make_tracer( + seen::IdDict, prev::ConcreteRArray{ElType,Shape,N}, path, mode, data +) where {ElType,Shape,N} if mode == ArrayToConcrete return prev end @@ -513,8 +517,9 @@ end return res end -@inline function make_tracer(seen::IdDict, prev::TracedRArray{ElType,Shape,N}, path, mode, - data) where {ElType,Shape,N} +@inline function make_tracer( + seen::IdDict, prev::TracedRArray{ElType,Shape,N}, path, mode, data +) where {ElType,Shape,N} if mode == ConcreteToTraced throw("Cannot trace existing trace type") end @@ -546,14 +551,17 @@ end throw("Cannot Unknown trace mode $mode") end -@inline function make_tracer(seen::IdDict, prev::RT, path, mode, - data) where {RT<:AbstractFloat} +@inline function make_tracer( + seen::IdDict, prev::RT, path, mode, data +) where {RT<:AbstractFloat} return prev end @inline function make_tracer(seen::IdDict, prev::Complex{RT}, path, mode, data) where {RT} - return Complex(make_tracer(seen, prev.re, append_path(path, :re), mode, data), - make_tracer(seen, prev.im, append_path(path, :im), mode, data)) + return Complex( + make_tracer(seen, prev.re, append_path(path, :re), mode, data), + make_tracer(seen, prev.im, append_path(path, :im), mode, data), + ) end @inline function make_tracer(seen::IdDict, prev::RT, path, mode, data) where {RT<:Array} @@ -585,19 +593,24 @@ end end @inline function make_tracer(seen::IdDict, prev::RT, path, mode, data) where {RT<:Tuple} - return ((make_tracer(seen, v, append_path(path, i), mode, data) for (i, v) in - enumerate(prev))...,) + return ( + ( + make_tracer(seen, v, append_path(path, i), mode, data) for + (i, v) in enumerate(prev) + )..., + ) end -@inline function make_tracer(seen::IdDict, prev::NamedTuple{A,RT}, path, mode, - data) where {A,RT} - return NamedTuple{A,traced_type(RT, (), Val(mode))}(((make_tracer(seen, - Base.getfield(prev, - name), - append_path(path, - name), - mode, data) for name in - A)...,)) +@inline function make_tracer( + seen::IdDict, prev::NamedTuple{A,RT}, path, mode, data +) where {A,RT} + return NamedTuple{A,traced_type(RT, (), Val(mode))}(( + ( + make_tracer( + seen, Base.getfield(prev, name), append_path(path, name), mode, data + ) for name in A + )..., + )) end @inline function make_tracer(seen::IdDict, prev::Core.Box, path, mode, data) @@ -615,8 +628,9 @@ end return res end -function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linear_results, - preserved_args) +function generate_jlfunc( + concrete_result, client, mod, Nargs, linear_args, linear_results, preserved_args +) args = ntuple(Val(Nargs)) do i Base.@_inline_meta return Symbol("arg_$i") @@ -731,9 +745,12 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea quote $(arg_syncs...) GC.@preserve $(topres...) begin - linearized_results = XLA.ExecutableCall($exec, ($(linearized_args...),), - $donated_args_set, - Val($(length(linear_results)))) + linearized_results = XLA.ExecutableCall( + $exec, + ($(linearized_args...),), + $donated_args_set, + Val($(length(linear_results))), + ) end end end @@ -742,7 +759,7 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea function create_result(tocopy::T, resname::Symbol, path) where {T} if T <: ConcreteRArray push!(concrete_result_maker, :($resname = $T($(result_stores[path])))) - return + return nothing end if T <: Tuple elems = Symbol[] @@ -751,10 +768,13 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea create_result(v, sym, (path..., i)) push!(elems, sym) end - push!(concrete_result_maker, quote - $resname = ($(elems...),) - end) - return + push!( + concrete_result_maker, + quote + $resname = ($(elems...),) + end, + ) + return nothing end if T <: Array elems = Symbol[] @@ -763,18 +783,21 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea create_result(v, sym, (path..., i)) push!(elems, sym) end - push!(concrete_result_maker, quote - $resname = $(eltype(T))[$(elems...)] - end) - return + push!( + concrete_result_maker, + quote + $resname = $(eltype(T))[$(elems...)] + end, + ) + return nothing end if T <: Int || T <: AbstractFloat || T <: AbstractString || T <: Nothing push!(concrete_result_maker, :($resname = $tocopy)) - return + return nothing end if T <: Symbol push!(concrete_result_maker, :($resname = $(QuoteNode(tocopy)))) - return + return nothing end if isstructtype(T) elems = Symbol[] @@ -784,13 +807,16 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea create_result(getfield(tocopy, i), sym, (path..., i)) push!(elems, sym) end - push!(concrete_result_maker, - quote - flds = Any[$(elems...)] - $resname = ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), $T, - flds, $nf) - end) - return + push!( + concrete_result_maker, + quote + flds = Any[$(elems...)] + $resname = ccall( + :jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), $T, flds, $nf + ) + end, + ) + return nothing end return error("canot copy $T") @@ -813,7 +839,9 @@ end const registry = Ref{MLIR.IR.DialectRegistry}() function __init__() registry[] = MLIR.IR.DialectRegistry() - @ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(registry[]::MLIR.API.MlirDialectRegistry)::Cvoid + @ccall MLIR.API.mlir_c.InitializeRegistryAndPasses( + registry[]::MLIR.API.MlirDialectRegistry + )::Cvoid end const opt_passes = """ @@ -976,8 +1004,9 @@ pad_dot_general<1>(1); enzyme-hlo-remove-transform """ -function compile(f::FTy, args::VAT; pipeline_options="", - client=nothing) where {FTy,VAT<:Tuple} +function compile( + f::FTy, args::VAT; pipeline_options="", client=nothing +) where {FTy,VAT<:Tuple} N = length(args) ctx = MLIR.IR.Context() Base.append!(registry[]; context=ctx) @@ -985,18 +1014,16 @@ function compile(f::FTy, args::VAT; pipeline_options="", MLIR.IR.context!(ctx) do mod = MLIR.IR.Module(MLIR.IR.Location()) MLIR.IR.mmodule!(mod) do - fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(mod, - f, - args, - (), - "main", - true) + fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + mod, f, args, (), "main", true + ) @assert !fnwrapped concrete_seen = IdDict() - concrete_result = make_tracer(concrete_seen, traced_result, ("result",), - TracedToConcrete, nothing) #=data=# + concrete_result = make_tracer( + concrete_seen, traced_result, ("result",), TracedToConcrete, nothing + ) #=data=# if client === nothing if length(linear_args) > 0 @@ -1012,9 +1039,12 @@ function compile(f::FTy, args::VAT; pipeline_options="", end end - XLA.RunPassPipeline(opt_passes * - ",enzyme,arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math," * - opt_passes, mod) + XLA.RunPassPipeline( + opt_passes * + ",enzyme,arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math," * + opt_passes, + mod, + ) preserved_args = Tuple{TracedRArray,Int}[] results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)] @@ -1037,10 +1067,11 @@ function compile(f::FTy, args::VAT; pipeline_options="", out_tys2 = [MLIR.IR.type(a) for a in nresults] - func3 = MLIR.Dialects.func.func_(; sym_name="main", - function_type=MLIR.IR.FunctionType(in_tys, - out_tys2), - body=MLIR.IR.Region()) + func3 = MLIR.Dialects.func.func_(; + sym_name="main", + function_type=MLIR.IR.FunctionType(in_tys, out_tys2), + body=MLIR.IR.Region(), + ) MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func3, 1), MLIR.IR.region(func2, 1)) push!(MLIR.IR.body(mod), func3) @@ -1050,8 +1081,15 @@ function compile(f::FTy, args::VAT; pipeline_options="", # println(string(mod)) - return generate_jlfunc(concrete_result, client, mod, N, linear_args, - linear_results2, preserved_args) + return generate_jlfunc( + concrete_result, + client, + mod, + N, + linear_args, + linear_results2, + preserved_args, + ) end end end diff --git a/src/XLA.jl b/src/XLA.jl index f721e7f4a..9f556bc8f 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -4,8 +4,9 @@ import ...MLIR function RunPassPipeline(pass_pipeline, mod::MLIR.IR.Module) GC.@preserve pass_pipeline mod begin - @ccall MLIR.API.mlir_c.RunPassPipeline(pass_pipeline::Cstring, - mod.module_::MLIR.API.MlirModule)::Cvoid + @ccall MLIR.API.mlir_c.RunPassPipeline( + pass_pipeline::Cstring, mod.module_::MLIR.API.MlirModule + )::Cvoid end end mutable struct Client @@ -61,8 +62,17 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu") # GC.@preserve allowed_devices begin f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeGPUClient") refstr = Ref{Cstring}() - client = ccall(f, Ptr{Cvoid}, (Cint, Cint, Ptr{Cvoid}, Cint, Cstring, Ptr{Cstring}), - node_id, num_nodes, C_NULL, 0, platform, refstr) + client = ccall( + f, + Ptr{Cvoid}, + (Cint, Cint, Ptr{Cvoid}, Cint, Cstring, Ptr{Cstring}), + node_id, + num_nodes, + C_NULL, + 0, + platform, + refstr, + ) if client == C_NULL throw(AssertionError(refstr[])) end @@ -142,12 +152,16 @@ end function device(buffer::Buffer) GC.@preserve buffer begin - return Device(@ccall MLIR.API.mlir_c.BufferToDevice(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid}) + return Device( + @ccall MLIR.API.mlir_c.BufferToDevice(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} + ) end end function client(buffer::Buffer) GC.@preserve buffer begin - return Client(@ccall MLIR.API.mlir_c.BufferToClient(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid}) + return Client( + @ccall MLIR.API.mlir_c.BufferToClient(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid} + ) end end function device(buffer::AsyncBuffer) @@ -158,7 +172,9 @@ function client(buffer::AsyncBuffer) end function client(device::Device) GC.@preserve device begin - return Client(@ccall MLIR.API.mlir_c.DeviceToClient(device.device::Ptr{Cvoid})::Ptr{Cvoid}) + return Client( + @ccall MLIR.API.mlir_c.DeviceToClient(device.device::Ptr{Cvoid})::Ptr{Cvoid} + ) end end @@ -167,11 +183,14 @@ function ArrayFromHostBuffer(client::Client, array::Array{T,N}, device) where {T dtype = MLIR.IR.Type(T) sizear = Int64[s for s in reverse(size(array))] GC.@preserve array sizear begin - @ccall MLIR.API.mlir_c.ArrayFromHostBuffer(client.client::Ptr{Cvoid}, - pointer(array)::Ptr{T}, - dtype::MLIR.API.MlirType, N::Csize_t, - pointer(sizear)::Ptr{Int64}, - device.device::Ptr{Cvoid})::Ptr{Cvoid} + @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( + client.client::Ptr{Cvoid}, + pointer(array)::Ptr{T}, + dtype::MLIR.API.MlirType, + N::Csize_t, + pointer(sizear)::Ptr{Int64}, + device.device::Ptr{Cvoid}, + )::Ptr{Cvoid} end end return Buffer(buffer) @@ -179,8 +198,9 @@ end function BufferToHost(buffer::Buffer, data) GC.@preserve buffer begin - @ccall MLIR.API.mlir_c.BufferToHost(buffer.buffer::Ptr{Cvoid}, - data::Ptr{Cvoid})::Cvoid + @ccall MLIR.API.mlir_c.BufferToHost( + buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} + )::Cvoid end end @@ -191,8 +211,11 @@ end function CopyBufferToDevice(buffer::Buffer, device::Device) GC.@preserve buffer device begin - Buffer(@ccall MLIR.API.mlir_c.CopyBufferToDevice(buffer.buffer::Ptr{Cvoid}, - device.device::Ptr{Cvoid})::Ptr{Cvoid}) + Buffer( + @ccall MLIR.API.mlir_c.CopyBufferToDevice( + buffer.buffer::Ptr{Cvoid}, device.device::Ptr{Cvoid} + )::Ptr{Cvoid} + ) end end @@ -227,37 +250,45 @@ entry: return res end -@generated function ExecutableCall(exec::LoadedExecutable, inputs::NTuple{N,Ptr{Cvoid}}, - donated_args::NTuple{N,UInt8}, - ::Val{n_outs}) where {N,n_outs} +@generated function ExecutableCall( + exec::LoadedExecutable, + inputs::NTuple{N,Ptr{Cvoid}}, + donated_args::NTuple{N,UInt8}, + ::Val{n_outs}, +) where {N,n_outs} sym0 = dlsym(Reactant_jll.libReactantExtra_handle, "XLAExecute") xla_execute_fn = reinterpret(UInt, sym0) ir = execute_ir(N, n_outs, xla_execute_fn) results = [] for i in 1:n_outs - push!(results, - :(AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing))) + push!( + results, + :(AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing)), + ) end return quote Base.@_inline_meta exec = exec.exec GC.@preserve exec begin - outputs, future_res, future = Base.llvmcall(($ir, "f"), - Tuple{NTuple{n_outs,Ptr{Cvoid}}, - NTuple{n_outs,Ptr{Cvoid}}, - Bool}, - Tuple{Ptr{Cvoid}, - NTuple{N,Ptr{Cvoid}}, - NTuple{N,UInt8}}, - exec, inputs, donated_args) + outputs, future_res, future = Base.llvmcall( + ($ir, "f"), + Tuple{NTuple{n_outs,Ptr{Cvoid}},NTuple{n_outs,Ptr{Cvoid}},Bool}, + Tuple{Ptr{Cvoid},NTuple{N,Ptr{Cvoid}},NTuple{N,UInt8}}, + exec, + inputs, + donated_args, + ) end return ($(results...),) end end -@inline function ExecutableCall0(exec::LoadedExecutable, inputs::NTuple{N,Ptr{Cvoid}}, - donated_args::NTuple{N,UInt8}, - ::Val{n_outs}) where {N,n_outs} +@inline function ExecutableCall0( + exec::LoadedExecutable, + inputs::NTuple{N,Ptr{Cvoid}}, + donated_args::NTuple{N,UInt8}, + ::Val{n_outs}, +) where {N,n_outs} outputs = Ref{NTuple{n_outs,Ptr{Cvoid}}}() future_res = Ref{NTuple{n_outs,Ptr{Cvoid}}}() futures = Ref{UInt8}(0) @@ -265,15 +296,16 @@ end inputs = Base.RefValue(inputs) donated_args = Base.RefValue(donated_args) GC.@preserve inputs donated_args outputs futures future_res begin - @ccall MLIR.API.mlir_c.XLAExecute(exec.exec::Ptr{Cvoid}, N::Cint, - inputs::Ptr{Cvoid}, donated_args::Ptr{UInt8}, - n_outs::Cint, - Base.unsafe_convert(Ptr{Cvoid}, - outputs)::Ptr{Cvoid}, - Base.unsafe_convert(Ptr{UInt8}, - futures)::Ptr{UInt8}, - Base.unsafe_convert(Ptr{Cvoid}, - future_res)::Ptr{Cvoid})::Cvoid + @ccall MLIR.API.mlir_c.XLAExecute( + exec.exec::Ptr{Cvoid}, + N::Cint, + inputs::Ptr{Cvoid}, + donated_args::Ptr{UInt8}, + n_outs::Cint, + Base.unsafe_convert(Ptr{Cvoid}, outputs)::Ptr{Cvoid}, + Base.unsafe_convert(Ptr{UInt8}, futures)::Ptr{UInt8}, + Base.unsafe_convert(Ptr{Cvoid}, future_res)::Ptr{Cvoid}, + )::Cvoid end outputs = outputs[] @@ -288,8 +320,11 @@ end function Compile(client::Client, mod::MLIR.IR.Module) GC.@preserve client mod begin - executable = LoadedExecutable(@ccall MLIR.API.mlir_c.ClientCompile(client.client::Ptr{Cvoid}, - mod.module_::MLIR.API.MlirModule)::Ptr{Cvoid}) + executable = LoadedExecutable( + @ccall MLIR.API.mlir_c.ClientCompile( + client.client::Ptr{Cvoid}, mod.module_::MLIR.API.MlirModule + )::Ptr{Cvoid} + ) end end @@ -301,7 +336,9 @@ end function ClientNumAddressableDevices(client::Client) GC.@preserve client begin - return @ccall MLIR.API.mlir_c.ClientNumAddressableDevices(client.client::Ptr{Cvoid})::Cint + return @ccall MLIR.API.mlir_c.ClientNumAddressableDevices( + client.client::Ptr{Cvoid} + )::Cint end end @@ -313,15 +350,21 @@ end function ClientGetDevice(client::Client, idx) GC.@preserve client begin - return Device(@ccall MLIR.API.mlir_c.ClientGetDevice(client.client::Ptr{Cvoid}, - idx::Cint)::Ptr{Cvoid}) + return Device( + @ccall MLIR.API.mlir_c.ClientGetDevice( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) end end function ClientGetAddressableDevice(client::Client, idx) GC.@preserve client begin - return Device(@ccall MLIR.API.mlir_c.ClientGetAddressableDevice(client.client::Ptr{Cvoid}, - idx::Cint)::Ptr{Cvoid}) + return Device( + @ccall MLIR.API.mlir_c.ClientGetAddressableDevice( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) end end diff --git a/src/mlir/IR/AffineExpr.jl b/src/mlir/IR/AffineExpr.jl index db36eb977..7a0f4931d 100644 --- a/src/mlir/IR/AffineExpr.jl +++ b/src/mlir/IR/AffineExpr.jl @@ -56,8 +56,8 @@ ismultipleof(expr::AffineExpr, factor) = API.mlirAffineExprIsMultipleOf(expr, fa Checks whether the given affine expression involves AffineDimExpr 'position'. """ -isfunctionofdimexpr(expr::AffineExpr, position) = API.mlirAffineExprIsFunctionOfDim(expr, - position) +isfunctionofdimexpr(expr::AffineExpr, position) = + API.mlirAffineExprIsFunctionOfDim(expr, position) """ isdimexpr(affineExpr) @@ -71,8 +71,8 @@ isdimexpr(expr::AffineExpr) = API.mlirAffineExprIsADim(expr) Creates an affine dimension expression with 'position' in the context. """ -AffineDimensionExpr(position; context::Context=context()) = AffineExpr(API.mlirAffineDimExprGet(context, - position)) +AffineDimensionExpr(position; context::Context=context()) = + AffineExpr(API.mlirAffineDimExprGet(context, position)) """ issymbolexpr(affineExpr) @@ -86,8 +86,8 @@ issymbolexpr(expr::AffineExpr) = API.mlirAffineExprIsASymbol(expr) Creates an affine symbol expression with 'position' in the context. """ -SymbolExpr(position; context::Context=context()) = AffineExpr(API.mlirAffineSymbolExprGet(context, - position)) +SymbolExpr(position; context::Context=context()) = + AffineExpr(API.mlirAffineSymbolExprGet(context, position)) """ position(affineExpr) @@ -100,7 +100,11 @@ function position(expr::AffineExpr) elseif issymbolexpr(expr) API.mlirAffineSymbolExprGetPosition(expr) else - throw(ArgumentError("The given affine expression is not a affine dimension expression or affine symbol expression")) + throw( + ArgumentError( + "The given affine expression is not a affine dimension expression or affine symbol expression", + ), + ) end end @@ -116,8 +120,8 @@ isconstantexpr(expr::AffineExpr) = API.mlirAffineExprIsAConstant(expr) Creates an affine constant expression with 'constant' in the context. """ -ConstantExpr(constant; context::Context=context()) = AffineExpr(API.mlirAffineConstantExprGet(context, - constant)) +ConstantExpr(constant; context::Context=context()) = + AffineExpr(API.mlirAffineConstantExprGet(context, constant)) """ value(affineExpr) @@ -185,8 +189,8 @@ isfloordiv(expr::AffineExpr) = API.mlirAffineExprIsAFloorDiv(expr) Creates an affine floordiv expression with 'lhs' and 'rhs'. """ -Base.div(lhs::AffineExpr, rhs::AffineExpr) = AffineExpr(API.mlirAffineFloorDivExprGet(lhs, - rhs)) +Base.div(lhs::AffineExpr, rhs::AffineExpr) = + AffineExpr(API.mlirAffineFloorDivExprGet(lhs, rhs)) Base.fld(lhs::AffineExpr, rhs::AffineExpr) = div(lhs, rhs) """ @@ -201,8 +205,8 @@ isceildiv(expr::AffineExpr) = API.mlirAffineExprIsACeilDiv(expr) Creates an affine ceildiv expression with 'lhs' and 'rhs'. """ -Base.cld(lhs::AffineExpr, rhs::AffineExpr) = AffineExpr(API.mlirAffineCeilDivExprGet(lhs, - rhs)) +Base.cld(lhs::AffineExpr, rhs::AffineExpr) = + AffineExpr(API.mlirAffineCeilDivExprGet(lhs, rhs)) """ isbinary(affineExpr) diff --git a/src/mlir/IR/AffineMap.jl b/src/mlir/IR/AffineMap.jl index 5a4b31d58..a70e6f722 100644 --- a/src/mlir/IR/AffineMap.jl +++ b/src/mlir/IR/AffineMap.jl @@ -44,9 +44,8 @@ context(map::AffineMap) = API.mlirAffineMapGetContext(map) Creates a zero result affine map of the given dimensions and symbols in the context. The affine map is owned by the context. """ -AffineMap(ndims, nsymbols; context::Context=context()) = AffineMap(API.mlirAffineMapZeroResultGet(context, - ndims, - nsymbols)) +AffineMap(ndims, nsymbols; context::Context=context()) = + AffineMap(API.mlirAffineMapZeroResultGet(context, ndims, nsymbols)) """ AffineMap(ndims, nsymbols, affineExprs; context=context()) @@ -54,27 +53,24 @@ AffineMap(ndims, nsymbols; context::Context=context()) = AffineMap(API.mlirAffin Creates an affine map with results defined by the given list of affine expressions. The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results. """ -AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) = AffineMap(API.mlirAffineMapGet(context, - ndims, - nsymbols, - length(exprs), - pointer(exprs))) +AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) = + AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), pointer(exprs))) """ ConstantAffineMap(val; context=context()) Creates a single constant result affine map in the context. The affine map is owned by the context. """ -ConstantAffineMap(val; context::Context=context()) = AffineMap(API.mlirAffineMapConstantGet(context, - val)) +ConstantAffineMap(val; context::Context=context()) = + AffineMap(API.mlirAffineMapConstantGet(context, val)) """ IdentityAffineMap(ndims; context=context()) Creates an affine map with 'ndims' identity in the context. The affine map is owned by the context. """ -IdentityAffineMap(ndims; context::Context=context()) = AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, - ndims)) +IdentityAffineMap(ndims; context::Context=context()) = + AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, ndims)) """ MinorIdentityAffineMap(ndims, nresults; context=context()) @@ -98,8 +94,9 @@ The affine map is owned by the context. function PermutationAffineMap(permutation; context::Context=context()) @assert Base.isperm(permutation) "$permutation must be a valid permutation" zero_perm = permutation .- 1 - return AffineMap(API.mlirAffineMapPermutationGet(context, length(zero_perm), - pointer(zero_perm))) + return AffineMap( + API.mlirAffineMapPermutationGet(context, length(zero_perm), pointer(zero_perm)) + ) end """ @@ -194,9 +191,8 @@ Base.isperm(map::AffineMap) = API.mlirAffineMapIsPermutation(map) Returns the affine map consisting of the `positions` subset. """ -submap(map::AffineMap, pos::Vector{Int}) = AffineMap(API.mlirAffineMapGetSubMap(map, - length(pos), - pointer(pos))) +submap(map::AffineMap, pos::Vector{Int}) = + AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pointer(pos))) """ majorsubmap(affineMap, nresults) @@ -205,8 +201,8 @@ Returns the affine map consisting of the most major `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero. Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map. """ -majorsubmap(map::AffineMap, nresults) = AffineMap(API.mlirAffineMapGetMajorSubMap(map, - nresults)) +majorsubmap(map::AffineMap, nresults) = + AffineMap(API.mlirAffineMapGetMajorSubMap(map, nresults)) """ minorsubmap(affineMap, nresults) @@ -214,19 +210,19 @@ majorsubmap(map::AffineMap, nresults) = AffineMap(API.mlirAffineMapGetMajorSubMa Returns the affine map consisting of the most minor `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero. Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map. """ -minorsubmap(map::AffineMap, nresults) = AffineMap(API.mlirAffineMapGetMinorSubMap(map, - nresults)) +minorsubmap(map::AffineMap, nresults) = + AffineMap(API.mlirAffineMapGetMinorSubMap(map, nresults)) """ mlirAffineMapReplace(affineMap, expression => replacement, numResultDims, numResultSyms) Apply `AffineExpr::replace(map)` to each of the results and return a new new AffineMap with the new results and the specified number of dims and symbols. """ -Base.replace(map::AffineMap, old_new::Pair{AffineExpr,AffineExpr}, nresultdims, nresultsyms) = AffineMap(API.mlirAffineMapReplace(map, - old_new.first, - old_new.second, - nresultdims, - nresultsyms)) +Base.replace( + map::AffineMap, old_new::Pair{AffineExpr,AffineExpr}, nresultdims, nresultsyms +) = AffineMap( + API.mlirAffineMapReplace(map, old_new.first, old_new.second, nresultdims, nresultsyms), +) """ simplify(affineMaps, size, result, populateResult) @@ -300,24 +296,25 @@ macro affinemap(expr) known_binops = [:+, :-, :*, :÷, :%, :fld, :cld] - affine_exprs = Expr(:vect, - map(rhs.args) do ex - walk(ex) do v - if v isa Integer - Expr(:call, ConstantExpr, Int64(v)) - elseif Meta.isexpr(v, :call) - v - elseif v isa Symbol - if v in dims || v in syms || v in known_binops - v - else - error("unknown item $v") - end - else - v - end - end - end...) + affine_exprs = Expr( + :vect, map(rhs.args) do ex + walk(ex) do v + if v isa Integer + Expr(:call, ConstantExpr, Int64(v)) + elseif Meta.isexpr(v, :call) + v + elseif v isa Symbol + if v in dims || v in syms || v in known_binops + v + else + error("unknown item $v") + end + else + v + end + end + end... + ) quote $(dimexprs...) diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index b9d86a656..f5ac3d664 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -16,8 +16,8 @@ Base.convert(::Core.Type{API.MlirAttribute}, attribute::Attribute) = attribute.a Parses an attribute. The attribute is owned by the context. """ -Base.parse(::Core.Type{Attribute}, str; context::Context=context()) = Attribute(API.mlirAttributeParseGet(context, - str)) +Base.parse(::Core.Type{Attribute}, str; context::Context=context()) = + Attribute(API.mlirAttributeParseGet(context, str)) """ ==(a1, a2) @@ -80,9 +80,8 @@ isarray(attr::Attribute) = API.mlirAttributeIsAArray(attr) Creates an array element containing the given list of elements in the given context. """ -Attribute(attrs::Vector{Attribute}; context::Context=context()) = Attribute(API.mlirArrayAttrGet(context, - length(attrs), - pointer(attrs))) +Attribute(attrs::Vector{Attribute}; context::Context=context()) = + Attribute(API.mlirArrayAttrGet(context, length(attrs), pointer(attrs))) """ isdict(attr) @@ -114,8 +113,9 @@ isfloat(attr::Attribute) = API.mlirAttributeIsAFloat(attr) Creates a floating point attribute in the given context with the given double value and double-precision FP semantics. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function Attribute(f::T; context::Context=context(), location::Location=Location(), - check::Bool=false) where {T<:AbstractFloat} +function Attribute( + f::T; context::Context=context(), location::Location=Location(), check::Bool=false +) where {T<:AbstractFloat} if check Attribute(API.mlirFloatAttrDoubleGetChecked(location, Type(T), Float64(f))) else @@ -145,8 +145,8 @@ isinteger(attr::Attribute) = API.mlirAttributeIsAInteger(attr) Creates an integer attribute of the given type with the given integer value. """ -Attribute(i::T, type=Type(T)) where {T<:Integer} = Attribute(API.mlirIntegerAttrGet(type, - Int64(i))) +Attribute(i::T, type=Type(T)) where {T<:Integer} = + Attribute(API.mlirIntegerAttrGet(type, Int64(i))) """ Int64(attr) @@ -214,11 +214,8 @@ isopaque(attr::Attribute) = API.mlirAttributeIsAOpaque(attr) Creates an opaque attribute in the given context associated with the dialect identified by its namespace. The attribute contains opaque byte data of the specified length (data need not be null-terminated). """ -OpaqueAttribute(namespace, data, type; context::Context=context) = Attribute(API.mlirOpaqueAttrGet(context, - namespace, - length(data), - data, - type)) +OpaqueAttribute(namespace, data, type; context::Context=context) = + Attribute(API.mlirOpaqueAttrGet(context, namespace, length(data), data, type)) """ mlirOpaqueAttrGetDialectNamespace(attr) @@ -252,8 +249,8 @@ isstring(attr::Attribute) = API.mlirAttributeIsAString(attr) Creates a string attribute in the given context containing the given string. """ -Attribute(str::AbstractString; context::Context=context()) = Attribute(API.mlirStringAttrGet(context, - str)) +Attribute(str::AbstractString; context::Context=context()) = + Attribute(API.mlirStringAttrGet(context, str)) """ Attribute(type, str) @@ -287,10 +284,11 @@ issymbolref(attr::Attribute) = API.mlirAttributeIsASymbolRef(attr) Creates a symbol reference attribute in the given context referencing a symbol identified by the given string inside a list of nested references. Each of the references in the list must not be nested. """ -SymbolRefAttribute(symbol::String, references::Vector{Attribute}; context::Context=context()) = Attribute(API.mlirSymbolRefAttrGet(context, - symbol, - length(references), - pointer(references))) +SymbolRefAttribute( + symbol::String, references::Vector{Attribute}; context::Context=context() +) = Attribute( + API.mlirSymbolRefAttrGet(context, symbol, length(references), pointer(references)) +) """ rootref(attr) @@ -334,8 +332,8 @@ isflatsymbolref(attr::Attribute) = API.mlirAttributeIsAFlatSymbolRef(attr) Creates a flat symbol reference attribute in the given context referencing a symbol identified by the given string. """ -FlatSymbolRefAttribute(symbol::String; context::Context=context()) = Attribute(API.mlirFlatSymbolRefAttrGet(context, - symbol)) +FlatSymbolRefAttribute(symbol::String; context::Context=context()) = + Attribute(API.mlirFlatSymbolRefAttrGet(context, symbol)) """ flatsymbol(attr) @@ -408,8 +406,9 @@ Creates a dense elements attribute with the given Shaped type and elements in th """ function DenseElementsAttribute(shaped_type::Type, elements::AbstractArray) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return Attribute(API.mlirDenseElementsAttrGet(shaped_type, length(elements), - pointer(elements))) + return Attribute( + API.mlirDenseElementsAttrGet(shaped_type, length(elements), pointer(elements)) + ) end # TODO mlirDenseElementsAttrRawBufferGet @@ -481,76 +480,88 @@ Creates a dense elements attribute with the given shaped type from elements of a """ function DenseElementsAttribute(values::AbstractVector{Bool}) shaped_type = TensorType(size(values), Type(Bool)) - return Attribute(API.mlirDenseElementsAttrBoolGet(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrBoolGet(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{UInt8}) shaped_type = TensorType(size(values), Type(UInt8)) - return Attribute(API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Int8}) shaped_type = TensorType(size(values), Type(Int8)) - return Attribute(API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{UInt16}) shaped_type = TensorType(size(values), Type(UInt16)) - return Attribute(API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Int16}) shaped_type = TensorType(size(values), Type(Int16)) - return Attribute(API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{UInt32}) shaped_type = TensorType(size(values), Type(UInt32)) - return Attribute(API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Int32}) shaped_type = TensorType(size(values), Type(Int32)) - return Attribute(API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{UInt64}) shaped_type = TensorType(size(values), Type(UInt64)) - return Attribute(API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Int64}) shaped_type = TensorType(size(values), Type(Int64)) - return Attribute(API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Float32}) shaped_type = TensorType(size(values), Type(Float32)) - return Attribute(API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), pointer(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Float64}) shaped_type = TensorType(size(values), Type(Float64)) - return Attribute(API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), pointer(values)) + ) end # TODO mlirDenseElementsAttrBFloat16Get function DenseElementsAttribute(values::AbstractArray{Float16}) shaped_type = TensorType(size(values), Type(Float16)) - return Attribute(API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), pointer(values)) + ) end """ @@ -561,8 +572,9 @@ Creates a dense elements attribute with the given shaped type from string elemen function DenseElementsAttribute(values::AbstractArray{String}) # TODO may fail because `Type(String)` is not defined shaped_type = TensorType(size(values), Type(String)) - return Attribute(API.mlirDenseElementsAttrStringGet(shaped_type, length(values), - pointer(values))) + return Attribute( + API.mlirDenseElementsAttrStringGet(shaped_type, length(values), pointer(values)) + ) end """ @@ -608,13 +620,20 @@ issparseelements(attr::Attribute) = API.mlirAttributeIsASparseElements(attr) """ function isdensearray end -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Bool}) = API.mlirAttributeIsADenseBoolArray(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int8}) = API.mlirAttributeIsADenseI8Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int16}) = API.mlirAttributeIsADenseI16Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int32}) = API.mlirAttributeIsADenseI32Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int64}) = API.mlirAttributeIsADenseI64Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float32}) = API.mlirAttributeIsADenseF32Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float64}) = API.mlirAttributeIsADenseF64Array(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Bool}) = + API.mlirAttributeIsADenseBoolArray(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int8}) = + API.mlirAttributeIsADenseI8Array(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int16}) = + API.mlirAttributeIsADenseI16Array(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int32}) = + API.mlirAttributeIsADenseI32Array(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int64}) = + API.mlirAttributeIsADenseI64Array(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float32}) = + API.mlirAttributeIsADenseF32Array(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float64}) = + API.mlirAttributeIsADenseF64Array(attr) @llvmversioned min = v"16" """ DenseArrayAttribute(array; context=context()) @@ -623,27 +642,27 @@ function isdensearray end """ function DenseArrayAttribute end -@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Bool}; context::Context=context()) = Attribute(API.mlirDenseBoolArrayGet(context, - length(values), - pointer(values))) -@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int8}; context::Context=context()) = Attribute(API.mlirDenseI8ArrayGet(context, - length(values), - pointer(values))) -@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int16}; context::Context=context()) = Attribute(API.mlirDenseI16ArrayGet(context, - length(values), - pointer(values))) -@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int32}; context::Context=context()) = Attribute(API.mlirDenseI32ArrayGet(context, - length(values), - pointer(values))) -@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int64}; context::Context=context()) = Attribute(API.mlirDenseI64ArrayGet(context, - length(values), - pointer(values))) -@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Float32}; context::Context=context()) = Attribute(API.mlirDenseF32ArrayGet(context, - length(values), - pointer(values))) -@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Float64}; context::Context=context()) = Attribute(API.mlirDenseF64ArrayGet(context, - length(values), - pointer(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{Bool}; context::Context=context() +) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), pointer(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{Int8}; context::Context=context() +) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), pointer(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{Int16}; context::Context=context() +) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), pointer(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{Int32}; context::Context=context() +) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), pointer(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{Int64}; context::Context=context() +) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), pointer(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{Float32}; context::Context=context() +) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), pointer(values))) +@llvmversioned min = v"16" DenseArrayAttribute( + values::AbstractArray{Float64}; context::Context=context() +) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), pointer(values))) @llvmversioned min = v"16" Attribute(values::AbstractArray) = DenseArrayAttribute(values) @@ -655,8 +674,9 @@ function Base.length(attr::Attribute) elseif iselements(attr) API.mlirElementsAttrGetNumElements(attr) else - _isdensearray = any(T -> isdensearray(attr, T), - [Bool, Int8, Int16, Int32, Int64, Float32, Float64]) + _isdensearray = any( + T -> isdensearray(attr, T), [Bool, Int8, Int16, Int32, Int64, Float32, Float64] + ) if _isdensearray API.mlirDenseBoolArrayGetNumElements(attr) end diff --git a/src/mlir/IR/Block.jl b/src/mlir/IR/Block.jl index 6e3752441..5b16edd00 100644 --- a/src/mlir/IR/Block.jl +++ b/src/mlir/IR/Block.jl @@ -82,9 +82,8 @@ end Appends an argument of the specified type to the block. Returns the newly added argument. """ -push_argument!(block::Block, type; location::Location=Location()) = Value(API.mlirBlockAddArgument(block, - type, - location)) +push_argument!(block::Block, type; location::Location=Location()) = + Value(API.mlirBlockAddArgument(block, type, location)) """ first_op(block) @@ -175,7 +174,7 @@ function activate!(blk::Block) return Block[] end Base.push!(stack, blk) - return + return nothing end function deactivate!(blk::Block) diff --git a/src/mlir/IR/Context.jl b/src/mlir/IR/Context.jl index 9160de9f7..ddd772496 100644 --- a/src/mlir/IR/Context.jl +++ b/src/mlir/IR/Context.jl @@ -39,7 +39,7 @@ function activate!(ctx::Context) return Context[] end Base.push!(stack, ctx) - return + return nothing end function deactivate!(ctx::Context) diff --git a/src/mlir/IR/ExecutionEngine.jl b/src/mlir/IR/ExecutionEngine.jl index 7fb2baecb..76395e964 100644 --- a/src/mlir/IR/ExecutionEngine.jl +++ b/src/mlir/IR/ExecutionEngine.jl @@ -18,10 +18,17 @@ LLVM passes at `optLevel` are run before code generation. The number and array of paths corresponding to shared libraries that will be loaded are specified via `numPaths` and `sharedLibPaths` respectively. TODO: figure out other options. """ -function ExecutionEngine(mod::Module, optLevel::Int, sharedlibs::Vector{String}=String[], - enableObjectDump::Bool=false) - return ExecutionEngine(API.mlirExecutionEngineCreate(mod, optLevel, length(sharedlibs), - sharedlibs, enableObjectDump)) +function ExecutionEngine( + mod::Module, + optLevel::Int, + sharedlibs::Vector{String}=String[], + enableObjectDump::Bool=false, +) + return ExecutionEngine( + API.mlirExecutionEngineCreate( + mod, optLevel, length(sharedlibs), sharedlibs, enableObjectDump + ), + ) end Base.convert(::Core.Type{API.MlirExecutionEngine}, engine::ExecutionEngine) = engine.engine @@ -34,8 +41,11 @@ Base.convert(::Core.Type{API.MlirExecutionEngine}, engine::ExecutionEngine) = en Lookup a native function in the execution engine by name, returns nullptr if the name can't be looked-up. """ function lookup(jit::ExecutionEngine, name::String; packed::Bool=false) - fn = packed ? API.mlirExecutionEngineLookupPacked(jit, name) : - API.mlirExecutionEngineLookup(jit, name) + fn = if packed + API.mlirExecutionEngineLookupPacked(jit, name) + else + API.mlirExecutionEngineLookup(jit, name) + end return fn == C_NULL ? nothing : fn end @@ -46,5 +56,5 @@ end Dump as an object in `fileName`. """ -Base.write(filename::String, jit::ExecutionEngine) = API.mlirExecutionEngineDumpToObjectFile(jit, - filename) +Base.write(filename::String, jit::ExecutionEngine) = + API.mlirExecutionEngineDumpToObjectFile(jit, filename) diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 8cfcd1211..c66242422 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -7,8 +7,19 @@ using ..API export Attribute, Block, Context, Dialect, Location, Operation, Region, Value export activate!, deactivate!, dispose!, enable_multithreading!, context! export context, type, type!, location, typeid, block, dialect -export nattrs, attr, attr!, rmattr!, nregions, region, nresults, result, noperands, operand, - operand!, nsuccessors, successor +export nattrs, + attr, + attr!, + rmattr!, + nregions, + region, + nresults, + result, + noperands, + operand, + operand!, + nsuccessors, + successor export BlockIterator, RegionIterator, OperationIterator export @affinemap @@ -32,7 +43,7 @@ macro llvmversioned(pred, expr) version = eval(version) if predname == :min && VersionNumber(19) >= version || - predname == :max && VersionNumber(19) <= version + predname == :max && VersionNumber(19) <= version esc(expr) else esc(:(nothing)) @@ -73,8 +84,9 @@ function Base.cconvert(::Core.Type{API.MlirStringRef}, s::AbstractString) end # Directly create `MlirStringRef` instead of adding an extra ccall. -function Base.unsafe_convert(::Core.Type{API.MlirStringRef}, - s::Union{Symbol,String,AbstractVector{UInt8}}) +function Base.unsafe_convert( + ::Core.Type{API.MlirStringRef}, s::Union{Symbol,String,AbstractVector{UInt8}} +) p = Base.unsafe_convert(Ptr{Cchar}, s) return API.MlirStringRef(p, sizeof(s)) end diff --git a/src/mlir/IR/Identifier.jl b/src/mlir/IR/Identifier.jl index d09ca81c4..2f58836b6 100644 --- a/src/mlir/IR/Identifier.jl +++ b/src/mlir/IR/Identifier.jl @@ -7,8 +7,8 @@ end Gets an identifier with the given string value. """ -Identifier(str::String; context::Context=context()) = Identifier(API.mlirIdentifierGet(context, - str)) +Identifier(str::String; context::Context=context()) = + Identifier(API.mlirIdentifierGet(context, str)) Base.convert(::Core.Type{API.MlirIdentifier}, id::Identifier) = id.identifier diff --git a/src/mlir/IR/IntegerSet.jl b/src/mlir/IR/IntegerSet.jl index 7fc10c0ac..e7f89cbc1 100644 --- a/src/mlir/IR/IntegerSet.jl +++ b/src/mlir/IR/IntegerSet.jl @@ -12,9 +12,8 @@ end Gets or creates a new canonically empty integer set with the give number of dimensions and symbols in the given context. """ -IntegerSet(ndims, nsymbols; context::Context=context()) = IntegerSet(API.mlirIntegerSetEmptyGet(context, - ndims, - nsymbols)) +IntegerSet(ndims, nsymbols; context::Context=context()) = + IntegerSet(API.mlirIntegerSetEmptyGet(context, ndims, nsymbols)) """ IntegerSet(ndims, nsymbols, constraints, eqflags; context=context()) @@ -23,12 +22,16 @@ Gets or creates a new integer set in the given context. The set is defined by a list of affine constraints, with the given number of input dimensions and symbols, which are treated as either equalities (eqflags is 1) or inequalities (eqflags is 0). Both `constraints` and `eqflags` need to be arrays of the same length. """ -IntegerSet(ndims, nsymbols, constraints, eqflags; context::Context=context()) = IntegerSet(API.mlirIntegerSetGet(context, - ndims, - nsymbols, - length(constraints), - pointer(constraints), - pointer(eqflags))) +IntegerSet(ndims, nsymbols, constraints, eqflags; context::Context=context()) = IntegerSet( + API.mlirIntegerSetGet( + context, + ndims, + nsymbols, + length(constraints), + pointer(constraints), + pointer(eqflags), + ), +) """ mlirIntegerSetReplaceGet(set, dimReplacements, symbolReplacements, numResultDims, numResultSymbols) @@ -37,11 +40,15 @@ Gets or creates a new integer set in which the values and dimensions of the give `dimReplacements` and `symbolReplacements` are expected to point to at least as many consecutive expressions as the given set has dimensions and symbols, respectively. The new set will have `numResultDims` and `numResultSymbols` dimensions and symbols, respectively. """ -Base.replace(set::IntegerSet, dim_replacements, symbol_replacements) = IntegerSet(API.mlirIntegerSetReplaceGet(set, - dim_replacements, - symbol_replacements, - length(dim_replacements), - length(symbol_replacements))) +Base.replace(set::IntegerSet, dim_replacements, symbol_replacements) = IntegerSet( + API.mlirIntegerSetReplaceGet( + set, + dim_replacements, + symbol_replacements, + length(dim_replacements), + length(symbol_replacements), + ), +) Base.convert(::Core.Type{API.MlirIntegerSet}, set::IntegerSet) = set.set diff --git a/src/mlir/IR/Location.jl b/src/mlir/IR/Location.jl index 32e848b78..88f403006 100644 --- a/src/mlir/IR/Location.jl +++ b/src/mlir/IR/Location.jl @@ -23,8 +23,9 @@ end # TODO rename to merge? function fuse(locations::Vector{Location}, metadata; context::Context=context()) - return Location(API.mlirLocationFusedGet(context, length(locations), pointer(locations), - metadata)) + return Location( + API.mlirLocationFusedGet(context, length(locations), pointer(locations), metadata) + ) end Base.convert(::Core.Type{API.MlirLocation}, location::Location) = location.location diff --git a/src/mlir/IR/Module.jl b/src/mlir/IR/Module.jl index 702a03bb4..fc10d0593 100644 --- a/src/mlir/IR/Module.jl +++ b/src/mlir/IR/Module.jl @@ -23,8 +23,8 @@ Base.convert(::Core.Type{API.MlirModule}, module_::Module) = module_.module_ Parses a module from the string and transfers ownership to the caller. """ -Base.parse(::Core.Type{Module}, module_; context::Context=context()) = Module(API.mlirModuleCreateParse(context, - module_)) +Base.parse(::Core.Type{Module}, module_; context::Context=context()) = + Module(API.mlirModuleCreateParse(context, module_)) macro mlir_str(code) quote @@ -66,7 +66,7 @@ function activate!(blk::Module) return Module[] end Base.push!(stack, blk) - return + return nothing end function deactivate!(blk::Module) diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 2fc7938f1..21b3f0b76 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -65,8 +65,8 @@ block(operation::Operation) = Block(API.mlirOperationGetBlock(operation), false) Gets the operation that owns this operation, returning null if the operation is not owned. """ -parent_op(operation::Operation) = Operation(API.mlirOperationGetParentOperation(operation), - false) +parent_op(operation::Operation) = + Operation(API.mlirOperationGetParentOperation(operation), false) """ rmfromparent(op) @@ -207,8 +207,8 @@ end Removes an attribute by name. Returns false if the attribute was not found and true if removed. """ -rmattr!(operation::Operation, name) = API.mlirOperationRemoveAttributeByName(operation, - name) +rmattr!(operation::Operation, name) = + API.mlirOperationRemoveAttributeByName(operation, name) function lose_ownership!(operation::Operation) @assert operation.owned @@ -267,16 +267,19 @@ end Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context. This will return true if the dialect is loaded and the operation is registered within the dialect. """ -is_registered(opname; context::Context=context()) = API.mlirContextIsRegisteredOperation(context, - opname) - -function create_operation(name, loc; - results=nothing, - operands=nothing, - owned_regions=nothing, - successors=nothing, - attributes=nothing, - result_inference=isnothing(results)) +is_registered(opname; context::Context=context()) = + API.mlirContextIsRegisteredOperation(context, opname) + +function create_operation( + name, + loc; + results=nothing, + operands=nothing, + owned_regions=nothing, + successors=nothing, + attributes=nothing, + result_inference=isnothing(results), +) GC.@preserve name loc begin state = Ref(API.mlirOperationStateGet(name, loc)) if !isnothing(results) @@ -292,16 +295,15 @@ function create_operation(name, loc; lose_ownership!.(owned_regions) GC.@preserve owned_regions begin mlir_regions = Base.unsafe_convert.(API.MlirRegion, owned_regions) - API.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), - mlir_regions) + API.mlirOperationStateAddOwnedRegions( + state, length(mlir_regions), mlir_regions + ) end end if !isnothing(successors) GC.@preserve successors begin mlir_blocks = Base.unsafe_convert.(API.MlirBlock, successors) - API.mlirOperationStateAddSuccessors(state, - length(mlir_blocks), - mlir_blocks) + API.mlirOperationStateAddSuccessors(state, length(mlir_blocks), mlir_blocks) end end if !isnothing(attributes) diff --git a/src/mlir/IR/Pass.jl b/src/mlir/IR/Pass.jl index 5c5faf5c4..e6ef13ac7 100644 --- a/src/mlir/IR/Pass.jl +++ b/src/mlir/IR/Pass.jl @@ -30,8 +30,8 @@ PassManager(; context::Context=context()) = PassManager(API.mlirPassManagerCreat Create a new top-level PassManager anchored on `anchorOp`. """ -PassManager(anchor_op::Operation; context::Context=context()) = PassManager(API.mlirPassManagerCreateOnOperation(context, - anchor_op)) +PassManager(anchor_op::Operation; context::Context=context()) = + PassManager(API.mlirPassManagerCreateOnOperation(context, anchor_op)) Base.convert(::Core.Type{API.MlirPassManager}, pass::PassManager) = pass.pass @@ -83,8 +83,8 @@ end Cast a top-level `PassManager` to a generic `OpPassManager`. """ -OpPassManager(pm::PassManager) = OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), - pm) +OpPassManager(pm::PassManager) = + OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) """ OpPassManager(passManager, operationName) @@ -92,18 +92,16 @@ OpPassManager(pm::PassManager) = OpPassManager(API.mlirPassManagerGetAsOpPassMan Nest an `OpPassManager` under the top-level PassManager, the nested passmanager will only run on operations matching the provided name. The returned `OpPassManager` will be destroyed when the parent is destroyed. To further nest more `OpPassManager` under the newly returned one, see `mlirOpPassManagerNest` below. """ -OpPassManager(pm::PassManager, opname) = OpPassManager(API.mlirPassManagerGetNestedUnder(pm, - opname), - pm) +OpPassManager(pm::PassManager, opname) = + OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) """ OpPassManager(opPassManager, operationName) Nest an `OpPassManager` under the provided `OpPassManager`, the nested passmanager will only run on operations matching the provided name. The returned `OpPassManager` will be destroyed when the parent is destroyed. """ -OpPassManager(opm::OpPassManager, opname) = OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, - opname), - opm.pass) +OpPassManager(opm::OpPassManager, opname) = + OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) Base.convert(::Core.Type{API.MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass @@ -151,15 +149,15 @@ end Parse a textual MLIR pass pipeline and add it to the provided `OpPassManager`. """ function Base.parse(opm::OpPassManager, pipeline::String) - result = LogicalResult(if true - io = IOBuffer() - c_print_callback = @cfunction(print_callback, Cvoid, - (API.MlirStringRef, Any)) - API.mlirParsePassPipeline(opm, pipeline, c_print_callback, - Ref(io)) - else - API.mlirParsePassPipeline(opm, pipeline) - end) + result = LogicalResult( + if true + io = IOBuffer() + c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) + API.mlirParsePassPipeline(opm, pipeline, c_print_callback, Ref(io)) + else + API.mlirParsePassPipeline(opm, pipeline) + end, + ) if isfailure(result) throw(AddPipelineException(String(take!(io)))) @@ -176,8 +174,9 @@ function add_pipeline!(op_pass::OpPassManager, pipeline) @static if isdefined(API, :mlirOpPassManagerAddPipeline) io = IOBuffer() c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) - result = LogicalResult(API.mlirOpPassManagerAddPipeline(op_pass, pipeline, - c_print_callback, Ref(io))) + result = LogicalResult( + API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, Ref(io)) + ) if isfailure(result) exc = AddPipelineException(String(take!(io))) throw(exc) @@ -236,26 +235,36 @@ end function create_external_pass!(oppass::OpPassManager, args...) return create_external_pass!(oppass.pass, args...) end - function create_external_pass!(manager, pass, name, argument, - description, opname=opname(pass), - dependent_dialects=API.MlirDialectHandle[]) + function create_external_pass!( + manager, + pass, + name, + argument, + description, + opname=opname(pass), + dependent_dialects=API.MlirDialectHandle[], + ) passid = TypeID(manager.allocator) - callbacks = API.MlirExternalPassCallbacks(@cfunction(_pass_construct, Cvoid, - (Any,)), - @cfunction(_pass_destruct, Cvoid, (Any,)), - @cfunction(_pass_initialize, - API.MlirLogicalResult, - (API.MlirContext, Any)), - @cfunction(_pass_clone, Any, (Any,)), - @cfunction(_pass_run, Cvoid, - (API.MlirOperation, - API.MlirExternalPass, Any))) + callbacks = API.MlirExternalPassCallbacks( + @cfunction(_pass_construct, Cvoid, (Any,)), + @cfunction(_pass_destruct, Cvoid, (Any,)), + @cfunction(_pass_initialize, API.MlirLogicalResult, (API.MlirContext, Any)), + @cfunction(_pass_clone, Any, (Any,)), + @cfunction(_pass_run, Cvoid, (API.MlirOperation, API.MlirExternalPass, Any)) + ) pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) userdata = Base.pointer_from_objref(pass_handle) - mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, - length(dependent_dialects), - dependent_dialects, - callbacks, userdata) + mlir_pass = API.mlirCreateExternalPass( + passid, + name, + argument, + description, + opname, + length(dependent_dialects), + dependent_dialects, + callbacks, + userdata, + ) return mlir_pass end end diff --git a/src/mlir/IR/Region.jl b/src/mlir/IR/Region.jl index 394666d37..cd1a8c0d6 100644 --- a/src/mlir/IR/Region.jl +++ b/src/mlir/IR/Region.jl @@ -59,18 +59,16 @@ end Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, prepends the block to the region. """ -insert_after!(region::Region, reference::Block, block::Block) = API.mlirRegionInsertOwnedBlockAfter(region, - reference, - lose_ownership!(block)) +insert_after!(region::Region, reference::Block, block::Block) = + API.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) """ insert_before!(region, reference, block) Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, appends the block to the region. """ -insert_before!(region::Region, reference::Block, block::Block) = API.mlirRegionInsertOwnedBlockBefore(region, - reference, - lose_ownership!(block)) +insert_before!(region::Region, reference::Block, block::Block) = + API.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) """ first_block(region) diff --git a/src/mlir/IR/SymbolTable.jl b/src/mlir/IR/SymbolTable.jl index 34eb09936..4bda62c5c 100644 --- a/src/mlir/IR/SymbolTable.jl +++ b/src/mlir/IR/SymbolTable.jl @@ -25,8 +25,8 @@ Base.convert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = st.st Looks up a symbol with the given name in the given symbol table and returns the operation that corresponds to the symbol. If the symbol cannot be found, returns a null operation. """ -lookup(st::SymbolTable, name::AbstractString) = Operation(API.mlirSymbolTableLookup(st, - name)) +lookup(st::SymbolTable, name::AbstractString) = + Operation(API.mlirSymbolTableLookup(st, name)) Base.getindex(st::SymbolTable, name::AbstractString) = lookup(st, name) """ diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index 848fb770a..0d3129ce0 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -14,8 +14,8 @@ Base.convert(::Core.Type{API.MlirType}, type::Type) = type.type Parses a type. The type is owned by the context. """ -Base.parse(::Core.Type{Type}, s; context::Context=context()) = Type(API.mlirTypeParseGet(context, - s)) +Base.parse(::Core.Type{Type}, s; context::Context=context()) = + Type(API.mlirTypeParseGet(context, s)) """ ==(t1, t2) @@ -73,8 +73,8 @@ isindex(type::Type) = API.mlirTypeIsAIndex(type) Creates a 1-bit signless integer type in the context. The type is owned by the context. """ -Type(::Core.Type{Bool}; context::Context=context()) = Type(API.mlirIntegerTypeGet(context, - 1)) +Type(::Core.Type{Bool}; context::Context=context()) = + Type(API.mlirIntegerTypeGet(context, 1)) # Integer types """ @@ -82,27 +82,24 @@ Type(::Core.Type{Bool}; context::Context=context()) = Type(API.mlirIntegerTypeGe Creates a signless integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Integer}; context::Context=context()) = Type(API.mlirIntegerTypeGet(context, - sizeof(T) * - 8)) +Type(T::Core.Type{<:Integer}; context::Context=context()) = + Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) """ Type(T::Core.Type{<:Signed}; context=context() Creates a signed integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Signed}; context::Context=context()) = Type(API.mlirIntegerTypeGet(context, - sizeof(T) * - 8)) +Type(T::Core.Type{<:Signed}; context::Context=context()) = + Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) """ Type(T::Core.Type{<:Unsigned}; context=context() Creates an unsigned integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Unsigned}; context::Context=context()) = Type(API.mlirIntegerTypeUnsignedGet(context, - sizeof(T) * - 8)) +Type(T::Core.Type{<:Unsigned}; context::Context=context()) = + Type(API.mlirIntegerTypeUnsignedGet(context, sizeof(T) * 8)) """ isinteger(type) @@ -328,10 +325,16 @@ dynstrideoroffset() = API.mlirShapedTypeGetDynamicStrideOrOffset() Creates a vector type of the shape identified by its rank and dimensions, with the given element type in the same context as the element type. The type is owned by the context. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function VectorType(rank, shape, elem_type; location::Location=Location(), - check::Bool=false) - return Type(check ? API.mlirVectorTypeGetChecked(location, rank, shape, elem_type) : - API.mlirVectorTypeGet(rank, shape, elem_type)) +function VectorType( + rank, shape, elem_type; location::Location=Location(), check::Bool=false +) + return Type( + if check + API.mlirVectorTypeGetChecked(location, rank, shape, elem_type) + else + API.mlirVectorTypeGet(rank, shape, elem_type) + end, + ) end """ @@ -349,14 +352,18 @@ Creates a tensor type of a fixed rank with the given shape, element type, and op The type is owned by the context. Tensor types without any specific encoding field should assign [`mlirAttributeGetNull`](@ref) to this parameter. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function TensorType(shape, elem_type, encoding=Attribute(); location::Location=Location(), - check::Bool=false) +function TensorType( + shape, elem_type, encoding=Attribute(); location::Location=Location(), check::Bool=false +) rank = length(shape) shape = shape isa AbstractVector ? shape : collect(shape) - return Type(check ? - API.mlirRankedTensorTypeGetChecked(location, rank, shape, elem_type, - encoding) : - API.mlirRankedTensorTypeGet(rank, shape, elem_type, encoding)) + return Type( + if check + API.mlirRankedTensorTypeGetChecked(location, rank, shape, elem_type, encoding) + else + API.mlirRankedTensorTypeGet(rank, shape, elem_type, encoding) + end, + ) end """ @@ -366,8 +373,13 @@ Creates an unranked tensor type with the given element type in the same context If `check=true`, emits appropriate diagnostics on illegal arguments. """ function TensorType(elem_type; location::Location=Location(), check::Bool=false) - return Type(check ? API.mlirUnrankedTensorTypeGetChecked(location, elem_type) : - API.mlirUnrankedTensorTypeGet(elem_type)) + return Type( + if check + API.mlirUnrankedTensorTypeGetChecked(location, elem_type) + else + API.mlirUnrankedTensorTypeGet(elem_type) + end, + ) end # TODO maybe add these helper methods? @@ -426,14 +438,26 @@ end Creates a MemRef type with the given rank and shape, a potentially empty list of affine layout maps, the given memory space and element type, in the same context as element type. The type is owned by the context. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function MemRefType(elem_type::Type, shape, layout, memspace; location::Location=Location(), - check::Bool=false) +function MemRefType( + elem_type::Type, + shape, + layout, + memspace; + location::Location=Location(), + check::Bool=false, +) if check - Type(API.mlirMemRefTypeGetChecked(location, elem_type, length(shape), - pointer(shape), layout, memspace)) + Type( + API.mlirMemRefTypeGetChecked( + location, elem_type, length(shape), pointer(shape), layout, memspace + ), + ) else - Type(API.mlirMemRefTypeGet(elem_type, length(shape), pointer(shape), layout, - memspace)) + Type( + API.mlirMemRefTypeGet( + elem_type, length(shape), pointer(shape), layout, memspace + ), + ) end end @@ -444,14 +468,21 @@ Creates a MemRef type with the given rank, shape, memory space and element type The type has no affine maps, i.e. represents a default row-major contiguous memref. The type is owned by the context. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function MemRefType(elem_type::Type, shape, memspace; location::Location=Location(), - check::Bool=false) +function MemRefType( + elem_type::Type, shape, memspace; location::Location=Location(), check::Bool=false +) if check - Type(API.mlirMemRefTypeContiguousGetChecked(location, elem_type, length(shape), - pointer(shape), memspace)) + Type( + API.mlirMemRefTypeContiguousGetChecked( + location, elem_type, length(shape), pointer(shape), memspace + ), + ) else - Type(API.mlirMemRefTypeContiguousGet(elem_type, length(shape), pointer(shape), - memspace)) + Type( + API.mlirMemRefTypeContiguousGet( + elem_type, length(shape), pointer(shape), memspace + ), + ) end end @@ -461,8 +492,9 @@ end Creates an Unranked MemRef type with the given element type and in the given memory space. The type is owned by the context of element type. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function MemRefType(elem_type::Type, memspace; location::Location=Location(), - check::Bool=false) +function MemRefType( + elem_type::Type, memspace; location::Location=Location(), check::Bool=false +) if check Type(API.mlirUnrankedMemRefTypeGetChecked(location, elem_type, memspace)) else @@ -527,9 +559,8 @@ end Creates a tuple type that consists of the given list of elemental types. The type is owned by the context. """ -Type(elements::Vector{Type}; context::Context=context()) = Type(API.mlirTupleTypeGet(context, - length(elements), - pointer(elements))) +Type(elements::Vector{Type}; context::Context=context()) = + Type(API.mlirTupleTypeGet(context, length(elements), pointer(elements))) function Type(@nospecialize(elements::NTuple{N,Type}); context::Context=context()) where {N} return Type(collect(elements); context) end @@ -558,8 +589,11 @@ isfunction(type::Type) = API.mlirTypeIsAFunction(type) Creates a function type, mapping a list of input types to result types. """ function FunctionType(inputs, results; context::Context=context()) - return Type(API.mlirFunctionTypeGet(context, length(inputs), pointer(inputs), - length(results), pointer(results))) + return Type( + API.mlirFunctionTypeGet( + context, length(inputs), pointer(inputs), length(results), pointer(results) + ), + ) end # TODO maybe add this helper method? @@ -613,9 +647,8 @@ end Creates an opaque type in the given context associated with the dialect identified by its namespace. The type contains opaque byte data of the specified length (data need not be null-terminated). """ -OpaqueType(namespace, data; context::Context=context()) = Type(API.mlirOpaqueTypeGet(context, - namespace, - data)) +OpaqueType(namespace, data; context::Context=context()) = + Type(API.mlirOpaqueTypeGet(context, namespace, data)) """ isopaque(type) diff --git a/src/mlir/MLIR.jl b/src/mlir/MLIR.jl index 28bb74ce9..71d11eecd 100644 --- a/src/mlir/MLIR.jl +++ b/src/mlir/MLIR.jl @@ -1,16 +1,16 @@ module MLIR module API -using CEnum -using Preferences -using Reactant_jll + using CEnum + using Preferences + using Reactant_jll -const mlir_c = Reactant_jll.libReactantExtra + const mlir_c = Reactant_jll.libReactantExtra -# MLIR C API -let - include("libMLIR_h.jl") -end + # MLIR C API + let + include("libMLIR_h.jl") + end end # module API include("IR/IR.jl") diff --git a/src/mlir/libMLIR_h.jl b/src/mlir/libMLIR_h.jl index ea3f9c25e..398566859 100644 --- a/src/mlir/libMLIR_h.jl +++ b/src/mlir/libMLIR_h.jl @@ -235,7 +235,9 @@ end Allocates a type id that is valid for the lifetime of the allocator """ function mlirTypeIDAllocatorAllocateTypeID(allocator) - @ccall mlir_c.mlirTypeIDAllocatorAllocateTypeID(allocator::MlirTypeIDAllocator)::MlirTypeID + @ccall mlir_c.mlirTypeIDAllocatorAllocateTypeID( + allocator::MlirTypeIDAllocator + )::MlirTypeID end struct MlirAsmState @@ -342,8 +344,9 @@ end Creates an MLIR context, setting the multithreading setting explicitly and pre-loading the dialects from the provided DialectRegistry. """ function mlirContextCreateWithRegistry(registry, threadingEnabled) - @ccall mlir_c.mlirContextCreateWithRegistry(registry::MlirDialectRegistry, - threadingEnabled::Bool)::MlirContext + @ccall mlir_c.mlirContextCreateWithRegistry( + registry::MlirDialectRegistry, threadingEnabled::Bool + )::MlirContext end """ @@ -379,8 +382,9 @@ end Sets whether unregistered dialects are allowed in this context. """ function mlirContextSetAllowUnregisteredDialects(context, allow) - @ccall mlir_c.mlirContextSetAllowUnregisteredDialects(context::MlirContext, - allow::Bool)::Cvoid + @ccall mlir_c.mlirContextSetAllowUnregisteredDialects( + context::MlirContext, allow::Bool + )::Cvoid end """ @@ -407,8 +411,9 @@ end Append the contents of the given dialect registry to the registry associated with the context. """ function mlirContextAppendDialectRegistry(ctx, registry) - @ccall mlir_c.mlirContextAppendDialectRegistry(ctx::MlirContext, - registry::MlirDialectRegistry)::Cvoid + @ccall mlir_c.mlirContextAppendDialectRegistry( + ctx::MlirContext, registry::MlirDialectRegistry + )::Cvoid end """ @@ -426,8 +431,9 @@ end Gets the dialect instance owned by the given context using the dialect namespace to identify it, loads (i.e., constructs the instance of) the dialect if necessary. If the dialect is not registered with the context, returns null. Use mlirContextLoadDialect to load an unregistered dialect. """ function mlirContextGetOrLoadDialect(context, name) - @ccall mlir_c.mlirContextGetOrLoadDialect(context::MlirContext, - name::MlirStringRef)::MlirDialect + @ccall mlir_c.mlirContextGetOrLoadDialect( + context::MlirContext, name::MlirStringRef + )::MlirDialect end """ @@ -454,8 +460,9 @@ end Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context. This will return true if the dialect is loaded and the operation is registered within the dialect. """ function mlirContextIsRegisteredOperation(context, name) - @ccall mlir_c.mlirContextIsRegisteredOperation(context::MlirContext, - name::MlirStringRef)::Bool + @ccall mlir_c.mlirContextIsRegisteredOperation( + context::MlirContext, name::MlirStringRef + )::Bool end """ @@ -464,8 +471,9 @@ end Sets the thread pool of the context explicitly, enabling multithreading in the process. This API should be used to avoid re-creating thread pools in long-running applications that perform multiple compilations, see the C++ documentation for MLIRContext for details. """ function mlirContextSetThreadPool(context, threadPool) - @ccall mlir_c.mlirContextSetThreadPool(context::MlirContext, - threadPool::MlirLlvmThreadPool)::Cvoid + @ccall mlir_c.mlirContextSetThreadPool( + context::MlirContext, threadPool::MlirLlvmThreadPool + )::Cvoid end """ @@ -519,8 +527,9 @@ end Inserts the dialect associated with the provided dialect handle into the provided dialect registry """ function mlirDialectHandleInsertDialect(arg1, arg2) - @ccall mlir_c.mlirDialectHandleInsertDialect(arg1::MlirDialectHandle, - arg2::MlirDialectRegistry)::Cvoid + @ccall mlir_c.mlirDialectHandleInsertDialect( + arg1::MlirDialectHandle, arg2::MlirDialectRegistry + )::Cvoid end """ @@ -529,8 +538,9 @@ end Registers the dialect associated with the provided dialect handle. """ function mlirDialectHandleRegisterDialect(arg1, arg2) - @ccall mlir_c.mlirDialectHandleRegisterDialect(arg1::MlirDialectHandle, - arg2::MlirContext)::Cvoid + @ccall mlir_c.mlirDialectHandleRegisterDialect( + arg1::MlirDialectHandle, arg2::MlirContext + )::Cvoid end """ @@ -539,8 +549,9 @@ end Loads the dialect associated with the provided dialect handle. """ function mlirDialectHandleLoadDialect(arg1, arg2) - @ccall mlir_c.mlirDialectHandleLoadDialect(arg1::MlirDialectHandle, - arg2::MlirContext)::MlirDialect + @ccall mlir_c.mlirDialectHandleLoadDialect( + arg1::MlirDialectHandle, arg2::MlirContext + )::MlirDialect end """ @@ -594,8 +605,9 @@ end Creates an File/Line/Column location owned by the given context. """ function mlirLocationFileLineColGet(context, filename, line, col) - @ccall mlir_c.mlirLocationFileLineColGet(context::MlirContext, filename::MlirStringRef, - line::Cuint, col::Cuint)::MlirLocation + @ccall mlir_c.mlirLocationFileLineColGet( + context::MlirContext, filename::MlirStringRef, line::Cuint, col::Cuint + )::MlirLocation end """ @@ -604,8 +616,9 @@ end Creates a call site location with a callee and a caller. """ function mlirLocationCallSiteGet(callee, caller) - @ccall mlir_c.mlirLocationCallSiteGet(callee::MlirLocation, - caller::MlirLocation)::MlirLocation + @ccall mlir_c.mlirLocationCallSiteGet( + callee::MlirLocation, caller::MlirLocation + )::MlirLocation end """ @@ -614,9 +627,12 @@ end Creates a fused location with an array of locations and metadata. """ function mlirLocationFusedGet(ctx, nLocations, locations, metadata) - @ccall mlir_c.mlirLocationFusedGet(ctx::MlirContext, nLocations::intptr_t, - locations::Ptr{MlirLocation}, - metadata::MlirAttribute)::MlirLocation + @ccall mlir_c.mlirLocationFusedGet( + ctx::MlirContext, + nLocations::intptr_t, + locations::Ptr{MlirLocation}, + metadata::MlirAttribute, + )::MlirLocation end """ @@ -625,8 +641,9 @@ end Creates a name location owned by the given context. Providing null location for childLoc is allowed and if childLoc is null location, then the behavior is the same as having unknown child location. """ function mlirLocationNameGet(context, name, childLoc) - @ccall mlir_c.mlirLocationNameGet(context::MlirContext, name::MlirStringRef, - childLoc::MlirLocation)::MlirLocation + @ccall mlir_c.mlirLocationNameGet( + context::MlirContext, name::MlirStringRef, childLoc::MlirLocation + )::MlirLocation end """ @@ -671,8 +688,9 @@ end Prints a location by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirLocationPrint(location, callback, userData) - @ccall mlir_c.mlirLocationPrint(location::MlirLocation, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirLocationPrint( + location::MlirLocation, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -690,8 +708,9 @@ end Parses a module from the string and transfers ownership to the caller. """ function mlirModuleCreateParse(context, _module) - @ccall mlir_c.mlirModuleCreateParse(context::MlirContext, - _module::MlirStringRef)::MlirModule + @ccall mlir_c.mlirModuleCreateParse( + context::MlirContext, _module::MlirStringRef + )::MlirModule end """ @@ -777,8 +796,9 @@ end Constructs an operation state from a name and a location. """ function mlirOperationStateGet(name, loc) - @ccall mlir_c.mlirOperationStateGet(name::MlirStringRef, - loc::MlirLocation)::MlirOperationState + @ccall mlir_c.mlirOperationStateGet( + name::MlirStringRef, loc::MlirLocation + )::MlirOperationState end """ @@ -787,31 +807,33 @@ end Adds a list of components to the operation state. """ function mlirOperationStateAddResults(state, n, results) - @ccall mlir_c.mlirOperationStateAddResults(state::Ptr{MlirOperationState}, n::intptr_t, - results::Ptr{MlirType})::Cvoid + @ccall mlir_c.mlirOperationStateAddResults( + state::Ptr{MlirOperationState}, n::intptr_t, results::Ptr{MlirType} + )::Cvoid end function mlirOperationStateAddOperands(state, n, operands) - @ccall mlir_c.mlirOperationStateAddOperands(state::Ptr{MlirOperationState}, n::intptr_t, - operands::Ptr{MlirValue})::Cvoid + @ccall mlir_c.mlirOperationStateAddOperands( + state::Ptr{MlirOperationState}, n::intptr_t, operands::Ptr{MlirValue} + )::Cvoid end function mlirOperationStateAddOwnedRegions(state, n, regions) - @ccall mlir_c.mlirOperationStateAddOwnedRegions(state::Ptr{MlirOperationState}, - n::intptr_t, - regions::Ptr{MlirRegion})::Cvoid + @ccall mlir_c.mlirOperationStateAddOwnedRegions( + state::Ptr{MlirOperationState}, n::intptr_t, regions::Ptr{MlirRegion} + )::Cvoid end function mlirOperationStateAddSuccessors(state, n, successors) - @ccall mlir_c.mlirOperationStateAddSuccessors(state::Ptr{MlirOperationState}, - n::intptr_t, - successors::Ptr{MlirBlock})::Cvoid + @ccall mlir_c.mlirOperationStateAddSuccessors( + state::Ptr{MlirOperationState}, n::intptr_t, successors::Ptr{MlirBlock} + )::Cvoid end function mlirOperationStateAddAttributes(state, n, attributes) - @ccall mlir_c.mlirOperationStateAddAttributes(state::Ptr{MlirOperationState}, - n::intptr_t, - attributes::Ptr{MlirNamedAttribute})::Cvoid + @ccall mlir_c.mlirOperationStateAddAttributes( + state::Ptr{MlirOperationState}, n::intptr_t, attributes::Ptr{MlirNamedAttribute} + )::Cvoid end """ @@ -820,7 +842,9 @@ end Enables result type inference for the operation under construction. If enabled, then the caller must not have called [`mlirOperationStateAddResults`](@ref)(). Note that if enabled, the [`mlirOperationCreate`](@ref)() call is failable: it will return a null operation on inference failure and will emit diagnostics. """ function mlirOperationStateEnableResultTypeInference(state) - @ccall mlir_c.mlirOperationStateEnableResultTypeInference(state::Ptr{MlirOperationState})::Cvoid + @ccall mlir_c.mlirOperationStateEnableResultTypeInference( + state::Ptr{MlirOperationState} + )::Cvoid end """ @@ -829,8 +853,9 @@ end Creates new AsmState, as with AsmState the IR should not be mutated in-between using this state. Must be freed with a call to [`mlirAsmStateDestroy`](@ref)(). """ function mlirAsmStateCreateForOperation(op, flags) - @ccall mlir_c.mlirAsmStateCreateForOperation(op::MlirOperation, - flags::MlirOpPrintingFlags)::MlirAsmState + @ccall mlir_c.mlirAsmStateCreateForOperation( + op::MlirOperation, flags::MlirOpPrintingFlags + )::MlirAsmState end """ @@ -839,8 +864,9 @@ end Creates new AsmState from value. Must be freed with a call to [`mlirAsmStateDestroy`](@ref)(). """ function mlirAsmStateCreateForValue(value, flags) - @ccall mlir_c.mlirAsmStateCreateForValue(value::MlirValue, - flags::MlirOpPrintingFlags)::MlirAsmState + @ccall mlir_c.mlirAsmStateCreateForValue( + value::MlirValue, flags::MlirOpPrintingFlags + )::MlirAsmState end """ @@ -876,8 +902,9 @@ end Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningless form instead of the element data. The `largeElementLimit` is used to configure what is considered to be a "large" ElementsAttr by providing an upper limit to the number of elements. """ function mlirOpPrintingFlagsElideLargeElementsAttrs(flags, largeElementLimit) - @ccall mlir_c.mlirOpPrintingFlagsElideLargeElementsAttrs(flags::MlirOpPrintingFlags, - largeElementLimit::intptr_t)::Cvoid + @ccall mlir_c.mlirOpPrintingFlagsElideLargeElementsAttrs( + flags::MlirOpPrintingFlags, largeElementLimit::intptr_t + )::Cvoid end """ @@ -886,8 +913,9 @@ end Enable or disable printing of debug information (based on `enable`). If 'prettyForm' is set to true, debug information is printed in a more readable 'pretty' form. Note: The IR generated with 'prettyForm' is not parsable. """ function mlirOpPrintingFlagsEnableDebugInfo(flags, enable, prettyForm) - @ccall mlir_c.mlirOpPrintingFlagsEnableDebugInfo(flags::MlirOpPrintingFlags, - enable::Bool, prettyForm::Bool)::Cvoid + @ccall mlir_c.mlirOpPrintingFlagsEnableDebugInfo( + flags::MlirOpPrintingFlags, enable::Bool, prettyForm::Bool + )::Cvoid end """ @@ -941,8 +969,9 @@ end Sets the version to emit in the writer config. """ function mlirBytecodeWriterConfigDesiredEmitVersion(flags, version) - @ccall mlir_c.mlirBytecodeWriterConfigDesiredEmitVersion(flags::MlirBytecodeWriterConfig, - version::Int64)::Cvoid + @ccall mlir_c.mlirBytecodeWriterConfigDesiredEmitVersion( + flags::MlirBytecodeWriterConfig, version::Int64 + )::Cvoid end """ @@ -964,8 +993,9 @@ Parses an operation, giving ownership to the caller. If parsing fails a null ope `sourceStr` may be either the text assembly format, or binary bytecode format. `sourceName` is used as the file name of the source; any IR without locations will get a `FileLineColLoc` location with `sourceName` as the file name. """ function mlirOperationCreateParse(context, sourceStr, sourceName) - @ccall mlir_c.mlirOperationCreateParse(context::MlirContext, sourceStr::MlirStringRef, - sourceName::MlirStringRef)::MlirOperation + @ccall mlir_c.mlirOperationCreateParse( + context::MlirContext, sourceStr::MlirStringRef, sourceName::MlirStringRef + )::MlirOperation end """ @@ -1118,8 +1148,9 @@ end Sets the `pos`-th operand of the operation. """ function mlirOperationSetOperand(op, pos, newValue) - @ccall mlir_c.mlirOperationSetOperand(op::MlirOperation, pos::intptr_t, - newValue::MlirValue)::Cvoid + @ccall mlir_c.mlirOperationSetOperand( + op::MlirOperation, pos::intptr_t, newValue::MlirValue + )::Cvoid end """ @@ -1128,8 +1159,9 @@ end Replaces the operands of the operation. """ function mlirOperationSetOperands(op, nOperands, operands) - @ccall mlir_c.mlirOperationSetOperands(op::MlirOperation, nOperands::intptr_t, - operands::Ptr{MlirValue})::Cvoid + @ccall mlir_c.mlirOperationSetOperands( + op::MlirOperation, nOperands::intptr_t, operands::Ptr{MlirValue} + )::Cvoid end """ @@ -1174,8 +1206,9 @@ end Set `pos`-th successor of the operation. """ function mlirOperationSetSuccessor(op, pos, block) - @ccall mlir_c.mlirOperationSetSuccessor(op::MlirOperation, pos::intptr_t, - block::MlirBlock)::Cvoid + @ccall mlir_c.mlirOperationSetSuccessor( + op::MlirOperation, pos::intptr_t, block::MlirBlock + )::Cvoid end """ @@ -1184,8 +1217,9 @@ end Returns true if this operation defines an inherent attribute with this name. Note: the attribute can be optional, so [`mlirOperationGetInherentAttributeByName`](@ref) can still return a null attribute. """ function mlirOperationHasInherentAttributeByName(op, name) - @ccall mlir_c.mlirOperationHasInherentAttributeByName(op::MlirOperation, - name::MlirStringRef)::Bool + @ccall mlir_c.mlirOperationHasInherentAttributeByName( + op::MlirOperation, name::MlirStringRef + )::Bool end """ @@ -1194,8 +1228,9 @@ end Returns an inherent attribute attached to the operation given its name. """ function mlirOperationGetInherentAttributeByName(op, name) - @ccall mlir_c.mlirOperationGetInherentAttributeByName(op::MlirOperation, - name::MlirStringRef)::MlirAttribute + @ccall mlir_c.mlirOperationGetInherentAttributeByName( + op::MlirOperation, name::MlirStringRef + )::MlirAttribute end """ @@ -1204,9 +1239,9 @@ end Sets an inherent attribute by name, replacing the existing if it exists. This has no effect if "name" does not match an inherent attribute. """ function mlirOperationSetInherentAttributeByName(op, name, attr) - @ccall mlir_c.mlirOperationSetInherentAttributeByName(op::MlirOperation, - name::MlirStringRef, - attr::MlirAttribute)::Cvoid + @ccall mlir_c.mlirOperationSetInherentAttributeByName( + op::MlirOperation, name::MlirStringRef, attr::MlirAttribute + )::Cvoid end """ @@ -1224,8 +1259,9 @@ end Return `pos`-th discardable attribute of the operation. """ function mlirOperationGetDiscardableAttribute(op, pos) - @ccall mlir_c.mlirOperationGetDiscardableAttribute(op::MlirOperation, - pos::intptr_t)::MlirNamedAttribute + @ccall mlir_c.mlirOperationGetDiscardableAttribute( + op::MlirOperation, pos::intptr_t + )::MlirNamedAttribute end """ @@ -1234,8 +1270,9 @@ end Returns a discardable attribute attached to the operation given its name. """ function mlirOperationGetDiscardableAttributeByName(op, name) - @ccall mlir_c.mlirOperationGetDiscardableAttributeByName(op::MlirOperation, - name::MlirStringRef)::MlirAttribute + @ccall mlir_c.mlirOperationGetDiscardableAttributeByName( + op::MlirOperation, name::MlirStringRef + )::MlirAttribute end """ @@ -1244,9 +1281,9 @@ end Sets a discardable attribute by name, replacing the existing if it exists or adding a new one otherwise. The new `attr` Attribute is not allowed to be null, use [`mlirOperationRemoveDiscardableAttributeByName`](@ref) to remove an Attribute instead. """ function mlirOperationSetDiscardableAttributeByName(op, name, attr) - @ccall mlir_c.mlirOperationSetDiscardableAttributeByName(op::MlirOperation, - name::MlirStringRef, - attr::MlirAttribute)::Cvoid + @ccall mlir_c.mlirOperationSetDiscardableAttributeByName( + op::MlirOperation, name::MlirStringRef, attr::MlirAttribute + )::Cvoid end """ @@ -1255,8 +1292,9 @@ end Removes a discardable attribute by name. Returns false if the attribute was not found and true if removed. """ function mlirOperationRemoveDiscardableAttributeByName(op, name) - @ccall mlir_c.mlirOperationRemoveDiscardableAttributeByName(op::MlirOperation, - name::MlirStringRef)::Bool + @ccall mlir_c.mlirOperationRemoveDiscardableAttributeByName( + op::MlirOperation, name::MlirStringRef + )::Bool end """ @@ -1274,8 +1312,9 @@ end Return `pos`-th attribute of the operation. Deprecated, please use `mlirOperationGetInherentAttribute` or [`mlirOperationGetDiscardableAttribute`](@ref). """ function mlirOperationGetAttribute(op, pos) - @ccall mlir_c.mlirOperationGetAttribute(op::MlirOperation, - pos::intptr_t)::MlirNamedAttribute + @ccall mlir_c.mlirOperationGetAttribute( + op::MlirOperation, pos::intptr_t + )::MlirNamedAttribute end """ @@ -1284,8 +1323,9 @@ end Returns an attribute attached to the operation given its name. Deprecated, please use [`mlirOperationGetInherentAttributeByName`](@ref) or [`mlirOperationGetDiscardableAttributeByName`](@ref). """ function mlirOperationGetAttributeByName(op, name) - @ccall mlir_c.mlirOperationGetAttributeByName(op::MlirOperation, - name::MlirStringRef)::MlirAttribute + @ccall mlir_c.mlirOperationGetAttributeByName( + op::MlirOperation, name::MlirStringRef + )::MlirAttribute end """ @@ -1294,8 +1334,9 @@ end Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise. Deprecated, please use [`mlirOperationSetInherentAttributeByName`](@ref) or [`mlirOperationSetDiscardableAttributeByName`](@ref). """ function mlirOperationSetAttributeByName(op, name, attr) - @ccall mlir_c.mlirOperationSetAttributeByName(op::MlirOperation, name::MlirStringRef, - attr::MlirAttribute)::Cvoid + @ccall mlir_c.mlirOperationSetAttributeByName( + op::MlirOperation, name::MlirStringRef, attr::MlirAttribute + )::Cvoid end """ @@ -1304,8 +1345,9 @@ end Removes an attribute by name. Returns false if the attribute was not found and true if removed. Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or [`mlirOperationRemoveDiscardableAttributeByName`](@ref). """ function mlirOperationRemoveAttributeByName(op, name) - @ccall mlir_c.mlirOperationRemoveAttributeByName(op::MlirOperation, - name::MlirStringRef)::Bool + @ccall mlir_c.mlirOperationRemoveAttributeByName( + op::MlirOperation, name::MlirStringRef + )::Bool end """ @@ -1314,8 +1356,9 @@ end Prints an operation by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirOperationPrint(op, callback, userData) - @ccall mlir_c.mlirOperationPrint(op::MlirOperation, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirOperationPrint( + op::MlirOperation, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -1324,9 +1367,12 @@ end Same as [`mlirOperationPrint`](@ref) but accepts flags controlling the printing behavior. """ function mlirOperationPrintWithFlags(op, flags, callback, userData) - @ccall mlir_c.mlirOperationPrintWithFlags(op::MlirOperation, flags::MlirOpPrintingFlags, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirOperationPrintWithFlags( + op::MlirOperation, + flags::MlirOpPrintingFlags, + callback::MlirStringCallback, + userData::Ptr{Cvoid}, + )::Cvoid end """ @@ -1335,9 +1381,12 @@ end Same as [`mlirOperationPrint`](@ref) but accepts AsmState controlling the printing behavior as well as caching computed names. """ function mlirOperationPrintWithState(op, state, callback, userData) - @ccall mlir_c.mlirOperationPrintWithState(op::MlirOperation, state::MlirAsmState, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirOperationPrintWithState( + op::MlirOperation, + state::MlirAsmState, + callback::MlirStringCallback, + userData::Ptr{Cvoid}, + )::Cvoid end """ @@ -1346,9 +1395,9 @@ end Same as [`mlirOperationPrint`](@ref) but writing the bytecode format. """ function mlirOperationWriteBytecode(op, callback, userData) - @ccall mlir_c.mlirOperationWriteBytecode(op::MlirOperation, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirOperationWriteBytecode( + op::MlirOperation, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -1357,10 +1406,12 @@ end Same as [`mlirOperationWriteBytecode`](@ref) but with writer config and returns failure only if desired bytecode could not be honored. """ function mlirOperationWriteBytecodeWithConfig(op, config, callback, userData) - @ccall mlir_c.mlirOperationWriteBytecodeWithConfig(op::MlirOperation, - config::MlirBytecodeWriterConfig, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::MlirLogicalResult + @ccall mlir_c.mlirOperationWriteBytecodeWithConfig( + op::MlirOperation, + config::MlirBytecodeWriterConfig, + callback::MlirStringCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult end """ @@ -1421,8 +1472,12 @@ const MlirOperationWalkCallback = Ptr{Cvoid} Walks operation `op` in `walkOrder` and calls `callback` on that operation. `*userData` is passed to the callback as well and can be used to tunnel some context or other data into the callback. """ function mlirOperationWalk(op, callback, userData, walkOrder) - @ccall mlir_c.mlirOperationWalk(op::MlirOperation, callback::MlirOperationWalkCallback, - userData::Ptr{Cvoid}, walkOrder::MlirWalkOrder)::Cvoid + @ccall mlir_c.mlirOperationWalk( + op::MlirOperation, + callback::MlirOperationWalkCallback, + userData::Ptr{Cvoid}, + walkOrder::MlirWalkOrder, + )::Cvoid end """ @@ -1485,8 +1540,9 @@ end Takes a block owned by the caller and inserts it at `pos` to the given region. This is an expensive operation that linearly scans the region, prefer insertAfter/Before instead. """ function mlirRegionInsertOwnedBlock(region, pos, block) - @ccall mlir_c.mlirRegionInsertOwnedBlock(region::MlirRegion, pos::intptr_t, - block::MlirBlock)::Cvoid + @ccall mlir_c.mlirRegionInsertOwnedBlock( + region::MlirRegion, pos::intptr_t, block::MlirBlock + )::Cvoid end """ @@ -1495,8 +1551,9 @@ end Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, prepends the block to the region. """ function mlirRegionInsertOwnedBlockAfter(region, reference, block) - @ccall mlir_c.mlirRegionInsertOwnedBlockAfter(region::MlirRegion, reference::MlirBlock, - block::MlirBlock)::Cvoid + @ccall mlir_c.mlirRegionInsertOwnedBlockAfter( + region::MlirRegion, reference::MlirBlock, block::MlirBlock + )::Cvoid end """ @@ -1505,8 +1562,9 @@ end Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, appends the block to the region. """ function mlirRegionInsertOwnedBlockBefore(region, reference, block) - @ccall mlir_c.mlirRegionInsertOwnedBlockBefore(region::MlirRegion, reference::MlirBlock, - block::MlirBlock)::Cvoid + @ccall mlir_c.mlirRegionInsertOwnedBlockBefore( + region::MlirRegion, reference::MlirBlock, block::MlirBlock + )::Cvoid end """ @@ -1542,8 +1600,9 @@ end Creates a new empty block with the given argument types and transfers ownership to the caller. """ function mlirBlockCreate(nArgs, args, locs) - @ccall mlir_c.mlirBlockCreate(nArgs::intptr_t, args::Ptr{MlirType}, - locs::Ptr{MlirLocation})::MlirBlock + @ccall mlir_c.mlirBlockCreate( + nArgs::intptr_t, args::Ptr{MlirType}, locs::Ptr{MlirLocation} + )::MlirBlock end """ @@ -1633,8 +1692,9 @@ end Takes an operation owned by the caller and appends it to the block. """ function mlirBlockAppendOwnedOperation(block, operation) - @ccall mlir_c.mlirBlockAppendOwnedOperation(block::MlirBlock, - operation::MlirOperation)::Cvoid + @ccall mlir_c.mlirBlockAppendOwnedOperation( + block::MlirBlock, operation::MlirOperation + )::Cvoid end """ @@ -1643,8 +1703,9 @@ end Takes an operation owned by the caller and inserts it as `pos` to the block. This is an expensive operation that scans the block linearly, prefer insertBefore/After instead. """ function mlirBlockInsertOwnedOperation(block, pos, operation) - @ccall mlir_c.mlirBlockInsertOwnedOperation(block::MlirBlock, pos::intptr_t, - operation::MlirOperation)::Cvoid + @ccall mlir_c.mlirBlockInsertOwnedOperation( + block::MlirBlock, pos::intptr_t, operation::MlirOperation + )::Cvoid end """ @@ -1653,9 +1714,9 @@ end Takes an operation owned by the caller and inserts it after the (non-owned) reference operation in the given block. If the reference is null, prepends the operation. Otherwise, the reference must belong to the block. """ function mlirBlockInsertOwnedOperationAfter(block, reference, operation) - @ccall mlir_c.mlirBlockInsertOwnedOperationAfter(block::MlirBlock, - reference::MlirOperation, - operation::MlirOperation)::Cvoid + @ccall mlir_c.mlirBlockInsertOwnedOperationAfter( + block::MlirBlock, reference::MlirOperation, operation::MlirOperation + )::Cvoid end """ @@ -1664,9 +1725,9 @@ end Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in the given block. If the reference is null, appends the operation. Otherwise, the reference must belong to the block. """ function mlirBlockInsertOwnedOperationBefore(block, reference, operation) - @ccall mlir_c.mlirBlockInsertOwnedOperationBefore(block::MlirBlock, - reference::MlirOperation, - operation::MlirOperation)::Cvoid + @ccall mlir_c.mlirBlockInsertOwnedOperationBefore( + block::MlirBlock, reference::MlirOperation, operation::MlirOperation + )::Cvoid end """ @@ -1684,8 +1745,9 @@ end Appends an argument of the specified type to the block. Returns the newly added argument. """ function mlirBlockAddArgument(block, type, loc) - @ccall mlir_c.mlirBlockAddArgument(block::MlirBlock, type::MlirType, - loc::MlirLocation)::MlirValue + @ccall mlir_c.mlirBlockAddArgument( + block::MlirBlock, type::MlirType, loc::MlirLocation + )::MlirValue end """ @@ -1694,8 +1756,9 @@ end Inserts an argument of the specified type at a specified index to the block. Returns the newly added argument. """ function mlirBlockInsertArgument(block, pos, type, loc) - @ccall mlir_c.mlirBlockInsertArgument(block::MlirBlock, pos::intptr_t, type::MlirType, - loc::MlirLocation)::MlirValue + @ccall mlir_c.mlirBlockInsertArgument( + block::MlirBlock, pos::intptr_t, type::MlirType, loc::MlirLocation + )::MlirValue end """ @@ -1713,8 +1776,9 @@ end Prints a block by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirBlockPrint(block, callback, userData) - @ccall mlir_c.mlirBlockPrint(block::MlirBlock, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirBlockPrint( + block::MlirBlock, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -1831,8 +1895,9 @@ end Prints a value by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirValuePrint(value, callback, userData) - @ccall mlir_c.mlirValuePrint(value::MlirValue, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirValuePrint( + value::MlirValue, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -1841,9 +1906,12 @@ end Prints a value as an operand (i.e., the ValueID). """ function mlirValuePrintAsOperand(value, state, callback, userData) - @ccall mlir_c.mlirValuePrintAsOperand(value::MlirValue, state::MlirAsmState, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirValuePrintAsOperand( + value::MlirValue, + state::MlirAsmState, + callback::MlirStringCallback, + userData::Ptr{Cvoid}, + )::Cvoid end """ @@ -1969,8 +2037,9 @@ end Prints a location by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirTypePrint(type, callback, userData) - @ccall mlir_c.mlirTypePrint(type::MlirType, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirTypePrint( + type::MlirType, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -1988,8 +2057,9 @@ end Parses an attribute. The attribute is owned by the context. """ function mlirAttributeParseGet(context, attr) - @ccall mlir_c.mlirAttributeParseGet(context::MlirContext, - attr::MlirStringRef)::MlirAttribute + @ccall mlir_c.mlirAttributeParseGet( + context::MlirContext, attr::MlirStringRef + )::MlirAttribute end """ @@ -2052,8 +2122,9 @@ end Prints an attribute by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirAttributePrint(attr, callback, userData) - @ccall mlir_c.mlirAttributePrint(attr::MlirAttribute, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirAttributePrint( + attr::MlirAttribute, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -2071,8 +2142,9 @@ end Associates an attribute with the name. Takes ownership of neither. """ function mlirNamedAttributeGet(name, attr) - @ccall mlir_c.mlirNamedAttributeGet(name::MlirIdentifier, - attr::MlirAttribute)::MlirNamedAttribute + @ccall mlir_c.mlirNamedAttributeGet( + name::MlirIdentifier, attr::MlirAttribute + )::MlirNamedAttribute end """ @@ -2081,8 +2153,9 @@ end Gets an identifier with the given string value. """ function mlirIdentifierGet(context, str) - @ccall mlir_c.mlirIdentifierGet(context::MlirContext, - str::MlirStringRef)::MlirIdentifier + @ccall mlir_c.mlirIdentifierGet( + context::MlirContext, str::MlirStringRef + )::MlirIdentifier end """ @@ -2163,8 +2236,9 @@ end Looks up a symbol with the given name in the given symbol table and returns the operation that corresponds to the symbol. If the symbol cannot be found, returns a null operation. """ function mlirSymbolTableLookup(symbolTable, name) - @ccall mlir_c.mlirSymbolTableLookup(symbolTable::MlirSymbolTable, - name::MlirStringRef)::MlirOperation + @ccall mlir_c.mlirSymbolTableLookup( + symbolTable::MlirSymbolTable, name::MlirStringRef + )::MlirOperation end """ @@ -2173,8 +2247,9 @@ end Inserts the given operation into the given symbol table. The operation must have the symbol trait. If the symbol table already has a symbol with the same name, renames the symbol being inserted to ensure name uniqueness. Note that this does not move the operation itself into the block of the symbol table operation, this should be done separately. Returns the name of the symbol after insertion. """ function mlirSymbolTableInsert(symbolTable, operation) - @ccall mlir_c.mlirSymbolTableInsert(symbolTable::MlirSymbolTable, - operation::MlirOperation)::MlirAttribute + @ccall mlir_c.mlirSymbolTableInsert( + symbolTable::MlirSymbolTable, operation::MlirOperation + )::MlirAttribute end """ @@ -2183,8 +2258,9 @@ end Removes the given operation from the symbol table and erases it. """ function mlirSymbolTableErase(symbolTable, operation) - @ccall mlir_c.mlirSymbolTableErase(symbolTable::MlirSymbolTable, - operation::MlirOperation)::Cvoid + @ccall mlir_c.mlirSymbolTableErase( + symbolTable::MlirSymbolTable, operation::MlirOperation + )::Cvoid end """ @@ -2193,9 +2269,9 @@ end Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol' with the provided 'newSymbol'. This does not traverse into nested symbol tables. Will fail atomically if there are any unknown operations that may be potential symbol tables. """ function mlirSymbolTableReplaceAllSymbolUses(oldSymbol, newSymbol, from) - @ccall mlir_c.mlirSymbolTableReplaceAllSymbolUses(oldSymbol::MlirStringRef, - newSymbol::MlirStringRef, - from::MlirOperation)::MlirLogicalResult + @ccall mlir_c.mlirSymbolTableReplaceAllSymbolUses( + oldSymbol::MlirStringRef, newSymbol::MlirStringRef, from::MlirOperation + )::MlirLogicalResult end """ @@ -2204,10 +2280,12 @@ end Walks all symbol table operations nested within, and including, `op`. For each symbol table operation, the provided callback is invoked with the op and a boolean signifying if the symbols within that symbol table can be treated as if all uses within the IR are visible to the caller. `allSymUsesVisible` identifies whether all of the symbol uses of symbols within `op` are visible. """ function mlirSymbolTableWalkSymbolTables(from, allSymUsesVisible, callback, userData) - @ccall mlir_c.mlirSymbolTableWalkSymbolTables(from::MlirOperation, - allSymUsesVisible::Bool, - callback::Ptr{Cvoid}, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirSymbolTableWalkSymbolTables( + from::MlirOperation, + allSymUsesVisible::Bool, + callback::Ptr{Cvoid}, + userData::Ptr{Cvoid}, + )::Cvoid end struct MlirAffineExpr @@ -2247,9 +2325,9 @@ end Prints an affine expression by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirAffineExprPrint(affineExpr, callback, userData) - @ccall mlir_c.mlirAffineExprPrint(affineExpr::MlirAffineExpr, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirAffineExprPrint( + affineExpr::MlirAffineExpr, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -2294,8 +2372,9 @@ end Checks whether the given affine expression is a multiple of 'factor'. """ function mlirAffineExprIsMultipleOf(affineExpr, factor) - @ccall mlir_c.mlirAffineExprIsMultipleOf(affineExpr::MlirAffineExpr, - factor::Int64)::Bool + @ccall mlir_c.mlirAffineExprIsMultipleOf( + affineExpr::MlirAffineExpr, factor::Int64 + )::Bool end """ @@ -2304,8 +2383,9 @@ end Checks whether the given affine expression involves AffineDimExpr 'position'. """ function mlirAffineExprIsFunctionOfDim(affineExpr, position) - @ccall mlir_c.mlirAffineExprIsFunctionOfDim(affineExpr::MlirAffineExpr, - position::intptr_t)::Bool + @ccall mlir_c.mlirAffineExprIsFunctionOfDim( + affineExpr::MlirAffineExpr, position::intptr_t + )::Bool end struct MlirAffineMap @@ -2318,8 +2398,9 @@ end Composes the given map with the given expression. """ function mlirAffineExprCompose(affineExpr, affineMap) - @ccall mlir_c.mlirAffineExprCompose(affineExpr::MlirAffineExpr, - affineMap::MlirAffineMap)::MlirAffineExpr + @ccall mlir_c.mlirAffineExprCompose( + affineExpr::MlirAffineExpr, affineMap::MlirAffineMap + )::MlirAffineExpr end """ @@ -2364,8 +2445,9 @@ end Creates an affine symbol expression with 'position' in the context. """ function mlirAffineSymbolExprGet(ctx, position) - @ccall mlir_c.mlirAffineSymbolExprGet(ctx::MlirContext, - position::intptr_t)::MlirAffineExpr + @ccall mlir_c.mlirAffineSymbolExprGet( + ctx::MlirContext, position::intptr_t + )::MlirAffineExpr end """ @@ -2392,8 +2474,9 @@ end Creates an affine constant expression with 'constant' in the context. """ function mlirAffineConstantExprGet(ctx, constant) - @ccall mlir_c.mlirAffineConstantExprGet(ctx::MlirContext, - constant::Int64)::MlirAffineExpr + @ccall mlir_c.mlirAffineConstantExprGet( + ctx::MlirContext, constant::Int64 + )::MlirAffineExpr end """ @@ -2420,8 +2503,9 @@ end Creates an affine add expression with 'lhs' and 'rhs'. """ function mlirAffineAddExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineAddExprGet(lhs::MlirAffineExpr, - rhs::MlirAffineExpr)::MlirAffineExpr + @ccall mlir_c.mlirAffineAddExprGet( + lhs::MlirAffineExpr, rhs::MlirAffineExpr + )::MlirAffineExpr end """ @@ -2439,8 +2523,9 @@ end Creates an affine mul expression with 'lhs' and 'rhs'. """ function mlirAffineMulExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineMulExprGet(lhs::MlirAffineExpr, - rhs::MlirAffineExpr)::MlirAffineExpr + @ccall mlir_c.mlirAffineMulExprGet( + lhs::MlirAffineExpr, rhs::MlirAffineExpr + )::MlirAffineExpr end """ @@ -2458,8 +2543,9 @@ end Creates an affine mod expression with 'lhs' and 'rhs'. """ function mlirAffineModExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineModExprGet(lhs::MlirAffineExpr, - rhs::MlirAffineExpr)::MlirAffineExpr + @ccall mlir_c.mlirAffineModExprGet( + lhs::MlirAffineExpr, rhs::MlirAffineExpr + )::MlirAffineExpr end """ @@ -2477,8 +2563,9 @@ end Creates an affine floordiv expression with 'lhs' and 'rhs'. """ function mlirAffineFloorDivExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineFloorDivExprGet(lhs::MlirAffineExpr, - rhs::MlirAffineExpr)::MlirAffineExpr + @ccall mlir_c.mlirAffineFloorDivExprGet( + lhs::MlirAffineExpr, rhs::MlirAffineExpr + )::MlirAffineExpr end """ @@ -2496,8 +2583,9 @@ end Creates an affine ceildiv expression with 'lhs' and 'rhs'. """ function mlirAffineCeilDivExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineCeilDivExprGet(lhs::MlirAffineExpr, - rhs::MlirAffineExpr)::MlirAffineExpr + @ccall mlir_c.mlirAffineCeilDivExprGet( + lhs::MlirAffineExpr, rhs::MlirAffineExpr + )::MlirAffineExpr end """ @@ -2560,8 +2648,9 @@ end Prints an affine map by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirAffineMapPrint(affineMap, callback, userData) - @ccall mlir_c.mlirAffineMapPrint(affineMap::MlirAffineMap, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirAffineMapPrint( + affineMap::MlirAffineMap, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -2588,8 +2677,9 @@ end Creates a zero result affine map of the given dimensions and symbols in the context. The affine map is owned by the context. """ function mlirAffineMapZeroResultGet(ctx, dimCount, symbolCount) - @ccall mlir_c.mlirAffineMapZeroResultGet(ctx::MlirContext, dimCount::intptr_t, - symbolCount::intptr_t)::MlirAffineMap + @ccall mlir_c.mlirAffineMapZeroResultGet( + ctx::MlirContext, dimCount::intptr_t, symbolCount::intptr_t + )::MlirAffineMap end """ @@ -2598,9 +2688,13 @@ end Creates an affine map with results defined by the given list of affine expressions. The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results. """ function mlirAffineMapGet(ctx, dimCount, symbolCount, nAffineExprs, affineExprs) - @ccall mlir_c.mlirAffineMapGet(ctx::MlirContext, dimCount::intptr_t, - symbolCount::intptr_t, nAffineExprs::intptr_t, - affineExprs::Ptr{MlirAffineExpr})::MlirAffineMap + @ccall mlir_c.mlirAffineMapGet( + ctx::MlirContext, + dimCount::intptr_t, + symbolCount::intptr_t, + nAffineExprs::intptr_t, + affineExprs::Ptr{MlirAffineExpr}, + )::MlirAffineMap end """ @@ -2618,8 +2712,9 @@ end Creates an affine map with 'numDims' identity in the context. The affine map is owned by the context. """ function mlirAffineMapMultiDimIdentityGet(ctx, numDims) - @ccall mlir_c.mlirAffineMapMultiDimIdentityGet(ctx::MlirContext, - numDims::intptr_t)::MlirAffineMap + @ccall mlir_c.mlirAffineMapMultiDimIdentityGet( + ctx::MlirContext, numDims::intptr_t + )::MlirAffineMap end """ @@ -2628,8 +2723,9 @@ end Creates an identity affine map on the most minor dimensions in the context. The affine map is owned by the context. The function asserts that the number of dimensions is greater or equal to the number of results. """ function mlirAffineMapMinorIdentityGet(ctx, dims, results) - @ccall mlir_c.mlirAffineMapMinorIdentityGet(ctx::MlirContext, dims::intptr_t, - results::intptr_t)::MlirAffineMap + @ccall mlir_c.mlirAffineMapMinorIdentityGet( + ctx::MlirContext, dims::intptr_t, results::intptr_t + )::MlirAffineMap end """ @@ -2638,8 +2734,9 @@ end Creates an affine map with a permutation expression and its size in the context. The permutation expression is a non-empty vector of integers. The elements of the permutation vector must be continuous from 0 and cannot be repeated (i.e. `[1,2,0]` is a valid permutation. `[2,0]` or `[1,1,2]` is an invalid permutation.) The affine map is owned by the context. """ function mlirAffineMapPermutationGet(ctx, size, permutation) - @ccall mlir_c.mlirAffineMapPermutationGet(ctx::MlirContext, size::intptr_t, - permutation::Ptr{Cuint})::MlirAffineMap + @ccall mlir_c.mlirAffineMapPermutationGet( + ctx::MlirContext, size::intptr_t, permutation::Ptr{Cuint} + )::MlirAffineMap end """ @@ -2720,8 +2817,9 @@ end Returns the result at the given position. """ function mlirAffineMapGetResult(affineMap, pos) - @ccall mlir_c.mlirAffineMapGetResult(affineMap::MlirAffineMap, - pos::intptr_t)::MlirAffineExpr + @ccall mlir_c.mlirAffineMapGetResult( + affineMap::MlirAffineMap, pos::intptr_t + )::MlirAffineExpr end """ @@ -2757,8 +2855,9 @@ end Returns the affine map consisting of the `resultPos` subset. """ function mlirAffineMapGetSubMap(affineMap, size, resultPos) - @ccall mlir_c.mlirAffineMapGetSubMap(affineMap::MlirAffineMap, size::intptr_t, - resultPos::Ptr{intptr_t})::MlirAffineMap + @ccall mlir_c.mlirAffineMapGetSubMap( + affineMap::MlirAffineMap, size::intptr_t, resultPos::Ptr{intptr_t} + )::MlirAffineMap end """ @@ -2767,8 +2866,9 @@ end Returns the affine map consisting of the most major `numResults` results. Returns the null AffineMap if the `numResults` is equal to zero. Returns the `affineMap` if `numResults` is greater or equals to number of results of the given affine map. """ function mlirAffineMapGetMajorSubMap(affineMap, numResults) - @ccall mlir_c.mlirAffineMapGetMajorSubMap(affineMap::MlirAffineMap, - numResults::intptr_t)::MlirAffineMap + @ccall mlir_c.mlirAffineMapGetMajorSubMap( + affineMap::MlirAffineMap, numResults::intptr_t + )::MlirAffineMap end """ @@ -2777,8 +2877,9 @@ end Returns the affine map consisting of the most minor `numResults` results. Returns the null AffineMap if the `numResults` is equal to zero. Returns the `affineMap` if `numResults` is greater or equals to number of results of the given affine map. """ function mlirAffineMapGetMinorSubMap(affineMap, numResults) - @ccall mlir_c.mlirAffineMapGetMinorSubMap(affineMap::MlirAffineMap, - numResults::intptr_t)::MlirAffineMap + @ccall mlir_c.mlirAffineMapGetMinorSubMap( + affineMap::MlirAffineMap, numResults::intptr_t + )::MlirAffineMap end """ @@ -2786,11 +2887,16 @@ end Apply AffineExpr::replace(`map`) to each of the results and return a new new AffineMap with the new results and the specified number of dims and symbols. """ -function mlirAffineMapReplace(affineMap, expression, replacement, numResultDims, - numResultSyms) - @ccall mlir_c.mlirAffineMapReplace(affineMap::MlirAffineMap, expression::MlirAffineExpr, - replacement::MlirAffineExpr, numResultDims::intptr_t, - numResultSyms::intptr_t)::MlirAffineMap +function mlirAffineMapReplace( + affineMap, expression, replacement, numResultDims, numResultSyms +) + @ccall mlir_c.mlirAffineMapReplace( + affineMap::MlirAffineMap, + expression::MlirAffineExpr, + replacement::MlirAffineExpr, + numResultDims::intptr_t, + numResultSyms::intptr_t, + )::MlirAffineMap end """ @@ -2799,9 +2905,12 @@ end Returns the simplified affine map resulting from dropping the symbols that do not appear in any of the individual maps in `affineMaps`. Asserts that all maps in `affineMaps` are normalized to the same number of dims and symbols. Takes a callback `populateResult` to fill the `res` container with value `m` at entry `idx`. This allows returning without worrying about ownership considerations. """ function mlirAffineMapCompressUnusedSymbols(affineMaps, size, result, populateResult) - @ccall mlir_c.mlirAffineMapCompressUnusedSymbols(affineMaps::Ptr{MlirAffineMap}, - size::intptr_t, result::Ptr{Cvoid}, - populateResult::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirAffineMapCompressUnusedSymbols( + affineMaps::Ptr{MlirAffineMap}, + size::intptr_t, + result::Ptr{Cvoid}, + populateResult::Ptr{Cvoid}, + )::Cvoid end """ @@ -2868,8 +2977,9 @@ end Creates an array element containing the given list of elements in the given context. """ function mlirArrayAttrGet(ctx, numElements, elements) - @ccall mlir_c.mlirArrayAttrGet(ctx::MlirContext, numElements::intptr_t, - elements::Ptr{MlirAttribute})::MlirAttribute + @ccall mlir_c.mlirArrayAttrGet( + ctx::MlirContext, numElements::intptr_t, elements::Ptr{MlirAttribute} + )::MlirAttribute end """ @@ -2914,8 +3024,9 @@ end Creates a dictionary attribute containing the given list of elements in the provided context. """ function mlirDictionaryAttrGet(ctx, numElements, elements) - @ccall mlir_c.mlirDictionaryAttrGet(ctx::MlirContext, numElements::intptr_t, - elements::Ptr{MlirNamedAttribute})::MlirAttribute + @ccall mlir_c.mlirDictionaryAttrGet( + ctx::MlirContext, numElements::intptr_t, elements::Ptr{MlirNamedAttribute} + )::MlirAttribute end """ @@ -2933,8 +3044,9 @@ end Returns pos-th element of the given dictionary attribute. """ function mlirDictionaryAttrGetElement(attr, pos) - @ccall mlir_c.mlirDictionaryAttrGetElement(attr::MlirAttribute, - pos::intptr_t)::MlirNamedAttribute + @ccall mlir_c.mlirDictionaryAttrGetElement( + attr::MlirAttribute, pos::intptr_t + )::MlirNamedAttribute end """ @@ -2943,8 +3055,9 @@ end Returns the dictionary attribute element with the given name or NULL if the given name does not exist in the dictionary. """ function mlirDictionaryAttrGetElementByName(attr, name) - @ccall mlir_c.mlirDictionaryAttrGetElementByName(attr::MlirAttribute, - name::MlirStringRef)::MlirAttribute + @ccall mlir_c.mlirDictionaryAttrGetElementByName( + attr::MlirAttribute, name::MlirStringRef + )::MlirAttribute end """ @@ -2971,8 +3084,9 @@ end Creates a floating point attribute in the given context with the given double value and double-precision FP semantics. """ function mlirFloatAttrDoubleGet(ctx, type, value) - @ccall mlir_c.mlirFloatAttrDoubleGet(ctx::MlirContext, type::MlirType, - value::Cdouble)::MlirAttribute + @ccall mlir_c.mlirFloatAttrDoubleGet( + ctx::MlirContext, type::MlirType, value::Cdouble + )::MlirAttribute end """ @@ -2981,8 +3095,9 @@ end Same as "[`mlirFloatAttrDoubleGet`](@ref)", but if the type is not valid for a construction of a FloatAttr, returns a null [`MlirAttribute`](@ref). """ function mlirFloatAttrDoubleGetChecked(loc, type, value) - @ccall mlir_c.mlirFloatAttrDoubleGetChecked(loc::MlirLocation, type::MlirType, - value::Cdouble)::MlirAttribute + @ccall mlir_c.mlirFloatAttrDoubleGetChecked( + loc::MlirLocation, type::MlirType, value::Cdouble + )::MlirAttribute end """ @@ -3117,9 +3232,13 @@ end Creates an opaque attribute in the given context associated with the dialect identified by its namespace. The attribute contains opaque byte data of the specified length (data need not be null-terminated). """ function mlirOpaqueAttrGet(ctx, dialectNamespace, dataLength, data, type) - @ccall mlir_c.mlirOpaqueAttrGet(ctx::MlirContext, dialectNamespace::MlirStringRef, - dataLength::intptr_t, data::Cstring, - type::MlirType)::MlirAttribute + @ccall mlir_c.mlirOpaqueAttrGet( + ctx::MlirContext, + dialectNamespace::MlirStringRef, + dataLength::intptr_t, + data::Cstring, + type::MlirType, + )::MlirAttribute end """ @@ -3209,9 +3328,12 @@ end Creates a symbol reference attribute in the given context referencing a symbol identified by the given string inside a list of nested references. Each of the references in the list must not be nested. """ function mlirSymbolRefAttrGet(ctx, symbol, numReferences, references) - @ccall mlir_c.mlirSymbolRefAttrGet(ctx::MlirContext, symbol::MlirStringRef, - numReferences::intptr_t, - references::Ptr{MlirAttribute})::MlirAttribute + @ccall mlir_c.mlirSymbolRefAttrGet( + ctx::MlirContext, + symbol::MlirStringRef, + numReferences::intptr_t, + references::Ptr{MlirAttribute}, + )::MlirAttribute end """ @@ -3247,8 +3369,9 @@ end Returns pos-th reference nested in the given symbol reference attribute. """ function mlirSymbolRefAttrGetNestedReference(attr, pos) - @ccall mlir_c.mlirSymbolRefAttrGetNestedReference(attr::MlirAttribute, - pos::intptr_t)::MlirAttribute + @ccall mlir_c.mlirSymbolRefAttrGetNestedReference( + attr::MlirAttribute, pos::intptr_t + )::MlirAttribute end """ @@ -3284,8 +3407,9 @@ end Creates a flat symbol reference attribute in the given context referencing a symbol identified by the given string. """ function mlirFlatSymbolRefAttrGet(ctx, symbol) - @ccall mlir_c.mlirFlatSymbolRefAttrGet(ctx::MlirContext, - symbol::MlirStringRef)::MlirAttribute + @ccall mlir_c.mlirFlatSymbolRefAttrGet( + ctx::MlirContext, symbol::MlirStringRef + )::MlirAttribute end """ @@ -3375,8 +3499,9 @@ end Returns the element at the given rank-dimensional index. """ function mlirElementsAttrGetValue(attr, rank, idxs) - @ccall mlir_c.mlirElementsAttrGetValue(attr::MlirAttribute, rank::intptr_t, - idxs::Ptr{UInt64})::MlirAttribute + @ccall mlir_c.mlirElementsAttrGetValue( + attr::MlirAttribute, rank::intptr_t, idxs::Ptr{UInt64} + )::MlirAttribute end """ @@ -3385,8 +3510,9 @@ end Checks whether the given rank-dimensional index is valid in the given elements attribute. """ function mlirElementsAttrIsValidIndex(attr, rank, idxs) - @ccall mlir_c.mlirElementsAttrIsValidIndex(attr::MlirAttribute, rank::intptr_t, - idxs::Ptr{UInt64})::Bool + @ccall mlir_c.mlirElementsAttrIsValidIndex( + attr::MlirAttribute, rank::intptr_t, idxs::Ptr{UInt64} + )::Bool end """ @@ -3441,38 +3567,45 @@ end Create a dense array attribute with the given elements. """ function mlirDenseBoolArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseBoolArrayGet(ctx::MlirContext, size::intptr_t, - values::Ptr{Cint})::MlirAttribute + @ccall mlir_c.mlirDenseBoolArrayGet( + ctx::MlirContext, size::intptr_t, values::Ptr{Cint} + )::MlirAttribute end function mlirDenseI8ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI8ArrayGet(ctx::MlirContext, size::intptr_t, - values::Ptr{Int8})::MlirAttribute + @ccall mlir_c.mlirDenseI8ArrayGet( + ctx::MlirContext, size::intptr_t, values::Ptr{Int8} + )::MlirAttribute end function mlirDenseI16ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI16ArrayGet(ctx::MlirContext, size::intptr_t, - values::Ptr{Int16})::MlirAttribute + @ccall mlir_c.mlirDenseI16ArrayGet( + ctx::MlirContext, size::intptr_t, values::Ptr{Int16} + )::MlirAttribute end function mlirDenseI32ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI32ArrayGet(ctx::MlirContext, size::intptr_t, - values::Ptr{Int32})::MlirAttribute + @ccall mlir_c.mlirDenseI32ArrayGet( + ctx::MlirContext, size::intptr_t, values::Ptr{Int32} + )::MlirAttribute end function mlirDenseI64ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI64ArrayGet(ctx::MlirContext, size::intptr_t, - values::Ptr{Int64})::MlirAttribute + @ccall mlir_c.mlirDenseI64ArrayGet( + ctx::MlirContext, size::intptr_t, values::Ptr{Int64} + )::MlirAttribute end function mlirDenseF32ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseF32ArrayGet(ctx::MlirContext, size::intptr_t, - values::Ptr{Cfloat})::MlirAttribute + @ccall mlir_c.mlirDenseF32ArrayGet( + ctx::MlirContext, size::intptr_t, values::Ptr{Cfloat} + )::MlirAttribute end function mlirDenseF64ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseF64ArrayGet(ctx::MlirContext, size::intptr_t, - values::Ptr{Cdouble})::MlirAttribute + @ccall mlir_c.mlirDenseF64ArrayGet( + ctx::MlirContext, size::intptr_t, values::Ptr{Cdouble} + )::MlirAttribute end """ @@ -3549,8 +3682,9 @@ end Creates a dense elements attribute with the given Shaped type and elements in the same context as the type. """ function mlirDenseElementsAttrGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrGet(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{MlirAttribute})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrGet( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{MlirAttribute} + )::MlirAttribute end """ @@ -3563,9 +3697,9 @@ The format of the raw buffer is a densely packed array of values that can be bit A raw buffer of a single element (or for 1-bit, a byte of value 0 or 255) will be interpreted as a splat. User code should be prepared for additional, conformant patterns to be identified as splats in the future. """ function mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, rawBuffer) - @ccall mlir_c.mlirDenseElementsAttrRawBufferGet(shapedType::MlirType, - rawBufferSize::Csize_t, - rawBuffer::Ptr{Cvoid})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrRawBufferGet( + shapedType::MlirType, rawBufferSize::Csize_t, rawBuffer::Ptr{Cvoid} + )::MlirAttribute end """ @@ -3574,53 +3708,63 @@ end Creates a dense elements attribute with the given Shaped type containing a single replicated element (splat). """ function mlirDenseElementsAttrSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrSplatGet(shapedType::MlirType, - element::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrSplatGet( + shapedType::MlirType, element::MlirAttribute + )::MlirAttribute end function mlirDenseElementsAttrBoolSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrBoolSplatGet(shapedType::MlirType, - element::Bool)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrBoolSplatGet( + shapedType::MlirType, element::Bool + )::MlirAttribute end function mlirDenseElementsAttrUInt8SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrUInt8SplatGet(shapedType::MlirType, - element::UInt8)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt8SplatGet( + shapedType::MlirType, element::UInt8 + )::MlirAttribute end function mlirDenseElementsAttrInt8SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrInt8SplatGet(shapedType::MlirType, - element::Int8)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt8SplatGet( + shapedType::MlirType, element::Int8 + )::MlirAttribute end function mlirDenseElementsAttrUInt32SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrUInt32SplatGet(shapedType::MlirType, - element::UInt32)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt32SplatGet( + shapedType::MlirType, element::UInt32 + )::MlirAttribute end function mlirDenseElementsAttrInt32SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrInt32SplatGet(shapedType::MlirType, - element::Int32)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt32SplatGet( + shapedType::MlirType, element::Int32 + )::MlirAttribute end function mlirDenseElementsAttrUInt64SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrUInt64SplatGet(shapedType::MlirType, - element::UInt64)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt64SplatGet( + shapedType::MlirType, element::UInt64 + )::MlirAttribute end function mlirDenseElementsAttrInt64SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrInt64SplatGet(shapedType::MlirType, - element::Int64)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt64SplatGet( + shapedType::MlirType, element::Int64 + )::MlirAttribute end function mlirDenseElementsAttrFloatSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrFloatSplatGet(shapedType::MlirType, - element::Cfloat)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrFloatSplatGet( + shapedType::MlirType, element::Cfloat + )::MlirAttribute end function mlirDenseElementsAttrDoubleSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrDoubleSplatGet(shapedType::MlirType, - element::Cdouble)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrDoubleSplatGet( + shapedType::MlirType, element::Cdouble + )::MlirAttribute end """ @@ -3629,74 +3773,81 @@ end Creates a dense elements attribute with the given shaped type from elements of a specific type. Expects the element type of the shaped type to match the data element type. """ function mlirDenseElementsAttrBoolGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrBoolGet(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{Cint})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrBoolGet( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{Cint} + )::MlirAttribute end function mlirDenseElementsAttrUInt8Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt8Get(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{UInt8})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt8Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{UInt8} + )::MlirAttribute end function mlirDenseElementsAttrInt8Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt8Get(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{Int8})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt8Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{Int8} + )::MlirAttribute end function mlirDenseElementsAttrUInt16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt16Get(shapedType::MlirType, - numElements::intptr_t, - elements::Ptr{UInt16})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt16Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{UInt16} + )::MlirAttribute end function mlirDenseElementsAttrInt16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt16Get(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{Int16})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt16Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{Int16} + )::MlirAttribute end function mlirDenseElementsAttrUInt32Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt32Get(shapedType::MlirType, - numElements::intptr_t, - elements::Ptr{UInt32})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt32Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{UInt32} + )::MlirAttribute end function mlirDenseElementsAttrInt32Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt32Get(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{Int32})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt32Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{Int32} + )::MlirAttribute end function mlirDenseElementsAttrUInt64Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt64Get(shapedType::MlirType, - numElements::intptr_t, - elements::Ptr{UInt64})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt64Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{UInt64} + )::MlirAttribute end function mlirDenseElementsAttrInt64Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt64Get(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{Int64})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt64Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{Int64} + )::MlirAttribute end function mlirDenseElementsAttrFloatGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrFloatGet(shapedType::MlirType, numElements::intptr_t, - elements::Ptr{Cfloat})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrFloatGet( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{Cfloat} + )::MlirAttribute end function mlirDenseElementsAttrDoubleGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrDoubleGet(shapedType::MlirType, - numElements::intptr_t, - elements::Ptr{Cdouble})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrDoubleGet( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{Cdouble} + )::MlirAttribute end function mlirDenseElementsAttrBFloat16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrBFloat16Get(shapedType::MlirType, - numElements::intptr_t, - elements::Ptr{UInt16})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrBFloat16Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{UInt16} + )::MlirAttribute end function mlirDenseElementsAttrFloat16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrFloat16Get(shapedType::MlirType, - numElements::intptr_t, - elements::Ptr{UInt16})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrFloat16Get( + shapedType::MlirType, numElements::intptr_t, elements::Ptr{UInt16} + )::MlirAttribute end """ @@ -3705,9 +3856,9 @@ end Creates a dense elements attribute with the given shaped type from string elements. """ function mlirDenseElementsAttrStringGet(shapedType, numElements, strs) - @ccall mlir_c.mlirDenseElementsAttrStringGet(shapedType::MlirType, - numElements::intptr_t, - strs::Ptr{MlirStringRef})::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrStringGet( + shapedType::MlirType, numElements::intptr_t, strs::Ptr{MlirStringRef} + )::MlirAttribute end """ @@ -3716,8 +3867,9 @@ end Creates a dense elements attribute that has the same data as the given dense elements attribute and a different shaped type. The new type must have the same total number of elements. """ function mlirDenseElementsAttrReshapeGet(attr, shapedType) - @ccall mlir_c.mlirDenseElementsAttrReshapeGet(attr::MlirAttribute, - shapedType::MlirType)::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrReshapeGet( + attr::MlirAttribute, shapedType::MlirType + )::MlirAttribute end """ @@ -3775,7 +3927,9 @@ function mlirDenseElementsAttrGetDoubleSplatValue(attr) end function mlirDenseElementsAttrGetStringSplatValue(attr) - @ccall mlir_c.mlirDenseElementsAttrGetStringSplatValue(attr::MlirAttribute)::MlirStringRef + @ccall mlir_c.mlirDenseElementsAttrGetStringSplatValue( + attr::MlirAttribute + )::MlirStringRef end """ @@ -3784,63 +3938,75 @@ end Returns the pos-th value (flat contiguous indexing) of a specific type contained by the given dense elements attribute. """ function mlirDenseElementsAttrGetBoolValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetBoolValue(attr::MlirAttribute, - pos::intptr_t)::Bool + @ccall mlir_c.mlirDenseElementsAttrGetBoolValue( + attr::MlirAttribute, pos::intptr_t + )::Bool end function mlirDenseElementsAttrGetInt8Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt8Value(attr::MlirAttribute, - pos::intptr_t)::Int8 + @ccall mlir_c.mlirDenseElementsAttrGetInt8Value( + attr::MlirAttribute, pos::intptr_t + )::Int8 end function mlirDenseElementsAttrGetUInt8Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt8Value(attr::MlirAttribute, - pos::intptr_t)::UInt8 + @ccall mlir_c.mlirDenseElementsAttrGetUInt8Value( + attr::MlirAttribute, pos::intptr_t + )::UInt8 end function mlirDenseElementsAttrGetInt16Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt16Value(attr::MlirAttribute, - pos::intptr_t)::Int16 + @ccall mlir_c.mlirDenseElementsAttrGetInt16Value( + attr::MlirAttribute, pos::intptr_t + )::Int16 end function mlirDenseElementsAttrGetUInt16Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt16Value(attr::MlirAttribute, - pos::intptr_t)::UInt16 + @ccall mlir_c.mlirDenseElementsAttrGetUInt16Value( + attr::MlirAttribute, pos::intptr_t + )::UInt16 end function mlirDenseElementsAttrGetInt32Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt32Value(attr::MlirAttribute, - pos::intptr_t)::Int32 + @ccall mlir_c.mlirDenseElementsAttrGetInt32Value( + attr::MlirAttribute, pos::intptr_t + )::Int32 end function mlirDenseElementsAttrGetUInt32Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt32Value(attr::MlirAttribute, - pos::intptr_t)::UInt32 + @ccall mlir_c.mlirDenseElementsAttrGetUInt32Value( + attr::MlirAttribute, pos::intptr_t + )::UInt32 end function mlirDenseElementsAttrGetInt64Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt64Value(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.mlirDenseElementsAttrGetInt64Value( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function mlirDenseElementsAttrGetUInt64Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt64Value(attr::MlirAttribute, - pos::intptr_t)::UInt64 + @ccall mlir_c.mlirDenseElementsAttrGetUInt64Value( + attr::MlirAttribute, pos::intptr_t + )::UInt64 end function mlirDenseElementsAttrGetFloatValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetFloatValue(attr::MlirAttribute, - pos::intptr_t)::Cfloat + @ccall mlir_c.mlirDenseElementsAttrGetFloatValue( + attr::MlirAttribute, pos::intptr_t + )::Cfloat end function mlirDenseElementsAttrGetDoubleValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetDoubleValue(attr::MlirAttribute, - pos::intptr_t)::Cdouble + @ccall mlir_c.mlirDenseElementsAttrGetDoubleValue( + attr::MlirAttribute, pos::intptr_t + )::Cdouble end function mlirDenseElementsAttrGetStringValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetStringValue(attr::MlirAttribute, - pos::intptr_t)::MlirStringRef + @ccall mlir_c.mlirDenseElementsAttrGetStringValue( + attr::MlirAttribute, pos::intptr_t + )::MlirStringRef end """ @@ -3861,105 +4027,140 @@ end Unlike the typed accessors below, constructs the attribute with a raw data buffer and no type/alignment checking. Use a more strongly typed accessor if possible. If dataIsMutable is false, then an immutable AsmResourceBlob will be created and that passed data contents will be treated as const. If the deleter is non NULL, then it will be called when the data buffer can no longer be accessed (passing userData to it). """ -function mlirUnmanagedDenseResourceElementsAttrGet(shapedType, name, data, dataLength, - dataAlignment, dataIsMutable, deleter, - userData) - @ccall mlir_c.mlirUnmanagedDenseResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - data::Ptr{Cvoid}, - dataLength::Csize_t, - dataAlignment::Csize_t, - dataIsMutable::Bool, - deleter::Ptr{Cvoid}, - userData::Ptr{Cvoid})::MlirAttribute -end - -function mlirUnmanagedDenseBoolResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseBoolResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{Cint})::MlirAttribute -end - -function mlirUnmanagedDenseUInt8ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseUInt8ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{UInt8})::MlirAttribute -end - -function mlirUnmanagedDenseInt8ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseInt8ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{Int8})::MlirAttribute -end - -function mlirUnmanagedDenseUInt16ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseUInt16ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{UInt16})::MlirAttribute -end - -function mlirUnmanagedDenseInt16ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseInt16ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{Int16})::MlirAttribute -end - -function mlirUnmanagedDenseUInt32ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseUInt32ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{UInt32})::MlirAttribute -end - -function mlirUnmanagedDenseInt32ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseInt32ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{Int32})::MlirAttribute -end - -function mlirUnmanagedDenseUInt64ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseUInt64ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{UInt64})::MlirAttribute -end - -function mlirUnmanagedDenseInt64ResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseInt64ResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{Int64})::MlirAttribute -end - -function mlirUnmanagedDenseFloatResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseFloatResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{Cfloat})::MlirAttribute -end - -function mlirUnmanagedDenseDoubleResourceElementsAttrGet(shapedType, name, numElements, - elements) - @ccall mlir_c.mlirUnmanagedDenseDoubleResourceElementsAttrGet(shapedType::MlirType, - name::MlirStringRef, - numElements::intptr_t, - elements::Ptr{Cdouble})::MlirAttribute +function mlirUnmanagedDenseResourceElementsAttrGet( + shapedType, name, data, dataLength, dataAlignment, dataIsMutable, deleter, userData +) + @ccall mlir_c.mlirUnmanagedDenseResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + data::Ptr{Cvoid}, + dataLength::Csize_t, + dataAlignment::Csize_t, + dataIsMutable::Bool, + deleter::Ptr{Cvoid}, + userData::Ptr{Cvoid}, + )::MlirAttribute +end + +function mlirUnmanagedDenseBoolResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseBoolResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{Cint}, + )::MlirAttribute +end + +function mlirUnmanagedDenseUInt8ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseUInt8ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{UInt8}, + )::MlirAttribute +end + +function mlirUnmanagedDenseInt8ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseInt8ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{Int8}, + )::MlirAttribute +end + +function mlirUnmanagedDenseUInt16ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseUInt16ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{UInt16}, + )::MlirAttribute +end + +function mlirUnmanagedDenseInt16ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseInt16ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{Int16}, + )::MlirAttribute +end + +function mlirUnmanagedDenseUInt32ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseUInt32ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{UInt32}, + )::MlirAttribute +end + +function mlirUnmanagedDenseInt32ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseInt32ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{Int32}, + )::MlirAttribute +end + +function mlirUnmanagedDenseUInt64ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseUInt64ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{UInt64}, + )::MlirAttribute +end + +function mlirUnmanagedDenseInt64ResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseInt64ResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{Int64}, + )::MlirAttribute +end + +function mlirUnmanagedDenseFloatResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseFloatResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{Cfloat}, + )::MlirAttribute +end + +function mlirUnmanagedDenseDoubleResourceElementsAttrGet( + shapedType, name, numElements, elements +) + @ccall mlir_c.mlirUnmanagedDenseDoubleResourceElementsAttrGet( + shapedType::MlirType, + name::MlirStringRef, + numElements::intptr_t, + elements::Ptr{Cdouble}, + )::MlirAttribute end """ @@ -3968,58 +4169,69 @@ end Returns the pos-th value (flat contiguous indexing) of a specific type contained by the given dense resource elements attribute. """ function mlirDenseBoolResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseBoolResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::Bool + @ccall mlir_c.mlirDenseBoolResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::Bool end function mlirDenseInt8ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt8ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::Int8 + @ccall mlir_c.mlirDenseInt8ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::Int8 end function mlirDenseUInt8ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt8ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::UInt8 + @ccall mlir_c.mlirDenseUInt8ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::UInt8 end function mlirDenseInt16ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt16ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::Int16 + @ccall mlir_c.mlirDenseInt16ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::Int16 end function mlirDenseUInt16ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt16ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::UInt16 + @ccall mlir_c.mlirDenseUInt16ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::UInt16 end function mlirDenseInt32ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt32ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::Int32 + @ccall mlir_c.mlirDenseInt32ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::Int32 end function mlirDenseUInt32ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt32ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::UInt32 + @ccall mlir_c.mlirDenseUInt32ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::UInt32 end function mlirDenseInt64ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt64ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.mlirDenseInt64ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function mlirDenseUInt64ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt64ResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::UInt64 + @ccall mlir_c.mlirDenseUInt64ResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::UInt64 end function mlirDenseFloatResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseFloatResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::Cfloat + @ccall mlir_c.mlirDenseFloatResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::Cfloat end function mlirDenseDoubleResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseDoubleResourceElementsAttrGetValue(attr::MlirAttribute, - pos::intptr_t)::Cdouble + @ccall mlir_c.mlirDenseDoubleResourceElementsAttrGetValue( + attr::MlirAttribute, pos::intptr_t + )::Cdouble end """ @@ -4037,9 +4249,9 @@ end Creates a sparse elements attribute of the given shape from a list of indices and a list of associated values. Both lists are expected to be dense elements attributes with the same number of elements. The list of indices is expected to contain 64-bit integers. The attribute is created in the same context as the type. """ function mlirSparseElementsAttribute(shapedType, denseIndices, denseValues) - @ccall mlir_c.mlirSparseElementsAttribute(shapedType::MlirType, - denseIndices::MlirAttribute, - denseValues::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirSparseElementsAttribute( + shapedType::MlirType, denseIndices::MlirAttribute, denseValues::MlirAttribute + )::MlirAttribute end """ @@ -4074,9 +4286,9 @@ function mlirAttributeIsAStridedLayout(attr) end function mlirStridedLayoutAttrGet(ctx, offset, numStrides, strides) - @ccall mlir_c.mlirStridedLayoutAttrGet(ctx::MlirContext, offset::Int64, - numStrides::intptr_t, - strides::Ptr{Int64})::MlirAttribute + @ccall mlir_c.mlirStridedLayoutAttrGet( + ctx::MlirContext, offset::Int64, numStrides::intptr_t, strides::Ptr{Int64} + )::MlirAttribute end function mlirStridedLayoutAttrGetOffset(attr) @@ -4682,8 +4894,9 @@ end Creates a vector type of the shape identified by its rank and dimensions, with the given element type in the same context as the element type. The type is owned by the context. """ function mlirVectorTypeGet(rank, shape, elementType) - @ccall mlir_c.mlirVectorTypeGet(rank::intptr_t, shape::Ptr{Int64}, - elementType::MlirType)::MlirType + @ccall mlir_c.mlirVectorTypeGet( + rank::intptr_t, shape::Ptr{Int64}, elementType::MlirType + )::MlirType end """ @@ -4692,9 +4905,9 @@ end Same as "[`mlirVectorTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirVectorTypeGetChecked(loc, rank, shape, elementType) - @ccall mlir_c.mlirVectorTypeGetChecked(loc::MlirLocation, rank::intptr_t, - shape::Ptr{Int64}, - elementType::MlirType)::MlirType + @ccall mlir_c.mlirVectorTypeGetChecked( + loc::MlirLocation, rank::intptr_t, shape::Ptr{Int64}, elementType::MlirType + )::MlirType end """ @@ -4703,9 +4916,9 @@ end Creates a scalable vector type with the shape identified by its rank and dimensions. A subset of dimensions may be marked as scalable via the corresponding flag list, which is expected to have as many entries as the rank of the vector. The vector is created in the same context as the element type. """ function mlirVectorTypeGetScalable(rank, shape, scalable, elementType) - @ccall mlir_c.mlirVectorTypeGetScalable(rank::intptr_t, shape::Ptr{Int64}, - scalable::Ptr{Bool}, - elementType::MlirType)::MlirType + @ccall mlir_c.mlirVectorTypeGetScalable( + rank::intptr_t, shape::Ptr{Int64}, scalable::Ptr{Bool}, elementType::MlirType + )::MlirType end """ @@ -4714,9 +4927,13 @@ end Same as "[`mlirVectorTypeGetScalable`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirVectorTypeGetScalableChecked(loc, rank, shape, scalable, elementType) - @ccall mlir_c.mlirVectorTypeGetScalableChecked(loc::MlirLocation, rank::intptr_t, - shape::Ptr{Int64}, scalable::Ptr{Bool}, - elementType::MlirType)::MlirType + @ccall mlir_c.mlirVectorTypeGetScalableChecked( + loc::MlirLocation, + rank::intptr_t, + shape::Ptr{Int64}, + scalable::Ptr{Bool}, + elementType::MlirType, + )::MlirType end """ @@ -4788,9 +5005,9 @@ end Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in the same context as the element type. The type is owned by the context. Tensor types without any specific encoding field should assign [`mlirAttributeGetNull`](@ref)() to this parameter. """ function mlirRankedTensorTypeGet(rank, shape, elementType, encoding) - @ccall mlir_c.mlirRankedTensorTypeGet(rank::intptr_t, shape::Ptr{Int64}, - elementType::MlirType, - encoding::MlirAttribute)::MlirType + @ccall mlir_c.mlirRankedTensorTypeGet( + rank::intptr_t, shape::Ptr{Int64}, elementType::MlirType, encoding::MlirAttribute + )::MlirType end """ @@ -4799,9 +5016,13 @@ end Same as "[`mlirRankedTensorTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirRankedTensorTypeGetChecked(loc, rank, shape, elementType, encoding) - @ccall mlir_c.mlirRankedTensorTypeGetChecked(loc::MlirLocation, rank::intptr_t, - shape::Ptr{Int64}, elementType::MlirType, - encoding::MlirAttribute)::MlirType + @ccall mlir_c.mlirRankedTensorTypeGetChecked( + loc::MlirLocation, + rank::intptr_t, + shape::Ptr{Int64}, + elementType::MlirType, + encoding::MlirAttribute, + )::MlirType end """ @@ -4828,8 +5049,9 @@ end Same as "[`mlirUnrankedTensorTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirUnrankedTensorTypeGetChecked(loc, elementType) - @ccall mlir_c.mlirUnrankedTensorTypeGetChecked(loc::MlirLocation, - elementType::MlirType)::MlirType + @ccall mlir_c.mlirUnrankedTensorTypeGetChecked( + loc::MlirLocation, elementType::MlirType + )::MlirType end """ @@ -4874,9 +5096,13 @@ end Creates a MemRef type with the given rank and shape, a potentially empty list of affine layout maps, the given memory space and element type, in the same context as element type. The type is owned by the context. """ function mlirMemRefTypeGet(elementType, rank, shape, layout, memorySpace) - @ccall mlir_c.mlirMemRefTypeGet(elementType::MlirType, rank::intptr_t, - shape::Ptr{Int64}, layout::MlirAttribute, - memorySpace::MlirAttribute)::MlirType + @ccall mlir_c.mlirMemRefTypeGet( + elementType::MlirType, + rank::intptr_t, + shape::Ptr{Int64}, + layout::MlirAttribute, + memorySpace::MlirAttribute, + )::MlirType end """ @@ -4885,10 +5111,14 @@ end Same as "[`mlirMemRefTypeGet`](@ref)" but returns a nullptr-wrapping [`MlirType`](@ref) o illegal arguments, emitting appropriate diagnostics. """ function mlirMemRefTypeGetChecked(loc, elementType, rank, shape, layout, memorySpace) - @ccall mlir_c.mlirMemRefTypeGetChecked(loc::MlirLocation, elementType::MlirType, - rank::intptr_t, shape::Ptr{Int64}, - layout::MlirAttribute, - memorySpace::MlirAttribute)::MlirType + @ccall mlir_c.mlirMemRefTypeGetChecked( + loc::MlirLocation, + elementType::MlirType, + rank::intptr_t, + shape::Ptr{Int64}, + layout::MlirAttribute, + memorySpace::MlirAttribute, + )::MlirType end """ @@ -4897,9 +5127,9 @@ end Creates a MemRef type with the given rank, shape, memory space and element type in the same context as the element type. The type has no affine maps, i.e. represents a default row-major contiguous memref. The type is owned by the context. """ function mlirMemRefTypeContiguousGet(elementType, rank, shape, memorySpace) - @ccall mlir_c.mlirMemRefTypeContiguousGet(elementType::MlirType, rank::intptr_t, - shape::Ptr{Int64}, - memorySpace::MlirAttribute)::MlirType + @ccall mlir_c.mlirMemRefTypeContiguousGet( + elementType::MlirType, rank::intptr_t, shape::Ptr{Int64}, memorySpace::MlirAttribute + )::MlirType end """ @@ -4908,10 +5138,13 @@ end Same as "[`mlirMemRefTypeContiguousGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirMemRefTypeContiguousGetChecked(loc, elementType, rank, shape, memorySpace) - @ccall mlir_c.mlirMemRefTypeContiguousGetChecked(loc::MlirLocation, - elementType::MlirType, rank::intptr_t, - shape::Ptr{Int64}, - memorySpace::MlirAttribute)::MlirType + @ccall mlir_c.mlirMemRefTypeContiguousGetChecked( + loc::MlirLocation, + elementType::MlirType, + rank::intptr_t, + shape::Ptr{Int64}, + memorySpace::MlirAttribute, + )::MlirType end """ @@ -4920,8 +5153,9 @@ end Creates an Unranked MemRef type with the given element type and in the given memory space. The type is owned by the context of element type. """ function mlirUnrankedMemRefTypeGet(elementType, memorySpace) - @ccall mlir_c.mlirUnrankedMemRefTypeGet(elementType::MlirType, - memorySpace::MlirAttribute)::MlirType + @ccall mlir_c.mlirUnrankedMemRefTypeGet( + elementType::MlirType, memorySpace::MlirAttribute + )::MlirType end """ @@ -4930,8 +5164,9 @@ end Same as "[`mlirUnrankedMemRefTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace) - @ccall mlir_c.mlirUnrankedMemRefTypeGetChecked(loc::MlirLocation, elementType::MlirType, - memorySpace::MlirAttribute)::MlirType + @ccall mlir_c.mlirUnrankedMemRefTypeGetChecked( + loc::MlirLocation, elementType::MlirType, memorySpace::MlirAttribute + )::MlirType end """ @@ -4967,8 +5202,9 @@ end Returns the strides of the MemRef if the layout map is in strided form. Both strides and offset are out params. strides must point to pre-allocated memory of length equal to the rank of the memref. """ function mlirMemRefTypeGetStridesAndOffset(type, strides, offset) - @ccall mlir_c.mlirMemRefTypeGetStridesAndOffset(type::MlirType, strides::Ptr{Int64}, - offset::Ptr{Int64})::MlirLogicalResult + @ccall mlir_c.mlirMemRefTypeGetStridesAndOffset( + type::MlirType, strides::Ptr{Int64}, offset::Ptr{Int64} + )::MlirLogicalResult end """ @@ -5004,8 +5240,9 @@ end Creates a tuple type that consists of the given list of elemental types. The type is owned by the context. """ function mlirTupleTypeGet(ctx, numElements, elements) - @ccall mlir_c.mlirTupleTypeGet(ctx::MlirContext, numElements::intptr_t, - elements::Ptr{MlirType})::MlirType + @ccall mlir_c.mlirTupleTypeGet( + ctx::MlirContext, numElements::intptr_t, elements::Ptr{MlirType} + )::MlirType end """ @@ -5050,9 +5287,13 @@ end Creates a function type, mapping a list of input types to result types. """ function mlirFunctionTypeGet(ctx, numInputs, inputs, numResults, results) - @ccall mlir_c.mlirFunctionTypeGet(ctx::MlirContext, numInputs::intptr_t, - inputs::Ptr{MlirType}, numResults::intptr_t, - results::Ptr{MlirType})::MlirType + @ccall mlir_c.mlirFunctionTypeGet( + ctx::MlirContext, + numInputs::intptr_t, + inputs::Ptr{MlirType}, + numResults::intptr_t, + results::Ptr{MlirType}, + )::MlirType end """ @@ -5115,8 +5356,9 @@ end Creates an opaque type in the given context associated with the dialect identified by its namespace. The type contains opaque byte data of the specified length (data need not be null-terminated). """ function mlirOpaqueTypeGet(ctx, dialectNamespace, typeData) - @ccall mlir_c.mlirOpaqueTypeGet(ctx::MlirContext, dialectNamespace::MlirStringRef, - typeData::MlirStringRef)::MlirType + @ccall mlir_c.mlirOpaqueTypeGet( + ctx::MlirContext, dialectNamespace::MlirStringRef, typeData::MlirStringRef + )::MlirType end """ @@ -5168,8 +5410,9 @@ end Create a new top-level PassManager anchored on `anchorOp`. """ function mlirPassManagerCreateOnOperation(ctx, anchorOp) - @ccall mlir_c.mlirPassManagerCreateOnOperation(ctx::MlirContext, - anchorOp::MlirStringRef)::MlirPassManager + @ccall mlir_c.mlirPassManagerCreateOnOperation( + ctx::MlirContext, anchorOp::MlirStringRef + )::MlirPassManager end """ @@ -5196,7 +5439,9 @@ end Cast a top-level PassManager to a generic OpPassManager. """ function mlirPassManagerGetAsOpPassManager(passManager) - @ccall mlir_c.mlirPassManagerGetAsOpPassManager(passManager::MlirPassManager)::MlirOpPassManager + @ccall mlir_c.mlirPassManagerGetAsOpPassManager( + passManager::MlirPassManager + )::MlirOpPassManager end """ @@ -5205,8 +5450,9 @@ end Run the provided `passManager` on the given `op`. """ function mlirPassManagerRunOnOp(passManager, op) - @ccall mlir_c.mlirPassManagerRunOnOp(passManager::MlirPassManager, - op::MlirOperation)::MlirLogicalResult + @ccall mlir_c.mlirPassManagerRunOnOp( + passManager::MlirPassManager, op::MlirOperation + )::MlirLogicalResult end """ @@ -5224,8 +5470,9 @@ end Enable / disable verify-each. """ function mlirPassManagerEnableVerifier(passManager, enable) - @ccall mlir_c.mlirPassManagerEnableVerifier(passManager::MlirPassManager, - enable::Bool)::Cvoid + @ccall mlir_c.mlirPassManagerEnableVerifier( + passManager::MlirPassManager, enable::Bool + )::Cvoid end """ @@ -5234,8 +5481,9 @@ end Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. To further nest more OpPassManager under the newly returned one, see `mlirOpPassManagerNest` below. """ function mlirPassManagerGetNestedUnder(passManager, operationName) - @ccall mlir_c.mlirPassManagerGetNestedUnder(passManager::MlirPassManager, - operationName::MlirStringRef)::MlirOpPassManager + @ccall mlir_c.mlirPassManagerGetNestedUnder( + passManager::MlirPassManager, operationName::MlirStringRef + )::MlirOpPassManager end """ @@ -5244,8 +5492,9 @@ end Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. """ function mlirOpPassManagerGetNestedUnder(passManager, operationName) - @ccall mlir_c.mlirOpPassManagerGetNestedUnder(passManager::MlirOpPassManager, - operationName::MlirStringRef)::MlirOpPassManager + @ccall mlir_c.mlirOpPassManagerGetNestedUnder( + passManager::MlirOpPassManager, operationName::MlirStringRef + )::MlirOpPassManager end """ @@ -5254,8 +5503,9 @@ end Add a pass and transfer ownership to the provided top-level mlirPassManager. If the pass is not a generic operation pass or a ModulePass, a new OpPassManager is implicitly nested under the provided PassManager. """ function mlirPassManagerAddOwnedPass(passManager, pass) - @ccall mlir_c.mlirPassManagerAddOwnedPass(passManager::MlirPassManager, - pass::MlirPass)::Cvoid + @ccall mlir_c.mlirPassManagerAddOwnedPass( + passManager::MlirPassManager, pass::MlirPass + )::Cvoid end """ @@ -5264,8 +5514,9 @@ end Add a pass and transfer ownership to the provided mlirOpPassManager. If the pass is not a generic operation pass or matching the type of the provided PassManager, a new OpPassManager is implicitly nested under the provided PassManager. """ function mlirOpPassManagerAddOwnedPass(passManager, pass) - @ccall mlir_c.mlirOpPassManagerAddOwnedPass(passManager::MlirOpPassManager, - pass::MlirPass)::Cvoid + @ccall mlir_c.mlirOpPassManagerAddOwnedPass( + passManager::MlirOpPassManager, pass::MlirPass + )::Cvoid end """ @@ -5274,10 +5525,12 @@ end Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. """ function mlirOpPassManagerAddPipeline(passManager, pipelineElements, callback, userData) - @ccall mlir_c.mlirOpPassManagerAddPipeline(passManager::MlirOpPassManager, - pipelineElements::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::MlirLogicalResult + @ccall mlir_c.mlirOpPassManagerAddPipeline( + passManager::MlirOpPassManager, + pipelineElements::MlirStringRef, + callback::MlirStringCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult end """ @@ -5286,9 +5539,9 @@ end Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirPrintPassPipeline(passManager, callback, userData) - @ccall mlir_c.mlirPrintPassPipeline(passManager::MlirOpPassManager, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirPrintPassPipeline( + passManager::MlirOpPassManager, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -5297,10 +5550,12 @@ end Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. """ function mlirParsePassPipeline(passManager, pipeline, callback, userData) - @ccall mlir_c.mlirParsePassPipeline(passManager::MlirOpPassManager, - pipeline::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::MlirLogicalResult + @ccall mlir_c.mlirParsePassPipeline( + passManager::MlirOpPassManager, + pipeline::MlirStringRef, + callback::MlirStringCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult end """ @@ -5329,15 +5584,28 @@ end Creates an external [`MlirPass`](@ref) that calls the supplied `callbacks` using the supplied `userData`. If `opName` is empty, the pass is a generic operation pass. Otherwise it is an operation pass specific to the specified pass name. """ -function mlirCreateExternalPass(passID, name, argument, description, opName, - nDependentDialects, dependentDialects, callbacks, userData) - @ccall mlir_c.mlirCreateExternalPass(passID::MlirTypeID, name::MlirStringRef, - argument::MlirStringRef, - description::MlirStringRef, opName::MlirStringRef, - nDependentDialects::intptr_t, - dependentDialects::Ptr{MlirDialectHandle}, - callbacks::MlirExternalPassCallbacks, - userData::Ptr{Cvoid})::MlirPass +function mlirCreateExternalPass( + passID, + name, + argument, + description, + opName, + nDependentDialects, + dependentDialects, + callbacks, + userData, +) + @ccall mlir_c.mlirCreateExternalPass( + passID::MlirTypeID, + name::MlirStringRef, + argument::MlirStringRef, + description::MlirStringRef, + opName::MlirStringRef, + nDependentDialects::intptr_t, + dependentDialects::Ptr{MlirDialectHandle}, + callbacks::MlirExternalPassCallbacks, + userData::Ptr{Cvoid}, + )::MlirPass end """ @@ -5985,9 +6253,9 @@ const MlirDiagnosticHandler = Ptr{Cvoid} Prints a diagnostic using the provided callback. """ function mlirDiagnosticPrint(diagnostic, callback, userData) - @ccall mlir_c.mlirDiagnosticPrint(diagnostic::MlirDiagnostic, - callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirDiagnosticPrint( + diagnostic::MlirDiagnostic, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -6005,7 +6273,9 @@ end Returns the severity of the diagnostic. """ function mlirDiagnosticGetSeverity(diagnostic) - @ccall mlir_c.mlirDiagnosticGetSeverity(diagnostic::MlirDiagnostic)::MlirDiagnosticSeverity + @ccall mlir_c.mlirDiagnosticGetSeverity( + diagnostic::MlirDiagnostic + )::MlirDiagnosticSeverity end """ @@ -6023,8 +6293,9 @@ end Returns `pos`-th note attached to the diagnostic. Expects `pos` to be a valid zero-based index into the list of notes. """ function mlirDiagnosticGetNote(diagnostic, pos) - @ccall mlir_c.mlirDiagnosticGetNote(diagnostic::MlirDiagnostic, - pos::intptr_t)::MlirDiagnostic + @ccall mlir_c.mlirDiagnosticGetNote( + diagnostic::MlirDiagnostic, pos::intptr_t + )::MlirDiagnostic end """ @@ -6033,10 +6304,12 @@ end Attaches the diagnostic handler to the context. Handlers are invoked in the reverse order of attachment until one of them processes the diagnostic completely. When a handler is invoked it is passed the `userData` that was provided when it was attached. If non-NULL, `deleteUserData` is called once the system no longer needs to call the handler (for instance after the handler is detached or the context is destroyed). Returns an identifier that can be used to detach the handler. """ function mlirContextAttachDiagnosticHandler(context, handler, userData, deleteUserData) - @ccall mlir_c.mlirContextAttachDiagnosticHandler(context::MlirContext, - handler::MlirDiagnosticHandler, - userData::Ptr{Cvoid}, - deleteUserData::Ptr{Cvoid})::MlirDiagnosticHandlerID + @ccall mlir_c.mlirContextAttachDiagnosticHandler( + context::MlirContext, + handler::MlirDiagnosticHandler, + userData::Ptr{Cvoid}, + deleteUserData::Ptr{Cvoid}, + )::MlirDiagnosticHandlerID end """ @@ -6045,8 +6318,9 @@ end Detaches an attached diagnostic handler from the context given its identifier. """ function mlirContextDetachDiagnosticHandler(context, id) - @ccall mlir_c.mlirContextDetachDiagnosticHandler(context::MlirContext, - id::MlirDiagnosticHandlerID)::Cvoid + @ccall mlir_c.mlirContextDetachDiagnosticHandler( + context::MlirContext, id::MlirDiagnosticHandlerID + )::Cvoid end """ @@ -6136,8 +6410,9 @@ end Sets the argument attribute 'name' of an argument at index 'pos'. Asserts that the operation is a FuncOp. """ function mlirFuncSetArgAttr(op, pos, name, attr) - @ccall mlir_c.mlirFuncSetArgAttr(op::MlirOperation, pos::intptr_t, name::MlirStringRef, - attr::MlirAttribute)::Cvoid + @ccall mlir_c.mlirFuncSetArgAttr( + op::MlirOperation, pos::intptr_t, name::MlirStringRef, attr::MlirAttribute + )::Cvoid end function mlirGetDialectHandle__gpu__() @@ -6265,9 +6540,12 @@ end Creates an llvm.func type. """ function mlirLLVMFunctionTypeGet(resultType, nArgumentTypes, argumentTypes, isVarArg) - @ccall mlir_c.mlirLLVMFunctionTypeGet(resultType::MlirType, nArgumentTypes::intptr_t, - argumentTypes::Ptr{MlirType}, - isVarArg::Bool)::MlirType + @ccall mlir_c.mlirLLVMFunctionTypeGet( + resultType::MlirType, + nArgumentTypes::intptr_t, + argumentTypes::Ptr{MlirType}, + isVarArg::Bool, + )::MlirType end """ @@ -6303,8 +6581,9 @@ end Returns the `positions`-th field of the struct. Asserts if the struct is opaque, not yet initialized or if the position is out of range. """ function mlirLLVMStructTypeGetElementType(type, position) - @ccall mlir_c.mlirLLVMStructTypeGetElementType(type::MlirType, - position::intptr_t)::MlirType + @ccall mlir_c.mlirLLVMStructTypeGetElementType( + type::MlirType, position::intptr_t + )::MlirType end """ @@ -6340,9 +6619,9 @@ end Creates an LLVM literal (unnamed) struct type. This may assert if the fields have types not compatible with the LLVM dialect. For a graceful failure, use the checked version. """ function mlirLLVMStructTypeLiteralGet(ctx, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeLiteralGet(ctx::MlirContext, nFieldTypes::intptr_t, - fieldTypes::Ptr{MlirType}, - isPacked::Bool)::MlirType + @ccall mlir_c.mlirLLVMStructTypeLiteralGet( + ctx::MlirContext, nFieldTypes::intptr_t, fieldTypes::Ptr{MlirType}, isPacked::Bool + )::MlirType end """ @@ -6351,10 +6630,9 @@ end Creates an LLVM literal (unnamed) struct type if possible. Emits a diagnostic at the given location and returns null otherwise. """ function mlirLLVMStructTypeLiteralGetChecked(loc, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeLiteralGetChecked(loc::MlirLocation, - nFieldTypes::intptr_t, - fieldTypes::Ptr{MlirType}, - isPacked::Bool)::MlirType + @ccall mlir_c.mlirLLVMStructTypeLiteralGetChecked( + loc::MlirLocation, nFieldTypes::intptr_t, fieldTypes::Ptr{MlirType}, isPacked::Bool + )::MlirType end """ @@ -6363,8 +6641,9 @@ end Creates an LLVM identified struct type with no body. If a struct type with this name already exists in the context, returns that type. Use [`mlirLLVMStructTypeIdentifiedNewGet`](@ref) to create a fresh struct type, potentially renaming it. The body should be set separatelty by calling [`mlirLLVMStructTypeSetBody`](@ref), if it isn't set already. """ function mlirLLVMStructTypeIdentifiedGet(ctx, name) - @ccall mlir_c.mlirLLVMStructTypeIdentifiedGet(ctx::MlirContext, - name::MlirStringRef)::MlirType + @ccall mlir_c.mlirLLVMStructTypeIdentifiedGet( + ctx::MlirContext, name::MlirStringRef + )::MlirType end """ @@ -6373,15 +6652,19 @@ end Creates an LLVM identified struct type with no body and a name starting with the given prefix. If a struct with the exact name as the given prefix already exists, appends an unspecified suffix to the name so that the name is unique in context. """ function mlirLLVMStructTypeIdentifiedNewGet(ctx, name, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeIdentifiedNewGet(ctx::MlirContext, name::MlirStringRef, - nFieldTypes::intptr_t, - fieldTypes::Ptr{MlirType}, - isPacked::Bool)::MlirType + @ccall mlir_c.mlirLLVMStructTypeIdentifiedNewGet( + ctx::MlirContext, + name::MlirStringRef, + nFieldTypes::intptr_t, + fieldTypes::Ptr{MlirType}, + isPacked::Bool, + )::MlirType end function mlirLLVMStructTypeOpaqueGet(ctx, name) - @ccall mlir_c.mlirLLVMStructTypeOpaqueGet(ctx::MlirContext, - name::MlirStringRef)::MlirType + @ccall mlir_c.mlirLLVMStructTypeOpaqueGet( + ctx::MlirContext, name::MlirStringRef + )::MlirType end """ @@ -6390,9 +6673,12 @@ end Sets the body of the identified struct if it hasn't been set yet. Returns whether the operation was successful. """ function mlirLLVMStructTypeSetBody(structType, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeSetBody(structType::MlirType, nFieldTypes::intptr_t, - fieldTypes::Ptr{MlirType}, - isPacked::Bool)::MlirLogicalResult + @ccall mlir_c.mlirLLVMStructTypeSetBody( + structType::MlirType, + nFieldTypes::intptr_t, + fieldTypes::Ptr{MlirType}, + isPacked::Bool, + )::MlirLogicalResult end @cenum MlirLLVMCConv::UInt32 begin @@ -6451,8 +6737,9 @@ end Creates a LLVM CConv attribute. """ function mlirLLVMCConvAttrGet(ctx, cconv) - @ccall mlir_c.mlirLLVMCConvAttrGet(ctx::MlirContext, - cconv::MlirLLVMCConv)::MlirAttribute + @ccall mlir_c.mlirLLVMCConvAttrGet( + ctx::MlirContext, cconv::MlirLLVMCConv + )::MlirAttribute end @cenum MlirLLVMComdat::UInt32 begin @@ -6469,8 +6756,9 @@ end Creates a LLVM Comdat attribute. """ function mlirLLVMComdatAttrGet(ctx, comdat) - @ccall mlir_c.mlirLLVMComdatAttrGet(ctx::MlirContext, - comdat::MlirLLVMComdat)::MlirAttribute + @ccall mlir_c.mlirLLVMComdatAttrGet( + ctx::MlirContext, comdat::MlirLLVMComdat + )::MlirAttribute end @cenum MlirLLVMLinkage::UInt32 begin @@ -6493,8 +6781,9 @@ end Creates a LLVM Linkage attribute. """ function mlirLLVMLinkageAttrGet(ctx, linkage) - @ccall mlir_c.mlirLLVMLinkageAttrGet(ctx::MlirContext, - linkage::MlirLLVMLinkage)::MlirAttribute + @ccall mlir_c.mlirLLVMLinkageAttrGet( + ctx::MlirContext, linkage::MlirLLVMLinkage + )::MlirAttribute end """ @@ -6512,9 +6801,9 @@ end Creates a LLVM DIExpressionElem attribute. """ function mlirLLVMDIExpressionElemAttrGet(ctx, opcode, nArguments, arguments) - @ccall mlir_c.mlirLLVMDIExpressionElemAttrGet(ctx::MlirContext, opcode::Cuint, - nArguments::intptr_t, - arguments::Ptr{UInt64})::MlirAttribute + @ccall mlir_c.mlirLLVMDIExpressionElemAttrGet( + ctx::MlirContext, opcode::Cuint, nArguments::intptr_t, arguments::Ptr{UInt64} + )::MlirAttribute end """ @@ -6523,8 +6812,9 @@ end Creates a LLVM DIExpression attribute. """ function mlirLLVMDIExpressionAttrGet(ctx, nOperations, operations) - @ccall mlir_c.mlirLLVMDIExpressionAttrGet(ctx::MlirContext, nOperations::intptr_t, - operations::Ptr{MlirAttribute})::MlirAttribute + @ccall mlir_c.mlirLLVMDIExpressionAttrGet( + ctx::MlirContext, nOperations::intptr_t, operations::Ptr{MlirAttribute} + )::MlirAttribute end @cenum MlirLLVMTypeEncoding::UInt32 begin @@ -6556,9 +6846,13 @@ end Creates a LLVM DIBasicType attribute. """ function mlirLLVMDIBasicTypeAttrGet(ctx, tag, name, sizeInBits, encoding) - @ccall mlir_c.mlirLLVMDIBasicTypeAttrGet(ctx::MlirContext, tag::Cuint, - name::MlirAttribute, sizeInBits::UInt64, - encoding::MlirLLVMTypeEncoding)::MlirAttribute + @ccall mlir_c.mlirLLVMDIBasicTypeAttrGet( + ctx::MlirContext, + tag::Cuint, + name::MlirAttribute, + sizeInBits::UInt64, + encoding::MlirLLVMTypeEncoding, + )::MlirAttribute end """ @@ -6566,16 +6860,36 @@ end Creates a LLVM DICompositeType attribute. """ -function mlirLLVMDICompositeTypeAttrGet(ctx, tag, recId, name, file, line, scope, baseType, - flags, sizeInBits, alignInBits, nElements, elements) - @ccall mlir_c.mlirLLVMDICompositeTypeAttrGet(ctx::MlirContext, tag::Cuint, - recId::MlirAttribute, name::MlirAttribute, - file::MlirAttribute, line::UInt32, - scope::MlirAttribute, - baseType::MlirAttribute, flags::Int64, - sizeInBits::UInt64, alignInBits::UInt64, - nElements::intptr_t, - elements::Ptr{MlirAttribute})::MlirAttribute +function mlirLLVMDICompositeTypeAttrGet( + ctx, + tag, + recId, + name, + file, + line, + scope, + baseType, + flags, + sizeInBits, + alignInBits, + nElements, + elements, +) + @ccall mlir_c.mlirLLVMDICompositeTypeAttrGet( + ctx::MlirContext, + tag::Cuint, + recId::MlirAttribute, + name::MlirAttribute, + file::MlirAttribute, + line::UInt32, + scope::MlirAttribute, + baseType::MlirAttribute, + flags::Int64, + sizeInBits::UInt64, + alignInBits::UInt64, + nElements::intptr_t, + elements::Ptr{MlirAttribute}, + )::MlirAttribute end """ @@ -6583,13 +6897,19 @@ end Creates a LLVM DIDerivedType attribute. """ -function mlirLLVMDIDerivedTypeAttrGet(ctx, tag, name, baseType, sizeInBits, alignInBits, - offsetInBits, extraData) - @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGet(ctx::MlirContext, tag::Cuint, - name::MlirAttribute, baseType::MlirAttribute, - sizeInBits::UInt64, alignInBits::UInt32, - offsetInBits::UInt64, - extraData::MlirAttribute)::MlirAttribute +function mlirLLVMDIDerivedTypeAttrGet( + ctx, tag, name, baseType, sizeInBits, alignInBits, offsetInBits, extraData +) + @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGet( + ctx::MlirContext, + tag::Cuint, + name::MlirAttribute, + baseType::MlirAttribute, + sizeInBits::UInt64, + alignInBits::UInt32, + offsetInBits::UInt64, + extraData::MlirAttribute, + )::MlirAttribute end """ @@ -6598,7 +6918,9 @@ end Gets the base type from a LLVM DIDerivedType attribute. """ function mlirLLVMDIDerivedTypeAttrGetBaseType(diDerivedType) - @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGetBaseType(diDerivedType::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGetBaseType( + diDerivedType::MlirAttribute + )::MlirAttribute end """ @@ -6607,8 +6929,9 @@ end Creates a LLVM DIFileAttr attribute. """ function mlirLLVMDIFileAttrGet(ctx, name, directory) - @ccall mlir_c.mlirLLVMDIFileAttrGet(ctx::MlirContext, name::MlirAttribute, - directory::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirLLVMDIFileAttrGet( + ctx::MlirContext, name::MlirAttribute, directory::MlirAttribute + )::MlirAttribute end @cenum MlirLLVMDIEmissionKind::UInt32 begin @@ -6623,12 +6946,18 @@ end Creates a LLVM DICompileUnit attribute. """ -function mlirLLVMDICompileUnitAttrGet(ctx, id, sourceLanguage, file, producer, isOptimized, - emissionKind) - @ccall mlir_c.mlirLLVMDICompileUnitAttrGet(ctx::MlirContext, id::MlirAttribute, - sourceLanguage::Cuint, file::MlirAttribute, - producer::MlirAttribute, isOptimized::Bool, - emissionKind::MlirLLVMDIEmissionKind)::MlirAttribute +function mlirLLVMDICompileUnitAttrGet( + ctx, id, sourceLanguage, file, producer, isOptimized, emissionKind +) + @ccall mlir_c.mlirLLVMDICompileUnitAttrGet( + ctx::MlirContext, + id::MlirAttribute, + sourceLanguage::Cuint, + file::MlirAttribute, + producer::MlirAttribute, + isOptimized::Bool, + emissionKind::MlirLLVMDIEmissionKind, + )::MlirAttribute end """ @@ -6646,9 +6975,13 @@ end Creates a LLVM DILexicalBlock attribute. """ function mlirLLVMDILexicalBlockAttrGet(ctx, scope, file, line, column) - @ccall mlir_c.mlirLLVMDILexicalBlockAttrGet(ctx::MlirContext, scope::MlirAttribute, - file::MlirAttribute, line::Cuint, - column::Cuint)::MlirAttribute + @ccall mlir_c.mlirLLVMDILexicalBlockAttrGet( + ctx::MlirContext, + scope::MlirAttribute, + file::MlirAttribute, + line::Cuint, + column::Cuint, + )::MlirAttribute end """ @@ -6657,9 +6990,9 @@ end Creates a LLVM DILexicalBlockFile attribute. """ function mlirLLVMDILexicalBlockFileAttrGet(ctx, scope, file, discriminator) - @ccall mlir_c.mlirLLVMDILexicalBlockFileAttrGet(ctx::MlirContext, scope::MlirAttribute, - file::MlirAttribute, - discriminator::Cuint)::MlirAttribute + @ccall mlir_c.mlirLLVMDILexicalBlockFileAttrGet( + ctx::MlirContext, scope::MlirAttribute, file::MlirAttribute, discriminator::Cuint + )::MlirAttribute end """ @@ -6667,13 +7000,19 @@ end Creates a LLVM DILocalVariableAttr attribute. """ -function mlirLLVMDILocalVariableAttrGet(ctx, scope, name, diFile, line, arg, alignInBits, - diType) - @ccall mlir_c.mlirLLVMDILocalVariableAttrGet(ctx::MlirContext, scope::MlirAttribute, - name::MlirAttribute, diFile::MlirAttribute, - line::Cuint, arg::Cuint, - alignInBits::Cuint, - diType::MlirAttribute)::MlirAttribute +function mlirLLVMDILocalVariableAttrGet( + ctx, scope, name, diFile, line, arg, alignInBits, diType +) + @ccall mlir_c.mlirLLVMDILocalVariableAttrGet( + ctx::MlirContext, + scope::MlirAttribute, + name::MlirAttribute, + diFile::MlirAttribute, + line::Cuint, + arg::Cuint, + alignInBits::Cuint, + diType::MlirAttribute, + )::MlirAttribute end """ @@ -6681,15 +7020,32 @@ end Creates a LLVM DISubprogramAttr attribute. """ -function mlirLLVMDISubprogramAttrGet(ctx, id, compileUnit, scope, name, linkageName, file, - line, scopeLine, subprogramFlags, type) - @ccall mlir_c.mlirLLVMDISubprogramAttrGet(ctx::MlirContext, id::MlirAttribute, - compileUnit::MlirAttribute, - scope::MlirAttribute, name::MlirAttribute, - linkageName::MlirAttribute, - file::MlirAttribute, line::Cuint, - scopeLine::Cuint, subprogramFlags::UInt64, - type::MlirAttribute)::MlirAttribute +function mlirLLVMDISubprogramAttrGet( + ctx, + id, + compileUnit, + scope, + name, + linkageName, + file, + line, + scopeLine, + subprogramFlags, + type, +) + @ccall mlir_c.mlirLLVMDISubprogramAttrGet( + ctx::MlirContext, + id::MlirAttribute, + compileUnit::MlirAttribute, + scope::MlirAttribute, + name::MlirAttribute, + linkageName::MlirAttribute, + file::MlirAttribute, + line::Cuint, + scopeLine::Cuint, + subprogramFlags::UInt64, + type::MlirAttribute, + )::MlirAttribute end """ @@ -6698,7 +7054,9 @@ end Gets the scope from this DISubprogramAttr. """ function mlirLLVMDISubprogramAttrGetScope(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetScope(diSubprogram::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetScope( + diSubprogram::MlirAttribute + )::MlirAttribute end """ @@ -6725,7 +7083,9 @@ end Gets the compile unit from this DISubprogram. """ function mlirLLVMDISubprogramAttrGetCompileUnit(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetCompileUnit(diSubprogram::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetCompileUnit( + diSubprogram::MlirAttribute + )::MlirAttribute end """ @@ -6734,7 +7094,9 @@ end Gets the file from this DISubprogramAttr. """ function mlirLLVMDISubprogramAttrGetFile(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetFile(diSubprogram::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetFile( + diSubprogram::MlirAttribute + )::MlirAttribute end """ @@ -6743,7 +7105,9 @@ end Gets the type from this DISubprogramAttr. """ function mlirLLVMDISubprogramAttrGetType(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetType(diSubprogram::MlirAttribute)::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetType( + diSubprogram::MlirAttribute + )::MlirAttribute end """ @@ -6752,10 +7116,12 @@ end Creates a LLVM DISubroutineTypeAttr attribute. """ function mlirLLVMDISubroutineTypeAttrGet(ctx, callingConvention, nTypes, types) - @ccall mlir_c.mlirLLVMDISubroutineTypeAttrGet(ctx::MlirContext, - callingConvention::Cuint, - nTypes::intptr_t, - types::Ptr{MlirAttribute})::MlirAttribute + @ccall mlir_c.mlirLLVMDISubroutineTypeAttrGet( + ctx::MlirContext, + callingConvention::Cuint, + nTypes::intptr_t, + types::Ptr{MlirAttribute}, + )::MlirAttribute end """ @@ -6763,14 +7129,20 @@ end Creates a LLVM DIModuleAttr attribute. """ -function mlirLLVMDIModuleAttrGet(ctx, file, scope, name, configMacros, includePath, - apinotes, line, isDecl) - @ccall mlir_c.mlirLLVMDIModuleAttrGet(ctx::MlirContext, file::MlirAttribute, - scope::MlirAttribute, name::MlirAttribute, - configMacros::MlirAttribute, - includePath::MlirAttribute, - apinotes::MlirAttribute, line::Cuint, - isDecl::Bool)::MlirAttribute +function mlirLLVMDIModuleAttrGet( + ctx, file, scope, name, configMacros, includePath, apinotes, line, isDecl +) + @ccall mlir_c.mlirLLVMDIModuleAttrGet( + ctx::MlirContext, + file::MlirAttribute, + scope::MlirAttribute, + name::MlirAttribute, + configMacros::MlirAttribute, + includePath::MlirAttribute, + apinotes::MlirAttribute, + line::Cuint, + isDecl::Bool, + )::MlirAttribute end """ @@ -6991,8 +7363,9 @@ end Returns the minimum possible value stored by a quantized type. """ function mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, integralWidth) - @ccall mlir_c.mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned::Bool, - integralWidth::Cuint)::Int64 + @ccall mlir_c.mlirQuantizedTypeGetDefaultMinimumForInteger( + isSigned::Bool, integralWidth::Cuint + )::Int64 end """ @@ -7001,8 +7374,9 @@ end Returns the maximum possible value stored by a quantized type. """ function mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, integralWidth) - @ccall mlir_c.mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned::Bool, - integralWidth::Cuint)::Int64 + @ccall mlir_c.mlirQuantizedTypeGetDefaultMaximumForInteger( + isSigned::Bool, integralWidth::Cuint + )::Int64 end """ @@ -7074,8 +7448,9 @@ end Returns `true` if the `candidate` type is compatible with the given quantized `type`. """ function mlirQuantizedTypeIsCompatibleExpressedType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeIsCompatibleExpressedType(type::MlirType, - candidate::MlirType)::Bool + @ccall mlir_c.mlirQuantizedTypeIsCompatibleExpressedType( + type::MlirType, candidate::MlirType + )::Bool end """ @@ -7093,8 +7468,9 @@ end Casts from a type based on the storage type of the given type to a corresponding type based on the given type. Returns a null type if the cast is not valid. """ function mlirQuantizedTypeCastFromStorageType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeCastFromStorageType(type::MlirType, - candidate::MlirType)::MlirType + @ccall mlir_c.mlirQuantizedTypeCastFromStorageType( + type::MlirType, candidate::MlirType + )::MlirType end """ @@ -7112,8 +7488,9 @@ end Casts from a type based on the expressed type of the given type to a corresponding type based on the given type. Returns a null type if the cast is not valid. """ function mlirQuantizedTypeCastFromExpressedType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeCastFromExpressedType(type::MlirType, - candidate::MlirType)::MlirType + @ccall mlir_c.mlirQuantizedTypeCastFromExpressedType( + type::MlirType, candidate::MlirType + )::MlirType end """ @@ -7131,8 +7508,9 @@ end Casts from a type based on the expressed type of the given quantized type to equivalent type based on storage type of the same quantized type. """ function mlirQuantizedTypeCastExpressedToStorageType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeCastExpressedToStorageType(type::MlirType, - candidate::MlirType)::MlirType + @ccall mlir_c.mlirQuantizedTypeCastExpressedToStorageType( + type::MlirType, candidate::MlirType + )::MlirType end """ @@ -7149,11 +7527,16 @@ end Creates an instance of AnyQuantizedType with the given parameters in the same context as `storageType` and returns it. The instance is owned by the context. """ -function mlirAnyQuantizedTypeGet(flags, storageType, expressedType, storageTypeMin, - storageTypeMax) - @ccall mlir_c.mlirAnyQuantizedTypeGet(flags::Cuint, storageType::MlirType, - expressedType::MlirType, storageTypeMin::Int64, - storageTypeMax::Int64)::MlirType +function mlirAnyQuantizedTypeGet( + flags, storageType, expressedType, storageTypeMin, storageTypeMax +) + @ccall mlir_c.mlirAnyQuantizedTypeGet( + flags::Cuint, + storageType::MlirType, + expressedType::MlirType, + storageTypeMin::Int64, + storageTypeMax::Int64, + )::MlirType end """ @@ -7170,12 +7553,18 @@ end Creates an instance of UniformQuantizedType with the given parameters in the same context as `storageType` and returns it. The instance is owned by the context. """ -function mlirUniformQuantizedTypeGet(flags, storageType, expressedType, scale, zeroPoint, - storageTypeMin, storageTypeMax) - @ccall mlir_c.mlirUniformQuantizedTypeGet(flags::Cuint, storageType::MlirType, - expressedType::MlirType, scale::Cdouble, - zeroPoint::Int64, storageTypeMin::Int64, - storageTypeMax::Int64)::MlirType +function mlirUniformQuantizedTypeGet( + flags, storageType, expressedType, scale, zeroPoint, storageTypeMin, storageTypeMax +) + @ccall mlir_c.mlirUniformQuantizedTypeGet( + flags::Cuint, + storageType::MlirType, + expressedType::MlirType, + scale::Cdouble, + zeroPoint::Int64, + storageTypeMin::Int64, + storageTypeMax::Int64, + )::MlirType end """ @@ -7219,16 +7608,28 @@ end Creates an instance of UniformQuantizedPerAxisType with the given parameters in the same context as `storageType` and returns it. `scales` and `zeroPoints` point to `nDims` number of elements. The instance is owned by the context. """ -function mlirUniformQuantizedPerAxisTypeGet(flags, storageType, expressedType, nDims, - scales, zeroPoints, quantizedDimension, - storageTypeMin, storageTypeMax) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGet(flags::Cuint, storageType::MlirType, - expressedType::MlirType, - nDims::intptr_t, scales::Ptr{Cdouble}, - zeroPoints::Ptr{Int64}, - quantizedDimension::Int32, - storageTypeMin::Int64, - storageTypeMax::Int64)::MlirType +function mlirUniformQuantizedPerAxisTypeGet( + flags, + storageType, + expressedType, + nDims, + scales, + zeroPoints, + quantizedDimension, + storageTypeMin, + storageTypeMax, +) + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGet( + flags::Cuint, + storageType::MlirType, + expressedType::MlirType, + nDims::intptr_t, + scales::Ptr{Cdouble}, + zeroPoints::Ptr{Int64}, + quantizedDimension::Int32, + storageTypeMin::Int64, + storageTypeMax::Int64, + )::MlirType end """ @@ -7246,8 +7647,9 @@ end Returns `pos`-th scale of the given quantized per-axis type. """ function mlirUniformQuantizedPerAxisTypeGetScale(type, pos) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetScale(type::MlirType, - pos::intptr_t)::Cdouble + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetScale( + type::MlirType, pos::intptr_t + )::Cdouble end """ @@ -7256,8 +7658,9 @@ end Returns `pos`-th zero point of the given quantized per-axis type. """ function mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, pos) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetZeroPoint(type::MlirType, - pos::intptr_t)::Int64 + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetZeroPoint( + type::MlirType, pos::intptr_t + )::Int64 end """ @@ -7266,7 +7669,9 @@ end Returns the index of the quantized dimension in the given quantized per-axis type. """ function mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type::MlirType)::Int32 + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetQuantizedDimension( + type::MlirType + )::Int32 end """ @@ -7293,8 +7698,9 @@ end Creates an instance of CalibratedQuantizedType with the given parameters in the same context as `expressedType` and returns it. The instance is owned by the context. """ function mlirCalibratedQuantizedTypeGet(expressedType, min, max) - @ccall mlir_c.mlirCalibratedQuantizedTypeGet(expressedType::MlirType, min::Cdouble, - max::Cdouble)::MlirType + @ccall mlir_c.mlirCalibratedQuantizedTypeGet( + expressedType::MlirType, min::Cdouble, max::Cdouble + )::MlirType end """ @@ -7370,13 +7776,18 @@ end Creates a `sparse\\_tensor.encoding` attribute with the given parameters. """ -function mlirSparseTensorEncodingAttrGet(ctx, lvlRank, lvlTypes, dimToLvl, lvlTodim, - posWidth, crdWidth) - @ccall mlir_c.mlirSparseTensorEncodingAttrGet(ctx::MlirContext, lvlRank::intptr_t, - lvlTypes::Ptr{MlirSparseTensorLevelType}, - dimToLvl::MlirAffineMap, - lvlTodim::MlirAffineMap, posWidth::Cint, - crdWidth::Cint)::MlirAttribute +function mlirSparseTensorEncodingAttrGet( + ctx, lvlRank, lvlTypes, dimToLvl, lvlTodim, posWidth, crdWidth +) + @ccall mlir_c.mlirSparseTensorEncodingAttrGet( + ctx::MlirContext, + lvlRank::intptr_t, + lvlTypes::Ptr{MlirSparseTensorLevelType}, + dimToLvl::MlirAffineMap, + lvlTodim::MlirAffineMap, + posWidth::Cint, + crdWidth::Cint, + )::MlirAttribute end """ @@ -7394,8 +7805,9 @@ end Returns a specified level-type of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetLvlType(attr, lvl) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlType(attr::MlirAttribute, - lvl::intptr_t)::MlirSparseTensorLevelType + @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlType( + attr::MlirAttribute, lvl::intptr_t + )::MlirSparseTensorLevelType end """ @@ -7404,8 +7816,9 @@ end Returns a specified level-format of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetLvlFmt(attr, lvl) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlFmt(attr::MlirAttribute, - lvl::intptr_t)::MlirSparseTensorLevelFormat + @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlFmt( + attr::MlirAttribute, lvl::intptr_t + )::MlirSparseTensorLevelFormat end """ @@ -7414,7 +7827,9 @@ end Returns the dimension-to-level mapping of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetDimToLvl(attr) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetDimToLvl(attr::MlirAttribute)::MlirAffineMap + @ccall mlir_c.mlirSparseTensorEncodingAttrGetDimToLvl( + attr::MlirAttribute + )::MlirAffineMap end """ @@ -7423,7 +7838,9 @@ end Returns the level-to-dimension mapping of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetLvlToDim(attr) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlToDim(attr::MlirAttribute)::MlirAffineMap + @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlToDim( + attr::MlirAttribute + )::MlirAffineMap end """ @@ -7445,18 +7862,25 @@ function mlirSparseTensorEncodingAttrGetCrdWidth(attr) end function mlirSparseTensorEncodingAttrGetStructuredN(lvlType) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredN(lvlType::MlirSparseTensorLevelType)::Cuint + @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredN( + lvlType::MlirSparseTensorLevelType + )::Cuint end function mlirSparseTensorEncodingAttrGetStructuredM(lvlType) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredM(lvlType::MlirSparseTensorLevelType)::Cuint + @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredM( + lvlType::MlirSparseTensorLevelType + )::Cuint end function mlirSparseTensorEncodingAttrBuildLvlType(lvlFmt, properties, propSize, n, m) - @ccall mlir_c.mlirSparseTensorEncodingAttrBuildLvlType(lvlFmt::MlirSparseTensorLevelFormat, - properties::Ptr{MlirSparseTensorLevelPropertyNondefault}, - propSize::Cuint, n::Cuint, - m::Cuint)::MlirSparseTensorLevelType + @ccall mlir_c.mlirSparseTensorEncodingAttrBuildLvlType( + lvlFmt::MlirSparseTensorLevelFormat, + properties::Ptr{MlirSparseTensorLevelPropertyNondefault}, + propSize::Cuint, + n::Cuint, + m::Cuint, + )::MlirSparseTensorLevelType end function mlirRegisterSparseTensorPasses() @@ -7628,8 +8052,9 @@ function mlirTransformOperationTypeGetTypeID() end function mlirTransformOperationTypeGet(ctx, operationName) - @ccall mlir_c.mlirTransformOperationTypeGet(ctx::MlirContext, - operationName::MlirStringRef)::MlirType + @ccall mlir_c.mlirTransformOperationTypeGet( + ctx::MlirContext, operationName::MlirStringRef + )::MlirType end function mlirTransformOperationTypeGetOperationName(type) @@ -7671,8 +8096,9 @@ end Enables or disables expensive checks in transform options. """ function mlirTransformOptionsEnableExpensiveChecks(transformOptions, enable) - @ccall mlir_c.mlirTransformOptionsEnableExpensiveChecks(transformOptions::MlirTransformOptions, - enable::Bool)::Cvoid + @ccall mlir_c.mlirTransformOptionsEnableExpensiveChecks( + transformOptions::MlirTransformOptions, enable::Bool + )::Cvoid end """ @@ -7681,7 +8107,9 @@ end Returns true if expensive checks are enabled in transform options. """ function mlirTransformOptionsGetExpensiveChecksEnabled(transformOptions) - @ccall mlir_c.mlirTransformOptionsGetExpensiveChecksEnabled(transformOptions::MlirTransformOptions)::Bool + @ccall mlir_c.mlirTransformOptionsGetExpensiveChecksEnabled( + transformOptions::MlirTransformOptions + )::Bool end """ @@ -7690,8 +8118,9 @@ end Enables or disables the enforcement of the top-level transform op being single in transform options. """ function mlirTransformOptionsEnforceSingleTopLevelTransformOp(transformOptions, enable) - @ccall mlir_c.mlirTransformOptionsEnforceSingleTopLevelTransformOp(transformOptions::MlirTransformOptions, - enable::Bool)::Cvoid + @ccall mlir_c.mlirTransformOptionsEnforceSingleTopLevelTransformOp( + transformOptions::MlirTransformOptions, enable::Bool + )::Cvoid end """ @@ -7700,7 +8129,9 @@ end Returns true if the enforcement of the top-level transform op being single is enabled in transform options. """ function mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(transformOptions) - @ccall mlir_c.mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(transformOptions::MlirTransformOptions)::Bool + @ccall mlir_c.mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( + transformOptions::MlirTransformOptions + )::Bool end """ @@ -7717,12 +8148,15 @@ end Applies the transformation script starting at the given transform root operation to the given payload operation. The module containing the transform root as well as the transform options should be provided. The transform operation must implement TransformOpInterface and the module must be a ModuleOp. Returns the status of the application. """ -function mlirTransformApplyNamedSequence(payload, transformRoot, transformModule, - transformOptions) - @ccall mlir_c.mlirTransformApplyNamedSequence(payload::MlirOperation, - transformRoot::MlirOperation, - transformModule::MlirOperation, - transformOptions::MlirTransformOptions)::MlirLogicalResult +function mlirTransformApplyNamedSequence( + payload, transformRoot, transformModule, transformOptions +) + @ccall mlir_c.mlirTransformApplyNamedSequence( + payload::MlirOperation, + transformRoot::MlirOperation, + transformModule::MlirOperation, + transformOptions::MlirTransformOptions, + )::MlirLogicalResult end function mlirGetDialectHandle__vector__() @@ -7739,9 +8173,13 @@ end Creates an ExecutionEngine for the provided ModuleOp. The ModuleOp is expected to be "translatable" to LLVM IR (only contains operations in dialects that implement the `LLVMTranslationDialectInterface`). The module ownership stays with the client and can be destroyed as soon as the call returns. `optLevel` is the optimization level to be used for transformation and code generation. LLVM passes at `optLevel` are run before code generation. The number and array of paths corresponding to shared libraries that will be loaded are specified via `numPaths` and `sharedLibPaths` respectively. TODO: figure out other options. """ function mlirExecutionEngineCreate(op, optLevel, numPaths, sharedLibPaths, enableObjectDump) - @ccall mlir_c.mlirExecutionEngineCreate(op::MlirModule, optLevel::Cint, numPaths::Cint, - sharedLibPaths::Ptr{MlirStringRef}, - enableObjectDump::Bool)::MlirExecutionEngine + @ccall mlir_c.mlirExecutionEngineCreate( + op::MlirModule, + optLevel::Cint, + numPaths::Cint, + sharedLibPaths::Ptr{MlirStringRef}, + enableObjectDump::Bool, + )::MlirExecutionEngine end """ @@ -7768,9 +8206,9 @@ end Invoke a native function in the execution engine by name with the arguments and result of the invoked function passed as an array of pointers. The function must have been tagged with the `llvm.emit\\_c\\_interface` attribute. Returns a failure if the execution fails for any reason (the function name can't be resolved for instance). """ function mlirExecutionEngineInvokePacked(jit, name, arguments) - @ccall mlir_c.mlirExecutionEngineInvokePacked(jit::MlirExecutionEngine, - name::MlirStringRef, - arguments::Ptr{Ptr{Cvoid}})::MlirLogicalResult + @ccall mlir_c.mlirExecutionEngineInvokePacked( + jit::MlirExecutionEngine, name::MlirStringRef, arguments::Ptr{Ptr{Cvoid}} + )::MlirLogicalResult end """ @@ -7779,8 +8217,9 @@ end Lookup the wrapper of the native function in the execution engine with the given name, returns nullptr if the function can't be looked-up. """ function mlirExecutionEngineLookupPacked(jit, name) - @ccall mlir_c.mlirExecutionEngineLookupPacked(jit::MlirExecutionEngine, - name::MlirStringRef)::Ptr{Cvoid} + @ccall mlir_c.mlirExecutionEngineLookupPacked( + jit::MlirExecutionEngine, name::MlirStringRef + )::Ptr{Cvoid} end """ @@ -7789,8 +8228,9 @@ end Lookup a native function in the execution engine by name, returns nullptr if the name can't be looked-up. """ function mlirExecutionEngineLookup(jit, name) - @ccall mlir_c.mlirExecutionEngineLookup(jit::MlirExecutionEngine, - name::MlirStringRef)::Ptr{Cvoid} + @ccall mlir_c.mlirExecutionEngineLookup( + jit::MlirExecutionEngine, name::MlirStringRef + )::Ptr{Cvoid} end """ @@ -7799,9 +8239,9 @@ end Register a symbol with the jit: this symbol will be accessible to the jitted code. """ function mlirExecutionEngineRegisterSymbol(jit, name, sym) - @ccall mlir_c.mlirExecutionEngineRegisterSymbol(jit::MlirExecutionEngine, - name::MlirStringRef, - sym::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirExecutionEngineRegisterSymbol( + jit::MlirExecutionEngine, name::MlirStringRef, sym::Ptr{Cvoid} + )::Cvoid end """ @@ -7810,8 +8250,9 @@ end Dump as an object in `fileName`. """ function mlirExecutionEngineDumpToObjectFile(jit, fileName) - @ccall mlir_c.mlirExecutionEngineDumpToObjectFile(jit::MlirExecutionEngine, - fileName::MlirStringRef)::Cvoid + @ccall mlir_c.mlirExecutionEngineDumpToObjectFile( + jit::MlirExecutionEngine, fileName::MlirStringRef + )::Cvoid end struct MlirIntegerSet @@ -7851,8 +8292,9 @@ end Prints an integer set by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirIntegerSetPrint(set, callback, userData) - @ccall mlir_c.mlirIntegerSetPrint(set::MlirIntegerSet, callback::MlirStringCallback, - userData::Ptr{Cvoid})::Cvoid + @ccall mlir_c.mlirIntegerSetPrint( + set::MlirIntegerSet, callback::MlirStringCallback, userData::Ptr{Cvoid} + )::Cvoid end """ @@ -7870,8 +8312,9 @@ end Gets or creates a new canonically empty integer set with the give number of dimensions and symbols in the given context. """ function mlirIntegerSetEmptyGet(context, numDims, numSymbols) - @ccall mlir_c.mlirIntegerSetEmptyGet(context::MlirContext, numDims::intptr_t, - numSymbols::intptr_t)::MlirIntegerSet + @ccall mlir_c.mlirIntegerSetEmptyGet( + context::MlirContext, numDims::intptr_t, numSymbols::intptr_t + )::MlirIntegerSet end """ @@ -7879,12 +8322,17 @@ end Gets or creates a new integer set in the given context. The set is defined by a list of affine constraints, with the given number of input dimensions and symbols, which are treated as either equalities (eqFlags is 1) or inequalities (eqFlags is 0). Both `constraints` and `eqFlags` are expected to point to at least `numConstraint` consecutive values. """ -function mlirIntegerSetGet(context, numDims, numSymbols, numConstraints, constraints, - eqFlags) - @ccall mlir_c.mlirIntegerSetGet(context::MlirContext, numDims::intptr_t, - numSymbols::intptr_t, numConstraints::intptr_t, - constraints::Ptr{MlirAffineExpr}, - eqFlags::Ptr{Bool})::MlirIntegerSet +function mlirIntegerSetGet( + context, numDims, numSymbols, numConstraints, constraints, eqFlags +) + @ccall mlir_c.mlirIntegerSetGet( + context::MlirContext, + numDims::intptr_t, + numSymbols::intptr_t, + numConstraints::intptr_t, + constraints::Ptr{MlirAffineExpr}, + eqFlags::Ptr{Bool}, + )::MlirIntegerSet end """ @@ -7892,13 +8340,16 @@ end Gets or creates a new integer set in which the values and dimensions of the given set are replaced with the given affine expressions. `dimReplacements` and `symbolReplacements` are expected to point to at least as many consecutive expressions as the given set has dimensions and symbols, respectively. The new set will have `numResultDims` and `numResultSymbols` dimensions and symbols, respectively. """ -function mlirIntegerSetReplaceGet(set, dimReplacements, symbolReplacements, numResultDims, - numResultSymbols) - @ccall mlir_c.mlirIntegerSetReplaceGet(set::MlirIntegerSet, - dimReplacements::Ptr{MlirAffineExpr}, - symbolReplacements::Ptr{MlirAffineExpr}, - numResultDims::intptr_t, - numResultSymbols::intptr_t)::MlirIntegerSet +function mlirIntegerSetReplaceGet( + set, dimReplacements, symbolReplacements, numResultDims, numResultSymbols +) + @ccall mlir_c.mlirIntegerSetReplaceGet( + set::MlirIntegerSet, + dimReplacements::Ptr{MlirAffineExpr}, + symbolReplacements::Ptr{MlirAffineExpr}, + numResultDims::intptr_t, + numResultSymbols::intptr_t, + )::MlirIntegerSet end """ @@ -7970,8 +8421,9 @@ end Returns `pos`-th constraint of the set. """ function mlirIntegerSetGetConstraint(set, pos) - @ccall mlir_c.mlirIntegerSetGetConstraint(set::MlirIntegerSet, - pos::intptr_t)::MlirAffineExpr + @ccall mlir_c.mlirIntegerSetGetConstraint( + set::MlirIntegerSet, pos::intptr_t + )::MlirAffineExpr end """ @@ -7989,8 +8441,9 @@ end Returns `true` if the given operation implements an interface identified by its TypeID. """ function mlirOperationImplementsInterface(operation, interfaceTypeID) - @ccall mlir_c.mlirOperationImplementsInterface(operation::MlirOperation, - interfaceTypeID::MlirTypeID)::Bool + @ccall mlir_c.mlirOperationImplementsInterface( + operation::MlirOperation, interfaceTypeID::MlirTypeID + )::Bool end """ @@ -7999,9 +8452,9 @@ end Returns `true` if the operation identified by its canonical string name implements the interface identified by its TypeID in the given context. Note that interfaces may be attached to operations in some contexts and not others. """ function mlirOperationImplementsInterfaceStatic(operationName, context, interfaceTypeID) - @ccall mlir_c.mlirOperationImplementsInterfaceStatic(operationName::MlirStringRef, - context::MlirContext, - interfaceTypeID::MlirTypeID)::Bool + @ccall mlir_c.mlirOperationImplementsInterfaceStatic( + operationName::MlirStringRef, context::MlirContext, interfaceTypeID::MlirTypeID + )::Bool end """ @@ -8024,20 +8477,32 @@ const MlirTypesCallback = Ptr{Cvoid} Infers the return types of the operation identified by its canonical given the arguments that will be supplied to its generic builder. Calls `callback` with the types of inferred arguments, potentially several times, on success. Returns failure otherwise. """ -function mlirInferTypeOpInterfaceInferReturnTypes(opName, context, location, nOperands, - operands, attributes, properties, - nRegions, regions, callback, userData) - @ccall mlir_c.mlirInferTypeOpInterfaceInferReturnTypes(opName::MlirStringRef, - context::MlirContext, - location::MlirLocation, - nOperands::intptr_t, - operands::Ptr{MlirValue}, - attributes::MlirAttribute, - properties::Ptr{Cvoid}, - nRegions::intptr_t, - regions::Ptr{MlirRegion}, - callback::MlirTypesCallback, - userData::Ptr{Cvoid})::MlirLogicalResult +function mlirInferTypeOpInterfaceInferReturnTypes( + opName, + context, + location, + nOperands, + operands, + attributes, + properties, + nRegions, + regions, + callback, + userData, +) + @ccall mlir_c.mlirInferTypeOpInterfaceInferReturnTypes( + opName::MlirStringRef, + context::MlirContext, + location::MlirLocation, + nOperands::intptr_t, + operands::Ptr{MlirValue}, + attributes::MlirAttribute, + properties::Ptr{Cvoid}, + nRegions::intptr_t, + regions::Ptr{MlirRegion}, + callback::MlirTypesCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult end """ @@ -8060,21 +8525,32 @@ const MlirShapedTypeComponentsCallback = Ptr{Cvoid} Infers the return shaped type components of the operation. Calls `callback` with the types of inferred arguments on success. Returns failure otherwise. """ -function mlirInferShapedTypeOpInterfaceInferReturnTypes(opName, context, location, - nOperands, operands, attributes, - properties, nRegions, regions, - callback, userData) - @ccall mlir_c.mlirInferShapedTypeOpInterfaceInferReturnTypes(opName::MlirStringRef, - context::MlirContext, - location::MlirLocation, - nOperands::intptr_t, - operands::Ptr{MlirValue}, - attributes::MlirAttribute, - properties::Ptr{Cvoid}, - nRegions::intptr_t, - regions::Ptr{MlirRegion}, - callback::MlirShapedTypeComponentsCallback, - userData::Ptr{Cvoid})::MlirLogicalResult +function mlirInferShapedTypeOpInterfaceInferReturnTypes( + opName, + context, + location, + nOperands, + operands, + attributes, + properties, + nRegions, + regions, + callback, + userData, +) + @ccall mlir_c.mlirInferShapedTypeOpInterfaceInferReturnTypes( + opName::MlirStringRef, + context::MlirContext, + location::MlirLocation, + nOperands::intptr_t, + operands::Ptr{MlirValue}, + attributes::MlirAttribute, + properties::Ptr{Cvoid}, + nRegions::intptr_t, + regions::Ptr{MlirRegion}, + callback::MlirShapedTypeComponentsCallback, + userData::Ptr{Cvoid}, + )::MlirLogicalResult end """ @@ -8323,8 +8799,9 @@ This function parses the given arguments using the LLVM command line parser. Not llvm::cl::ParseCommandLineOptions() """ function LLVMParseCommandLineOptions(argc, argv, Overview) - @ccall mlir_c.LLVMParseCommandLineOptions(argc::Cint, argv::Ptr{Cstring}, - Overview::Cstring)::Cvoid + @ccall mlir_c.LLVMParseCommandLineOptions( + argc::Cint, argv::Ptr{Cstring}, Overview::Cstring + )::Cvoid end """ @@ -8360,8 +8837,9 @@ Translate operation that satisfies LLVM dialect module requirements into an LLVM the generated LLVM IR Module from the translated MLIR module, it is owned by the caller. """ function mlirTranslateModuleToLLVMIR(_module, context) - @ccall mlir_c.mlirTranslateModuleToLLVMIR(_module::MlirOperation, - context::LLVMContextRef)::LLVMModuleRef + @ccall mlir_c.mlirTranslateModuleToLLVMIR( + _module::MlirOperation, context::LLVMContextRef + )::LLVMModuleRef end function mlirRegisterTransformsPasses() @@ -8520,18 +8998,26 @@ function mlirRegisterTransformsViewOpGraph() @ccall mlir_c.mlirRegisterTransformsViewOpGraph()::Cvoid end -function stablehloScatterDimensionNumbersGet(ctx, nUpdateWindowDims, updateWindowDims, - nInsertedWindowDims, insertedWindowDims, - nScatteredDimsToOperandDims, - scatteredDimsToOperandDims, indexVectorDim) - @ccall mlir_c.stablehloScatterDimensionNumbersGet(ctx::MlirContext, - nUpdateWindowDims::intptr_t, - updateWindowDims::Ptr{Int64}, - nInsertedWindowDims::intptr_t, - insertedWindowDims::Ptr{Int64}, - nScatteredDimsToOperandDims::intptr_t, - scatteredDimsToOperandDims::Ptr{Int64}, - indexVectorDim::Int64)::MlirAttribute +function stablehloScatterDimensionNumbersGet( + ctx, + nUpdateWindowDims, + updateWindowDims, + nInsertedWindowDims, + insertedWindowDims, + nScatteredDimsToOperandDims, + scatteredDimsToOperandDims, + indexVectorDim, +) + @ccall mlir_c.stablehloScatterDimensionNumbersGet( + ctx::MlirContext, + nUpdateWindowDims::intptr_t, + updateWindowDims::Ptr{Int64}, + nInsertedWindowDims::intptr_t, + insertedWindowDims::Ptr{Int64}, + nScatteredDimsToOperandDims::intptr_t, + scatteredDimsToOperandDims::Ptr{Int64}, + indexVectorDim::Int64, + )::MlirAttribute end function stablehloAttributeIsAScatterDimensionNumbers(attr) @@ -8539,47 +9025,65 @@ function stablehloAttributeIsAScatterDimensionNumbers(attr) end function stablehloScatterDimensionNumbersGetUpdateWindowDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsSize( + attr::MlirAttribute + )::intptr_t end function stablehloScatterDimensionNumbersGetUpdateWindowDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloScatterDimensionNumbersGetInsertedWindowDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsSize( + attr::MlirAttribute + )::intptr_t end function stablehloScatterDimensionNumbersGetInsertedWindowDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize( + attr::MlirAttribute + )::intptr_t end function stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloDimensionNumbersGetIndexVectorDim(attr) @ccall mlir_c.stablehloDimensionNumbersGetIndexVectorDim(attr::MlirAttribute)::Int64 end -function stablehloGatherDimensionNumbersGet(ctx, nOffsetDims, offsetDims, - nCollapsedSliceDims, collapsedSliceDims, - nStartIndexMap, startIndexMap, indexVectorDim) - @ccall mlir_c.stablehloGatherDimensionNumbersGet(ctx::MlirContext, - nOffsetDims::intptr_t, - offsetDims::Ptr{Int64}, - nCollapsedSliceDims::intptr_t, - collapsedSliceDims::Ptr{Int64}, - nStartIndexMap::intptr_t, - startIndexMap::Ptr{Int64}, - indexVectorDim::Int64)::MlirAttribute +function stablehloGatherDimensionNumbersGet( + ctx, + nOffsetDims, + offsetDims, + nCollapsedSliceDims, + collapsedSliceDims, + nStartIndexMap, + startIndexMap, + indexVectorDim, +) + @ccall mlir_c.stablehloGatherDimensionNumbersGet( + ctx::MlirContext, + nOffsetDims::intptr_t, + offsetDims::Ptr{Int64}, + nCollapsedSliceDims::intptr_t, + collapsedSliceDims::Ptr{Int64}, + nStartIndexMap::intptr_t, + startIndexMap::Ptr{Int64}, + indexVectorDim::Int64, + )::MlirAttribute end function stablehloAttributeIsAGatherDimensionNumbers(attr) @@ -8587,51 +9091,69 @@ function stablehloAttributeIsAGatherDimensionNumbers(attr) end function stablehloGatherDimensionNumbersGetOffsetDimsSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsSize( + attr::MlirAttribute + )::intptr_t end function stablehloGatherDimensionNumbersGetOffsetDimsElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize( + attr::MlirAttribute + )::intptr_t end function stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloGatherDimensionNumbersGetStartIndexMapSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapSize( + attr::MlirAttribute + )::intptr_t end function stablehloGatherDimensionNumbersGetStartIndexMapElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloGatherDimensionNumbersGetIndexVectorDim(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetIndexVectorDim(attr::MlirAttribute)::Int64 -end - -function stablehloDotDimensionNumbersGet(ctx, nLhsBatchingDimensions, lhsBatchingDimensions, - nRhsBatchingDimensions, rhsBatchingDimensions, - nLhsContractingDimensions, - lhsContractingDimensions, - nRhsContractingDimensions, - rhsContractingDimensions) - @ccall mlir_c.stablehloDotDimensionNumbersGet(ctx::MlirContext, - nLhsBatchingDimensions::intptr_t, - lhsBatchingDimensions::Ptr{Int64}, - nRhsBatchingDimensions::intptr_t, - rhsBatchingDimensions::Ptr{Int64}, - nLhsContractingDimensions::intptr_t, - lhsContractingDimensions::Ptr{Int64}, - nRhsContractingDimensions::intptr_t, - rhsContractingDimensions::Ptr{Int64})::MlirAttribute + @ccall mlir_c.stablehloGatherDimensionNumbersGetIndexVectorDim( + attr::MlirAttribute + )::Int64 +end + +function stablehloDotDimensionNumbersGet( + ctx, + nLhsBatchingDimensions, + lhsBatchingDimensions, + nRhsBatchingDimensions, + rhsBatchingDimensions, + nLhsContractingDimensions, + lhsContractingDimensions, + nRhsContractingDimensions, + rhsContractingDimensions, +) + @ccall mlir_c.stablehloDotDimensionNumbersGet( + ctx::MlirContext, + nLhsBatchingDimensions::intptr_t, + lhsBatchingDimensions::Ptr{Int64}, + nRhsBatchingDimensions::intptr_t, + rhsBatchingDimensions::Ptr{Int64}, + nLhsContractingDimensions::intptr_t, + lhsContractingDimensions::Ptr{Int64}, + nRhsContractingDimensions::intptr_t, + rhsContractingDimensions::Ptr{Int64}, + )::MlirAttribute end function stablehloAttributeIsADotDimensionNumbers(attr) @@ -8639,61 +9161,83 @@ function stablehloAttributeIsADotDimensionNumbers(attr) end function stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize( + attr::MlirAttribute + )::intptr_t end function stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize( + attr::MlirAttribute + )::intptr_t end function stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize( + attr::MlirAttribute + )::intptr_t end function stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize( + attr::MlirAttribute + )::intptr_t end function stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 -end - -function stablehloConvDimensionNumbersGet(ctx, inputBatchDimension, inputFeatureDimension, - nInputSpatialDimensions, inputSpatialDimensions, - kernelInputFeatureDimension, - kernelOutputFeatureDimension, - nKernelSpatialDimensions, kernelSpatialDimensions, - outputBatchDimension, outputFeatureDimension, - nOutputSpatialDimensions, outputSpatialDimensions) - @ccall mlir_c.stablehloConvDimensionNumbersGet(ctx::MlirContext, - inputBatchDimension::Int64, - inputFeatureDimension::Int64, - nInputSpatialDimensions::intptr_t, - inputSpatialDimensions::Ptr{Int64}, - kernelInputFeatureDimension::Int64, - kernelOutputFeatureDimension::Int64, - nKernelSpatialDimensions::intptr_t, - kernelSpatialDimensions::Ptr{Int64}, - outputBatchDimension::Int64, - outputFeatureDimension::Int64, - nOutputSpatialDimensions::intptr_t, - outputSpatialDimensions::Ptr{Int64})::MlirAttribute + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 +end + +function stablehloConvDimensionNumbersGet( + ctx, + inputBatchDimension, + inputFeatureDimension, + nInputSpatialDimensions, + inputSpatialDimensions, + kernelInputFeatureDimension, + kernelOutputFeatureDimension, + nKernelSpatialDimensions, + kernelSpatialDimensions, + outputBatchDimension, + outputFeatureDimension, + nOutputSpatialDimensions, + outputSpatialDimensions, +) + @ccall mlir_c.stablehloConvDimensionNumbersGet( + ctx::MlirContext, + inputBatchDimension::Int64, + inputFeatureDimension::Int64, + nInputSpatialDimensions::intptr_t, + inputSpatialDimensions::Ptr{Int64}, + kernelInputFeatureDimension::Int64, + kernelOutputFeatureDimension::Int64, + nKernelSpatialDimensions::intptr_t, + kernelSpatialDimensions::Ptr{Int64}, + outputBatchDimension::Int64, + outputFeatureDimension::Int64, + nOutputSpatialDimensions::intptr_t, + outputSpatialDimensions::Ptr{Int64}, + )::MlirAttribute end function stablehloAttributeIsAConvDimensionNumbers(attr) @@ -8701,65 +9245,93 @@ function stablehloAttributeIsAConvDimensionNumbers(attr) end function stablehloConvDimensionNumbersGetInputBatchDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputBatchDimension(attr::MlirAttribute)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetInputBatchDimension( + attr::MlirAttribute + )::Int64 end function stablehloConvDimensionNumbersGetInputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputFeatureDimension(attr::MlirAttribute)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetInputFeatureDimension( + attr::MlirAttribute + )::Int64 end function stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize( + attr::MlirAttribute + )::intptr_t end function stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(attr, pos) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloConvDimensionNumbersGetKernelInputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(attr::MlirAttribute)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension( + attr::MlirAttribute + )::Int64 end function stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(attr::MlirAttribute)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension( + attr::MlirAttribute + )::Int64 end function stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize( + attr::MlirAttribute + )::intptr_t end function stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(attr, pos) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloConvDimensionNumbersGetOutputBatchDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputBatchDimension(attr::MlirAttribute)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputBatchDimension( + attr::MlirAttribute + )::Int64 end function stablehloConvDimensionNumbersGetOutputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputFeatureDimension(attr::MlirAttribute)::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputFeatureDimension( + attr::MlirAttribute + )::Int64 end function stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize( + attr::MlirAttribute + )::intptr_t end function stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(attr, pos) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 -end - -function stablehloOutputOperandAliasGet(ctx, nOutputTupleIndices, outputTupleIndices, - operandIndex, nOperandTupleIndices, - operandTupleIndices) - @ccall mlir_c.stablehloOutputOperandAliasGet(ctx::MlirContext, - nOutputTupleIndices::intptr_t, - outputTupleIndices::Ptr{Int64}, - operandIndex::Int64, - nOperandTupleIndices::intptr_t, - operandTupleIndices::Ptr{Int64})::MlirAttribute + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 +end + +function stablehloOutputOperandAliasGet( + ctx, + nOutputTupleIndices, + outputTupleIndices, + operandIndex, + nOperandTupleIndices, + operandTupleIndices, +) + @ccall mlir_c.stablehloOutputOperandAliasGet( + ctx::MlirContext, + nOutputTupleIndices::intptr_t, + outputTupleIndices::Ptr{Int64}, + operandIndex::Int64, + nOperandTupleIndices::intptr_t, + operandTupleIndices::Ptr{Int64}, + )::MlirAttribute end function stablehloAttributeIsAOutputOperandAlias(attr) @@ -8767,12 +9339,15 @@ function stablehloAttributeIsAOutputOperandAlias(attr) end function stablehloOutputOperandAliasGetOutputTupleIndicesSize(attr) - @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesSize( + attr::MlirAttribute + )::intptr_t end function stablehloOutputOperandAliasGetOutputTupleIndicesElem(attr, pos) - @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloOutputOperandAliasGetOperandIndex(attr) @@ -8780,17 +9355,21 @@ function stablehloOutputOperandAliasGetOperandIndex(attr) end function stablehloOutputOperandAliasGetOperandTupleIndicesSize(attr) - @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesSize(attr::MlirAttribute)::intptr_t + @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesSize( + attr::MlirAttribute + )::intptr_t end function stablehloOutputOperandAliasGetOperandTupleIndicesElem(attr, pos) - @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end function stablehloComparisonDirectionAttrGet(ctx, value) - @ccall mlir_c.stablehloComparisonDirectionAttrGet(ctx::MlirContext, - value::MlirStringRef)::MlirAttribute + @ccall mlir_c.stablehloComparisonDirectionAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end function stablehloAttributeIsAComparisonDirectionAttr(attr) @@ -8798,12 +9377,15 @@ function stablehloAttributeIsAComparisonDirectionAttr(attr) end function stablehloComparisonDirectionAttrGetValue(attr) - @ccall mlir_c.stablehloComparisonDirectionAttrGetValue(attr::MlirAttribute)::MlirStringRef + @ccall mlir_c.stablehloComparisonDirectionAttrGetValue( + attr::MlirAttribute + )::MlirStringRef end function stablehloComparisonTypeAttrGet(ctx, value) - @ccall mlir_c.stablehloComparisonTypeAttrGet(ctx::MlirContext, - value::MlirStringRef)::MlirAttribute + @ccall mlir_c.stablehloComparisonTypeAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end function stablehloAttributeIsAComparisonTypeAttr(attr) @@ -8815,8 +9397,9 @@ function stablehloComparisonTypeAttrGetValue(attr) end function stablehloPrecisionAttrGet(ctx, value) - @ccall mlir_c.stablehloPrecisionAttrGet(ctx::MlirContext, - value::MlirStringRef)::MlirAttribute + @ccall mlir_c.stablehloPrecisionAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end function stablehloAttributeIsAPrecisionAttr(attr) @@ -8828,8 +9411,9 @@ function stablehloPrecisionAttrGetValue(attr) end function stablehloFftTypeAttrGet(ctx, value) - @ccall mlir_c.stablehloFftTypeAttrGet(ctx::MlirContext, - value::MlirStringRef)::MlirAttribute + @ccall mlir_c.stablehloFftTypeAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end function stablehloAttributeIsAFftTypeAttr(attr) @@ -8841,8 +9425,9 @@ function stablehloFftTypeAttrGetValue(attr) end function stablehloTransposeAttrGet(ctx, value) - @ccall mlir_c.stablehloTransposeAttrGet(ctx::MlirContext, - value::MlirStringRef)::MlirAttribute + @ccall mlir_c.stablehloTransposeAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end function stablehloAttributeIsATransposeAttr(attr) @@ -8854,8 +9439,9 @@ function stablehloTransposeAttrGetValue(attr) end function stablehloRngDistributionAttrGet(ctx, value) - @ccall mlir_c.stablehloRngDistributionAttrGet(ctx::MlirContext, - value::MlirStringRef)::MlirAttribute + @ccall mlir_c.stablehloRngDistributionAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end function stablehloAttributeIsARngDistributionAttr(attr) @@ -8867,8 +9453,9 @@ function stablehloRngDistributionAttrGetValue(attr) end function stablehloRngAlgorithmAttrGet(ctx, value) - @ccall mlir_c.stablehloRngAlgorithmAttrGet(ctx::MlirContext, - value::MlirStringRef)::MlirAttribute + @ccall mlir_c.stablehloRngAlgorithmAttrGet( + ctx::MlirContext, value::MlirStringRef + )::MlirAttribute end function stablehloAttributeIsARngAlgorithmAttr(attr) @@ -8880,8 +9467,9 @@ function stablehloRngAlgorithmAttrGetValue(attr) end function stablehloChannelHandleGet(ctx, handle, type) - @ccall mlir_c.stablehloChannelHandleGet(ctx::MlirContext, handle::Int64, - type::Int64)::MlirAttribute + @ccall mlir_c.stablehloChannelHandleGet( + ctx::MlirContext, handle::Int64, type::Int64 + )::MlirAttribute end function stablehloAttributeIsChannelHandle(attr) @@ -8897,8 +9485,9 @@ function stablehloChannelHandleGetType(attr) end function stablehloTypeExtensionsGet(ctx, nBounds, bounds) - @ccall mlir_c.stablehloTypeExtensionsGet(ctx::MlirContext, nBounds::intptr_t, - bounds::Ptr{Int64})::MlirAttribute + @ccall mlir_c.stablehloTypeExtensionsGet( + ctx::MlirContext, nBounds::intptr_t, bounds::Ptr{Int64} + )::MlirAttribute end function stablehloAttributeIsTypeExtensions(attr) @@ -8910,6 +9499,7 @@ function stablehloTypeExtensionsGetBoundsSize(attr) end function stablehloTypeExtensionsGetBoundsElem(attr, pos) - @ccall mlir_c.stablehloTypeExtensionsGetBoundsElem(attr::MlirAttribute, - pos::intptr_t)::Int64 + @ccall mlir_c.stablehloTypeExtensionsGetBoundsElem( + attr::MlirAttribute, pos::intptr_t + )::Int64 end diff --git a/src/overloads.jl b/src/overloads.jl index 2c08c66ac..643c69a2d 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -47,19 +47,16 @@ function has_residx(x) return false end -@inline act_from_type(x, reverse, needs_primal=true) = throw(AssertionError("Unhandled activity $(typeof(x))")) -@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = act_from_type(Enzyme.Const, - reverse, - needs_primal) -@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = act_from_type(Enzyme.Duplicated, - reverse, - needs_primal) -@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) = reverse ? - enzyme_out : - enzyme_dupnoneed -@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = act_from_tuple(Enzyme.Active, - reverse, - needs_primal) +@inline act_from_type(x, reverse, needs_primal=true) = + throw(AssertionError("Unhandled activity $(typeof(x))")) +@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = + act_from_type(Enzyme.Const, reverse, needs_primal) +@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = + act_from_type(Enzyme.Duplicated, reverse, needs_primal) +@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) = + reverse ? enzyme_out : enzyme_dupnoneed +@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = + act_from_tuple(Enzyme.Active, reverse, needs_primal) @inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) = if needs_primal @@ -152,11 +149,14 @@ function set!(x, path, tostore; emptypath=false) end end -function Cassette.overdub(::TraceCtx, ::typeof(Enzyme.autodiff), ::CMode, f::FA, ::Type{A}, - args::Vararg{Enzyme.Annotation,Nargs}) where {CMode<:Enzyme.Mode, - FA<:Enzyme.Annotation, - A<:Enzyme.Annotation, - Nargs} +function Cassette.overdub( + ::TraceCtx, + ::typeof(Enzyme.autodiff), + ::CMode, + f::FA, + ::Type{A}, + args::Vararg{Enzyme.Annotation,Nargs}, +) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs} reverse = CMode <: Enzyme.ReverseMode primf = f.val @@ -164,13 +164,9 @@ function Cassette.overdub(::TraceCtx, ::typeof(Enzyme.autodiff), ::CMode, f::FA, mod = MLIR.IR.mmodule() - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(mod, - primf, - primargs, - (), - string(f) * - "_autodiff", - false) + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + mod, primf, primargs, (), string(f) * "_autodiff", false + ) activity = Int32[] ad_inputs = MLIR.IR.Value[] @@ -190,7 +186,8 @@ function Cassette.overdub(::TraceCtx, ::typeof(Enzyme.autodiff), ::CMode, f::FA, end outtys = MLIR.IR.Type[] - @inline needs_primal(::Type{<:Enzyme.ReverseMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal + @inline needs_primal(::Type{<:Enzyme.ReverseMode{ReturnPrimal}}) where {ReturnPrimal} = + ReturnPrimal for a in linear_results if has_residx(a) if needs_primal(CMode) @@ -244,23 +241,20 @@ function Cassette.overdub(::TraceCtx, ::typeof(Enzyme.autodiff), ::CMode, f::FA, end function act_attr(val) - val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet(MLIR.IR.context()::MLIR.API.MlirContext, - val::Int32)::MLIR.API.MlirAttribute + val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, val::Int32 + )::MLIR.API.MlirAttribute return MLIR.IR.Attribute(val) end fname = get_attribute_by_name(func2, "sym_name") fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)([transpose_val(v) - for v in - ad_inputs]; - outputs=outtys, - fn=fname, - activity=MLIR.IR.Attribute([act_attr(a) - for a in - activity]), - ret_activity=MLIR.IR.Attribute([act_attr(a) - for a in - ret_activity])) + res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( + [transpose_val(v) for v in ad_inputs]; + outputs=outtys, + fn=fname, + activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), + ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), + ) residx = 1 @@ -307,13 +301,19 @@ function Cassette.overdub(::TraceCtx, ::typeof(Enzyme.autodiff), ::CMode, f::FA, continue end if args[idx] isa Enzyme.Active - set_act!(args[idx], path[3:end], false, - transpose_val(MLIR.IR.result(res, residx)); emptypaths=true) #=reverse=# + set_act!( + args[idx], + path[3:end], + false, + transpose_val(MLIR.IR.result(res, residx)); + emptypaths=true, + ) #=reverse=# residx += 1 continue end - set_act!(args[idx], path[3:end], reverse, - transpose_val(MLIR.IR.result(res, residx))) + set_act!( + args[idx], path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx)) + ) end residx += 1 end @@ -341,7 +341,11 @@ end function promote_to(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N} if !(rhs <: Number) if ElType != eltype(rhs) - throw(ArgumentError("Cannot promote $(typeof(rhs)) to $(TracedRArray{ElType,Shape,N}) with different element types")) + throw( + ArgumentError( + "Cannot promote $(typeof(rhs)) to $(TracedRArray{ElType,Shape,N}) with different element types", + ), + ) end if Shape != size(rhs) throw(ArgumentError("Cannot promote to TracedRArray with different shapes")) @@ -362,143 +366,211 @@ function promote_to(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape, return TracedRArray{ElType,Shape,N}(nothing, MLIR.Dialects.stablehlo.constant(attr)) end -for (jlop, hloop, RT) in - ((:(Base.min), :minimum, :ElType), (:(Base.max), :maximum, :ElType), - (:(Base.:+), :add, :ElType), (:(Base.:-), :subtract, :ElType)) +for (jlop, hloop, RT) in ( + (:(Base.min), :minimum, :ElType), + (:(Base.max), :maximum, :ElType), + (:(Base.:+), :add, :ElType), + (:(Base.:-), :subtract, :ElType), +) @eval begin - function $jlop(lhs::TracedRArray{ElType,Shape,N}, - rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data), - 1)) + function $jlop( + lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) end function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N} rhs = promote_to(lhs, rhs) - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data), - 1)) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) end function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} lhs = promote_to(rhs, lhs) - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data), - 1)) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) end end end Cassette.overdub(context::TraceCtx, f::typeof(Enzyme.make_zero), args...) = f(args...) -function Base.:*(lhs::TracedRArray{ElType,Shape,2}, - rhs::TracedRArray{ElType,Shape2,2}) where {ElType,Shape,Shape2} +function Base.:*( + lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Shape2,2} +) where {ElType,Shape,Shape2} lhsty = MLIR.IR.type(lhs.mlir_data) rhsty = MLIR.IR.type(rhs.mlir_data) resty = MLIR.IR.TensorType((Base.size(lhsty)[1], Base.size(rhsty)[2]), eltype(lhsty)) - dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(MLIR.IR.context(), 0, - [], 0, [], 1, [1], 1, - [0]) - prec = MLIR.IR.Attribute(MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), - "DEFAULT")) + dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( + MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0] + ) + prec = MLIR.IR.Attribute( + MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") + ) precar = MLIR.IR.Attribute([prec, prec]) - res = MLIR.IR.result(MLIR.Dialects.stablehlo.dot_general(lhs.mlir_data, rhs.mlir_data; - result_0=resty, - dot_dimension_numbers=dot_dimension_numbers, - precision_config=precar), 1) + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.dot_general( + lhs.mlir_data, + rhs.mlir_data; + result_0=resty, + dot_dimension_numbers=dot_dimension_numbers, + precision_config=precar, + ), + 1, + ) return TracedRArray{ElType,(Base.size(lhsty)[1], Base.size(rhsty)[2]),2}((), res) end Cassette.overdub(context::TraceCtx, f::typeof(Base.:*), args...) = f(args...) -for (jlop, hloop) in ((:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos), :cosine), - (:(Base.tanh), :tanh), (:(Base.FastMath.tanh_fast), :tanh), - (:(Base.exp), :exponential), (:(Base.FastMath.exp_fast), :exponential), - (:(Base.log), :log), (:(Base.sqrt), :sqrt)) +for (jlop, hloop) in ( + (:(Base.:-), :negate), + (:(Base.sin), :sine), + (:(Base.cos), :cosine), + (:(Base.tanh), :tanh), + (:(Base.FastMath.tanh_fast), :tanh), + (:(Base.exp), :exponential), + (:(Base.FastMath.exp_fast), :exponential), + (:(Base.log), :log), + (:(Base.sqrt), :sqrt), +) @eval begin function $jlop(lhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} - return TracedRArray{ElType,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), - 1)) + return TracedRArray{ElType,Shape,N}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) + ) end Cassette.overdub(context::TraceCtx, f::typeof($jlop), args...) = f(args...) end end -for (jlop, hloop, RT) in - ((:(Base.min), :minimum, :ElType), (:(Base.max), :maximum, :ElType), - (:(Base.:+), :add, :ElType), (:(Base.add_sum), :add, :ElType), - (:(Base.:-), :subtract, :ElType), (:(Base.:*), :multiply, :ElType), - (:(Base.:/), :divide, :ElType)) +for (jlop, hloop, RT) in ( + (:(Base.min), :minimum, :ElType), + (:(Base.max), :maximum, :ElType), + (:(Base.:+), :add, :ElType), + (:(Base.add_sum), :add, :ElType), + (:(Base.:-), :subtract, :ElType), + (:(Base.:*), :multiply, :ElType), + (:(Base.:/), :divide, :ElType), +) @eval begin - function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, - rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data), - 1)) + function elem_apply( + ::typeof($jlop), + lhs::TracedRArray{ElType,Shape,N}, + rhs::TracedRArray{ElType,Shape,N}, + ) where {ElType,Shape,N} + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) end - function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, - rhs) where {ElType,Shape,N} + function elem_apply( + ::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs + ) where {ElType,Shape,N} rhs = promote_to(lhs, rhs) - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data), - 1)) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) end - function elem_apply(::typeof($jlop), lhs, - rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} + function elem_apply( + ::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} lhs = promote_to(rhs, lhs) - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data), - 1)) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) end end end -for (jlop, hloop, hlocomp, RT) in ((:(Base.:(==)), :compare, "EQ", :ElType), - (:(Base.:(!=)), :compare, "NE", :ElType), - (:(Base.:(>=)), :compare, "GE", :ElType), - (:(Base.:(>)), :compare, "GT", :ElType), - (:(Base.:(<=)), :compare, "LE", :ElType), - (:(Base.:(<)), :compare, "LT", :ElType)) +for (jlop, hloop, hlocomp, RT) in ( + (:(Base.:(==)), :compare, "EQ", :ElType), + (:(Base.:(!=)), :compare, "NE", :ElType), + (:(Base.:(>=)), :compare, "GE", :ElType), + (:(Base.:(>)), :compare, "GT", :ElType), + (:(Base.:(<=)), :compare, "LE", :ElType), + (:(Base.:(<)), :compare, "LT", :ElType), +) @eval begin - function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, - rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), - $hlocomp)), - 1)) + function elem_apply( + ::typeof($jlop), + lhs::TracedRArray{ElType,Shape,N}, + rhs::TracedRArray{ElType,Shape,N}, + ) where {ElType,Shape,N} + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), $hlocomp + ), + ), + 1, + ), + ) end - function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, - rhs) where {ElType,Shape,N} + function elem_apply( + ::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs + ) where {ElType,Shape,N} rhs = promote_to(lhs, rhs) - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), - $hlocomp)), - 1)) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), $hlocomp + ), + ), + 1, + ), + ) end - function elem_apply(::typeof($jlop), lhs, - rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} + function elem_apply( + ::typeof($jlop), lhs, rhs::TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} lhs = promote_to(rhs, lhs) - return TracedRArray{$RT,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(MLIR.IR.context(), - $hlocomp)), - 1)) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), $hlocomp + ), + ), + 1, + ), + ) end end end @@ -506,16 +578,24 @@ end function elem_apply(::typeof(identity), lhs) return lhs end -for (jlop, hloop) in ((:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos), :cosine), - (:(Base.tanh), :tanh), (:(Base.FastMath.tanh_fast), :tanh), - (:(Base.exp), :exponential), (:(Base.FastMath.exp_fast), :exponential), - (:(Base.log), :log), (:(Base.sqrt), :sqrt)) +for (jlop, hloop) in ( + (:(Base.:-), :negate), + (:(Base.sin), :sine), + (:(Base.cos), :cosine), + (:(Base.tanh), :tanh), + (:(Base.FastMath.tanh_fast), :tanh), + (:(Base.exp), :exponential), + (:(Base.FastMath.exp_fast), :exponential), + (:(Base.log), :log), + (:(Base.sqrt), :sqrt), +) @eval begin - function elem_apply(::typeof($jlop), - lhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} - return TracedRArray{ElType,Shape,N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), - 1)) + function elem_apply( + ::typeof($jlop), lhs::TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} + return TracedRArray{ElType,Shape,N}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) + ) end end end @@ -527,8 +607,9 @@ Cassette.overdub(context::TraceCtx, f::typeof(elem_apply), args...) = f(args...) end Cassette.overdub(context::TraceCtx, f::typeof(Base.reshape), args...) = f(args...) -@inline function Base.reshape(A::ConcreteRArray{T,Shape,N}, - dims::NTuple{NT,Int}) where {T,Shape,N,NT} +@inline function Base.reshape( + A::ConcreteRArray{T,Shape,N}, dims::NTuple{NT,Int} +) where {T,Shape,N,NT} prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A))) host = convert(Array{T,N}, A) # HLO reshape semantics collapse the opposite so enforce on Julia Side @@ -536,9 +617,9 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.reshape), args...) = f(args.. host = reshape(host, dims) client = XLA.client(A.data) device = XLA.device(A.data) - return ConcreteRArray{T,dims,NT}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, host, - device), - nothing)) + return ConcreteRArray{T,dims,NT}( + XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, host, device), nothing) + ) # ConcreteRArray{T, dims, NT}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, XLA.to_row_major(host), device), nothing)) end @@ -546,42 +627,50 @@ Base.copy(A::TracedRArray{T,Shape,N}) where {T,Shape,N} = TracedRArray((), A.mli Cassette.overdub(context::TraceCtx, f::typeof(Base.copy), args...) = f(args...) @inline function Base.permutedims(A::TracedRArray{T,Shape,N}, perm) where {T,Shape,N} - return TracedArray{T,tuple(Shape[i] for i in perm),N}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(A.mlir_data, - DenseArrayAttribute([Int64(i - - 1) - for i in - perm])), - 1)) + return TracedArray{T,tuple(Shape[i] for i in perm),N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.transpose( + A.mlir_data, DenseArrayAttribute([Int64(i - 1) for i in perm]) + ), + 1, + ), + ) end Cassette.overdub(context::TraceCtx, f::typeof(Base.permutedims), args...) = f(args...) -@inline function Base.reshape(A::TracedRArray{T,Shape,N}, - dims::NTuple{NT,Int}) where {T,Shape,N,NT} +@inline function Base.reshape( + A::TracedRArray{T,Shape,N}, dims::NTuple{NT,Int} +) where {T,Shape,N,NT} prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A))) # HLO reshape semantics collapse the opposite way - res1 = MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(A.mlir_data; - permutation=MLIR.IR.DenseArrayAttribute([Int64(N - - 1 - - i) - for i in - 0:(N - 1)])), - 1) - - res2 = MLIR.IR.result(MLIR.Dialects.stablehlo.reshape(res1; - result_0=MLIR.IR.TensorType([Int64(i) - for i in - reverse(dims)], - eltype(MLIR.IR.type(res1))))) - - res3 = MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(res2; - permutation=MLIR.IR.DenseArrayAttribute([Int64(NT - - 1 - - i) - for i in - 0:(NT - 1)])), - 1) + res1 = MLIR.IR.result( + MLIR.Dialects.stablehlo.transpose( + A.mlir_data; + permutation=MLIR.IR.DenseArrayAttribute([Int64(N - 1 - i) for i in 0:(N - 1)]), + ), + 1, + ) + + res2 = MLIR.IR.result( + MLIR.Dialects.stablehlo.reshape( + res1; + result_0=MLIR.IR.TensorType( + [Int64(i) for i in reverse(dims)], eltype(MLIR.IR.type(res1)) + ), + ), + ) + + res3 = MLIR.IR.result( + MLIR.Dialects.stablehlo.transpose( + res2; + permutation=MLIR.IR.DenseArrayAttribute([ + Int64(NT - 1 - i) for i in 0:(NT - 1) + ]), + ), + 1, + ) return TracedRArray{T,dims,NT}((), res3) end @@ -608,8 +697,9 @@ function Base.similar(x::TracedRArray{T,Shape,N}, ::Type{T2}) where {T,Shape,N,T end Cassette.overdub(context::TraceCtx, f::typeof(Base.similar), args...) = f(args...) -@inline function Base.similar(bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, - dims) where {T,N} +@inline function Base.similar( + bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims +) where {T,N} @assert N isa Int return TracedRArray{T,map(length, dims),N}((), nothing) end @@ -650,8 +740,9 @@ end return copyto!(sim, bc) end -@inline function Base.materialize!(::Style, dest, - bc::Broadcasted) where {Style<:AbstractReactantArrayStyle} +@inline function Base.materialize!( + ::Style, dest, bc::Broadcasted +) where {Style<:AbstractReactantArrayStyle} return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) end Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(args...) @@ -662,10 +753,9 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(a attr = MLIR.IR.DenseElementsAttribute(arg) len = ndims(arg) @assert typeof(len) == Int - arg = TracedRArray{eltype(arg),size(arg),len}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; - value=attr), - 1)) + arg = TracedRArray{eltype(arg),size(arg),len}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) + ) return arg end @@ -683,10 +773,9 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.fill!), args...) = f(args...) @inline function broadcast_to_size(arg::T, rsize) where {T<:Number} TT = MLIR.IR.TensorType([Int64(s) for s in rsize], MLIR.IR.Type(typeof(arg))) attr = Base.fill(arg, TT) - return arg = TracedRArray{T,rsize,length(rsize)}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; - value=attr), - 1)) + return arg = TracedRArray{T,rsize,length(rsize)}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) + ) end @inline function broadcast_to_size(arg::Broadcast.Extruded, rsize) @@ -712,14 +801,17 @@ end len = length(rsize) @assert typeof(len) == Int - return TracedRArray{eltype(x),rsize,len}((), - MLIR.IR.result(MLIR.Dialects.stablehlo.broadcast_in_dim(x.mlir_data; - result_0=MLIR.IR.TensorType([t - for t in - rsize], - eltype(mlirty)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims)), - 1)) + return TracedRArray{eltype(x),rsize,len}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), + ), + 1, + ), + ) end @inline function _copyto!(dest::TracedRArray, bc::Broadcasted) @@ -735,15 +827,21 @@ end return dest end -function Cassette.overdub(context::Cassette.Context, ::Core.kwftype(typeof(Base.mapreduce)), - kwargs::Any, ::typeof(Base.mapreduce), args...) +function Cassette.overdub( + context::Cassette.Context, + ::Core.kwftype(typeof(Base.mapreduce)), + kwargs::Any, + ::typeof(Base.mapreduce), + args..., +) return Base.mapreduce(args...; kwargs...) end Cassette.overdub(context::Cassette.Context, f::typeof(Base.mapreduce), args...) = f(args...) -function Base.mapreduce(f, op, A::TracedRArray{ElType,Shape,N}; dims=:, - init=nothing) where {ElType,Shape,N} +function Base.mapreduce( + f, op, A::TracedRArray{ElType,Shape,N}; dims=:, init=nothing +) where {ElType,Shape,N} if dims isa Int dims = [dims] end @@ -764,13 +862,16 @@ function Base.mapreduce(f, op, A::TracedRArray{ElType,Shape,N}; dims=:, Int64[i - 1 for i in dims] end - in_tys = [MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(arg))) - for arg in (inp[1], init[1])] + in_tys = [ + MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(arg))) for arg in (inp[1], init[1]) + ] fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys]) - args = (TracedRArray{ElType,(),0}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in - enumerate(in_tys)) + args = ( + TracedRArray{ElType,(),0}((), MLIR.IR.argument(fnbody, i)) for + (i, ty) in enumerate(in_tys) + ) res = MLIR.IR.block!(fnbody) do tmp = broadcast_to_size(op(args...), (1,)).mlir_data @@ -781,22 +882,26 @@ function Base.mapreduce(f, op, A::TracedRArray{ElType,Shape,N}; dims=:, toonedims = [(in(i - 1, rdims) ? 1 : Shape[i]) for i in 1:N] outdims = [Shape[i] for i in 1:N if (i - 1) ∉ rdims] - TT = [MLIR.IR.TensorType(outdims, eltype(MLIR.IR.type(inp0))) - for (inp0, res0) in zip(inp, (res,))] + TT = [ + MLIR.IR.TensorType(outdims, eltype(MLIR.IR.type(inp0))) for + (inp0, res0) in zip(inp, (res,)) + ] body = MLIR.IR.Region() push!(body, fnbody) - red = MLIR.Dialects.stablehlo.reduce(inp, init; result_0=TT, - dimensions=MLIR.IR.DenseArrayAttribute(rdims), - body) + red = MLIR.Dialects.stablehlo.reduce( + inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body + ) red = MLIR.IR.result(red, 1) if dims != (:) - red = MLIR.IR.result(MLIR.Dialects.stablehlo.reshape(red; - result_0=MLIR.IR.TensorType(toonedims, - eltype(MLIR.IR.type(red)))), - 1) + red = MLIR.IR.result( + MLIR.Dialects.stablehlo.reshape( + red; result_0=MLIR.IR.TensorType(toonedims, eltype(MLIR.IR.type(red))) + ), + 1, + ) red = TracedRArray{ElType,(toonedims...,),length(toonedims)}((), red) else red = TracedRArray{ElType,(outdims...,),length(outdims)}((), red) diff --git a/src/utils.jl b/src/utils.jl index 61e13a14c..2dfd97567 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,9 @@ function transpose_ty(mlirty) return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty)) end function transpose_val(val) - attr = MLIR.IR.DenseArrayAttribute(Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...]) + attr = MLIR.IR.DenseArrayAttribute( + Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] + ) return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) end @@ -13,16 +15,22 @@ end function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true) if sizeof(typeof(f)) != 0 - return (true, - make_mlir_fn(mod, apply, (f, args...), kwargs, name, concretein)[2:end]...) + return ( + true, make_mlir_fn(mod, apply, (f, args...), kwargs, name, concretein)[2:end]... + ) end N = length(args) seen_args = IdDict() traced_args = ntuple(Val(N)) do i Base.@_inline_meta - return make_tracer(seen_args, args[i], ("args", i), - concretein ? ConcreteToTraced : TracedSetPath, nothing) #=data=# + return make_tracer( + seen_args, + args[i], + ("args", i), + concretein ? ConcreteToTraced : TracedSetPath, + nothing, + ) #=data=# end linear_args = TracedRArray[] @@ -40,9 +48,11 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true) sym_visibility = MLIR.IR.Attribute("private") end - func = MLIR.Dialects.func.func_(; sym_name=name * "_tmp", - function_type=MLIR.IR.FunctionType(in_tys, []), - body=MLIR.IR.Region()) + func = MLIR.Dialects.func.func_(; + sym_name=name * "_tmp", + function_type=MLIR.IR.FunctionType(in_tys, []), + body=MLIR.IR.Region(), + ) fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) push!(MLIR.IR.region(func, 1), fnbody) @@ -59,13 +69,19 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true) seen_results = IdDict() - traced_result = make_tracer(seen_results, result, ("result",), - concretein ? TracedTrack : TracedSetPath, nothing) #=data=# + traced_result = make_tracer( + seen_results, result, ("result",), concretein ? TracedTrack : TracedSetPath, nothing + ) #=data=# retraced_args = ntuple(Val(N)) do i Base.@_inline_meta - return make_tracer(seen_results, traced_args[i], concretein ? ("resargs", i) : (), - TracedTrack, nothing) #=data=# + return make_tracer( + seen_results, + traced_args[i], + concretein ? ("resargs", i) : (), + TracedTrack, + nothing, + ) #=data=# end linear_results = TracedRArray[] @@ -91,9 +107,12 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true) end func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; sym_name=name, - function_type=MLIR.IR.FunctionType(in_tys, out_tys), - body=MLIR.IR.Region(), sym_visibility) + return MLIR.Dialects.func.func_(; + sym_name=name, + function_type=MLIR.IR.FunctionType(in_tys, out_tys), + body=MLIR.IR.Region(), + sym_visibility, + ) end MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1)) @@ -101,6 +120,7 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true) MLIR.API.mlirOperationDestroy(func.operation) func.operation = MLIR.API.MlirOperation(C_NULL) end - return false, func2, traced_result, result, seen_args, ret, linear_args, in_tys, - linear_results + return false, + func2, traced_result, result, seen_args, ret, linear_args, in_tys, + linear_results end diff --git a/test/bcast.jl b/test/bcast.jl index c978a3371..9d05200ad 100644 --- a/test/bcast.jl +++ b/test/bcast.jl @@ -31,9 +31,11 @@ function test() in_tys = [MLIR.IR.TensorType([4], MLIR.IR.Type(Float64))] - func = MLIR.Dialects.func.func_(; sym_name="main_tmp", - function_type=MLIR.IR.FunctionType(in_tys, []), - body=MLIR.IR.Region()) + func = MLIR.Dialects.func.func_(; + sym_name="main_tmp", + function_type=MLIR.IR.FunctionType(in_tys, []), + body=MLIR.IR.Region(), + ) fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for _ in in_tys]) push!(MLIR.IR.region(func, 1), fnbody) @@ -42,8 +44,9 @@ function test() MLIR.IR.block!(fnbody) do a = ones(4) b = ones(4) - d = Data(Reactant.TracedRArray{Float64,(4,),1}((), - MLIR.IR.argument(fnbody, 1))) + d = Data( + Reactant.TracedRArray{Float64,(4,),1}((), MLIR.IR.argument(fnbody, 1)) + ) return tmp(a, b, d) end diff --git a/test/nn.jl b/test/nn.jl index ed2d04e14..e0aec9364 100644 --- a/test/nn.jl +++ b/test/nn.jl @@ -8,10 +8,12 @@ noisy = rand(Float32, 2, 1000) # 2×1000 Matr truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool} # Define our model, a multi-layer perceptron with one hidden layer of size 3: -model = Chain(Dense(2 => 3, tanh), # activation function inside layer - BatchNorm(3), - Dense(3 => 2), - softmax) +model = Chain( + Dense(2 => 3, tanh), # activation function inside layer + BatchNorm(3), + Dense(3 => 2), + softmax, +) using BenchmarkTools diff --git a/test/struct.jl b/test/struct.jl index 07a2cedf8..c73943a3d 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -65,7 +65,7 @@ end y = f(x2) @test y isa - MutableMockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}} + MutableMockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}} @test isapprox(parent(y), cos.(parent(x))) @test x.inds == [:i, :j] end