From 9e8eec051c61c4c122c694ac2fb68b1598968cc0 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 12 Nov 2024 17:52:20 +0100 Subject: [PATCH] implement `@trace` for (#255) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * implement `@trace` for * Apply suggestions from code review Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> * Allow using induction variable * floating point ranges * lightspeed * import LinearAlgebra * generate 0 to N loop * ir test * clean iter * Revert "Apply suggestions from code review" This reverts commit 079ed4af9b25427cf486746343334bd5ce31a99e. * remove precompilation warning * format and fix * loop ranges as traced numbers * fmt and add non unit step test * fmt2 * integers --------- Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- lib/ReactantCore/src/ReactantCore.jl | 104 +++++++++++++++++++++++-- src/Compiler.jl | 4 +- src/ControlFlow.jl | 63 ++++++++++++++- src/TracedRNumber.jl | 14 ++++ src/utils.jl | 20 +++-- test/control_flow.jl | 112 +++++++++++++++++++++++++++ 6 files changed, 299 insertions(+), 18 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 22ec9f3b9..9f826536d 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -31,6 +31,7 @@ if no traced value is found inside the expression, then there is no overhead. - `if` conditions (with `elseif` and other niceties) (`@trace if ...`) - `if` statements with a preceeding assignment (`@trace a = if ...`) (note the positioning of the macro needs to be before the assignment and not before the `if`) +- `for` statements with a single induction variable iterating over a syntactic `StepRange` of integers. ## Special Considerations @@ -81,6 +82,15 @@ end This will not compile since `y` is a `Float32` in one branch and a `Float64` in the other. You need to ensure that all branches have the same type. +Another example is the following for loop which changes the type of `x` between iterations. + +```julia +x = ... # ConcreteRArray{Int64, 1} +for i in 1f0:0.5f0:10f0 + x = x .+ i # ConcreteRArray{Float32, 1} +end +``` + ### Certain Symbols are Reserved Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For @@ -100,15 +110,84 @@ end """ macro trace(expr) expr = macroexpand(__module__, expr) - if expr.head == :(=) - if expr.args[2] isa Expr && expr.args[2].head == :if + if Meta.isexpr(expr, :(=)) + if Meta.isexpr(expr.args[2], :if) return esc(trace_if_with_returns(__module__, expr)) end end - expr.head == :if && return esc(trace_if(__module__, expr)) + Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr)) + Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr))) return error("Only `if-elseif-else` blocks are currently supported by `@trace`") end +function trace_for(mod, expr) + Meta.isexpr(expr, :for, 2) || error("expected for expr") + assign, body = expr.args + + error_if_any_control_flow(body) + if !Meta.isexpr(assign, :(=)) || + !(assign.args[1] isa Symbol) || + !Meta.isexpr(assign.args[2], :call) || + assign.args[2].args[1] !== :(:) + error("malformed for loop assignment") + end + + induction, range = assign.args + + counter = gensym(:i) + num_iters = gensym(:num_iters) + + start = range.args[2] + step = length(range.args) == 3 ? 1 : range.args[3] + limit = range.args[end] + + body_symbols = ExpressionExplorer.compute_symbols_state( + quote + $(Expr(:local, assign)) + $body + end, + ) + + external_syms = body_symbols.assignments ∪ body_symbols.references + filter!(∉(SPECIAL_SYMBOLS), external_syms) + + all_syms = Expr(:tuple, counter, external_syms...) + args_init = Expr( + :tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms... + ) + + reactant_code_block = quote + let args = $(args_init) + cond_fn = + $(all_syms) -> begin + local num_iters = div($limit - $start, $step, RoundDown) + local num_iters = Reactant.promote_to( + Reactant.TracedRNumber{Int64}, num_iters + ) + $counter < num_iters + 1 + end + body_fn = + $(all_syms) -> begin + local step_ = $step + local start_ = $start + local $induction = start_ + $counter * step_ + $body + ($counter + 1, $(all_syms.args[(begin + 1):end]...)) + end + + $(ReactantCore).traced_while(cond_fn, body_fn, args) + end + end + + return quote + if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...))) + $(reactant_code_block) + else + $(expr) + end + end +end + # ... = if ... style expressions function trace_if_with_returns(mod, expr) new_expr, _, all_check_vars = trace_if( @@ -128,7 +207,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) original_expr = expr if depth == 0 - error_if_return(expr) + error_if_any_control_flow(expr) counter = 0 expr = MacroTools.prewalk(expr) do x @@ -285,6 +364,13 @@ function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn} return cond ? true_fn(args) : false_fn(args) end +function traced_while(cond_fn, body_fn, args) where {CFn,BFn} + while cond_fn(args...) + args = body_fn(args...) + end + return args +end + function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars) return MacroTools.postwalk(expr) do x if x isa Symbol && x ∈ all_vars @@ -294,10 +380,14 @@ function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars) end end -function error_if_return(expr) +const CONTROL_FLOW_EXPRS = [:return, :break, :continue, :symbolicgoto] + +function error_if_any_control_flow(expr) return MacroTools.postwalk(expr) do x - if x isa Expr && x.head == :return - error("Cannot use @trace on a block that contains a return statement") + for head in CONTROL_FLOW_EXPRS + if Meta.isexpr(x, head) + error("Cannot use @trace on a block that contains a $head statement") + end end return x end diff --git a/src/Compiler.jl b/src/Compiler.jl index 38677e478..2b119db92 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -417,7 +417,7 @@ macro code_hlo(options, maybe_call=nothing) f = $(fname) args = $(Expr(:vect, call.args[2:end]...)) mode = first($(compile_mlir)(f, args; optimize=options.optimize)) - return mode + mode end elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple) quote @@ -425,7 +425,7 @@ macro code_hlo(options, maybe_call=nothing) f = Base.Broadcast.BroadcastFunction($(call.args[1])) args = $(call.args[2:end]...) mode = first($(compile_mlir)(f, args; optimize=options.optimize)) - return mode + mode end else error("Invalid function call: $(call)") diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 98c9d9469..3b30c4cb6 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -74,9 +74,70 @@ function ReactantCore.traced_if( end end -function get_region_removing_missing_values(compiled_fn, insertions) +function ReactantCore.traced_while( + cond_fn::CFn, body_fn::BFn, args +) where {CFn<:Function,BFn<:Function} + # TODO: detect and prevent mutation within the condition + + # We promote all incoming args (is there a better way to do this?) + traced_args = [ + if v isa Number && !(v isa TracedType) + Reactant.promote_to(TracedRNumber{typeof(v)}, v) + else + v + end for v in args + ] + + (_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.make_mlir_fn( + cond_fn, + traced_args, + (), + string(gensym("cond_fn")), + false; + no_args_in_result=true, + return_dialect=:stablehlo, + do_transpose=false, + ) + + (_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.make_mlir_fn( + body_fn, + traced_args, + (), + string(gensym("body_fn")), + false; + no_args_in_result=true, + return_dialect=:stablehlo, + do_transpose=false, + ) + + cond_reg = take_region(cond_fn_compiled) + body_reg = take_region(body_fn_compiled) + + MLIR.IR.rmfromparent!(cond_fn_compiled) + MLIR.IR.rmfromparent!(body_fn_compiled) + + result_0 = in_tys + + operands = MLIR.IR.Value[v.mlir_data for v in traced_args] + + while_compiled = MLIR.Dialects.stablehlo.while_( + operands; result_0, cond=cond_reg, body=body_reg + ) + + return map(enumerate(traced_args)) do (i, res) + res.mlir_data = MLIR.IR.result(while_compiled, i) + return res + end +end + +function take_region(compiled_fn) region = MLIR.IR.Region() MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0)) + return region +end + +function get_region_removing_missing_values(compiled_fn, insertions) + region = take_region(compiled_fn) block = MLIR.IR.Block(MLIR.API.mlirRegionGetFirstBlock(region), false) return_op = MLIR.IR.terminator(block) for (i, rt) in insertions diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 4ddb02131..9a1c1725e 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -124,6 +124,20 @@ for (jlop, hloop) in ( end end +function Base.div( + @nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown) +) where {T<:Integer} + return TracedRNumber{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.divide( + lhs.mlir_data, promote_to(TracedRNumber{T}, rhs).mlir_data + ), + 1, + ), + ) +end + for (jlop, hloop, hlocomp) in ( (:(Base.:(==)), :compare, "EQ"), (:(Base.:(!=)), :compare, "NE"), diff --git a/src/utils.jl b/src/utils.jl index 46d94e75b..deefb3503 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -43,6 +43,7 @@ function make_mlir_fn( return_dialect=:func, no_args_in_result::Bool=false, construct_function_without_args::Bool=false, + do_transpose=true, ) if sizeof(typeof(f)) != 0 || f isa BroadcastFunction return ( @@ -57,6 +58,7 @@ function make_mlir_fn( return_dialect, no_args_in_result, construct_function_without_args, + do_transpose, )[2:end]..., ) end @@ -82,8 +84,10 @@ function make_mlir_fn( in_tys = if toscalar [MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args] - else + elseif do_transpose [transpose_ty(mlir_type(arg)) for arg in linear_args] + else + [mlir_type(arg) for arg in linear_args] end sym_visibility = nothing @@ -115,7 +119,7 @@ function make_mlir_fn( arg.mlir_data = args[i].mlir_data else raw_arg = MLIR.IR.argument(fnbody, i) - row_maj_arg = transpose_val(raw_arg) + row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg arg.mlir_data = row_maj_arg end end @@ -180,12 +184,12 @@ function make_mlir_fn( ret = MLIR.IR.block!(fnbody) do vals = MLIR.IR.Value[] for res in linear_results - if res isa MissingTracedValue - col_maj = broadcast_to_size(false, ()).mlir_data - elseif construct_function_without_args - col_maj = res.mlir_data - else - col_maj = transpose_val(res.mlir_data) + col_maj = if res isa MissingTracedValue + broadcast_to_size(false, ()).mlir_data + elseif construct_function_without_args || !do_transpose + res.mlir_data + elseif do_transpose + transpose_val(res.mlir_data) end push!(vals, col_maj) end diff --git a/test/control_flow.jl b/test/control_flow.jl index f17623dd9..1254e0c00 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -1,4 +1,5 @@ using Reactant, Test +using LinearAlgebra function condition1(x) y = sum(x) @@ -453,3 +454,114 @@ end @test @jit(condition12_compile_test(x_ra, y_ra, z_ra)) ≈ condition12_compile_test(x, y, z) end + +function for_with_step(x) + @trace for i in 10:3:22 + x[i] = i * i + end + return x +end + +@testset "for: for with step" begin + x = rand(1:100, 22) + x_ra = Reactant.to_rarray(x) + + @test @jit(for_with_step(x_ra)) == for_with_step(x) +end + +function nnorm(x, n) + @trace for i in 1:n + x = x * i ./ sum(x) + end + return x +end + +@testset "for: induction" begin + x = randn(Float32, 10) + x_ra = Reactant.to_rarray(x) + + n = 10 + + @test @jit(nnorm(x_ra, n)) ≈ nnorm(x, n) +end + +function sinkhorn(μ, ν, C) + λ = eltype(C)(0.8) + K = @. exp(-C / λ) + + u = fill!(similar(μ), one(eltype(μ))) + v = similar(ν) + + @trace for _ in 1:10 + v = ν ./ (K' * u) + u = μ ./ (K * v) + end + + return Diagonal(u) * K * Diagonal(v) +end + +@testset "for: sinkhorn" begin + Nμ = 10 + Nν = 5 + + μ = ones(Float32, Nμ) ./ Nμ + ν = ones(Float32, Nν) ./ Nν + C = randn(Float32, Nμ, Nν) + + μ_ra = Reactant.to_rarray(μ) + ν_ra = Reactant.to_rarray(ν) + C_ra = Reactant.to_rarray(C) + + @test @jit(sinkhorn(μ_ra, ν_ra, C_ra)) ≈ sinkhorn(μ, ν, C) +end + +@testset "for: forbidden syntax" begin + @test_throws "break" @eval function f_with_break() + @trace for i in 1:10 + break + end + end + + @test_throws "continue" @eval function f_with_continue() + @trace for i in 1:10 + continue + end + end + + @test_throws "return" @eval function f_with_return() + @trace for i in 1:10 + return nothing + end + end +end + +function cumsum!(x) + v = zero(eltype(x)) + @trace for i in 1:length(x) + v += x[i] + x[i] = v + end + return x +end + +@testset "for: mutation within loop" begin + x = rand(1:100, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(cumsum!(x_ra)) == cumsum!(x) +end + +function for_ref_outer(x) + i = sum(x) + @trace for i in 1:length(x) + x .+= i + end + return x / i +end + +@testset "for: outer reference" begin + x = randn(Float64, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(for_ref_outer(x_ra)) ≈ for_ref_outer(x) +end