From b92a7d88403a4be6d7a6f8ef1478f3fa60f535a5 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Fri, 15 Nov 2024 22:00:41 +0000 Subject: [PATCH] fix attribute for arrray of bools (#279) * fix attribute for arrray of bools * test * remove uses of unsafe pointer calls --- src/mlir/IR/AffineMap.jl | 8 ++--- src/mlir/IR/Attribute.jl | 66 +++++++++++++++------------------------ src/mlir/IR/IntegerSet.jl | 7 +---- src/mlir/IR/Location.jl | 2 +- src/mlir/IR/Type.jl | 22 ++++--------- test/compile.jl | 7 +++++ 6 files changed, 44 insertions(+), 68 deletions(-) diff --git a/src/mlir/IR/AffineMap.jl b/src/mlir/IR/AffineMap.jl index a70e6f722..39044e804 100644 --- a/src/mlir/IR/AffineMap.jl +++ b/src/mlir/IR/AffineMap.jl @@ -54,7 +54,7 @@ Creates an affine map with results defined by the given list of affine expressio The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results. """ AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) = - AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), pointer(exprs))) + AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), exprs)) """ ConstantAffineMap(val; context=context()) @@ -94,9 +94,7 @@ The affine map is owned by the context. function PermutationAffineMap(permutation; context::Context=context()) @assert Base.isperm(permutation) "$permutation must be a valid permutation" zero_perm = permutation .- 1 - return AffineMap( - API.mlirAffineMapPermutationGet(context, length(zero_perm), pointer(zero_perm)) - ) + return AffineMap(API.mlirAffineMapPermutationGet(context, length(zero_perm), zero_perm)) end """ @@ -192,7 +190,7 @@ Base.isperm(map::AffineMap) = API.mlirAffineMapIsPermutation(map) Returns the affine map consisting of the `positions` subset. """ submap(map::AffineMap, pos::Vector{Int}) = - AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pointer(pos))) + AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pos)) """ majorsubmap(affineMap, nresults) diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index 6380d75ed..69bd9b99a 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -81,7 +81,7 @@ isarray(attr::Attribute) = API.mlirAttributeIsAArray(attr) Creates an array element containing the given list of elements in the given context. """ Attribute(attrs::Vector{Attribute}; context::Context=context()) = - Attribute(API.mlirArrayAttrGet(context, length(attrs), pointer(attrs))) + Attribute(API.mlirArrayAttrGet(context, length(attrs), attrs)) """ isdict(attr) @@ -97,7 +97,7 @@ Creates a dictionary attribute containing the given list of elements in the prov """ function Attribute(attrs::Dict; context::Context=context()) attrs = map(splat(NamedAttribute), attrs) - return Attribute(API.mlirDictionaryAttrGet(context, length(attrs), pointer(attrs))) + return Attribute(API.mlirDictionaryAttrGet(context, length(attrs), attrs)) end """ @@ -309,9 +309,7 @@ Each of the references in the list must not be nested. """ SymbolRefAttribute( symbol::String, references::Vector{Attribute}; context::Context=context() -) = Attribute( - API.mlirSymbolRefAttrGet(context, symbol, length(references), pointer(references)) -) +) = Attribute(API.mlirSymbolRefAttrGet(context, symbol, length(references), references)) """ rootref(attr) @@ -429,9 +427,7 @@ Creates a dense elements attribute with the given Shaped type and elements in th """ function DenseElementsAttribute(shaped_type::Type, elements::AbstractArray) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return Attribute( - API.mlirDenseElementsAttrGet(shaped_type, length(elements), pointer(elements)) - ) + return Attribute(API.mlirDenseElementsAttrGet(shaped_type, length(elements), elements)) end # TODO mlirDenseElementsAttrRawBufferGet @@ -504,77 +500,67 @@ Creates a dense elements attribute with the given shaped type from elements of a function DenseElementsAttribute(values::AbstractArray{Bool}) shaped_type = TensorType(size(values), Type(Bool)) return Attribute( - API.mlirDenseElementsAttrBoolGet(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrBoolGet( + shaped_type, length(values), AbstractArray{Cint}(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{UInt8}) shaped_type = TensorType(size(values), Type(UInt8)) - return Attribute( - API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), pointer(values)) - ) + return Attribute(API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), values)) end function DenseElementsAttribute(values::AbstractArray{Int8}) shaped_type = TensorType(size(values), Type(Int8)) - return Attribute( - API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), pointer(values)) - ) + return Attribute(API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), values)) end function DenseElementsAttribute(values::AbstractArray{UInt16}) shaped_type = TensorType(size(values), Type(UInt16)) return Attribute( - API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), values) ) end function DenseElementsAttribute(values::AbstractArray{Int16}) shaped_type = TensorType(size(values), Type(Int16)) - return Attribute( - API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), pointer(values)) - ) + return Attribute(API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), values)) end function DenseElementsAttribute(values::AbstractArray{UInt32}) shaped_type = TensorType(size(values), Type(UInt32)) return Attribute( - API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), values) ) end function DenseElementsAttribute(values::AbstractArray{Int32}) shaped_type = TensorType(size(values), Type(Int32)) - return Attribute( - API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), pointer(values)) - ) + return Attribute(API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), values)) end function DenseElementsAttribute(values::AbstractArray{UInt64}) shaped_type = TensorType(size(values), Type(UInt64)) return Attribute( - API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), values) ) end function DenseElementsAttribute(values::AbstractArray{Int64}) shaped_type = TensorType(size(values), Type(Int64)) - return Attribute( - API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), pointer(values)) - ) + return Attribute(API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), values)) end function DenseElementsAttribute(values::AbstractArray{Float32}) shaped_type = TensorType(size(values), Type(Float32)) - return Attribute( - API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), pointer(values)) - ) + return Attribute(API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), values)) end function DenseElementsAttribute(values::AbstractArray{Float64}) shaped_type = TensorType(size(values), Type(Float64)) return Attribute( - API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), values) ) end @@ -583,7 +569,7 @@ end function DenseElementsAttribute(values::AbstractArray{Float16}) shaped_type = TensorType(size(values), Type(Float16)) return Attribute( - API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), values) ) end @@ -606,7 +592,7 @@ function DenseElementsAttribute(values::AbstractArray{String}) # TODO may fail because `Type(String)` is not defined shaped_type = TensorType(size(values), Type(String)) return Attribute( - API.mlirDenseElementsAttrStringGet(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrStringGet(shaped_type, length(values), values) ) end @@ -677,25 +663,25 @@ function DenseArrayAttribute end @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Bool}; context::Context=context() -) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), pointer(values))) +) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), values)) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int8}; context::Context=context() -) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), pointer(values))) +) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), values)) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int16}; context::Context=context() -) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), pointer(values))) +) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), values)) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int32}; context::Context=context() -) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), pointer(values))) +) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), values)) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int64}; context::Context=context() -) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), pointer(values))) +) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), values)) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Float32}; context::Context=context() -) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), pointer(values))) +) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), values)) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Float64}; context::Context=context() -) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), pointer(values))) +) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), values)) @llvmversioned min = v"16" Attribute(values::AbstractArray) = DenseArrayAttribute(values) diff --git a/src/mlir/IR/IntegerSet.jl b/src/mlir/IR/IntegerSet.jl index e7f89cbc1..eb57939cc 100644 --- a/src/mlir/IR/IntegerSet.jl +++ b/src/mlir/IR/IntegerSet.jl @@ -24,12 +24,7 @@ Both `constraints` and `eqflags` need to be arrays of the same length. """ IntegerSet(ndims, nsymbols, constraints, eqflags; context::Context=context()) = IntegerSet( API.mlirIntegerSetGet( - context, - ndims, - nsymbols, - length(constraints), - pointer(constraints), - pointer(eqflags), + context, ndims, nsymbols, length(constraints), constraints, eqflags ), ) diff --git a/src/mlir/IR/Location.jl b/src/mlir/IR/Location.jl index 88f403006..d98e4fa1c 100644 --- a/src/mlir/IR/Location.jl +++ b/src/mlir/IR/Location.jl @@ -24,7 +24,7 @@ end # TODO rename to merge? function fuse(locations::Vector{Location}, metadata; context::Context=context()) return Location( - API.mlirLocationFusedGet(context, length(locations), pointer(locations), metadata) + API.mlirLocationFusedGet(context, length(locations), locations, metadata) ) end diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index 0d3129ce0..c77224d68 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -449,15 +449,11 @@ function MemRefType( if check Type( API.mlirMemRefTypeGetChecked( - location, elem_type, length(shape), pointer(shape), layout, memspace + location, elem_type, length(shape), shape, layout, memspace ), ) else - Type( - API.mlirMemRefTypeGet( - elem_type, length(shape), pointer(shape), layout, memspace - ), - ) + Type(API.mlirMemRefTypeGet(elem_type, length(shape), shape, layout, memspace)) end end @@ -474,15 +470,11 @@ function MemRefType( if check Type( API.mlirMemRefTypeContiguousGetChecked( - location, elem_type, length(shape), pointer(shape), memspace + location, elem_type, length(shape), shape, memspace ), ) else - Type( - API.mlirMemRefTypeContiguousGet( - elem_type, length(shape), pointer(shape), memspace - ), - ) + Type(API.mlirMemRefTypeContiguousGet(elem_type, length(shape), shape, memspace)) end end @@ -560,7 +552,7 @@ end Creates a tuple type that consists of the given list of elemental types. The type is owned by the context. """ Type(elements::Vector{Type}; context::Context=context()) = - Type(API.mlirTupleTypeGet(context, length(elements), pointer(elements))) + Type(API.mlirTupleTypeGet(context, length(elements), elements)) function Type(@nospecialize(elements::NTuple{N,Type}); context::Context=context()) where {N} return Type(collect(elements); context) end @@ -590,9 +582,7 @@ Creates a function type, mapping a list of input types to result types. """ function FunctionType(inputs, results; context::Context=context()) return Type( - API.mlirFunctionTypeGet( - context, length(inputs), pointer(inputs), length(results), pointer(results) - ), + API.mlirFunctionTypeGet(context, length(inputs), inputs, length(results), results) ) end diff --git a/test/compile.jl b/test/compile.jl index 682c8d02b..0a65ecbd0 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -65,3 +65,10 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= # end # end end + +@testset "Bool attributes" begin + x_ra = Reactant.to_rarray(false; track_numbers=(Number,)) + @test @jit(iszero(x_ra)) == true + x_ra = Reactant.to_rarray(true; track_numbers=(Number,)) + @test @jit(iszero(x_ra)) == false +end