Skip to content

Commit

Permalink
fix: handling floating point sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 11, 2024
1 parent 8673389 commit 7d17faf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
19 changes: 18 additions & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ function rng_bit_generator(
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
) where {T<:ReactantPrimitive}
) where {T<:Integer}
@assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
if algorithm == "PHILOX"
@assert length(seed) (2, 3)
Expand All @@ -1007,6 +1007,23 @@ function rng_bit_generator(
)
end

function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
) where {T<:Union{Float16,Float32,Float64}}
nbits = sizeof(T) * 8
uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
output = divide(
convert(TracedRArray{T,ndims(output)}, output),
constant(fill(T(typemax(uT)), Tuple(shape)); location),
)
return (; output_state, output)
end

# functional ops
function return_(
results::Union{TracedRArray,TracedRNumber}...;
Expand Down
6 changes: 4 additions & 2 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ end

# XXX: Currently we get an illegal instruction if we don't call Random.default_rng()

# XXX: rng_bit_generator doesn't support floating point types

function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
Expand Down Expand Up @@ -62,6 +60,10 @@ for randfun in (:rand, :randn)
return Random.$(randfun!)(rng, TracedRArray{T,length(dims)}((), nothing, dims))
end

function Random.$(randfun)(rng::TracedRNG, dims::Dims)
return Random.$(randfun)(rng, Float64, dims)
end

function Random.$(randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...)
return Random.$(randfun)(rng, Dims((dim1, dims...)))
end
Expand Down

0 comments on commit 7d17faf

Please sign in to comment.