Skip to content

Commit

Permalink
fix: make eltype of Traced/Concrete Arrays to be respective RNumbers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 24, 2024
1 parent 2f53601 commit e35a7e3
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 51 deletions.
21 changes: 2 additions & 19 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,3 @@
struct XLAArray{T,N} <: RArray{T,N}
# size::NTuple{N,Int}
end

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
# data::XLAArray{T, N}
shape::NTuple{N,Int}
end

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

function ConcreteRNumber{T}(
data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
) where {T<:Number,T2<:Number}
Expand Down Expand Up @@ -234,12 +217,12 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
# start += (args[i]-1)
end
start += 1
return unsafe_load(ptr, start)
return ConcreteRNumber(unsafe_load(ptr, start))
end
end

GPUArraysCore.assertscalar("getindex(::ConcreteRArray, ::Vararg{Int, N})")
return convert(Array, a)[args...]
return ConcreteRNumber(convert(Array, a)[args...])
end

function mysetindex!(a, v, args::Vararg{Any,N}) where {N}
Expand Down
52 changes: 37 additions & 15 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,8 @@ else
}
end

abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
abstract type RNumber{T<:ReactantPrimitive} <: Number end

Base.collect(A::RArray) = copy(A)

function ancestor(x::AbstractArray)
p_x = parent(x)
p_x === x && return x
Expand All @@ -71,7 +68,21 @@ include("Interpreter.jl")

include("utils.jl")

mutable struct TracedRArray{T,N} <: RArray{T,N}
mutable struct TracedRNumber{T} <: RNumber{T}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}

function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
return new{T}(paths, mlir_data)
end
end

mutable struct TracedRArray{T,N} <: AbstractArray{TracedRNumber{T},N}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}
shape::NTuple{N,Int}
Expand Down Expand Up @@ -102,20 +113,31 @@ function TracedRArray(data::MLIR.IR.Value)
)
end

mutable struct TracedRNumber{T} <: RNumber{T}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}
struct XLAArray{T,N} <: AbstractArray{T,N} end

function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
return new{T}(paths, mlir_data)
end
mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

mutable struct ConcreteRArray{T,N} <: AbstractArray{ConcreteRNumber{T},N}
data::XLA.AsyncBuffer
shape::NTuple{N,Int}
end

unwrapped_eltype(::Type{T}) where {T<:Number} = T
unwrapped_eltype(::Type{<:TracedRNumber{T}}) where {T} = T
unwrapped_eltype(::Type{<:TracedRArray{T,N}}) where {T,N} = T
unwrapped_eltype(::Type{<:XLAArray{T,N}}) where {T,N} = T
unwrapped_eltype(::Type{<:ConcreteRNumber{T}}) where {T} = T
unwrapped_eltype(::Type{<:ConcreteRArray{T,N}}) where {T,N} = T
unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = error(1)
unwrapped_eltype(x) = (@show(x); unwrapped_eltype(typeof(x)))

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

const RArray{T,N} = Union{ConcreteRArray{T,N},TracedRArray{T,N},XLAArray{T,N}}

include("Ops.jl")
include("TracedUtils.jl")

Expand Down
2 changes: 2 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(Base.getindex, x), length(x))

Base.size(x::TracedRArray) = x.shape

Base.collect(x::TracedRArray) = copy(x) # XXX: Is this correct?

Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))

# TODO is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`?
Expand Down
19 changes: 4 additions & 15 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
module TracedRNumberOverrides

import ..TracedRNumber
import ..TracedRArray
import ..ReactantPrimitive
using ..TracedUtils
import ..Ops
import ..MLIR
using ..Reactant:
Reactant, TracedRNumber, TracedRArray, ReactantPrimitive, TracedUtils, Ops, MLIR
using ReactantCore

ReactantCore.is_traced(::TracedRNumber) = true

Base.eltype(::Type{TracedRNumber{T}}) where {T} = T

Base.getindex(a::TracedRNumber{T}) where {T} = a

Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, zero(T))
Expand Down Expand Up @@ -51,20 +45,15 @@ TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x
TracedRNumber{T}(x::TracedRNumber) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x)
TracedRNumber{T}(x::Number) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x)

(T::Type{<:Number})(x::TracedRNumber) = TracedUtils.promote_to(TracedRNumber{T}, x)

function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
if rhs isa TracedRNumber
rhs isa TracedRNumber{T} && return rhs
return Ops.convert(TracedRNumber{T}, rhs)
end
if rhs isa TracedRArray{<:Any,0}
return TracedUtils.promote_to(
TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data)
TracedRNumber{T},
TracedRNumber{Reactant.unwrapped_eltype(rhs)}((), rhs.mlir_data),
)
end
rhs isa Number &&
Expand Down
1 change: 0 additions & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ module TracedUtils
using LinearAlgebra: LinearAlgebra
using Adapt: Adapt
using ..Reactant:
RArray,
RNumber,
TracedRArray,
TracedRNumber,
Expand Down
2 changes: 1 addition & 1 deletion src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
TracedSetPath = 5
end

for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArray, RNumber)
for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
@eval function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:$T}
return T
end
Expand Down

0 comments on commit e35a7e3

Please sign in to comment.