Skip to content

Commit

Permalink
fix: make eltype of Traced/Concrete Arrays to be respective RNumbers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 24, 2024
1 parent 2f53601 commit db906a4
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 64 deletions.
21 changes: 2 additions & 19 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,3 @@
struct XLAArray{T,N} <: RArray{T,N}
# size::NTuple{N,Int}
end

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
# data::XLAArray{T, N}
shape::NTuple{N,Int}
end

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

function ConcreteRNumber{T}(
data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
) where {T<:Number,T2<:Number}
Expand Down Expand Up @@ -234,12 +217,12 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
# start += (args[i]-1)
end
start += 1
return unsafe_load(ptr, start)
return ConcreteRNumber(unsafe_load(ptr, start))
end
end

GPUArraysCore.assertscalar("getindex(::ConcreteRArray, ::Vararg{Int, N})")
return convert(Array, a)[args...]
return ConcreteRNumber(convert(Array, a)[args...])
end

function mysetindex!(a, v, args::Vararg{Any,N}) where {N}
Expand Down
51 changes: 36 additions & 15 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,8 @@ else
}
end

abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
abstract type RNumber{T<:ReactantPrimitive} <: Number end

Base.collect(A::RArray) = copy(A)

function ancestor(x::AbstractArray)
p_x = parent(x)
p_x === x && return x
Expand All @@ -71,7 +68,21 @@ include("Interpreter.jl")

include("utils.jl")

mutable struct TracedRArray{T,N} <: RArray{T,N}
mutable struct TracedRNumber{T} <: RNumber{T}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}

function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
return new{T}(paths, mlir_data)
end
end

mutable struct TracedRArray{T,N} <: AbstractArray{TracedRNumber{T},N}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}
shape::NTuple{N,Int}
Expand Down Expand Up @@ -102,20 +113,30 @@ function TracedRArray(data::MLIR.IR.Value)
)
end

mutable struct TracedRNumber{T} <: RNumber{T}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}
struct XLAArray{T,N} <: AbstractArray{T,N} end

function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
return new{T}(paths, mlir_data)
end
mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

mutable struct ConcreteRArray{T,N} <: AbstractArray{ConcreteRNumber{T},N}
data::XLA.AsyncBuffer
shape::NTuple{N,Int}
end

unwrapped_eltype(::Type{T}) where {T<:Number} = T
unwrapped_eltype(::Type{<:TracedRNumber{T}}) where {T} = T
unwrapped_eltype(::Type{<:TracedRArray{T,N}}) where {T,N} = T
unwrapped_eltype(::Type{<:XLAArray{T,N}}) where {T,N} = T
unwrapped_eltype(::Type{<:ConcreteRNumber{T}}) where {T} = T
unwrapped_eltype(::Type{<:ConcreteRArray{T,N}}) where {T,N} = T
unwrapped_eltype(x) = unwrapped_eltype(typeof(x))

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

const RArray{T,N} = Union{ConcreteRArray{T,N},TracedRArray{T,N},XLAArray{T,N}}

include("Ops.jl")
include("TracedUtils.jl")

Expand Down
2 changes: 2 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(Base.getindex, x), length(x))

Base.size(x::TracedRArray) = x.shape

Base.collect(x::TracedRArray) = copy(x) # XXX: Is this correct?

Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))

# TODO is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`?
Expand Down
19 changes: 4 additions & 15 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
module TracedRNumberOverrides

import ..TracedRNumber
import ..TracedRArray
import ..ReactantPrimitive
using ..TracedUtils
import ..Ops
import ..MLIR
using ..Reactant:
Reactant, TracedRNumber, TracedRArray, ReactantPrimitive, TracedUtils, Ops, MLIR
using ReactantCore

ReactantCore.is_traced(::TracedRNumber) = true

Base.eltype(::Type{TracedRNumber{T}}) where {T} = T

Base.getindex(a::TracedRNumber{T}) where {T} = a

Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, zero(T))
Expand Down Expand Up @@ -51,20 +45,15 @@ TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x
TracedRNumber{T}(x::TracedRNumber) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x)
TracedRNumber{T}(x::Number) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x)

(T::Type{<:Number})(x::TracedRNumber) = TracedUtils.promote_to(TracedRNumber{T}, x)

function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
if rhs isa TracedRNumber
rhs isa TracedRNumber{T} && return rhs
return Ops.convert(TracedRNumber{T}, rhs)
end
if rhs isa TracedRArray{<:Any,0}
return TracedUtils.promote_to(
TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data)
TracedRNumber{T},
TracedRNumber{Reactant.unwrapped_eltype(rhs)}((), rhs.mlir_data),
)
end
rhs isa Number &&
Expand Down
29 changes: 15 additions & 14 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@ module TracedUtils
using LinearAlgebra: LinearAlgebra
using Adapt: Adapt
using ..Reactant:
RArray,
Reactant,
MLIR,
RNumber,
TracedRArray,
TracedRNumber,
WrappedTracedRArray,
AnyTracedRArray,
MissingTracedValue,
OrderedIdDict
import ..Reactant
import ..Reactant.MLIR
import ..ReactantPrimitive
import ..Ops
OrderedIdDict,
ReactantPrimitive,
Ops

materialize_traced_array(x::TracedRArray) = x
materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
Expand Down Expand Up @@ -164,7 +163,10 @@ function make_mlir_fn(
end

in_tys = if toscalar
[MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args]
[
MLIR.IR.TensorType((), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for
arg in linear_args
]
elseif do_transpose
[transpose_ty(Ops.mlir_type(arg)) for arg in linear_args]
else
Expand Down Expand Up @@ -416,7 +418,8 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
in_tys2 = [Ops.mlir_type(invmap[arg]) for arg in linear_args]

out_tys2 = [
MLIR.IR.TensorType(OutShape, MLIR.IR.Type(eltype(arg))) for arg in linear_results
MLIR.IR.TensorType(OutShape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for
arg in linear_results
]

fname = get_attribute_by_name(func2, "sym_name")
Expand Down Expand Up @@ -487,11 +490,9 @@ end

broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize)))

function broadcast_to_size(arg::TracedRNumber, rsize)
function broadcast_to_size(arg::TracedRNumber{T}, rsize) where {T}
length(rsize) == 0 && return arg
return broadcast_to_size_internal(
TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize
)
return broadcast_to_size_internal(TracedRArray{T,0}((), arg.mlir_data, ()), rsize)
end

function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T}
Expand All @@ -512,7 +513,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize)
return broadcast_to_size_internal(x, rsize)
end

@noinline function broadcast_to_size_internal(x::TracedRArray, rsize)
@noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T}
dims = collect(Int64, 0:(length(size(x)) - 1))

if length(size(MLIR.IR.type(x.mlir_data))) != length(dims)
Expand All @@ -525,7 +526,7 @@ end
@assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims)
mlirty = MLIR.IR.type(x.mlir_data)

return TracedRArray{eltype(x),Int(length(rsize))}(
return TracedRArray{T,Int(length(rsize))}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.broadcast_in_dim(
Expand Down
2 changes: 1 addition & 1 deletion src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
TracedSetPath = 5
end

for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArray, RNumber)
for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
@eval function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:$T}
return T
end
Expand Down

0 comments on commit db906a4

Please sign in to comment.