From 41268bd1f4fff4b5d815438443c5324792b2f703 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 18:20:33 -0500 Subject: [PATCH] fix: incorrect IR for traced RNGs (#494) * fix: missing scalar indexing check for setindex * fix: out of region transpose usage (#492) * fix: missing paths for missing values * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Compiler.jl | 32 +++++++++++++++++++++------- src/Interpreter.jl | 28 ++++++++++++++++++------- src/Reactant.jl | 16 +++++++------- src/TracedRArray.jl | 3 ++- src/TracedUtils.jl | 49 ++++++++++++++++++++++++------------------- src/Tracing.jl | 6 +++--- src/stdlibs/Random.jl | 3 +++ 7 files changed, 88 insertions(+), 49 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index a5051b3fa..55a0f9b3d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -614,11 +614,13 @@ function codegen_flatten!(linear_args, result_stores) # resarg_code = Expr[] for (i, arg) in enumerate(linear_args) - paths = ((p for p in arg.paths if p[1] == :args)...,) + paths = ((p for p in Reactant.TracedUtils.get_paths(arg) if p[1] == :args)...,) path = if length(paths) == 1 paths[1] else - throw("Invalid path duplication $(arg.paths) into $(paths)") + throw( + "Invalid path duplication $(Reactant.TracedUtils.get_paths(arg)) into $(paths)", + ) end usbuf = Symbol(:usbuf_, i) @@ -633,7 +635,7 @@ function codegen_flatten!(linear_args, result_stores) push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf))) # TODO: unused for the time being - # respaths = ((p for p in arg.paths if p[1] == :result || p[1] == :resargs)...,) + # respaths = ((p for p in Reactant.TracedUtils.get_paths(arg) if p[1] == :result || p[1] == :resargs)...,) # resarg = false # for respath in respaths @@ -688,7 +690,12 @@ function codegen_unflatten!( # mutate the result stores to point to the correct concrete results for (concrete_res_name, result) in zip(concretized_res_names, linear_results) - paths = ((p for p in result.paths if p[1] == :result || p[1] == :resargs)...,) + paths = ( + ( + p for p in Reactant.TracedUtils.get_paths(result) if + p[1] == :result || p[1] == :resargs + )..., + ) for path in paths if path[1] == :result unflatcode = :result @@ -739,7 +746,7 @@ function codegen_unflatten!( end end else - unflatcode = :($unflatcode.data = $concrete_res_name) + unflatcode = :(traced_setfield!($unflatcode, :data, $concrete_res_name)) end push!(unflatten_code, unflatcode) end @@ -753,9 +760,18 @@ function codegen_unflatten!( # if some argument is mutated, change them to point to the correct concrete results for (result, arg_idx) in preserved_args - for path in result.paths + paths = ( + ( + p for p in Reactant.TracedUtils.get_paths(result) if + p[1] == :result || p[1] == :resargs || p[1] == :args + )..., + ) + + for path in paths arg = linear_args[arg_idx + 1] - argpath = only((p for p in arg.paths if p[1] == :args)) + argpath = only(( + p for p in Reactant.TracedUtils.get_paths(arg) if p[1] == :args + )) if path[1] == :result res = :result @@ -764,7 +780,7 @@ function codegen_unflatten!( continue end else - @assert path[1] == :resargs || path[1] == :args + @assert path[1] == :resargs || path[1] == :args "Expected :resargs or :args, got $(path[1])" # We can optimize cases where we set the arg to itself if path[2:end] == argpath[2:end] continue diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 06d888345..7065319d7 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -208,14 +208,13 @@ function set_act!(inp, path, reverse, tostore; emptypath=false) end #if inp isa Enzyme.Active || !reverse - x.mlir_data = tostore + TracedUtils.set_mlir_data!(x, tostore) #else # x.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(x.mlir_data, tostore), 1) #end - if emptypath - x.paths = () - end + emptypath && TracedUtils.set_paths!(x, ()) + return nothing end function overload_autodiff( @@ -266,22 +265,35 @@ function overload_autodiff( for a in linear_results if TracedUtils.has_residx(a) if needs_primal(CMode) - push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) + push!( + outtys, + TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))), + ) end if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const) if width == 1 - push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) + push!( + outtys, + TracedUtils.transpose_ty( + MLIR.IR.type(TracedUtils.get_mlir_data(a)) + ), + ) else push!( outtys, TracedUtils.batch_ty( - width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)) + width, + TracedUtils.transpose_ty( + MLIR.IR.type(TracedUtils.get_mlir_data(a)) + ), ), ) end end else - push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) + push!( + outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))) + ) end end for (i, act) in enumerate(activity) diff --git a/src/Reactant.jl b/src/Reactant.jl index ce1a86a19..6396bd65c 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -167,14 +167,6 @@ function aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} return Ops.reshape(vcat(x...), size(x)...) end -include("Ops.jl") -include("TracedUtils.jl") - -include("TracedRNumber.jl") -include("TracedRArray.jl") - -include("ConcreteRArray.jl") - mutable struct ConcreteRNG <: Random.AbstractRNG seed::ConcreteRArray{UInt64,1} const algorithm::String @@ -185,6 +177,14 @@ mutable struct TracedRNG <: Random.AbstractRNG const algorithm::String end +include("Ops.jl") +include("TracedUtils.jl") + +include("TracedRNumber.jl") +include("TracedRArray.jl") + +include("ConcreteRArray.jl") + use_overlayed_version(iter) = any(use_overlayed_version, iter) use_overlayed_version(::TracedRArray) = true diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 8244bfc24..a7d72b3ca 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -175,7 +175,8 @@ end function maybe_assert_scalar_setindexing( ::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} - return GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})") + GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})") + return nothing end maybe_assert_scalar_setindexing(args...) = nothing diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 6bd29764b..ee9087557 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -16,6 +16,7 @@ using ..Reactant: OrderedIdDict, ReactantPrimitive, Ops +using ReactantCore: MissingTracedValue materialize_traced_array(x::TracedRArray) = x @@ -35,9 +36,16 @@ end get_mlir_data(x::TracedRNumber) = x.mlir_data set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) +get_paths(x::TracedRNumber) = x.paths +set_paths!(x::TracedRNumber, paths) = (x.paths = paths; return x) get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) +get_paths(x::TracedRArray) = x.paths +set_paths!(x::TracedRArray, paths) = (x.paths = paths; return x) + +get_paths(x::MissingTracedValue) = x.paths +set_paths!(x::MissingTracedValue, paths) = (x.paths = paths; return x) function set_mlir_data!(x::TracedRArray, data) x.mlir_data = data @@ -173,11 +181,11 @@ function make_mlir_fn( result = try for (i, arg) in enumerate(linear_args) if construct_function_without_args - arg.mlir_data = args[i].mlir_data + set_mlir_data!(arg, get_mlir_data(args[i])) else raw_arg = MLIR.IR.argument(fnbody, i) row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg - arg.mlir_data = row_maj_arg + set_mlir_data!(arg, row_maj_arg) end end @@ -210,7 +218,8 @@ function make_mlir_fn( for (k, v) in seen_results v isa Reactant.TracedType || continue - (no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue + paths = get_paths(v) + (no_args_in_result && length(paths) > 0 && paths[1][1] == :args) && continue push!(linear_results, v) end @@ -221,11 +230,11 @@ function make_mlir_fn( vals = MLIR.IR.Value[] for res in linear_results col_maj = if res isa MissingTracedValue - broadcast_to_size(false, ()).mlir_data + get_mlir_data(broadcast_to_size(false, ())) elseif construct_function_without_args || !do_transpose - res.mlir_data + get_mlir_data(res) elseif do_transpose - transpose_val(res.mlir_data) + transpose_val(get_mlir_data(res)) end push!(vals, col_maj) end @@ -299,12 +308,12 @@ function push_val!(ad_inputs, x, path) for p in path x = Reactant.Compiler.traced_getfield(x, p) end - x = x.mlir_data + x = get_mlir_data(x) return push!(ad_inputs, x) end function get_argidx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -316,7 +325,7 @@ function get_argidx(x) end function has_argidx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -332,15 +341,13 @@ function set!(x, path, tostore; emptypath=false) x = Reactant.Compiler.traced_getfield(x, p) end - x.mlir_data = tostore + set_mlir_data!(x, tostore) - if emptypath - x.paths = () - end + return emptypath && set_paths!(x, ()) end function get_residx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -352,7 +359,7 @@ function get_residx(x) end function has_residx(x) - for path in x.paths + for path in get_paths(x) if length(path) == 0 continue end @@ -467,12 +474,12 @@ broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize) function broadcast_to_size(arg::TracedRNumber{T}, rsize) where {T} length(rsize) == 0 && return arg - return broadcast_to_size_internal(TracedRArray{T,0}((), arg.mlir_data, ()), rsize) + return broadcast_to_size_internal(TracedRArray{T,0}((), get_mlir_data(arg), ()), rsize) end function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} arg = materialize_traced_array(arg) - return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) + return broadcast_to_size(TracedRNumber{T}((), get_mlir_data(arg)), rsize) end function broadcast_to_size(arg::AnyTracedRArray, rsize) @@ -491,21 +498,21 @@ end @noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T} dims = collect(Int64, 0:(length(size(x)) - 1)) - if length(size(MLIR.IR.type(x.mlir_data))) != length(dims) + if length(size(MLIR.IR.type(get_mlir_data(x)))) != length(dims) @show x @show arg @show rsize @show rsize2 @show dims end - @assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims) - mlirty = MLIR.IR.type(x.mlir_data) + @assert length(size(MLIR.IR.type(get_mlir_data(x)))) == length(dims) + mlirty = MLIR.IR.type(get_mlir_data(x)) return TracedRArray{T,Int(length(rsize))}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.broadcast_in_dim( - x.mlir_data; + get_mlir_data(x); result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), ), diff --git a/src/Tracing.jl b/src/Tracing.jl index f1e224d61..e00fdcb00 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -454,7 +454,7 @@ function make_tracer( throw("Cannot trace existing trace type") end if mode == TracedTrack - prev.paths = (prev.paths..., path) + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) return seen[prev] = prev end @@ -500,7 +500,7 @@ function make_tracer( throw("Cannot trace existing trace type") end if mode == TracedTrack - prev.paths = (prev.paths..., path) + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) return seen[prev] = prev end @@ -540,7 +540,7 @@ function make_tracer( throw("Cannot trace existing trace type") end if mode == TracedTrack - prev.paths = (prev.paths..., path) + TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) return seen[prev] = prev end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 1b6686d68..617f1fac1 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -66,6 +66,9 @@ end return rng end +Base.copy(rng::ConcreteRNG) = ConcreteRNG(copy(rng.seed), rng.algorithm) +Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm) + @noinline ConcreteRNG() = ConcreteRNG(ConcreteRArray(make_seed())) @noinline ConcreteRNG(seed::ConcreteRArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT")