diff --git a/src/Reactant.jl b/src/Reactant.jl index e76820962..9b0757e40 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -8,6 +8,7 @@ abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType, N} end @inline Base.eltype(::RArray{ElType,Shape}) where {ElType, Shape} = ElType @inline Base.size(::RArray{ElType,Shape}) where {ElType, Shape} = Shape +@inline Base.size(::Type{<:RArray{ElType,Shape}}) where {ElType, Shape} = Shape @inline Base.ndims(::RArray{ElType,Shape, N}) where {ElType, Shape, N} = N @inline Base.ndims(::Type{<:RArray{ElType,Shape, N}}) where {ElType, Shape, N} = N @@ -140,8 +141,7 @@ end end if RT <: TracedRArray - @assert typeof(prev) <: RT # prev has concrete shape while RT might not have that - res = typeof(prev)(prev.paths, prev.mlir_data) + res = broadcast_to_size(eltype(RT)(0), size(prev)) seen[prev] = res return res end