Skip to content

Commit

Permalink
fix: incorrect IR for traced RNGs (#494)
Browse files Browse the repository at this point in the history
* fix: missing scalar indexing check for setindex

* fix: out of region transpose usage (#492)

* fix: missing paths for missing values

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
avik-pal and github-actions[bot] authored Jan 7, 2025
1 parent 46b8c14 commit 41268bd
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 49 deletions.
32 changes: 24 additions & 8 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,13 @@ function codegen_flatten!(linear_args, result_stores)
# resarg_code = Expr[]

for (i, arg) in enumerate(linear_args)
paths = ((p for p in arg.paths if p[1] == :args)...,)
paths = ((p for p in Reactant.TracedUtils.get_paths(arg) if p[1] == :args)...,)
path = if length(paths) == 1
paths[1]
else
throw("Invalid path duplication $(arg.paths) into $(paths)")
throw(
"Invalid path duplication $(Reactant.TracedUtils.get_paths(arg)) into $(paths)",
)
end

usbuf = Symbol(:usbuf_, i)
Expand All @@ -633,7 +635,7 @@ function codegen_flatten!(linear_args, result_stores)
push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf)))

# TODO: unused for the time being
# respaths = ((p for p in arg.paths if p[1] == :result || p[1] == :resargs)...,)
# respaths = ((p for p in Reactant.TracedUtils.get_paths(arg) if p[1] == :result || p[1] == :resargs)...,)

# resarg = false
# for respath in respaths
Expand Down Expand Up @@ -688,7 +690,12 @@ function codegen_unflatten!(

# mutate the result stores to point to the correct concrete results
for (concrete_res_name, result) in zip(concretized_res_names, linear_results)
paths = ((p for p in result.paths if p[1] == :result || p[1] == :resargs)...,)
paths = (
(
p for p in Reactant.TracedUtils.get_paths(result) if
p[1] == :result || p[1] == :resargs
)...,
)
for path in paths
if path[1] == :result
unflatcode = :result
Expand Down Expand Up @@ -739,7 +746,7 @@ function codegen_unflatten!(
end
end
else
unflatcode = :($unflatcode.data = $concrete_res_name)
unflatcode = :(traced_setfield!($unflatcode, :data, $concrete_res_name))
end
push!(unflatten_code, unflatcode)
end
Expand All @@ -753,9 +760,18 @@ function codegen_unflatten!(

# if some argument is mutated, change them to point to the correct concrete results
for (result, arg_idx) in preserved_args
for path in result.paths
paths = (
(
p for p in Reactant.TracedUtils.get_paths(result) if
p[1] == :result || p[1] == :resargs || p[1] == :args
)...,
)

for path in paths
arg = linear_args[arg_idx + 1]
argpath = only((p for p in arg.paths if p[1] == :args))
argpath = only((
p for p in Reactant.TracedUtils.get_paths(arg) if p[1] == :args
))

if path[1] == :result
res = :result
Expand All @@ -764,7 +780,7 @@ function codegen_unflatten!(
continue
end
else
@assert path[1] == :resargs || path[1] == :args
@assert path[1] == :resargs || path[1] == :args "Expected :resargs or :args, got $(path[1])"
# We can optimize cases where we set the arg to itself
if path[2:end] == argpath[2:end]
continue
Expand Down
28 changes: 20 additions & 8 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,13 @@ function set_act!(inp, path, reverse, tostore; emptypath=false)
end

#if inp isa Enzyme.Active || !reverse
x.mlir_data = tostore
TracedUtils.set_mlir_data!(x, tostore)
#else
# x.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(x.mlir_data, tostore), 1)
#end

if emptypath
x.paths = ()
end
emptypath && TracedUtils.set_paths!(x, ())
return nothing
end

function overload_autodiff(
Expand Down Expand Up @@ -266,22 +265,35 @@ function overload_autodiff(
for a in linear_results
if TracedUtils.has_residx(a)
if needs_primal(CMode)
push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))
push!(
outtys,
TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))),
)
end
if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const)
if width == 1
push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))
push!(
outtys,
TracedUtils.transpose_ty(
MLIR.IR.type(TracedUtils.get_mlir_data(a))
),
)
else
push!(
outtys,
TracedUtils.batch_ty(
width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))
width,
TracedUtils.transpose_ty(
MLIR.IR.type(TracedUtils.get_mlir_data(a))
),
),
)
end
end
else
push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))
push!(
outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a)))
)
end
end
for (i, act) in enumerate(activity)
Expand Down
16 changes: 8 additions & 8 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,6 @@ function aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
return Ops.reshape(vcat(x...), size(x)...)
end

include("Ops.jl")
include("TracedUtils.jl")

include("TracedRNumber.jl")
include("TracedRArray.jl")

include("ConcreteRArray.jl")

mutable struct ConcreteRNG <: Random.AbstractRNG
seed::ConcreteRArray{UInt64,1}
const algorithm::String
Expand All @@ -185,6 +177,14 @@ mutable struct TracedRNG <: Random.AbstractRNG
const algorithm::String
end

include("Ops.jl")
include("TracedUtils.jl")

include("TracedRNumber.jl")
include("TracedRArray.jl")

include("ConcreteRArray.jl")

use_overlayed_version(iter) = any(use_overlayed_version, iter)

use_overlayed_version(::TracedRArray) = true
Expand Down
3 changes: 2 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ end
function maybe_assert_scalar_setindexing(
::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
return GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})")
GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})")
return nothing
end

maybe_assert_scalar_setindexing(args...) = nothing
Expand Down
49 changes: 28 additions & 21 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using ..Reactant:
OrderedIdDict,
ReactantPrimitive,
Ops
using ReactantCore: MissingTracedValue

materialize_traced_array(x::TracedRArray) = x

Expand All @@ -35,9 +36,16 @@ end

get_mlir_data(x::TracedRNumber) = x.mlir_data
set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x)
get_paths(x::TracedRNumber) = x.paths
set_paths!(x::TracedRNumber, paths) = (x.paths = paths; return x)

get_mlir_data(x::TracedRArray) = x.mlir_data
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x))
get_paths(x::TracedRArray) = x.paths
set_paths!(x::TracedRArray, paths) = (x.paths = paths; return x)

get_paths(x::MissingTracedValue) = x.paths
set_paths!(x::MissingTracedValue, paths) = (x.paths = paths; return x)

function set_mlir_data!(x::TracedRArray, data)
x.mlir_data = data
Expand Down Expand Up @@ -173,11 +181,11 @@ function make_mlir_fn(
result = try
for (i, arg) in enumerate(linear_args)
if construct_function_without_args
arg.mlir_data = args[i].mlir_data
set_mlir_data!(arg, get_mlir_data(args[i]))
else
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
arg.mlir_data = row_maj_arg
set_mlir_data!(arg, row_maj_arg)
end
end

Expand Down Expand Up @@ -210,7 +218,8 @@ function make_mlir_fn(

for (k, v) in seen_results
v isa Reactant.TracedType || continue
(no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue
paths = get_paths(v)
(no_args_in_result && length(paths) > 0 && paths[1][1] == :args) && continue
push!(linear_results, v)
end

Expand All @@ -221,11 +230,11 @@ function make_mlir_fn(
vals = MLIR.IR.Value[]
for res in linear_results
col_maj = if res isa MissingTracedValue
broadcast_to_size(false, ()).mlir_data
get_mlir_data(broadcast_to_size(false, ()))
elseif construct_function_without_args || !do_transpose
res.mlir_data
get_mlir_data(res)
elseif do_transpose
transpose_val(res.mlir_data)
transpose_val(get_mlir_data(res))
end
push!(vals, col_maj)
end
Expand Down Expand Up @@ -299,12 +308,12 @@ function push_val!(ad_inputs, x, path)
for p in path
x = Reactant.Compiler.traced_getfield(x, p)
end
x = x.mlir_data
x = get_mlir_data(x)
return push!(ad_inputs, x)
end

function get_argidx(x)
for path in x.paths
for path in get_paths(x)
if length(path) == 0
continue
end
Expand All @@ -316,7 +325,7 @@ function get_argidx(x)
end

function has_argidx(x)
for path in x.paths
for path in get_paths(x)
if length(path) == 0
continue
end
Expand All @@ -332,15 +341,13 @@ function set!(x, path, tostore; emptypath=false)
x = Reactant.Compiler.traced_getfield(x, p)
end

x.mlir_data = tostore
set_mlir_data!(x, tostore)

if emptypath
x.paths = ()
end
return emptypath && set_paths!(x, ())
end

function get_residx(x)
for path in x.paths
for path in get_paths(x)
if length(path) == 0
continue
end
Expand All @@ -352,7 +359,7 @@ function get_residx(x)
end

function has_residx(x)
for path in x.paths
for path in get_paths(x)
if length(path) == 0
continue
end
Expand Down Expand Up @@ -467,12 +474,12 @@ broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize)

function broadcast_to_size(arg::TracedRNumber{T}, rsize) where {T}
length(rsize) == 0 && return arg
return broadcast_to_size_internal(TracedRArray{T,0}((), arg.mlir_data, ()), rsize)
return broadcast_to_size_internal(TracedRArray{T,0}((), get_mlir_data(arg), ()), rsize)
end

function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T}
arg = materialize_traced_array(arg)
return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize)
return broadcast_to_size(TracedRNumber{T}((), get_mlir_data(arg)), rsize)
end

function broadcast_to_size(arg::AnyTracedRArray, rsize)
Expand All @@ -491,21 +498,21 @@ end
@noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T}
dims = collect(Int64, 0:(length(size(x)) - 1))

if length(size(MLIR.IR.type(x.mlir_data))) != length(dims)
if length(size(MLIR.IR.type(get_mlir_data(x)))) != length(dims)
@show x
@show arg
@show rsize
@show rsize2
@show dims
end
@assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims)
mlirty = MLIR.IR.type(x.mlir_data)
@assert length(size(MLIR.IR.type(get_mlir_data(x)))) == length(dims)
mlirty = MLIR.IR.type(get_mlir_data(x))

return TracedRArray{T,Int(length(rsize))}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.broadcast_in_dim(
x.mlir_data;
get_mlir_data(x);
result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)),
broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
),
Expand Down
6 changes: 3 additions & 3 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ function make_tracer(
throw("Cannot trace existing trace type")
end
if mode == TracedTrack
prev.paths = (prev.paths..., path)
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
return seen[prev] = prev
end
Expand Down Expand Up @@ -500,7 +500,7 @@ function make_tracer(
throw("Cannot trace existing trace type")
end
if mode == TracedTrack
prev.paths = (prev.paths..., path)
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
return seen[prev] = prev
end
Expand Down Expand Up @@ -540,7 +540,7 @@ function make_tracer(
throw("Cannot trace existing trace type")
end
if mode == TracedTrack
prev.paths = (prev.paths..., path)
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
return seen[prev] = prev
end
Expand Down
3 changes: 3 additions & 0 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ end
return rng
end

Base.copy(rng::ConcreteRNG) = ConcreteRNG(copy(rng.seed), rng.algorithm)
Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)

@noinline ConcreteRNG() = ConcreteRNG(ConcreteRArray(make_seed()))
@noinline ConcreteRNG(seed::ConcreteRArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT")

Expand Down

0 comments on commit 41268bd

Please sign in to comment.