Skip to content

Commit

Permalink
feat: efficient sampling for non-native RNGs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 11, 2024
1 parent b9411d7 commit 8673389
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@ mutable struct TracedRNG <: Random.AbstractRNG
const algorithm::String
end

# TODO: Base.seed!
function Random.seed!(rng::TracedRNG, seed::Number)
seed = reinterpret(UInt64, Random.hash_seed(seed))
# TODO: Using `seed!` inside tracing should generate a TracedRArray
return Random.seed!(rng, ConcreteRArray(seed[1:length(rng.seed)]))
end

function Random.seed!(
rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
)
rng.seed = seed
return rng
end

make_seed() = rand(Random.RandomDevice(), UInt64, 2)

Expand All @@ -21,6 +32,8 @@ end

# XXX: Currently we get an illegal instruction if we don't call Random.default_rng()

# XXX: rng_bit_generator doesn't support floating point types

function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
Expand Down Expand Up @@ -62,11 +75,21 @@ for randfun in (:rand, :randn)
Random.$(randfun!)(A::AnyTracedRArray) = Random.$(randfun!)(default_rng(), A)

# scalars
function Random.$(randfun)(rng::TracedRNG, ::Type{T} = Float64) where {T}
function Random.$(randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T}
A = promote_to(TracedRArray{T,0}, fill(T(0)))
Random.$(randfun!)(rng, A)
return A[]
end

# Non-Traced RNGs if used will lead to disastrous performance. We attempt to fix
# that but with a warning
function Random.$(randfun!)(rng::Random.AbstractRNG, A::AnyTracedRArray)
@warn "`rng` is not a `TracedRNG`. We will use this to seed the `TracedRNG` \
instead of generating samples from this RNG type." maxlog = 1
seed = promote_to(TracedRArray{UInt64,1}, rand(rng, UInt64, 2))
trng = TracedRNG(seed, "DEFAULT")
return Random.$(randfun!)(trng, A)
end
end
end

Expand All @@ -77,18 +100,6 @@ function Random.randn(rng::TracedRNG, T::Random.BitFloatType)
return A[]
end

# # CPU arrays
# function Random.rand!(rng::RNG, A::AbstractArray{T}) where {T}
# B = CuArray{T}(undef, size(A))
# rand!(rng, B)
# copyto!(A, B)
# end
# function Random.randn!(rng::RNG, A::AbstractArray{T}) where {T}
# B = CuArray{T}(undef, size(A))
# randn!(rng, B)
# copyto!(A, B)
# end

# TODO: At some later point we might want to implement the sampler API as well since it
# makes all RNG implementation work by default. From the post-optimize IR we need to
# confirm that the dynamic_update_slice calls are optimized away into a single
Expand Down

0 comments on commit 8673389

Please sign in to comment.