From 38f1c9c0f6b6d2cf4d13c58f98c259e534655190 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 30 Apr 2024 15:56:19 -0400 Subject: [PATCH] better overdub --- src/Reactant.jl | 5 ++++- src/overloads.jl | 21 ++++++++++++++++++--- src/utils.jl | 2 +- test/basic.jl | 32 ++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++-- 5 files changed, 57 insertions(+), 7 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index ba943e7c4..e0e5e5721 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -100,7 +100,8 @@ function Base.promote_rule(A::Type{T}, B::Type{TracedRArray{S, Shape, N}}) where end function Base.show(io::IO, X::TracedRArray{ElType, Shape, N}) where {ElType, Shape, N} - print(io, "TracedRArray{", ElType, ",", Shape, ",", N, "N}(", X.paths, ", ", X.mlir_data, ")") + print(io, "TracedRArray{", ElType, ",", Shape, ",", N, "N}(", X.paths, ", ") + print(io, X.mlir_data, ")") end include("overloads.jl") @@ -636,6 +637,8 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea return result end end + @show func + flush(stdout) return eval(func) end diff --git a/src/overloads.jl b/src/overloads.jl index fe934f583..81ef7ac93 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -116,8 +116,11 @@ function Cassette.overdub(::TraceCtx, ::CMode, f::FA, ::Type{A}, args::Vararg{En end end + @show func2 res = (reverse ? MLIR.IR.enzyme.autodiff : MLIR.IR.enzyme.fwddiff)(ad_inps; outputs=outtys, fn=func2, activity=DenseArrayAttribute(activity)) + @show res + residx = 1 restup = Any[(a isa Active) ? copy(a) : nothing for a in args] @@ -227,15 +230,17 @@ function Base.:*(lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Sha return TracedRArray{ElType,(Base.size(lhsty)[1], Base.size(rhsty)[2]),2}((), res) end +Cassette.overdub(context::TraceCtx, ::typeof(Base.:*), args...) = Base.*(args...) + for (jlop, hloop) in ((:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos), :cosine), (:(Base.tanh), :tanh), (:(Base.FastMath.tanh_fast), :tanh), (:(Base.exp), :exponential), (:(Base.FastMath.exp_fast), :exponential), (:(Base.log), :log), (:(Base.sqrt), :sqrt)) @eval begin function $jlop(lhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} return TracedRArray{ElType,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)) end + Cassette.overdub(context::TraceCtx, ::typeof($jlop), args...) = $jlop(args...) end end - for (jlop, hloop, RT) in ((:(Base.min), :minimum, :ElType),(:(Base.max), :maximum, :ElType), (:(Base.:+), :add, :ElType), (:(Base.add_sum), :add, :ElType), (:(Base.:-), :subtract, :ElType), (:(Base.:*), :multiply, :ElType), (:(Base.:/), :divide, :ElType)) @eval begin function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} @@ -266,11 +271,12 @@ for (jlop, hloop) in ((:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos), end end - +Cassette.overdub(context::TraceCtx, ::typeof(elem_apply), args...) = elem_apply(args...) @inline function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) reshape(A, Base._reshape_uncolon(A, dims)) end +Cassette.overdub(context::TraceCtx, f::typeof(Base.reshape), args...) = f(args...) @inline function Base.reshape(A::ConcreteRArray{T, Shape, N}, dims::NTuple{NT, Int}) where {T, Shape, N, NT} prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A))) @@ -285,7 +291,7 @@ end end Base.copy(A::TracedRArray{T, Shape, N}) where {T, Shape, N} = TracedRArray((), A.mlir_data) - +Cassette.overdub(context::TraceCtx, f::typeof(Base.copy), args...) = f(args...) @inline function Base.permutedims(A::TracedRArray{T, Shape, N}, perm) where {T, Shape, N} TracedArray{T, tuple(Shape[i] for i in perm), N}( @@ -293,6 +299,7 @@ Base.copy(A::TracedRArray{T, Shape, N}) where {T, Shape, N} = TracedRArray((), A MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(A.mlir_data, DenseArrayAttribute([Int64(i-1) for i in perm])), 1) ) end +Cassette.overdub(context::TraceCtx, f::typeof(Base.permutedims), args...) = f(args...) @inline function Base.reshape(A::TracedRArray{T, Shape, N}, dims::NTuple{NT, Int}) where {T, Shape, N, NT} prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A))) @@ -328,6 +335,7 @@ BroadcastStyle(::Type{T}) where {T<:TracedRArray} = AbstractReactantArrayStyle{n Base.similar(x::TracedRArray{T, Shape, N}, ::Type{T2}) where {T, Shape, N, T2} = TracedRArray{T2, Shape, N}((), nothing) +Cassette.overdub(context::TraceCtx, f::typeof(Base.similar), args...) = f(args...) @inline function Base.similar(bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims) where {T,N} @assert N isa Int @@ -339,6 +347,7 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle{0}}) dest = copyto!(similar(bc, ElType), bc) return dest[CartesianIndex()] # 0D broadcast needs to unwrap results end +Cassette.overdub(context::TraceCtx, f::typeof(Broadcast.copy), args...) = f(args...) @inline Base.eltype(b::Broadcast.Extruded{T}) where T = eltype(T) @@ -373,6 +382,7 @@ end @inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractReactantArrayStyle} return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) end +Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(args...) @inline Base.copyto!(dest::TracedRArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict @@ -394,6 +404,7 @@ function Base.fill!(A::TracedRArray{T, Shape, N}, x) where {T, Shape, N} A.mlir_data = bcast.mlir_data A end +Cassette.overdub(context::TraceCtx, f::typeof(Base.fill!), args...) = f(args...) @inline function broadcast_to_size(arg::T, rsize) where T <: Number TT = MLIR.IR.TensorType([Int64(s) for s in rsize], MLIR.IR.Type(typeof(arg))) @@ -440,6 +451,8 @@ end return dest end +Cassette.overdub(context::Cassette.Context, ::Core.kwftype(typeof(Base.mapreduce)), kwargs::Any, ::typeof(Base.mapreduce), args...) = Base.mapreduce(args...; kwargs...) + function Base.mapreduce(f, op, A::TracedRArray{ElType, Shape, N}; dims=:, init=nothing) where {ElType, Shape, N} if dims isa Int dims = [dims] @@ -491,6 +504,7 @@ function Base.mapreduce(f, op, A::TracedRArray{ElType, Shape, N}; dims=:, init=n else red = TracedRArray{ElType, (outdims...,), length(outdims)}((), red) end + @show red return red end @@ -499,3 +513,4 @@ function Base.mapreducedim!(f, op, R::TracedRArray, A::Base.AbstractArrayOrBroad R.mlir_data = elem_apply(op, R, tmp).mlir_data return R end +Cassette.overdub(context::TraceCtx, f::typeof(Base.mapreducedim!), args...) = f(args...) diff --git a/src/utils.jl b/src/utils.jl index 3e2992721..099b06712 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -45,7 +45,7 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true) arg.mlir_data = row_maj_arg end - f(traced_args...; kwargs...) + Cassette.overdub(TraceCtx(), f, traced_args...; kwargs...) end seen_results = IdDict() diff --git a/test/basic.jl b/test/basic.jl index 75148b709..d0fc00348 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -48,4 +48,36 @@ end @test r ≈ cos.(ones(3,2)) end +function sumcos(x) + return sum(cos.(x)) +end + +function grad_ip(x) + dx = Enzyme.make_zero(x) + Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx)) + return dx +end + +function resgrad_ip(f, x) + dx = Enzyme.make_zero(x) + res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx)) + return (res, dx) +end + +@testset "Basic grad cos" begin + c = Reactant.ConcreteRArray(ones(3,2)) + + f=Reactant.compile(grad_ip, (c,)) + r = f(c) + @show r + + @test r ≈ sin.(ones(3,2)) + + f=Reactant.compile(resgrad_ip, (c,)) + orig, r = f(c) + @show r + + @test orig[1] ≈ cos.(ones(3,2)) + @test r ≈ sin.(ones(3,2)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 412de6427..a24fdb77b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using Reactant #include("layout.jl") -#include("basic.jl") +include("basic.jl") #include("bcast.jl") -include("nn.jl") +#include("nn.jl")