feat: tracing Random.jl functionality correctly #363

Merged 22 commits on Dec 18, 2024

Conversation

avik-pal
Collaborator

@avik-pal avik-pal commented Dec 11, 2024

TODOs

  • Tests
  • Overlay seems to act weird: can't switch to the non-Reactant interpreter nicely. Using a workaround for now.
  • Fix documentation build
  • Overlay Random123 to generate specific RNGs (Threefry and Philox)
  • Overlay all random number generators
    • construct TracedRNG during tracing

@avik-pal
Collaborator Author

avik-pal commented Dec 11, 2024

julia> using Reactant, Random

julia> fn() = randn(Random.default_rng(), 2, 3)
fn (generic function with 1 method)

julia> @code_hlo optimize = false fn()
module {
  func.func @main() -> tensor<3x2xf64> {
    %c = stablehlo.constant dense<[9454987348304227925, 11257230962712577529]> : tensor<2xui64>
    %output_state, %output = stablehlo.rng_bit_generator %c, algorithm =  DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x3xui64>)
    %0 = stablehlo.convert %output : (tensor<2x3xui64>) -> tensor<2x3xf64>
    %cst = stablehlo.constant dense<1.8446744073709552E+19> : tensor<2x3xf64>
    %1 = stablehlo.divide %0, %cst : tensor<2x3xf64>
    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf64>
    %2 = stablehlo.multiply %1, %cst_0 : tensor<2x3xf64>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf64>
    %3 = stablehlo.subtract %2, %cst_1 : tensor<2x3xf64>
    %4 = chlo.erf_inv %3 : tensor<2x3xf64> -> tensor<2x3xf64>
    %cst_2 = stablehlo.constant dense<1.4142135623730951> : tensor<2x3xf64>
    %5 = stablehlo.multiply %4, %cst_2 : tensor<2x3xf64>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %6 : tensor<3x2xf64>
  }
}

julia> @code_hlo fn()
module {
  func.func @main() -> tensor<3x2xf64> {
    %cst = stablehlo.constant dense<1.4142135623730951> : tensor<2x3xf64>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf64>
    %cst_1 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf64>
    %cst_2 = stablehlo.constant dense<1.8446744073709552E+19> : tensor<2x3xf64>
    %c = stablehlo.constant dense<[17523564455668573441, 5342821220909967229]> : tensor<2xui64>
    %output_state, %output = stablehlo.rng_bit_generator %c, algorithm =  DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x3xui64>)
    %0 = stablehlo.convert %output : (tensor<2x3xui64>) -> tensor<2x3xf64>
    %1 = stablehlo.divide %0, %cst_2 : tensor<2x3xf64>
    %2 = stablehlo.multiply %1, %cst_1 : tensor<2x3xf64>
    %3 = stablehlo.subtract %2, %cst_0 : tensor<2x3xf64>
    %4 = chlo.erf_inv %3 : tensor<2x3xf64> -> tensor<2x3xf64>
    %5 = stablehlo.multiply %4, %cst : tensor<2x3xf64>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %6 : tensor<3x2xf64>
  }
}
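
The HLO above amounts to an inverse-CDF transform: raw 64-bit words are scaled to the unit interval, mapped to [-1, 1], passed through the inverse error function, and multiplied by sqrt(2). A minimal plain-Julia sketch of the same arithmetic follows. It is not Reactant's implementation: `bits_to_normal` is a hypothetical name, `MersenneTwister` only stands in for `stablehlo.rng_bit_generator`, and `erfinv` is assumed to come from SpecialFunctions.jl.

```julia
using Random
using SpecialFunctions: erfinv  # assumption: SpecialFunctions.jl is installed

# Hypothetical helper mirroring the op sequence in the HLO listing.
function bits_to_normal(bits::AbstractArray{UInt64})
    u = Float64.(bits) ./ 1.8446744073709552e19  # convert, then divide by 2^64
    v = 2.0 .* u .- 1.0                          # multiply by 2, subtract 1
    return sqrt(2.0) .* erfinv.(v)               # erf_inv, then multiply by sqrt(2)
end

# Stand-in for the bit generator: any source of UInt64 words works here.
bits = rand(Random.MersenneTwister(0), UInt64, 2, 3)
z = bits_to_normal(bits)
```

Note that the endpoints of the interval map to infinities under `erfinv`, so a word of 0 produces `-Inf` in this sketch.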

@avik-pal
Collaborator Author

This is kind of working now. Can I get an initial review?

src/Interpreter.jl Outdated
Member

@wsmoses wsmoses left a comment

This looks good to me, but it should wait until the interpreter lands (which will simplify the override, since it can just use the macro).

test/ops.jl
@avik-pal avik-pal force-pushed the ap/random_numbers branch 2 times, most recently from dd1f9e6 to e99089e, December 15, 2024 03:48
@avik-pal
Collaborator Author

The overlay mechanism fixed the previous unreachable instruction issue 🎉

@avik-pal avik-pal linked an issue Dec 15, 2024 that may be closed by this pull request
src/Ops.jl Outdated
@avik-pal avik-pal force-pushed the ap/random_numbers branch 3 times, most recently from 8af045f to 22f6810, December 17, 2024 04:48
src/Overlay.jl Outdated
@warn "Directly writing to an array using Random.jl functions inside \
ReactantInterpreter will generate a constant array in the IR. Use with \
caution." maxlog = 1
return Random.$(randfun!)(rng, A)
Collaborator Author

@wsmoses my understanding was that this should call the non-overlaid version, but here I get:

┌ Warning: Directly writing to an array using Random.jl functions inside ReactantInterpreter will generate a constant array in the IR. Use with caution.
└ @ Reactant /mnt/software/lux/Reactant.jl/src/Overlay.jl:84
Unreachable reached at 0x7a64877b6b16

[1417783] signal 4 (2): Illegal instruction
in expression starting at REPL[7]:1
fn2 at ./REPL[6]:4 [inlined]
opaque closure at ./<missing>:0
unknown function (ip: 0x7a64877b6bff)
fn2 at ./REPL[6]:2 [inlined]
call_with_reactant at /mnt/software/lux/Reactant.jl/src/utils.jl:0
#8 at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:210
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
unknown function (ip: 0x7a64877b6316)
#make_mlir_fn#1 at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:197
make_mlir_fn at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:117 [inlined]
#10 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:295 [inlined]
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
#9 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:294 [inlined]
mmodule! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:92
unknown function (ip: 0x7a64877b5976)
#compile_mlir!#8 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:291
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:290 [inlined]
#6 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:285 [inlined]
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
unknown function (ip: 0x7a64877b2a76)
#compile_mlir#5 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:283
compile_mlir at /mnt/software/lux/Reactant.jl/src/Compiler.jl:280
unknown function (ip: 0x7a64877b0a66)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:625
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:539
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10705 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_FGbh7.so (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_15174 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_FGbh7.so (unknown line)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73406.1 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
true_main at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/jlapi.c:1059
main at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/cli/loader_exe.c:58
unknown function (ip: 0x7a65341e2e07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 48616895 (Pool: 48615481; Big: 1414); GC: 47
[1]    1417783 illegal hardware instruction (core dumped)  julia --project=envs --threads=4 --check-bounds=yes

Collaborator Author

@avik-pal avik-pal Dec 17, 2024

Using the trick from #369 (comment) seems to work correctly, though it has other weird edge cases.

Collaborator Author

This is definitely a general issue with any kind of recursion of an overlaid function. Do we have a way to force the use of the NativeInterpreter?

Member

ccing @gbaraldi and @aviatesk

function $(overload_randfun)(rng::AbstractRNG, args...)
# XXX: Ideally the following should just work but currently it gives an illegal
# instruction error. Maybe an issue with Julia's AbsInt?
# seed_uint64 = rand(rng, UInt64, 2)
Collaborator Author

Same problem here: calling this leads to the illegal instruction error.

@avik-pal avik-pal requested a review from wsmoses December 18, 2024 03:21
src/Overlay.jl Outdated
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
error("Reactant doesn't support sampling of $(T) with the current interpreter.")
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- error("Reactant doesn't support sampling of $(T) with the current interpreter.")
+ return error(
+     "Reactant doesn't support sampling of $(T) with the current interpreter."
+ )

src/Overlay.jl Outdated
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
error("Reactant doesn't support sampling of $(T) with the current interpreter.")
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- error("Reactant doesn't support sampling of $(T) with the current interpreter.")
+ return error(
+     "Reactant doesn't support sampling of $(T) with the current interpreter."
+ )

src/Overlay.jl Outdated
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
error("Reactant doesn't support sampling of $(T) with the current interpreter.")
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- error("Reactant doesn't support sampling of $(T) with the current interpreter.")
+ return error(
+     "Reactant doesn't support sampling of $(T) with the current interpreter."
+ )

@avik-pal
Collaborator Author

I am working around the AbsInt issues for now, but those would be nice to sort out, especially for handling the wrapped arrays.

Also, I am explicitly throwing errors in the cases where Julia would otherwise crash with an illegal instruction.
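
The guard pattern in the src/Overlay.jl excerpts (check the element type, dispatch if supported, otherwise raise a regular Julia error) can be sketched in isolation as follows. This is an illustration under stated assumptions, not the PR's code: `SupportedPrimitive` is a hypothetical stand-in for `ReactantPrimitive`, `guarded_rand` is a hypothetical name, and the supported branch simply calls `Random.rand` rather than a traced implementation.

```julia
using Random

# Hypothetical stand-in for ReactantPrimitive; the real set of types may differ.
const SupportedPrimitive = Union{Float32,Float64,Int32,Int64,UInt64}

# Hypothetical guard: sample supported element types, fail loudly otherwise.
function guarded_rand(rng::AbstractRNG, ::Type{T}, dims::Dims) where {T}
    if T <: SupportedPrimitive
        return rand(rng, T, dims)
    end
    # Raising an ordinary exception keeps the failure catchable by the caller.
    return error("Sampling of $(T) is not supported in this sketch.")
end

x = guarded_rand(Random.default_rng(), Float32, (2, 3))
```

Calling `guarded_rand(Random.default_rng(), String, (2,))` throws an `ErrorException` that the caller can catch, which is the point of raising an explicit error instead of letting the process die.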

@avik-pal avik-pal marked this pull request as ready for review December 18, 2024 03:24
@wsmoses wsmoses merged commit 94e9576 into main Dec 18, 2024
35 of 54 checks passed
@wsmoses wsmoses deleted the ap/random_numbers branch December 18, 2024 16:39
wsmoses added a commit that referenced this pull request Dec 18, 2024
Successfully merging this pull request may close these issues.

Handling random numbers correctly
3 participants