Skip to content

Commit

Permalink
fix: recursion in AbsInt working (#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Jan 7, 2025
1 parent ae305e3 commit 12531c9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 40 deletions.
42 changes: 10 additions & 32 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,9 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dims)
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Base.inferencebarrier(Random.$(randfun))(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
Expand All @@ -69,13 +65,9 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dim1, dims...)
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Base.inferencebarrier(Random.$(randfun))(rng, T, dim1, dims...)
end

# scalars
Expand All @@ -85,13 +77,9 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T)
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Base.inferencebarrier(Random.$(randfun))(rng, T)
end

# inplace
Expand All @@ -100,21 +88,11 @@ for randfun in (:rand, :randn, :randexp)
)
return TracedRandom.$(overload_randfun!)(rng, A)
end

# XXX: Uncomment once AbsInt issues with recursive calls are resolved
# @reactant_overlay @noinline function Random.$(randfun!)(
# rng::AbstractRNG, A::AbstractArray
# )
# @warn "Directly writing to an array using Random.jl functions inside \
# ReactantInterpreter will generate a constant array in the IR. Use with \
# caution." maxlog = 1
# return Random.$(randfun!)(rng, A)
# end
end
end

# LinearAlgebra.jl overloads
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
## `mul!` goes through too many layers of abstractions and we aren't able to overload
## without specializing on every possible combination of types
for (cT, aT, bT) in (
(:AbstractVector, :AbstractMatrix, :AbstractVector),
Expand Down
10 changes: 2 additions & 8 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,8 @@ using ..Reactant:
unwrapped_eltype
using Random: Random, AbstractRNG

@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice())
# XXX: We should really be able to call this here. But with our AbsInt it leads to a
# segfault. So we'll just call it in the rand! method.
# return rand(rng, UInt64, 2)
seed = Array{UInt64}(undef, 2)
Random.rand!(rng, seed)
return seed
end
@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) =
Random.rand!(rng, Vector{UInt64}(undef, 2))

@noinline function Random.seed!(rng::TracedRNG, seed::Number)
if seed isa TracedRNumber
Expand Down

0 comments on commit 12531c9

Please sign in to comment.