diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index f7bf1653f..b676ce5f6 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -13,22 +13,36 @@ mutable struct ConcreteRNumber{T} <: RNumber{T} end function ConcreteRNumber( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[] + data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing ) where {T<:Number} - crarray = ConcreteRArray(fill(data); client, idx) + crarray = ConcreteRArray(fill(data); client, idx, device) return ConcreteRNumber{T}(crarray.data) end Base.size(::ConcreteRNumber) = () +# Ensure the device and client are the same as the input +function Base.float(x::ConcreteRNumber{T}) where {T} + client = XLA.client(x.data) + device = XLA.device(x.data) + return ConcreteRNumber(float(T)(to_number(x)); client, device) +end + +# written like this to avoid ambiguity errors +for T in Base.uniontypes(ReactantPrimitive) + @eval (::Type{$(T)})(x::ConcreteRNumber) = convert($T, x) +end + +Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x)) + function ConcreteRArray( - data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[] + data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing ) where {T<:Number} Base.depwarn( "ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead", :ConcreteRArray, ) - return ConcreteRArray(fill(data); client, idx) + return ConcreteRArray(fill(data); client, idx, device) end const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}} @@ -36,13 +50,15 @@ const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}} Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x) function ConcreteRArray( - data::Array{T,N}; client=XLA.default_backend[], idx=XLA.default_device_idx[] + data::Array{T,N}; + client=XLA.default_backend[], + idx=XLA.default_device_idx[], + device=nothing, ) where {T,N} - device = XLA.ClientGetDevice(client, idx) + device = device === nothing ? XLA.ClientGetDevice(client, idx) : device return ConcreteRArray{T,N}( XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing), size(data) ) - # ConcreteRArray{T, size(data), N}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, XLA.to_row_major(data), device), nothing)) end Base.size(x::ConcreteRArray) = x.shape diff --git a/test/basic.jl b/test/basic.jl index 007b3a26b..478bd7802 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -498,3 +498,16 @@ relu(x) = relu.(x) @test @jit(relu(x_ra)) ≈ relu(x) end + +@testset "concrete number to julia number" begin + x = ConcreteRNumber(3.14) + @test Float32(x) isa Float32 + @test Float64(x) isa Float64 + @test_throws InexactError Int(x) + + x = ConcreteRNumber(3) + @test Float32(x) isa Float32 + @test Float64(x) isa Float64 + @test Int(x) isa Int + @test float(x) isa ConcreteRNumber{Float64} +end