Skip to content

Commit

Permalink
feat: missing BF16 dispatches (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Dec 31, 2024
1 parent 6556944 commit 885264e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/mlir/IR/Attribute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,16 @@ function DenseElementsAttribute(values::AbstractArray{Float64})
)
end

# TODO mlirDenseElementsAttrBFloat16Get
if isdefined(Core, :BFloat16)
function DenseElementsAttribute(values::AbstractArray{Core.BFloat16})
shaped_type = TensorType(size(values), Type(Core.BFloat16))
return Attribute(
API.mlirDenseElementsAttrBFloat16Get(
shaped_type, length(values), to_row_major(values)
),
)
end
end

function DenseElementsAttribute(values::AbstractArray{Float16})
shaped_type = TensorType(size(values), Type(Float16))
Expand Down
11 changes: 11 additions & 0 deletions src/mlir/IR/Type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ Creates an f16 type in the given context. The type is owned by the context.
"""
Type(::Core.Type{Float16}; context::Context=context()) = Type(API.mlirF16TypeGet(context))

if isdefined(Core, :BFloat16)
"""
Type(::Core.Type{Core.BFloat16}; context=context())
Creates an bf16 type in the given context. The type is owned by the context.
"""
Type(::Core.Type{Core.BFloat16}; context::Context=context()) = BFloat16Type(; context)
end

"""
Type(Core.Type{Float32}; context=context())
Expand Down Expand Up @@ -721,6 +730,8 @@ function julia_type(type::Type)
throw("could not convert unsigned $width-bit integer type to julia")
end
end
elseif isbf16(type)
Core.BFloat16
elseif isf16(type)
Float16
elseif isf32(type)
Expand Down

0 comments on commit 885264e

Please sign in to comment.