Skip to content

Commit

Permalink
feat: allow conversion to numbers (#222)
Browse files Browse the repository at this point in the history
* feat: allow conversion to numbers

* fix: preserve device
  • Loading branch information
avik-pal authored Nov 4, 2024
1 parent b57171f commit b0fd2a0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
30 changes: 23 additions & 7 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,52 @@ mutable struct ConcreteRNumber{T} <: RNumber{T}
end

# Build a `ConcreteRNumber{T}` from a plain Julia number: the value is placed in a
# zero-dimensional array with `fill(data)`, passed to `ConcreteRArray`, and the
# `data` field of that array is re-wrapped as a `ConcreteRNumber{T}`.
#
# Keywords: `client` (defaults to `XLA.default_backend[]`), `idx` (defaults to
# `XLA.default_device_idx[]`) and, in the later signature line, `device`
# (defaults to `nothing`); all are forwarded unchanged to `ConcreteRArray`.
#
# NOTE(review): this span is a rendered diff without +/- markers. Each pair of
# near-identical lines is presumably the removed line followed by its
# replacement (the replacement adds and forwards `device`) — confirm against
# the repository before treating this text as compilable source.
function ConcreteRNumber(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[]
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
) where {T<:Number}
crarray = ConcreteRArray(fill(data); client, idx)
crarray = ConcreteRArray(fill(data); client, idx, device)
return ConcreteRNumber{T}(crarray.data)
end

# A `ConcreteRNumber` reports an empty size tuple, i.e. it has no dimensions.
function Base.size(::ConcreteRNumber)
    return ()
end

# Convert `x` to the floating-point counterpart of its element type `T`.
# The client and device are read from the buffer backing `x` and passed on to the
# new `ConcreteRNumber`, so the result uses the same client and device as the input.
function Base.float(x::ConcreteRNumber{T}) where {T}
    buffer = x.data
    backend = XLA.client(buffer)
    dev = XLA.device(buffer)
    FloatT = float(T)
    return ConcreteRNumber(FloatT(to_number(x)); client=backend, device=dev)
end

# Define one constructor method per member type of the `ReactantPrimitive` union so
# that e.g. `Float32(x::ConcreteRNumber)` works. The methods are generated per
# concrete member type (rather than as one generic method) to avoid ambiguity errors.
for PrimT in Base.uniontypes(ReactantPrimitive)
    @eval function (::Type{$PrimT})(x::ConcreteRNumber)
        return convert($PrimT, x)
    end
end

# Convert a `ConcreteRNumber` to the numeric type `T`: extract the value with
# `to_number`, then delegate to `convert` on that extracted value.
function Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number}
    value = to_number(x)
    return convert(T, value)
end

# Deprecated scalar entry point for `ConcreteRArray`: emits a deprecation warning
# that points callers to `ConcreteRNumber(data)`, then forwards `fill(data)` and
# the keyword arguments to the array method of `ConcreteRArray`.
#
# Keywords: `client` (defaults to `XLA.default_backend[]`), `idx` (defaults to
# `XLA.default_device_idx[]`) and, in the later signature line, `device`
# (defaults to `nothing`).
#
# NOTE(review): this span is a rendered diff without +/- markers. Each pair of
# near-identical lines is presumably the removed line followed by its
# replacement (the replacement adds and forwards `device`) — confirm against
# the repository before treating this text as compilable source.
function ConcreteRArray(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[]
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
) where {T<:Number}
Base.depwarn(
"ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead",
:ConcreteRArray,
)
return ConcreteRArray(fill(data); client, idx)
return ConcreteRArray(fill(data); client, idx, device)
end

# Alias covering both scalar representations with element type `T`:
# a zero-dimensional `ConcreteRArray{T,0}` or a `ConcreteRNumber{T}`.
const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}}

# Adapt hook: when the requested storage type `T` is a `ConcreteRArray` type,
# adapt an array by calling the constructor `T(x)` on it.
function Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray}
    return T(x)
end

# Build a `ConcreteRArray{T,N}` from a Julia `Array{T,N}`.
#
# A device is selected (in the earlier line always via
# `XLA.ClientGetDevice(client, idx)`; in the later line only when `device` is
# `nothing`, otherwise the supplied `device` is used). The array is then passed
# to `XLA.ArrayFromHostBuffer(client, data, device)`, the result is wrapped in
# `XLA.AsyncBuffer(..., nothing)`, and `size(data)` is stored alongside it.
#
# Keywords: `client` (defaults to `XLA.default_backend[]`), `idx` (defaults to
# `XLA.default_device_idx[]`) and, in the later signature lines, `device`
# (defaults to `nothing`).
#
# NOTE(review): this span is a rendered diff without +/- markers. The single-line
# signature and the first `device = ...` assignment are presumably the removed
# lines, followed by their replacements — confirm against the repository before
# treating this text as compilable source.
function ConcreteRArray(
data::Array{T,N}; client=XLA.default_backend[], idx=XLA.default_device_idx[]
data::Array{T,N};
client=XLA.default_backend[],
idx=XLA.default_device_idx[],
device=nothing,
) where {T,N}
device = XLA.ClientGetDevice(client, idx)
device = device === nothing ? XLA.ClientGetDevice(client, idx) : device
return ConcreteRArray{T,N}(
XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing), size(data)
)
# ConcreteRArray{T, size(data), N}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, XLA.to_row_major(data), device), nothing))
end

# The size of a `ConcreteRArray` is whatever is stored in its `shape` field.
function Base.size(x::ConcreteRArray)
    return x.shape
end
Expand Down
13 changes: 13 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,16 @@ relu(x) = relu.(x)

@test @jit(relu(x_ra)) ≈ relu(x)
end

# Checks that a ConcreteRNumber can be turned into plain Julia numbers through the
# primitive-type constructors, and that `float` keeps the result a ConcreteRNumber.
@testset "concrete number to julia number" begin
    # Non-integral value: float constructors succeed, integer conversion must throw.
    x_frac = ConcreteRNumber(3.14)
    for F in (Float32, Float64)
        @test F(x_frac) isa F
    end
    @test_throws InexactError Int(x_frac)

    # Integral value: every conversion succeeds.
    x_int = ConcreteRNumber(3)
    for F in (Float32, Float64, Int)
        @test F(x_int) isa F
    end
    @test float(x_int) isa ConcreteRNumber{Float64}
end

1 comment on commit b0fd2a0

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: b0fd2a0 Previous: b57171f Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1256872102 ns 1206417685 ns 1.04
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1333443649 ns 1214871383 ns 1.10
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1185122837 ns 1200999245 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3031400324 ns 2749115179 ns 1.10
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 220916486 ns 220152561 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 7130674872 ns 5954497792 ns 1.20
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5034713515 ns 5547013317 ns 0.91
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5072244215 ns 5953205389 ns 0.85
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7332305315 ns 7769274161 ns 0.94
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 33343934482 ns 29042431211 ns 1.15
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1218677413 ns 1466527087 ns 0.83
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1234638573.5 ns 1271259724.5 ns 0.97
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1311424337.5 ns 1220631229.5 ns 1.07
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2867765274 ns 2955830416 ns 0.97
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8677702 ns 8809576 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1555217423.5 ns 1632601737 ns 0.95
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1544020815 ns 1615677017 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1545762311.5 ns 1612188044 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3286621346 ns 3470065078 ns 0.95
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2857108371 ns 2295456922 ns 1.24
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1272248465.5 ns 1454883401 ns 0.87
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1253802424.5 ns 1282969154 ns 0.98
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1213693005 ns 1275643905.5 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3281264577 ns 2832382223 ns 1.16
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 22707295.5 ns 22539732.5 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2133484502 ns 2223968702 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2121542869 ns 2223076468 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2114356826 ns 2233158724 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3896755968 ns 4094621805 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 6263782616 ns 6063156187.5 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1227311824.5 ns 1278819728.5 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1291628216.5 ns 1248655937 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1244002051 ns 1311559501.5 ns 0.95
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3118021900 ns 2882409531 ns 1.08
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7152965 ns 6611350 ns 1.08
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1417433358.5 ns 1465556368 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1405986386 ns 1461960298 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1391770564 ns 1433405185 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3124123780 ns 3212094865 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1152850766 ns 1042273863 ns 1.11
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1298963457 ns 1299522709 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1321624559.5 ns 1266625759.5 ns 1.04
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1228743753.5 ns 1199316274.5 ns 1.02
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3141106256 ns 2847580453 ns 1.10
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 12274730.5 ns 12144311 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1701322034 ns 1732841938 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1696176663 ns 1736374516 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1686548008 ns 1737563947 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3445787389 ns 3537460358 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3140343700.5 ns 3164282300 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1308923747 ns 1195368902 ns 1.09
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1285000527 ns 1222022759.5 ns 1.05
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1255551465.5 ns 1209980754 ns 1.04
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3175248554 ns 3057829023 ns 1.04
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 27225045 ns 27300237 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2141752079 ns 2217674882 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2120471203 ns 2213440570 ns 0.96
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2125844215 ns 2212843779 ns 0.96
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3892667998 ns 4023171863 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 5059333048.5 ns 5808870489 ns 0.87
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1244546410 ns 1273797452 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1244188168 ns 1258136707 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1456657537 ns 1295078250 ns 1.12
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3045543887 ns 3124884909 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 52743562.5 ns 52782039 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2953298455 ns 3152861032 ns 0.94
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 2971048669 ns 3116586732 ns 0.95
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3053641042 ns 3125795636 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4884156519 ns 5054882478 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 9364284339 ns 9000472087 ns 1.04
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1409383874 ns 1296586618 ns 1.09
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1312408615 ns 1284476181 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1300800117 ns 1212095690 ns 1.07
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3068426125 ns 3178477034 ns 0.97
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 70749333.5 ns 70924805.5 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3143918589 ns 3244919853 ns 0.97
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3077221269 ns 3677731020 ns 0.84
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3183700916 ns 3206184434 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5000845941 ns 5266676547 ns 0.95
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 13624219305 ns 13901530015 ns 0.98
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1281095534 ns 1293777394 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1305561574 ns 1285741954.5 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1346306927.5 ns 1200435469 ns 1.12
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3137623892 ns 2909812931 ns 1.08
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 20759677.5 ns 20593219 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1859442749 ns 1898829192 ns 0.98
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1832501006 ns 1886577590 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1848302233 ns 1888126659 ns 0.98
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3580451934 ns 3697218074 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3224128269 ns 3235249605 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.