diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 3b30c4cb6..7c3657fde 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -142,10 +142,12 @@ function get_region_removing_missing_values(compiled_fn, insertions) return_op = MLIR.IR.terminator(block) for (i, rt) in insertions if rt isa TracedRNumber - attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, ())) + attr = MLIR.IR.DenseElementsAttribute(Array{concrete_eltype(rt)}(undef, ())) op = MLIR.Dialects.stablehlo.constant(; value=attr) elseif rt isa TracedRArray - attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, size(rt))) + attr = MLIR.IR.DenseElementsAttribute( + Array{concrete_eltype(rt)}(undef, size(rt)) + ) op = MLIR.Dialects.stablehlo.constant(; value=attr) else error("Unknown type $(typeof(rt))") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 8377f9900..83913a79e 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -326,7 +326,7 @@ end function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} if all(iszero ∘ ndims, args) scalar_args = map(args) do arg - return promote_to(TracedRNumber{eltype(arg)}, arg) + return promote_to(TracedRNumber{concrete_eltype(arg)}, arg) end return f(scalar_args...) end @@ -722,7 +722,7 @@ end function broadcast_to_size(arg::TracedRNumber, rsize) length(rsize) == 0 && return arg return broadcast_to_size_internal( - TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize + TracedRArray{concrete_eltype(arg),0}((), arg.mlir_data, ()), rsize ) end @@ -757,7 +757,7 @@ function broadcast_to_size_internal(x::TracedRArray, rsize) @assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims) mlirty = MLIR.IR.type(x.mlir_data) - return TracedRArray{eltype(x),Int(length(rsize))}( + return TracedRArray{concrete_eltype(x),Int(length(rsize))}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.broadcast_in_dim( diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 9a1c1725e..16e7c5b77 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -16,7 +16,10 @@ ReactantCore.is_traced(::TracedRNumber) = true new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) -Base.eltype(::Type{TracedRNumber{T}}) where {T} = T +concrete_eltype(x) = eltype(x) +concrete_eltype(::TracedRNumber{T}) where {T} = T +concrete_eltype(::Type{TracedRNumber{T}}) where {T} = T +Base.eltype(::Type{T}) where {T<:TracedRNumber} = T Base.getindex(a::TracedRNumber{T}) where {T} = a @@ -210,17 +213,17 @@ end for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) @eval begin function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) - return TracedRNumber{promote_type(eltype(x), eltype(y))}( + return TracedRNumber{promote_type(concrete_eltype(x), concrete_eltype(y))}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) ) end function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) - return TracedRNumber{promote_type(eltype(x), eltype(y))}( + return TracedRNumber{promote_type(concrete_eltype(x), concrete_eltype(y))}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) ) end function Base.:!(x::TracedRNumber{<:$(T1)}) - return TracedRNumber{eltype(x)}( + return TracedRNumber{concrete_eltype(x)}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) ) end