Skip to content

Commit

Permalink
implement @trace for (#255)
Browse files Browse the repository at this point in the history
* implement `@trace` for

* Apply suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

* Allow using induction variable

* floating point ranges

* lightspeed

* import LinearAlgebra

* generate 0 to N loop

* ir test

* clean iter

* Revert "Apply suggestions from code review"

This reverts commit 079ed4a.

* remove precompilation warning

* format and fix

* loop ranges as traced numbers

* fmt and add non unit step test

* fmt2

* integers

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
Pangoraw and mofeing authored Nov 12, 2024
1 parent f2a91bf commit 9e8eec0
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 18 deletions.
104 changes: 97 additions & 7 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ if no traced value is found inside the expression, then there is no overhead.
- `if` conditions (with `elseif` and other niceties) (`@trace if ...`)
- `if` statements with a preceeding assignment (`@trace a = if ...`) (note the positioning
of the macro needs to be before the assignment and not before the `if`)
- `for` statements with a single induction variable iterating over a syntactic `StepRange` of integers.
## Special Considerations
Expand Down Expand Up @@ -81,6 +82,15 @@ end
This will not compile since `y` is a `Float32` in one branch and a `Float64` in the other.
You need to ensure that all branches have the same type.
Another example is the following for loop which changes the type of `x` between iterations.
```julia
x = ... # ConcreteRArray{Int64, 1}
for i in 1f0:0.5f0:10f0
x = x .+ i # ConcreteRArray{Float32, 1}
end
```
### Certain Symbols are Reserved
Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
Expand All @@ -100,15 +110,84 @@ end
"""
macro trace(expr)
expr = macroexpand(__module__, expr)
if expr.head == :(=)
if expr.args[2] isa Expr && expr.args[2].head == :if
if Meta.isexpr(expr, :(=))
if Meta.isexpr(expr.args[2], :if)
return esc(trace_if_with_returns(__module__, expr))
end
end
expr.head == :if && return esc(trace_if(__module__, expr))
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr)))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
end

function trace_for(mod, expr)
Meta.isexpr(expr, :for, 2) || error("expected for expr")
assign, body = expr.args

error_if_any_control_flow(body)
if !Meta.isexpr(assign, :(=)) ||
!(assign.args[1] isa Symbol) ||
!Meta.isexpr(assign.args[2], :call) ||
assign.args[2].args[1] !== :(:)
error("malformed for loop assignment")
end

induction, range = assign.args

counter = gensym(:i)
num_iters = gensym(:num_iters)

start = range.args[2]
step = length(range.args) == 3 ? 1 : range.args[3]
limit = range.args[end]

body_symbols = ExpressionExplorer.compute_symbols_state(
quote
$(Expr(:local, assign))
$body
end,
)

external_syms = body_symbols.assignments body_symbols.references
filter!((SPECIAL_SYMBOLS), external_syms)

all_syms = Expr(:tuple, counter, external_syms...)
args_init = Expr(
:tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
)

reactant_code_block = quote
let args = $(args_init)
cond_fn =
$(all_syms) -> begin
local num_iters = div($limit - $start, $step, RoundDown)
local num_iters = Reactant.promote_to(
Reactant.TracedRNumber{Int64}, num_iters
)
$counter < num_iters + 1
end
body_fn =
$(all_syms) -> begin
local step_ = $step
local start_ = $start
local $induction = start_ + $counter * step_
$body
($counter + 1, $(all_syms.args[(begin + 1):end]...))
end

$(ReactantCore).traced_while(cond_fn, body_fn, args)
end
end

return quote
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
$(reactant_code_block)
else
$(expr)
end
end
end

# ... = if ... style expressions
function trace_if_with_returns(mod, expr)
new_expr, _, all_check_vars = trace_if(
Expand All @@ -128,7 +207,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
original_expr = expr

if depth == 0
error_if_return(expr)
error_if_any_control_flow(expr)

counter = 0
expr = MacroTools.prewalk(expr) do x
Expand Down Expand Up @@ -285,6 +364,13 @@ function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn}
return cond ? true_fn(args) : false_fn(args)
end

function traced_while(cond_fn, body_fn, args) where {CFn,BFn}
while cond_fn(args...)
args = body_fn(args...)
end
return args
end

function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
return MacroTools.postwalk(expr) do x
if x isa Symbol && x all_vars
Expand All @@ -294,10 +380,14 @@ function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
end
end

function error_if_return(expr)
const CONTROL_FLOW_EXPRS = [:return, :break, :continue, :symbolicgoto]

function error_if_any_control_flow(expr)
return MacroTools.postwalk(expr) do x
if x isa Expr && x.head == :return
error("Cannot use @trace on a block that contains a return statement")
for head in CONTROL_FLOW_EXPRS
if Meta.isexpr(x, head)
error("Cannot use @trace on a block that contains a $head statement")
end
end
return x
end
Expand Down
4 changes: 2 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,15 @@ macro code_hlo(options, maybe_call=nothing)
f = $(fname)
args = $(Expr(:vect, call.args[2:end]...))
mode = first($(compile_mlir)(f, args; optimize=options.optimize))
return mode
mode
end
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
quote
options = $(options)
f = Base.Broadcast.BroadcastFunction($(call.args[1]))
args = $(call.args[2:end]...)
mode = first($(compile_mlir)(f, args; optimize=options.optimize))
return mode
mode
end
else
error("Invalid function call: $(call)")
Expand Down
63 changes: 62 additions & 1 deletion src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,70 @@ function ReactantCore.traced_if(
end
end

function get_region_removing_missing_values(compiled_fn, insertions)
function ReactantCore.traced_while(
cond_fn::CFn, body_fn::BFn, args
) where {CFn<:Function,BFn<:Function}
# TODO: detect and prevent mutation within the condition

# We promote all incoming args (is there a better way to do this?)
traced_args = [
if v isa Number && !(v isa TracedType)
Reactant.promote_to(TracedRNumber{typeof(v)}, v)
else
v
end for v in args
]

(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.make_mlir_fn(
cond_fn,
traced_args,
(),
string(gensym("cond_fn")),
false;
no_args_in_result=true,
return_dialect=:stablehlo,
do_transpose=false,
)

(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.make_mlir_fn(
body_fn,
traced_args,
(),
string(gensym("body_fn")),
false;
no_args_in_result=true,
return_dialect=:stablehlo,
do_transpose=false,
)

cond_reg = take_region(cond_fn_compiled)
body_reg = take_region(body_fn_compiled)

MLIR.IR.rmfromparent!(cond_fn_compiled)
MLIR.IR.rmfromparent!(body_fn_compiled)

result_0 = in_tys

operands = MLIR.IR.Value[v.mlir_data for v in traced_args]

while_compiled = MLIR.Dialects.stablehlo.while_(
operands; result_0, cond=cond_reg, body=body_reg
)

return map(enumerate(traced_args)) do (i, res)
res.mlir_data = MLIR.IR.result(while_compiled, i)
return res
end
end

function take_region(compiled_fn)
region = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0))
return region
end

function get_region_removing_missing_values(compiled_fn, insertions)
region = take_region(compiled_fn)
block = MLIR.IR.Block(MLIR.API.mlirRegionGetFirstBlock(region), false)
return_op = MLIR.IR.terminator(block)
for (i, rt) in insertions
Expand Down
14 changes: 14 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ for (jlop, hloop) in (
end
end

function Base.div(
@nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown)
) where {T<:Integer}
return TracedRNumber{T}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.divide(
lhs.mlir_data, promote_to(TracedRNumber{T}, rhs).mlir_data
),
1,
),
)
end

for (jlop, hloop, hlocomp) in (
(:(Base.:(==)), :compare, "EQ"),
(:(Base.:(!=)), :compare, "NE"),
Expand Down
20 changes: 12 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ function make_mlir_fn(
return_dialect=:func,
no_args_in_result::Bool=false,
construct_function_without_args::Bool=false,
do_transpose=true,
)
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
return (
Expand All @@ -57,6 +58,7 @@ function make_mlir_fn(
return_dialect,
no_args_in_result,
construct_function_without_args,
do_transpose,
)[2:end]...,
)
end
Expand All @@ -82,8 +84,10 @@ function make_mlir_fn(

in_tys = if toscalar
[MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args]
else
elseif do_transpose
[transpose_ty(mlir_type(arg)) for arg in linear_args]
else
[mlir_type(arg) for arg in linear_args]
end

sym_visibility = nothing
Expand Down Expand Up @@ -115,7 +119,7 @@ function make_mlir_fn(
arg.mlir_data = args[i].mlir_data
else
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = transpose_val(raw_arg)
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
arg.mlir_data = row_maj_arg
end
end
Expand Down Expand Up @@ -180,12 +184,12 @@ function make_mlir_fn(
ret = MLIR.IR.block!(fnbody) do
vals = MLIR.IR.Value[]
for res in linear_results
if res isa MissingTracedValue
col_maj = broadcast_to_size(false, ()).mlir_data
elseif construct_function_without_args
col_maj = res.mlir_data
else
col_maj = transpose_val(res.mlir_data)
col_maj = if res isa MissingTracedValue
broadcast_to_size(false, ()).mlir_data
elseif construct_function_without_args || !do_transpose
res.mlir_data
elseif do_transpose
transpose_val(res.mlir_data)
end
push!(vals, col_maj)
end
Expand Down
Loading

0 comments on commit 9e8eec0

Please sign in to comment.