From 7d17faf180352ec629546cc80e7bb9086d601d88 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 19:24:40 +0530 Subject: [PATCH] fix: handling floating point sampling --- src/Ops.jl | 19 ++++++++++++++++++- src/stdlibs/Random.jl | 6 ++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 189c4ffa5..4ccfd403a 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -987,7 +987,7 @@ function rng_bit_generator( shape; algorithm::String="DEFAULT", location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), -) where {T<:ReactantPrimitive} +) where {T<:Integer} @assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY") if algorithm == "PHILOX" @assert length(seed) ∈ (2, 3) @@ -1007,6 +1007,23 @@ function rng_bit_generator( ) end +function rng_bit_generator( + ::Type{T}, + seed::TracedRArray{UInt64,1}, + shape; + algorithm::String="DEFAULT", + location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), +) where {T<:Union{Float16,Float32,Float64}} + nbits = sizeof(T) * 8 + uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64) + (; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location) + output = divide( + convert(TracedRArray{T,ndims(output)}, output), + constant(fill(T(typemax(uT)), Tuple(shape)); location), + ) + return (; output_state, output) +end + # functional ops function return_( results::Union{TracedRArray,TracedRNumber}...; diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index fccc583bf..090b72bb3 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -32,8 +32,6 @@ end # XXX: Currently we get an illegal instruction if we don't call Random.default_rng() -# XXX: rng_bit_generator doesn't support floating point types - function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) @@ -62,6 +60,10 @@ for randfun in (:rand, :randn) return Random.$(randfun!)(rng, TracedRArray{T,length(dims)}((), nothing, dims)) end + function Random.$(randfun)(rng::TracedRNG, dims::Dims) + return Random.$(randfun)(rng, Float64, dims) + end + function Random.$(randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...) return Random.$(randfun)(rng, Dims((dim1, dims...))) end