Skip to content

Commit

Permalink
feat: inherit scalar indexing functionality from GPUArraysCore (#268)
Browse files Browse the repository at this point in the history
* feat: inherit scalar indexing functionality from GPUArraysCore

* chore: run formatter

* fix: always warn inside tracing unless opt-out

* chore: reexport @allowscalar

* feat: add isapprox for array types

* fix: test fixes for scalar indexing

* fix: allow scalar indexing in gather
  • Loading branch information
avik-pal authored Nov 13, 2024
1 parent 05bd81f commit 5a60501
Show file tree
Hide file tree
Showing 14 changed files with 92 additions and 85 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand Down Expand Up @@ -35,6 +36,7 @@ ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9"
OrderedCollections = "1"
Expand Down
3 changes: 2 additions & 1 deletion ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ReactantNNlibExt

using NNlib
using GPUArraysCore: @allowscalar
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
using ReactantCore: @trace
Expand Down Expand Up @@ -367,7 +368,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
colons = ntuple(Returns(Colon()), dims)
start_sizes = ntuple(i -> size(src, i), dims)
results = map(CartesianIndices(idxs)) do k
res = src[colons..., Tuple(idxs[k])...]
res = @allowscalar src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
return reshape(res, start_sizes..., :)
end
Expand Down
48 changes: 21 additions & 27 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El
return data
# XLA.from_row_major(data)
end
# `Array(x)` for a `ConcreteRArray` simply forwards to the `convert(Array, x)` method.
function Base.Array(x::ConcreteRArray)
    return convert(Array, x)
end

function synchronize(x::Union{ConcreteRArray,ConcreteRNumber})
XLA.synced_buffer(x.data)
Expand Down Expand Up @@ -145,6 +146,20 @@ for T in (ConcreteRNumber, ConcreteRArray{<:Any,0})
end
end

# Approximate comparison with a `ConcreteRArray` on the left-hand side.
# Both operands are converted to `Array` first; every keyword argument
# (e.g. tolerances) is forwarded unchanged to `Base.isapprox`.
function Base.isapprox(x::ConcreteRArray, y::AbstractArray; kwargs...)
    lhs = convert(Array, x)
    rhs = convert(Array, y)
    return Base.isapprox(lhs, rhs; kwargs...)
end
# Approximate comparison with a `ConcreteRArray` on the right-hand side.
# Both operands are converted to `Array` first; every keyword argument
# (e.g. tolerances) is forwarded unchanged to `Base.isapprox`.
function Base.isapprox(x::AbstractArray, y::ConcreteRArray; kwargs...)
    lhs = convert(Array, x)
    rhs = convert(Array, y)
    return Base.isapprox(lhs, rhs; kwargs...)
end
# Approximate comparison of two `ConcreteRArray`s: both are converted to `Array`
# and the keyword arguments are forwarded unchanged to `Base.isapprox`.
# NOTE(review): presumably this method also disambiguates the two mixed
# `ConcreteRArray`/`AbstractArray` methods — confirm `ConcreteRArray <: AbstractArray`.
function Base.isapprox(x::ConcreteRArray, y::ConcreteRArray; kwargs...)
    lhs = convert(Array, x)
    rhs = convert(Array, y)
    return Base.isapprox(lhs, rhs; kwargs...)
end

# Equality with a `ConcreteRArray` on the left: both operands are converted to
# `Array` and compared with the ordinary array `==`.
function Base.:(==)(x::ConcreteRArray, y::AbstractArray)
    lhs = convert(Array, x)
    rhs = convert(Array, y)
    return lhs == rhs
end
# Equality with a `ConcreteRArray` on the right: both operands are converted to
# `Array` and compared with the ordinary array `==`.
function Base.:(==)(x::AbstractArray, y::ConcreteRArray)
    lhs = convert(Array, x)
    rhs = convert(Array, y)
    return lhs == rhs
end
# Equality of two `ConcreteRArray`s: both are converted to `Array` and compared
# with the ordinary array `==`.
# NOTE(review): presumably also needed to disambiguate the two mixed-type `==`
# methods — confirm `ConcreteRArray <: AbstractArray`.
function Base.:(==)(x::ConcreteRArray, y::ConcreteRArray)
    lhs = convert(Array, x)
    rhs = convert(Array, y)
    return lhs == rhs
end

function Base.show(io::IO, X::ConcreteRScalar{T}) where {T}
if X.data == XLA.AsyncEmptyBuffer
println(io, "<Empty buffer>")
Expand All @@ -171,12 +186,11 @@ function Base.show(io::IO, X::ConcreteRArray)
return print(io, "$(typeof(X))($(str))")
end

const getindex_warned = Ref(false)
function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
if a.data == XLA.AsyncEmptyBuffer
throw("Cannot getindex from empty buffer")
end
# error("""Scalar indexing is disallowed.""")

XLA.await(a.data)
if XLA.BufferOnCPU(a.data.buffer)
buf = a.data.buffer
Expand All @@ -193,16 +207,8 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
return unsafe_load(ptr, start)
end
end
if !getindex_warned[]
@warn(
"""Performing scalar get-indexing on task $(current_task()).
Invocation resulted in scalar indexing of a ConcreteRArray.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on device, but very slowly on the CPU,
and require expensive copies and synchronization each time and therefore should be avoided."""
)
getindex_warned[] = true
end

GPUArraysCore.assertscalar("getindex(::ConcreteRArray, ::Vararg{Int, N})")
return convert(Array, a)[args...]
end

Expand All @@ -211,12 +217,11 @@ function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
return nothing
end

const setindex_warned = Ref(false)

function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N}
if a.data == XLA.AsyncEmptyBuffer
throw("Cannot setindex! to empty buffer")
end

XLA.await(a.data)
if XLA.BufferOnCPU(a.data.buffer)
buf = a.data.buffer
Expand All @@ -234,19 +239,8 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
end
return a
end
if !setindex_warned[]
@warn(
"""Performing scalar set-indexing on task $(current_task()).
Invocation resulted in scalar indexing of a ConcreteRArray.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on device, but very slowly on the CPU,
and require expensive copies and synchronization each time and therefore should be avoided.
This error message will only be printed for the first invocation for brevity.
"""
)
setindex_warned[] = true
end

GPUArraysCore.assertscalar("setindex!(::ConcreteRArray, ::Any, ::Vararg{Int, N})")
fn = Reactant.compile(mysetindex!, (a, v, args...))
fn(a, v, args...)
return a
Expand Down
6 changes: 4 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using ReactantCore: ReactantCore, @trace, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

export @allowscalar # re-exported from GPUArraysCore

# auxiliary types and functions
include("OrderedIdDict.jl")
Expand Down Expand Up @@ -114,8 +117,7 @@ function set_default_backend(backend::XLA.Client)
end

function set_default_backend(backend::String)
backend = XLA.backends[backend]
return XLA.default_backend[] = backend
return set_default_backend(XLA.backends[backend])
end

end # module
8 changes: 1 addition & 7 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,7 @@ end
function Base.getindex(
a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
@warn(
"""Performing scalar indexing on task $(current_task()).
Invocation resulted in scalar indexing of a TracedRArray.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on device, but very slowly on the CPU,
and require expensive copies and synchronization each time and therefore should be avoided."""
)
GPUArraysCore.assertscalar("getindex(::TracedRArray, ::Vararg{Int, N})")

start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index]
slice_sizes = [Int64(1) for _ in index]
Expand Down
1 change: 1 addition & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ function __init__()
end
end
end

return nothing
end

Expand Down
42 changes: 23 additions & 19 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ using Test
using Enzyme
using Statistics

# Reactant.set_default_backend("gpu")

fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))

using InteractiveUtils
Expand All @@ -16,7 +14,7 @@ using InteractiveUtils

a = Reactant.ConcreteRArray(x)

c_res = sum(a)
c_res = @allowscalar sum(a)
@test c_res ≈ r_res

@test @jit(sum(a)) ≈ r_res
Expand All @@ -29,7 +27,7 @@ end

a = Reactant.ConcreteRArray(x)

c_res = fastmax(a)
c_res = @allowscalar fastmax(a)
@test c_res ≈ r_res

@test @jit(fastmax(a)) ≈ r_res
Expand All @@ -45,7 +43,7 @@ sinexpbc(x) = sinexp.(x)

a = Reactant.ConcreteRArray(x)

c_res = sinexpbc(a)
c_res = @allowscalar sinexpbc(a)
@test c_res ≈ r_res

@test @jit(sinexpbc(a)) ≈ r_res
Expand Down Expand Up @@ -427,10 +425,10 @@ end
@test y ≈ y_ra

x_ra_array = Array(x_ra)
@test all(iszero, x_ra_array[1, :])
@test all(iszero, x_ra_array[2, :])
@test all(isone, x_ra_array[3, :])
@test all(isone, x_ra_array[4, :])
@test @allowscalar all(iszero, x_ra_array[1, :])
@test @allowscalar all(iszero, x_ra_array[2, :])
@test @allowscalar all(isone, x_ra_array[3, :])
@test @allowscalar all(isone, x_ra_array[4, :])
end

tuple_byref(x) = (; a=(; b=x))
Expand Down Expand Up @@ -504,14 +502,14 @@ end

f2 = @compile f1(x_ra)
res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=(Number,)))
@test only(res2) ≈ 5 * 3.14
@test @allowscalar(only(res2)) ≈ 5 * 3.14
@test res2 isa ConcreteRArray

x_ra = Reactant.to_rarray(x)

f3 = @compile f1(x_ra)
res3 = f3(Reactant.to_rarray((5, [3.14])))
@test only(res3) ≈ only(f1(x))
@test @allowscalar(only(res3)) ≈ only(f1(x))
@test res3 isa ConcreteRArray
end
end
Expand Down Expand Up @@ -544,18 +542,22 @@ end
x_ra = Reactant.to_rarray(x)

y = @jit(clamp!(x_ra, 0.0, 0.25))
@test maximum(y) ≤ 0.25
@test minimum(y) ≥ 0.0
@test maximum(x_ra) == maximum(y)
@test minimum(x_ra) == minimum(y)
@allowscalar begin
@test maximum(y) ≤ 0.25
@test minimum(y) ≥ 0.0
@test maximum(x_ra) == maximum(y)
@test minimum(x_ra) == minimum(y)
end

x = randn(2, 3)
x_ra = Reactant.to_rarray(x)

y = @jit(clamp.(x_ra, 0.0, 0.25))
@test maximum(y) ≤ 0.25
@test minimum(y) ≥ 0.0
@test x_ra ≈ x
@allowscalar begin
@test maximum(y) ≤ 0.25
@test minimum(y) ≥ 0.0
@test x_ra ≈ x
end
end

@testset "dynamic indexing" begin
Expand All @@ -565,6 +567,8 @@ end
idx = [1, 2, 3]
idx_ra = Reactant.to_rarray(idx)

y = @jit(getindex(x_ra, idx_ra, :))
fn(x, idx) = @allowscalar x[idx, :]

y = @jit(fn(x_ra, idx_ra))
@test y ≈ x[idx, :]
end
10 changes: 6 additions & 4 deletions test/closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ using Reactant
muler(x) = y -> x * y

@testset "closure" begin
x = Reactant.ConcreteRArray(ones(2, 2))
y = Reactant.ConcreteRArray(ones(2, 2))
x = ones(2, 2)
y = ones(2, 2)
x_ra = Reactant.ConcreteRArray(x)
y_ra = Reactant.ConcreteRArray(y)

f = muler(x)
@test @jit(f(y)) ≈ x * y
f = muler(x_ra)
@test @jit(f(y_ra)) ≈ x * y
end
17 changes: 10 additions & 7 deletions test/compile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,25 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
end

@testset "world-age" begin
a = Reactant.ConcreteRArray(ones(2, 10))
b = Reactant.ConcreteRArray(ones(10, 2))
a = ones(2, 10)
b = ones(10, 2)
a_ra = Reactant.ConcreteRArray(a)
b_ra = Reactant.ConcreteRArray(b)

fworld(x, y) = @jit(*(x, y))
fworld(x, y) = @jit(x * y)

@test fworld(a, b) ≈ ones(2, 2) * 10
@test fworld(a_ra, b_ra) ≈ ones(2, 2) * 10
end

@testset "type casting & optimized out returns" begin
a = Reactant.ConcreteRArray(rand(2, 10))
a = ones(2, 10)
a_ra = Reactant.ConcreteRArray(a)

ftype1(x) = Float64.(x)
ftype2(x) = Float32.(x)

y1 = @jit ftype1(a)
y2 = @jit ftype2(a)
y1 = @jit ftype1(a_ra)
y2 = @jit ftype2(a_ra)

@test y1 isa Reactant.ConcreteRArray{Float64,2}
@test y2 isa Reactant.ConcreteRArray{Float32,2}
Expand Down
5 changes: 3 additions & 2 deletions test/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ end
end

@testset "complex reduction" begin
x_ra = Reactant.ConcreteRArray(randn(ComplexF32, 10, 10))
@test @jit(sum(abs2, x_ra)) ≈ sum(abs2, x_ra)
x = randn(ComplexF32, 10, 10)
x_ra = Reactant.ConcreteRArray(x)
@test @jit(sum(abs2, x_ra)) ≈ sum(abs2, x)
end
18 changes: 9 additions & 9 deletions test/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,20 +365,20 @@ end
x_ra = Reactant.to_rarray(x)

res_ra = @jit(condition10_condition_with_setindex(x_ra))
@test res_ra[1, 1] == -1.0
@test res_ra[2, 1] == -1.0
@test x_ra[1, 1] == -1.0 broken = true
@test x_ra[2, 1] == -1.0 broken = true
@test @allowscalar(res_ra[1, 1]) == -1.0
@test @allowscalar(res_ra[2, 1]) == -1.0
@test @allowscalar(x_ra[1, 1]) == -1.0 broken = true
@test @allowscalar(x_ra[2, 1]) == -1.0 broken = true

x = -rand(2, 10)
x[2, 1] = 0.0
x_ra = Reactant.to_rarray(x)

res_ra = @jit(condition10_condition_with_setindex(x_ra))
@test res_ra[1, 1] == 1.0
@test res_ra[2, 1] == 0.0
@test x_ra[1, 1] == 1.0 broken = true
@test x_ra[2, 1] == 0.0
@test @allowscalar(res_ra[1, 1]) == 1.0
@test @allowscalar(res_ra[2, 1]) == 0.0
@test @allowscalar(x_ra[1, 1]) == 1.0 broken = true
@test @allowscalar(x_ra[2, 1]) == 0.0
end

function condition11_nested_ifff(x, y, z)
Expand Down Expand Up @@ -538,7 +538,7 @@ end
function cumsum!(x)
v = zero(eltype(x))
@trace for i in 1:length(x)
v += x[i]
v += @allowscalar x[i]
x[i] = v
end
return x
Expand Down
10 changes: 6 additions & 4 deletions test/layout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ using Test

# @show [y[1,1], y[1,2], y[2, 1], y[2, 2]]

@test y[1, 1] == x[1, 1]
@test y[1, 2] == x[1, 2]
@test y[2, 1] == x[2, 1]
@test y[2, 2] == x[2, 2]
@allowscalar begin
@test y[1, 1] == x[1, 1]
@test y[1, 2] == x[1, 2]
@test y[2, 1] == x[2, 1]
@test y[2, 2] == x[2, 2]
end
end
Loading

0 comments on commit 5a60501

Please sign in to comment.