better overdub

wsmoses committed Apr 30, 2024
1 parent deee3bf commit 38f1c9c
Showing 5 changed files with 57 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/Reactant.jl
@@ -100,7 +100,8 @@ function Base.promote_rule(A::Type{T}, B::Type{TracedRArray{S, Shape, N}}) where
end

function Base.show(io::IO, X::TracedRArray{ElType, Shape, N}) where {ElType, Shape, N}
print(io, "TracedRArray{", ElType, ",", Shape, ",", N, "N}(", X.paths, ", ", X.mlir_data, ")")
print(io, "TracedRArray{", ElType, ",", Shape, ",", N, "N}(", X.paths, ", ")
print(io, X.mlir_data, ")")
end

include("overloads.jl")
@@ -636,6 +637,8 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea
return result
end
end
@show func
flush(stdout)
return eval(func)
end

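The hunk above prints the generated `func` expression and flushes stdout before `eval`ing it. For readers unfamiliar with that generate-then-eval style, here is a minimal, hypothetical sketch of the same pattern (`build_adder` and its names are illustrative only, not part of Reactant):

```julia
# Build a function as an Expr, dump it with @show for debugging, then eval it
# into a callable — the same inspect-before-eval pattern used in generate_jlfunc.
function build_adder(n::Int)
    fname = gensym(:adder)
    func = :(function $fname(x)
        return x + $n
    end)
    @show func        # print the generated expression
    flush(stdout)     # make sure the dump appears even if a later step fails
    return eval(func) # evaluate the definition and return the new function
end

add3 = build_adder(3)   # prints the generated Expr
@assert add3(4) == 7
```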
21 changes: 18 additions & 3 deletions src/overloads.jl
@@ -116,8 +116,11 @@ function Cassette.overdub(::TraceCtx, ::CMode, f::FA, ::Type{A}, args::Vararg{En
end
end

@show func2
res = (reverse ? MLIR.IR.enzyme.autodiff : MLIR.IR.enzyme.fwddiff)(ad_inps; outputs=outtys, fn=func2, activity=DenseArrayAttribute(activity))

@show res

residx = 1

restup = Any[(a isa Active) ? copy(a) : nothing for a in args]
@@ -227,15 +230,17 @@ function Base.:*(lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Sha
return TracedRArray{ElType,(Base.size(lhsty)[1], Base.size(rhsty)[2]),2}((), res)
end

Cassette.overdub(context::TraceCtx, ::typeof(Base.:*), args...) = Base.*(args...)

for (jlop, hloop) in ((:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos), :cosine), (:(Base.tanh), :tanh), (:(Base.FastMath.tanh_fast), :tanh), (:(Base.exp), :exponential), (:(Base.FastMath.exp_fast), :exponential), (:(Base.log), :log), (:(Base.sqrt), :sqrt))
@eval begin
function $jlop(lhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
return TracedRArray{ElType,Shape,N}((), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1))
end
Cassette.overdub(context::TraceCtx, ::typeof($jlop), args...) = $jlop(args...)
end
end


for (jlop, hloop, RT) in ((:(Base.min), :minimum, :ElType),(:(Base.max), :maximum, :ElType), (:(Base.:+), :add, :ElType), (:(Base.add_sum), :add, :ElType), (:(Base.:-), :subtract, :ElType), (:(Base.:*), :multiply, :ElType), (:(Base.:/), :divide, :ElType))
@eval begin
function elem_apply(::typeof($jlop), lhs::TracedRArray{ElType,Shape,N}, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
@@ -266,11 +271,12 @@ for (jlop, hloop) in ((:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos),
end
end


Cassette.overdub(context::TraceCtx, ::typeof(elem_apply), args...) = elem_apply(args...)

@inline function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
reshape(A, Base._reshape_uncolon(A, dims))
end
Cassette.overdub(context::TraceCtx, f::typeof(Base.reshape), args...) = f(args...)

@inline function Base.reshape(A::ConcreteRArray{T, Shape, N}, dims::NTuple{NT, Int}) where {T, Shape, N, NT}
prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A)))
@@ -285,14 +291,15 @@ end
end

Base.copy(A::TracedRArray{T, Shape, N}) where {T, Shape, N} = TracedRArray((), A.mlir_data)

Cassette.overdub(context::TraceCtx, f::typeof(Base.copy), args...) = f(args...)

@inline function Base.permutedims(A::TracedRArray{T, Shape, N}, perm) where {T, Shape, N}
TracedArray{T, tuple(Shape[i] for i in perm), N}(
(),
MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(A.mlir_data, DenseArrayAttribute([Int64(i-1) for i in perm])), 1)
)
end
Cassette.overdub(context::TraceCtx, f::typeof(Base.permutedims), args...) = f(args...)

@inline function Base.reshape(A::TracedRArray{T, Shape, N}, dims::NTuple{NT, Int}) where {T, Shape, N, NT}
prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A)))
@@ -328,6 +335,7 @@ BroadcastStyle(::Type{T}) where {T<:TracedRArray} = AbstractReactantArrayStyle{n


Base.similar(x::TracedRArray{T, Shape, N}, ::Type{T2}) where {T, Shape, N, T2} = TracedRArray{T2, Shape, N}((), nothing)
Cassette.overdub(context::TraceCtx, f::typeof(Base.similar), args...) = f(args...)

@inline function Base.similar(bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims) where {T,N}
@assert N isa Int
@@ -339,6 +347,7 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle{0}})
dest = copyto!(similar(bc, ElType), bc)
return dest[CartesianIndex()] # 0D broadcast needs to unwrap results
end
Cassette.overdub(context::TraceCtx, f::typeof(Broadcast.copy), args...) = f(args...)

@inline Base.eltype(b::Broadcast.Extruded{T}) where T = eltype(T)

@@ -373,6 +382,7 @@ end
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractReactantArrayStyle}
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
end
Cassette.overdub(context::TraceCtx, f::typeof(Base.materialize!), args...) = f(args...)

@inline Base.copyto!(dest::TracedRArray, bc::Broadcasted{Nothing}) =
_copyto!(dest, bc) # Keep it for ArrayConflict
@@ -394,6 +404,7 @@ function Base.fill!(A::TracedRArray{T, Shape, N}, x) where {T, Shape, N}
A.mlir_data = bcast.mlir_data
A
end
Cassette.overdub(context::TraceCtx, f::typeof(Base.fill!), args...) = f(args...)

@inline function broadcast_to_size(arg::T, rsize) where T <: Number
TT = MLIR.IR.TensorType([Int64(s) for s in rsize], MLIR.IR.Type(typeof(arg)))
@@ -440,6 +451,8 @@ end
return dest
end

Cassette.overdub(context::Cassette.Context, ::Core.kwftype(typeof(Base.mapreduce)), kwargs::Any, ::typeof(Base.mapreduce), args...) = Base.mapreduce(args...; kwargs...)

function Base.mapreduce(f, op, A::TracedRArray{ElType, Shape, N}; dims=:, init=nothing) where {ElType, Shape, N}
if dims isa Int
dims = [dims]
@@ -491,6 +504,7 @@ function Base.mapreduce(f, op, A::TracedRArray{ElType, Shape, N}; dims=:, init=n
else
red = TracedRArray{ElType, (outdims...,), length(outdims)}((), red)
end
@show red
return red
end

@@ -499,3 +513,4 @@ function Base.mapreducedim!(f, op, R::TracedRArray, A::Base.AbstractArrayOrBroad
R.mlir_data = elem_apply(op, R, tmp).mlir_data
return R
end
Cassette.overdub(context::TraceCtx, f::typeof(Base.mapreducedim!), args...) = f(args...)
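Taken together, the overloads.jl changes follow one pattern: user code is executed under `TraceCtx`, operations with traced implementations are intercepted, and Base functions whose internals Cassette should not descend into get one-line `overdub` pass-throughs (reshape, copy, permutedims, similar, fill!, materialize!, mapreduce, mapreducedim!, and so on). Below is a minimal, self-contained sketch of that Cassette pattern, using hypothetical names (`DemoCtx`, the logging `sin` interceptor, `g`) that are not part of Reactant:

```julia
using Cassette

Cassette.@context DemoCtx

# Intercept `sin`: calls to sin made inside overdubbed code dispatch here
# instead of being recursed into (a stand-in for routing to a traced method).
Cassette.overdub(::DemoCtx, ::typeof(Base.sin), x) = (println("intercepted sin(", x, ")"); sin(x))

# Pass `+` through untouched — the same shape as the
# `Cassette.overdub(context::TraceCtx, f::typeof(...), args...) = f(args...)`
# forwarders this commit adds.
Cassette.overdub(::DemoCtx, f::typeof(Base.:+), args...) = f(args...)

g(x) = sin(x) + one(x)

# Run `g` under the context, as make_mlir_fn now does with TraceCtx (see utils.jl below).
Cassette.overdub(DemoCtx(), g, 0.3)
```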
2 changes: 1 addition & 1 deletion src/utils.jl
@@ -45,7 +45,7 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true)
arg.mlir_data = row_maj_arg
end

f(traced_args...; kwargs...)
Cassette.overdub(TraceCtx(), f, traced_args...; kwargs...)
end

seen_results = IdDict()
32 changes: 32 additions & 0 deletions test/basic.jl
@@ -48,4 +48,36 @@ end
@test r ≈ cos.(ones(3,2))
end

function sumcos(x)
return sum(cos.(x))
end

function grad_ip(x)
dx = Enzyme.make_zero(x)
Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx))
return dx
end

function resgrad_ip(f, x)
dx = Enzyme.make_zero(x)
res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx))
return (res, dx)
end

@testset "Basic grad cos" begin
c = Reactant.ConcreteRArray(ones(3,2))


f=Reactant.compile(grad_ip, (c,))
r = f(c)
@show r

@test r ≈ sin.(ones(3,2))

f=Reactant.compile(resgrad_ip, (c,))
orig, r = f(c)
@show r

@test orig[1] ≈ cos.(ones(3,2))
@test r ≈ sin.(ones(3,2))
end
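For reference, the new test helpers use the plain Enzyme.jl calling convention. The sketch below runs the same computation without compiling through Reactant, to make the `Active`/`Duplicated` roles explicit; the expected values come from ordinary calculus, not from output captured in this commit:

```julia
using Enzyme

sumcos(x) = sum(cos.(x))

x  = ones(3, 2)
dx = Enzyme.make_zero(x)   # zeroed shadow buffer; Enzyme accumulates the gradient here

# Reverse mode: `Active` marks the scalar return as the differentiated output,
# `Duplicated(x, dx)` pairs the input with its gradient buffer.
Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx))
@assert dx ≈ -sin.(x)      # ∂/∂xᵢ sum(cos.(x)) = -sin(xᵢ)

# ReverseWithPrimal also returns the primal value alongside the derivative tuple.
dx2 = Enzyme.make_zero(x)
res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx2))
@assert res[2] ≈ sum(cos.(x))   # second element of the result is the primal
```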
4 changes: 2 additions & 2 deletions test/runtests.jl
@@ -1,6 +1,6 @@
using Reactant

#include("layout.jl")
#include("basic.jl")
include("basic.jl")
#include("bcast.jl")
include("nn.jl")
#include("nn.jl")
