Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 14, 2024
1 parent 585c485 commit 0c56d35
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 182 deletions.
18 changes: 5 additions & 13 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,9 @@ module ReactantNNlibExt

using NNlib
using GPUArraysCore: @allowscalar
using Reactant:
Reactant,
Ops,
TracedRArray,
AnyTracedRArray,
MLIR,
TracedRNumber

using Reactant.TracedUtils:
materialize_traced_array,
get_mlir_data,
set_mlir_data!
using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber

using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data!

using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu
Expand Down Expand Up @@ -332,7 +323,8 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
start_sizes = ntuple(i -> size(src, i), dims)
results = map(CartesianIndices(idxs)) do k
res = @allowscalar src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
res isa TracedRNumber &&
(res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
Expand Down
4 changes: 3 additions & 1 deletion lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ function trace_for(mod, expr)

all_syms = Expr(:tuple, counter, external_syms...)
args_init = Expr(
:tuple, :(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
:tuple,
:(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)),
external_syms...,
)

reactant_code_block = quote
Expand Down
4 changes: 3 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,9 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
end

# Compiling within a compile should return simply the original function
Reactant.@reactant_override function Reactant.Compiler.compile(f, args; client=nothing, optimize=true, sync=false)
Reactant.@reactant_override function Reactant.Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

Expand Down
67 changes: 32 additions & 35 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,8 @@ function set_reactant_abi(
# Improve inference by considering call_with_reactant as having the same results as
# the original call
if f === Reactant.call_with_reactant
arginfo2 = ArgInfo(
fargs isa Nothing ? nothing :
fargs[2:end],
argtypes[2:end],
)
return abstract_call(
interp,
arginfo2::ArgInfo,
si,
sv,
max_methods,
)
arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end])
return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods)
end

return Base.@invoke abstract_call_known(
Expand Down Expand Up @@ -280,7 +270,12 @@ function overload_autodiff(
if width == 1
push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))
else
push!(outtys, TracedUtils.batch_ty(width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))))
push!(
outtys,
TracedUtils.batch_ty(
width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))
),
)
end
end
else
Expand Down Expand Up @@ -393,13 +388,21 @@ function overload_autodiff(
else
idx, path = TracedUtils.get_argidx(a)
if idx == 1 && fnwrap
TracedUtils.set!(f.val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx)))
TracedUtils.set!(
f.val,
path[3:end],
TracedUtils.transpose_val(MLIR.IR.result(res, residx)),
)
residx += 1
else
if fnwrap
idx -= 1
end
TracedUtils.set!(args[idx].val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx)))
TracedUtils.set!(
args[idx].val,
path[3:end],
TracedUtils.transpose_val(MLIR.IR.result(res, residx)),
)
residx += 1
end
end
Expand All @@ -417,7 +420,12 @@ function overload_autodiff(
residx += 1
continue
end
set_act!(f, path[3:end], reverse, TracedUtils.transpose_val(MLIR.IR.result(res, residx)))
set_act!(
f,
path[3:end],
reverse,
TracedUtils.transpose_val(MLIR.IR.result(res, residx)),
)
else
if fnwrap
idx -= 1
Expand All @@ -437,7 +445,10 @@ function overload_autodiff(
continue
end
set_act!(
args[idx], path[3:end], reverse, TracedUtils.transpose_val(MLIR.IR.result(res, residx))
args[idx],
path[3:end],
reverse,
TracedUtils.transpose_val(MLIR.IR.result(res, residx)),
)
end
residx += 1
Expand Down Expand Up @@ -470,27 +481,13 @@ function overload_autodiff(
end

@reactant_override @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode,
f::FA,
rt::Type{A},
args::Vararg{Annotation,Nargs},
) where {
FA<:Annotation,
A<:Annotation,
Nargs,
}
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_override @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode,
f::FA,
rt::Type{A},
args::Vararg{Annotation,Nargs},
) where {
FA<:Annotation,
A<:Annotation,
Nargs,
}
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
26 changes: 14 additions & 12 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
module Ops
using ..MLIR: MLIR
using ..MLIR.Dialects: stablehlo, chlo, enzyme
using ..Reactant:
Reactant,
TracedRArray,
TracedRNumber,
RArray,
RNumber,
MissingTracedValue
using ..Reactant: Reactant, TracedRArray, TracedRNumber, RArray, RNumber, MissingTracedValue

function mlir_type(x::RArray{T,N}) where {T,N}
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
Expand Down Expand Up @@ -569,7 +563,9 @@ end
return TracedRNumber{T}((), res)
end

@noinline function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N}
@noinline function clamp(
min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T
) where {T,N}
return clamp(constant(min), x, constant(max))
end

Expand Down Expand Up @@ -818,17 +814,23 @@ end
# end

# paralell ops
@noinline function partition_id(; location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__))
@noinline function partition_id(;
location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)
)
res = MLIR.IR.result(stablehlo.partition_id(; location))
return TracedRNumber{UInt32}((), res)
end

@noinline function replica_id(; location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__))
@noinline function replica_id(;
location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__)
)
res = MLIR.IR.result(stablehlo.replica_id(; location))
return TracedRNumber{UInt32}((), res)
end

@noinline function after_all(tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__))
@noinline function after_all(
tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__)
)
tokens = [token.mlir_data for token in tokens]
res = MLIR.IR.result(stablehlo.after_all(tokens; location))
return Token(res)
Expand Down Expand Up @@ -1078,7 +1080,7 @@ end
comparison_direction::String,
compare_type=nothing,
location=mlir_stacktrace("compare", @__FILE__, @__LINE__),
) where {AT <: Union{TracedRArray,TracedRNumber}}
) where {AT<:Union{TracedRArray,TracedRNumber}}
@assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT")
@assert size(lhs) == size(rhs)

Expand Down
1 change: 0 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ include("TracedRArray.jl")

include("ConcreteRArray.jl")


include("linear_algebra.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
Expand Down
32 changes: 20 additions & 12 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ using ..TracedUtils
import ..Ops
import ..MLIR
import ..ancestor
import ReactantCore
using ReactantCore: ReactantCore
import ..TracedUtils: materialize_traced_array
import GPUArraysCore
using GPUArraysCore: GPUArraysCore

ReactantCore.is_traced(::TracedRArray) = true

Expand All @@ -31,13 +31,14 @@ end

TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x)


function Base.getindex(
a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
GPUArraysCore.assertscalar("getindex(::TracedRArray, ::Vararg{Int, N})")

start_indices = [TracedUtils.promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index]
start_indices = [
TracedUtils.promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index
]
slice_sizes = [Int64(1) for _ in index]

res1 = MLIR.IR.result(
Expand Down Expand Up @@ -107,8 +108,9 @@ function Base.setindex!(
v = TracedUtils.broadcast_to_size(v, length.(indices))
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
indices = [
(TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
i in indices
(
TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
).mlir_data for i in indices
]
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.dynamic_update_slice(
Expand Down Expand Up @@ -168,7 +170,9 @@ Base.imag(A::AnyTracedRArray) = zero(A)
Base.imag(A::AnyTracedRArray{<:Complex}) = Ops.imag(materialize_traced_array(A))

TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs)
TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N} = TracedUtils.promote_to(TracedRArray{T,N}, rhs)
function TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N}
return TracedUtils.promote_to(TracedRArray{T,N}, rhs)
end

for (jlop, hloop, hlocomp, merge) in
((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any))
Expand Down Expand Up @@ -214,12 +218,12 @@ function Base.mapreduce(
rdims = Int64[]

if dims == (:)
for i in 0:(N-1)
for i in 0:(N - 1)
push!(rdims, i)
end
else
for i in dims
push!(rdims, i-1)
push!(rdims, i - 1)
end
end

Expand All @@ -244,7 +248,7 @@ function Base.mapreduce(
toonedims = Int[]
outdims = Int[]
for i in 1:N
tmp = if in(i-1, rdims)
tmp = if in(i - 1, rdims)
1
else
sz = size(A, i)
Expand Down Expand Up @@ -283,7 +287,9 @@ function Base.mapreducedim!(
@nospecialize(R::TracedRArray),
A::Base.AbstractArrayOrBroadcasted,
)
tmp = TracedUtils.broadcast_to_size(Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...))
tmp = TracedUtils.broadcast_to_size(
Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...)
)
R.mlir_data = broadcast(op, R, tmp).mlir_data
return R
end
Expand All @@ -295,7 +301,9 @@ function Base.fill!(A::TracedRArray{T,N}, x) where {T,N}
end

function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
bcast = TracedUtils.broadcast_to_size(TracedUtils.promote_to(TracedRNumber{T}, x), size(A))
bcast = TracedUtils.broadcast_to_size(
TracedUtils.promote_to(TracedRNumber{T}, x), size(A)
)
A.mlir_data = bcast.mlir_data
return A
end
Expand Down
Loading

0 comments on commit 0c56d35

Please sign in to comment.