From 6f99e332d9a991fe00272f9ac14e1a7e2ba75622 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 May 2024 16:49:11 -0700 Subject: [PATCH] Add copyto! --- Project.toml | 2 +- src/overloads.jl | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ddc27a0b6..d47c765a0 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ ArrayInterface = "7.10" CEnum = "0.4, 0.5" Cassette = "0.3" Enzyme = "0.11, 0.12" -NNlib = "0.9.17" +NNlib = "0.9" Preferences = "1.4" Reactant_jll = "0.0.6" julia = "1" diff --git a/src/overloads.jl b/src/overloads.jl index bd7a3a8ea..ca948f43c 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -549,6 +549,12 @@ Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(a @inline Base.copyto!(dest::TracedRArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict +@inline function Base.copyto!(dest::TracedRArray{ElType, Shape, N}, + src::TracedRArray{ElType, Shape, N}) where {ElType, Shape, N} + dest.mlir_data = src.mlir_data + return dest +end + @inline function broadcast_to_size(arg::AbstractArray, rsize) attr = MLIR.IR.DenseElementsAttribute(arg) len = ndims(arg)