From 885264eaf59e957f48937e831adda0408ce9907a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 31 Dec 2024 15:55:19 -0500 Subject: [PATCH] feat: missing BF16 dispatches (#443) --- src/mlir/IR/Attribute.jl | 11 ++++++++++- src/mlir/IR/Type.jl | 11 +++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index d37e7c986..d4354198a 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -588,7 +588,16 @@ function DenseElementsAttribute(values::AbstractArray{Float64}) ) end -# TODO mlirDenseElementsAttrBFloat16Get +if isdefined(Core, :BFloat16) + function DenseElementsAttribute(values::AbstractArray{Core.BFloat16}) + shaped_type = TensorType(size(values), Type(Core.BFloat16)) + return Attribute( + API.mlirDenseElementsAttrBFloat16Get( + shaped_type, length(values), to_row_major(values) + ), + ) + end +end function DenseElementsAttribute(values::AbstractArray{Float16}) shaped_type = TensorType(size(values), Type(Float16)) diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index c77224d68..bd44b4d6f 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -168,6 +168,15 @@ Creates an f16 type in the given context. The type is owned by the context. """ Type(::Core.Type{Float16}; context::Context=context()) = Type(API.mlirF16TypeGet(context)) +if isdefined(Core, :BFloat16) + """ + Type(::Core.Type{Core.BFloat16}; context=context()) + + Creates an bf16 type in the given context. The type is owned by the context. + """ + Type(::Core.Type{Core.BFloat16}; context::Context=context()) = BFloat16Type(; context) +end + """ Type(Core.Type{Float32}; context=context()) @@ -721,6 +730,8 @@ function julia_type(type::Type) throw("could not convert unsigned $width-bit integer type to julia") end end + elseif isbf16(type) + Core.BFloat16 elseif isf16(type) Float16 elseif isf32(type)