Skip to content

Commit

Permalink
feat: support custom user types
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 24, 2025
1 parent 2cd100b commit 3a493bf
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 31 deletions.
19 changes: 19 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,25 @@ end
return TracedRArray{T,N}((), res, size(x))
end

## This is somewhat a hack because I can't seem to find the corresponding mlir
## DenseElementsAttribute functions (also our optimizations will run a pass converting this
## to a single operation)
for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ)
@eval @noinline function constant(
x::DenseArray{<:Reactant.$(T),N};
location=mlir_stacktrace("constant", @__FILE__, @__LINE__),
) where {N}
value = MLIR.IR.DenseElementsAttribute(
map(Float16 Base.Fix2(getproperty, :val), x)
)
output = mlir_type(TracedRArray{Float16,N}, size(x))
res = MLIR.IR.result(stablehlo.constant(; output, value, location))
return convert(
TracedRArray{eltype(x),N}, TracedRArray{Float16,N}((), res, size(x)); location
)
end
end

@noinline function constant(
x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
) where {T<:Number}
Expand Down
32 changes: 18 additions & 14 deletions src/PrimitiveTypes.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
# The types listed in this file are the ones present in StableHLO specification.

# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of
# types, users can use these to convert their types to the primitive types supported by
# Reactant.
for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ)
@eval begin
primitive type $(T) <: AbstractFloat 8 end
struct $(T){inT} <: AbstractFloat
val::inT
end

Base.promote_rule(::Type{$(T)}, ::Type{Float16}) = Float16
Base.promote_rule(::Type{Float16}, ::Type{$(T)}) = Float16
Base.promote_rule(::Type{<:$(T)}, ::Type{Float16}) = Float16
Base.promote_rule(::Type{Float16}, ::Type{<:$(T)}) = Float16

Base.promote_rule(::Type{$(T)}, ::Type{Float32}) = Float32
Base.promote_rule(::Type{Float32}, ::Type{$(T)}) = Float32
Base.promote_rule(::Type{<:$(T)}, ::Type{Float32}) = Float32
Base.promote_rule(::Type{Float32}, ::Type{<:$(T)}) = Float32

Base.promote_rule(::Type{$(T)}, ::Type{Float64}) = Float64
Base.promote_rule(::Type{Float64}, ::Type{$(T)}) = Float64
Base.promote_rule(::Type{<:$(T)}, ::Type{Float64}) = Float64
Base.promote_rule(::Type{Float64}, ::Type{<:$(T)}) = Float64

Base.promote_rule(::Type{$(T)}, ::Type{<:Integer}) = $(T)
Base.promote_rule(::Type{<:Integer}, ::Type{$(T)}) = $(T)
Base.promote_rule(::Type{<:$(T){inT}}, ::Type{<:Integer}) where {inT} = $(T){inT}
Base.promote_rule(::Type{<:Integer}, ::Type{<:$(T){inT}}) where {inT} = $(T){inT}

@static if isdefined(Core, :BFloat16)
Base.promote_rule(::Type{$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16
Base.promote_rule(::Type{Core.BFloat16}, ::Type{$(T)}) = Core.BFloat16
Base.promote_rule(::Type{<:$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16
Base.promote_rule(::Type{Core.BFloat16}, ::Type{<:$(T)}) = Core.BFloat16
end
end
end
Expand All @@ -36,9 +40,9 @@ else
const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...}
end

const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(ReactantFloat)]...}
const ReactantComplexFloat = Union{Complex{Float32}, Complex{Float64}}

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}
const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}

const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...}

Expand All @@ -53,7 +57,7 @@ const ReactantPrimitive = Union{
Base.uniontypes(ReactantComplexFloat)...,
}

to_reactant_primitive(v::T) where {T} = reinterpret(reactant_primitive(T), v)
to_reactant_primitive(v::T) where {T} = reactant_primitive(T)(v)
reactant_primitive(::Type{T}) where {T} = nothing

for T in Base.uniontypes(ReactantPrimitive)
Expand Down
16 changes: 9 additions & 7 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T
Base.one(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, one(T))
Base.collect(x::TracedRNumber{T}) where {T} = TracedRArray{T,0}((), x.mlir_data, ())

function Base.eps(::Type{TracedRNumber{T}}) where {T}
function Base.eps(::Type{<:TracedRNumber{T}}) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, eps(T))
end

Expand All @@ -36,24 +36,26 @@ end

Base.only(A::TracedRNumber{T}) where {T} = A

function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S}
function Base.promote_rule(
::Type{<:TracedRNumber{T}}, ::Type{<:TracedRNumber{S}}
) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

# Bool has special promotion rules in Base
function Base.promote_rule(::Type{Bool}, ::Type{TracedRNumber{T}}) where {T}
function Base.promote_rule(::Type{Bool}, ::Type{<:TracedRNumber{T}}) where {T}
return TracedRNumber{T}
end

function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{Bool}) where {T}
function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{Bool}) where {T}
return TracedRNumber{T}
end

function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
function Base.promote_rule(::Type{T}, ::Type{<:TracedRNumber{S}}) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{S}) where {T,S}
function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{S}) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

Expand All @@ -67,7 +69,7 @@ function TracedRNumber{T}(x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x)
end

function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
function TracedUtils.promote_to(::Type{<:TracedRNumber{T}}, rhs) where {T}
if rhs isa TracedRNumber
rhs isa TracedRNumber{T} && return rhs
return Ops.convert(TracedRNumber{T}, rhs)
Expand Down
10 changes: 5 additions & 5 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,11 @@ end
@inline primitive_type(::Type{Float16}) = 10
@inline primitive_type(::Type{Float32}) = 11

@inline primitive_type(::Type{Reactant.F8E5M2}) = 19
@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20
@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23
@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24
@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25
@inline primitive_type(::Type{<:Reactant.F8E5M2}) = 19
@inline primitive_type(::Type{<:Reactant.F8E4M3FN}) = 20
@inline primitive_type(::Type{<:Reactant.F8E4M3B11FNUZ}) = 23
@inline primitive_type(::Type{<:Reactant.F8E5M2FNUZ}) = 24
@inline primitive_type(::Type{<:Reactant.F8E4M3FNUZ}) = 25

@static if isdefined(Core, :BFloat16)
@inline primitive_type(::Type{Core.BFloat16}) = 16
Expand Down
10 changes: 5 additions & 5 deletions src/mlir/IR/Type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet
Creates a f8e5m2 type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E5M2}; context::Context=context())
function Type(::Core.Type{<:Reactant.F8E5M2}; context::Context=context())
return Type(API.mlirFloat8E5M2TypeGet(context))
end

Expand All @@ -205,7 +205,7 @@ end
Creates a f8e4m3fn type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E4M3FN}; context::Context=context())
function Type(::Core.Type{<:Reactant.F8E4M3FN}; context::Context=context())
return Type(API.mlirFloat8E4M3FNTypeGet(context))
end

Expand All @@ -214,7 +214,7 @@ end
Creates a f8e4m3b11fnuz type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context::Context=context())
function Type(::Core.Type{<:Reactant.F8E4M3B11FNUZ}; context::Context=context())
return Type(API.mlirFloat8E4M3B11FNUZTypeGet(context))
end

Expand All @@ -223,7 +223,7 @@ end
Creates a f8e5m2fnuz type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E5M2FNUZ}; context::Context=context())
function Type(::Core.Type{<:Reactant.F8E5M2FNUZ}; context::Context=context())
return Type(API.mlirFloat8E5M2FNUZTypeGet(context))
end

Expand All @@ -232,7 +232,7 @@ end
Creates a f8e4m3fnuz type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E4M3FNUZ}; context::Context=context())
function Type(::Core.Type{<:Reactant.F8E4M3FNUZ}; context::Context=context())
return Type(API.mlirFloat8E4M3FNTypeGet(context))
end

Expand Down

0 comments on commit 3a493bf

Please sign in to comment.