diff --git a/src/Overlay.jl b/src/Overlay.jl index eac99c25f..7c41346e9 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -48,13 +48,9 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dims) end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T, dims) + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Base.inferencebarrier(Random.$(randfun))(rng, T, dims) end @reactant_overlay @noinline function Random.$(randfun)( @@ -69,13 +65,9 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T, dim1, dims...) + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Base.inferencebarrier(Random.$(randfun))(rng, T, dim1, dims...) end # scalars @@ -85,13 +77,9 @@ for randfun in (:rand, :randn, :randexp) if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T) end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T) + @warn "Reactant doesn't support sampling of $(T) with the current \ + interpreter. Falling back to native interpreter." maxlog = 1 + return Base.inferencebarrier(Random.$(randfun))(rng, T) end # inplace @@ -100,21 +88,11 @@ for randfun in (:rand, :randn, :randexp) ) return TracedRandom.$(overload_randfun!)(rng, A) end - - # XXX: Uncomment once AbsInt issues with recursive calls are resolved - # @reactant_overlay @noinline function Random.$(randfun!)( - # rng::AbstractRNG, A::AbstractArray - # ) - # @warn "Directly writing to an array using Random.jl functions inside \ - # ReactantInterpreter will generate a constant array in the IR. Use with \ - # caution." maxlog = 1 - # return Random.$(randfun!)(rng, A) - # end end end # LinearAlgebra.jl overloads -## `_mul!` goes through too many layers of abstractions and we aren't able to overload +## `mul!` goes through too many layers of abstractions and we aren't able to overload ## without specializing on every possible combination of types for (cT, aT, bT) in ( (:AbstractVector, :AbstractMatrix, :AbstractVector), diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 2d783c09c..1b6686d68 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -19,14 +19,8 @@ using ..Reactant: unwrapped_eltype using Random: Random, AbstractRNG -@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice()) - # XXX: We should really be able to call this here. But with our AbsInt it leads to a - # segfault. So we'll just call it in the rand! method. - # return rand(rng, UInt64, 2) - seed = Array{UInt64}(undef, 2) - Random.rand!(rng, seed) - return seed -end +@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) = + Random.rand!(rng, Vector{UInt64}(undef, 2)) @noinline function Random.seed!(rng::TracedRNG, seed::Number) if seed isa TracedRNumber