From f23abb596793a993806e2c10c988a7e45631d56b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 12:01:01 -0700 Subject: [PATCH] Add a make_zero path for TracedRArray --- src/Reactant.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 908161389..e76820962 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -138,6 +138,13 @@ end seen[prev] = res return res end + + if RT <: TracedRArray + @assert typeof(prev) <: RT # prev has concrete shape while RT might not have that + res = typeof(prev)(prev.paths, prev.mlir_data) + seen[prev] = res + return res + end attr = fill(MLIR.IR.Attribute(eltype(RT)(0)), mlir_type(prev)) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(value=attr), 1) @@ -146,8 +153,6 @@ end return res end - - function Base.promote_rule(A::Type{TracedRArray{T, Shape, N}}, B::Type{TracedRArray{S, Shape, N}}) where {T, S, Shape, N} TracedRArray{Base.promote_type(T, S), Shape, N} end