diff --git a/Project.toml b/Project.toml index ddc27a0b6..d47c765a0 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ ArrayInterface = "7.10" CEnum = "0.4, 0.5" Cassette = "0.3" Enzyme = "0.11, 0.12" -NNlib = "0.9.17" +NNlib = "0.9" Preferences = "1.4" Reactant_jll = "0.0.6" julia = "1" diff --git a/src/overloads.jl b/src/overloads.jl index bd7a3a8ea..ca948f43c 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -549,6 +549,12 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(a @inline Base.copyto!(dest::TracedRArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict +@inline function Base.copyto!(dest::TracedRArray{ElType, Shape, N}, + src::TracedRArray{ElType, Shape, N}) where {ElType, Shape, N} + dest.mlir_data = src.mlir_data + return dest +end + @inline function broadcast_to_size(arg::AbstractArray, rsize) attr = MLIR.IR.DenseElementsAttribute(arg) len = ndims(arg)