fix: make eltype of Traced/Concrete Arrays to be respective RNumbers (#426)

* feat: overlay eltype conversion

* fix: overload the main methods

* fix: make eltype of Traced/Concrete Arrays to be respective RNumbers

* fix: handle more cases

* fix: tracing of wrapped types

* fix: arrayinterface overload

* fix: python call
avik-pal authored Dec 29, 2024
1 parent f079a9d commit b0a58bd
Showing 17 changed files with 249 additions and 176 deletions.
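
The headline change, in usage terms (a minimal sketch based on the src/Reactant.jl hunk below; assumes Reactant at this commit, and the output comments are illustrative):

using Reactant

eltype(Reactant.TracedRArray{Float32,2})
# TracedRNumber{Float32} -- traced arrays now carry traced-number element types

Reactant.unwrapped_eltype(Reactant.TracedRArray{Float32,2})
# Float32 -- the new helper recovers the underlying primitive element type
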
1 change: 1 addition & 0 deletions ext/ReactantArrayInterfaceExt.jl
@@ -13,6 +13,7 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where
return x_c
end

ArrayInterface.aos_to_soa(x::TracedRArray) = x
function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
return Ops.reshape(vcat(x...), size(x)...)
end
45 changes: 20 additions & 25 deletions ext/ReactantPythonCallExt.jl
@@ -8,22 +8,22 @@ using PythonCall

const jaxptr = Ref{Py}()

const NUMPY_SIMPLE_TYPES = (
("bool_", Bool),
("int8", Int8),
("int16", Int16),
("int32", Int32),
("int64", Int64),
("uint8", UInt8),
("uint16", UInt16),
("uint32", UInt32),
("uint64", UInt64),
("float16", Float16),
("float32", Float32),
("float64", Float64),
("complex32", ComplexF16),
("complex64", ComplexF32),
("complex128", ComplexF64),
const NUMPY_SIMPLE_TYPES = Dict(
Bool => :bool_,
Int8 => :int8,
Int16 => :int16,
Int32 => :int32,
Int64 => :int64,
UInt8 => :uint8,
UInt16 => :uint16,
UInt32 => :uint32,
UInt64 => :uint64,
Float16 => :float16,
Float32 => :float32,
Float64 => :float64,
ComplexF16 => :complex32,
ComplexF32 => :complex64,
ComplexF64 => :complex128,
)

function PythonCall.pycall(
@@ -32,15 +32,10 @@ function PythonCall.pycall(
jax = jaxptr[]
numpy = jax.numpy
inputs = map((arg0, argNs...)) do arg
JT = eltype(arg)
PT = nothing
for (CPT, CJT) in NUMPY_SIMPLE_TYPES
if JT == CJT
PT = CPT
break
end
end
numpy.zeros(size(arg); dtype=getproperty(numpy, Symbol(PT)))
numpy.zeros(
size(arg);
dtype=getproperty(numpy, NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(arg)]),
)
end
lowered = jax.jit(f).lower(inputs...)
txt = pyconvert(String, lowered.as_text())
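
Not part of the diff: a trimmed-down, hypothetical mirror of the new dtype lookup (the names NUMPY_TYPES_SKETCH and numpy_dtype_name are invented, and only two table entries are shown), illustrating why a Dict keyed by the unwrapped Julia element type replaces the old linear scan over (name, type) tuples:

const NUMPY_TYPES_SKETCH = Dict(Float32 => :float32, Float64 => :float64)

# One hash lookup per argument instead of iterating over every pair;
# the real overload then does getproperty(numpy, <this symbol>).
numpy_dtype_name(::Type{T}) where {T} = NUMPY_TYPES_SKETCH[T]

numpy_dtype_name(Float32)  # :float32
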
12 changes: 8 additions & 4 deletions src/Compiler.jl
@@ -16,19 +16,20 @@ import ..Reactant:
make_tracer,
TracedToConcrete,
append_path,
ancestor,
TracedType

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
(isbitstype(T) || obj isa RArray) && return Base.getfield(obj, field)
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
return Base.getindex(obj, field)
end

@inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, field, val)
@inline function traced_setfield!(
@nospecialize(obj::AbstractArray{T}), field, val
) where {T}
(isbitstype(T) || obj isa RArray) && return Base.setfield!(obj, field, val)
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.setfield!(obj, field, val)
return Base.setindex!(obj, val, field)
end

@@ -666,7 +667,8 @@ function codegen_unflatten!(
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcreteRArray{
eltype($final_val),ndims($final_val)
$(Reactant.unwrapped_eltype)($final_val),
ndims($final_val),
}(
$concrete_res_name, size($final_val)
)
@@ -677,7 +679,9 @@ function codegen_unflatten!(
$clocal = if haskey($cache_dict, $final_val)
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcreteRNumber{eltype($final_val)}(
$cache_dict[$final_val] = ConcreteRNumber{
$(Reactant.unwrapped_eltype)($final_val)
}(
$concrete_res_name
)
$cache_dict[$final_val]
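
For context, ancestor (added in src/Reactant.jl further down) follows parent until an array is its own parent, which is how traced_getfield/traced_setfield! now detect an RArray hidden behind wrapper types. A standalone sketch of the same recursion, runnable with plain Base arrays:

ancestor(x::AbstractArray) = parent(x) === x ? x : ancestor(parent(x))

A = rand(Float32, 4, 4)
v = view(PermutedDimsArray(A, (2, 1)), :, 1:2)  # SubArray over PermutedDimsArray over A
@assert ancestor(v) === A                       # both wrappers are peeled off
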
23 changes: 5 additions & 18 deletions src/ConcreteRArray.jl
@@ -1,20 +1,3 @@
struct XLAArray{T,N} <: RArray{T,N}
# size::NTuple{N,Int}
end

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
# data::XLAArray{T, N}
shape::NTuple{N,Int}
end

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

function ConcreteRNumber{T}(
data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
) where {T<:Number,T2<:Number}
@@ -29,7 +12,9 @@ function ConcreteRNumber(
return ConcreteRNumber{T}(crarray.data)
end

Base.collect(x::ConcreteRNumber{T}) where {T} = ConcreteRArray{T,0}(copy(x).data, ())
function Base.collect(x::ConcreteRNumber{T}) where {T}
return collect(ConcreteRArray{T,0}(copy(x).data, ()))
end

Base.size(::ConcreteRNumber) = ()
Base.real(x::ConcreteRNumber{<:Real}) = x
@@ -323,6 +308,8 @@ function Base.copyto!(dest::ConcreteRArray, src::ConcreteRArray)
return dest
end

Base.collect(x::AnyConcreteRArray) = convert(Array, x)

function Base.mapreduce(
@nospecialize(f),
@nospecialize(op),
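
The intent of the two collect changes, sketched as usage (illustrative only; running this needs a working Reactant/XLA backend, and the constructor calls are assumed from the definitions in this file and in src/Reactant.jl):

using Reactant

x = Reactant.ConcreteRNumber(1.0f0)
collect(x)   # 0-dimensional host Array{Float32, 0}, no longer a ConcreteRArray

A = Reactant.ConcreteRArray(ones(Float32, 2, 2))
collect(A)   # plain 2x2 Matrix{Float32}, materialised via convert(Array, A)
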
6 changes: 4 additions & 2 deletions src/ControlFlow.jl
@@ -142,10 +142,12 @@ function get_region_removing_missing_values(compiled_fn, insertions)
return_op = MLIR.IR.terminator(block)
for (i, rt) in insertions
if rt isa TracedRNumber
attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, ()))
attr = MLIR.IR.DenseElementsAttribute(Array{unwrapped_eltype(rt)}(undef, ()))
op = MLIR.Dialects.stablehlo.constant(; value=attr)
elseif rt isa TracedRArray
attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, size(rt)))
attr = MLIR.IR.DenseElementsAttribute(
Array{unwrapped_eltype(rt)}(undef, size(rt))
)
op = MLIR.Dialects.stablehlo.constant(; value=attr)
else
error("Unknown type $(typeof(rt))")
6 changes: 3 additions & 3 deletions src/Interpreter.jl
@@ -167,7 +167,7 @@ end
function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
TracedUtils.push_val!(ad_inputs, x.val, path)
if !reverse
ET = eltype(x.val)
ET = unwrapped_eltype(x.val)
predims = size(x.val)
cval = MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
@@ -182,7 +182,7 @@ end
function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse)
TracedUtils.push_val!(ad_inputs, x.val, path)
if !reverse
ET = eltype(x.val)
ET = unwrapped_eltype(x.val)
predims = size(x.val)
cval = MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
@@ -298,7 +298,7 @@ function overload_autodiff(
act = act_from_type(A, reverse, needs_primal(CMode))
push!(ret_activity, act)
if act == enzyme_out || act == enzyme_outnoneed
attr = fill(MLIR.IR.Attribute(eltype(a)(1)), Ops.mlir_type(a))
attr = fill(MLIR.IR.Attribute(unwrapped_eltype(a)(1)), Ops.mlir_type(a))
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
push!(ad_inputs, cst)
end
25 changes: 15 additions & 10 deletions src/Ops.jl
@@ -4,23 +4,28 @@
module Ops
using ..MLIR: MLIR
using ..MLIR.Dialects: stablehlo, chlo, enzyme
using ..Reactant: Reactant, TracedRArray, TracedRNumber, RArray, RNumber, MissingTracedValue

function mlir_type(x::RArray{T,N}) where {T,N}
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
using ..Reactant:
Reactant,
TracedRArray,
TracedRNumber,
RArray,
RNumber,
MissingTracedValue,
unwrapped_eltype

function mlir_type(x::Union{RNumber,RArray})
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(unwrapped_eltype(x)))
end

mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T))

mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool))

function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N}
function mlir_type(RT::Type{<:RArray{T,N}}, shape) where {T,N}
@assert length(shape) == N
return MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
return MLIR.IR.TensorType(shape, MLIR.IR.Type(unwrapped_eltype(RT)))
end

function mlir_type(::Type{<:RNumber{T}}) where {T}
return MLIR.IR.TensorType((), MLIR.IR.Type(T))
function mlir_type(RT::Type{<:RNumber})
return MLIR.IR.TensorType((), MLIR.IR.Type(unwrapped_eltype(RT)))
end

function mlir_type(::Type{<:MissingTracedValue})
78 changes: 63 additions & 15 deletions src/Reactant.jl
@@ -54,24 +54,48 @@ else
}
end

abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
abstract type RNumber{T<:ReactantPrimitive} <: Number end

Base.collect(A::RArray) = copy(A)
abstract type RArray{T,N} <: AbstractArray{T,N} end

function ancestor(x::AbstractArray)
p_x = parent(x)
p_x === x && return x
return ancestor(p_x)
end

function ancestor(T::Type{<:AbstractArray})
if applicable(Adapt.parent_type, T)
p_T = Adapt.parent_type(T)
p_T == T && return T
return ancestor(p_T)
end
@warn "`Adapt.parent_type` is not implemented for $(T). Assuming $T isn't a wrapped \
array." maxlog = 1
return T
end

include("mlir/MLIR.jl")
include("XLA.jl")
include("Interpreter.jl")

include("utils.jl")

mutable struct TracedRArray{T,N} <: RArray{T,N}
mutable struct TracedRNumber{T} <: RNumber{T}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}

function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
return new{T}(paths, mlir_data)
end
end

mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}
shape::NTuple{N,Int}
@@ -87,7 +111,11 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
end
end

const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N}

const WrappedTracedRArray{T,N} = WrappedArray{
TracedRNumber{T},N,TracedRArray,TracedRArray{T,N}
}
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
const AnyTracedRMatrix{T} = Union{
@@ -102,20 +130,40 @@ function TracedRArray(data::MLIR.IR.Value)
)
end

mutable struct TracedRNumber{T} <: RNumber{T}
paths::Tuple
mlir_data::Union{Nothing,MLIR.IR.Value}
struct XLAArray{T,N} <: RArray{T,N} end

function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
return new{T}(paths, mlir_data)
end
Adapt.parent_type(::Type{XLAArray{T,N}}) where {T,N} = XLAArray{T,N}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
shape::NTuple{N,Int}
end

Adapt.parent_type(::Type{ConcreteRArray{T,N}}) where {T,N} = ConcreteRArray{T,N}

const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}

unwrapped_eltype(::Type{T}) where {T<:Number} = T
unwrapped_eltype(::Type{<:RNumber{T}}) where {T} = T
unwrapped_eltype(::Type{<:TracedRNumber{T}}) where {T} = T

unwrapped_eltype(::T) where {T<:Number} = T
unwrapped_eltype(::RNumber{T}) where {T} = T
unwrapped_eltype(::TracedRNumber{T}) where {T} = T

unwrapped_eltype(::Type{<:RArray{T,N}}) where {T,N} = T
unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = unwrapped_eltype(T)
unwrapped_eltype(::Type{<:AnyTracedRArray{T,N}}) where {T,N} = T

unwrapped_eltype(::RArray{T,N}) where {T,N} = T
unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T)
unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T

include("Ops.jl")
include("TracedUtils.jl")

Expand Down