diff --git a/Project.toml b/Project.toml index 67a795be0..f2ca30990 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -35,6 +36,7 @@ ArrayInterface = "7.10" CEnum = "0.4, 0.5" Downloads = "1.6" Enzyme = "0.13" +GPUArraysCore = "0.1, 0.2" LinearAlgebra = "1.10" NNlib = "0.9" OrderedCollections = "1" diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 2f54d4602..07c8a8e83 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -1,6 +1,7 @@ module ReactantNNlibExt using NNlib +using GPUArraysCore: @allowscalar using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber using ReactantCore: @trace @@ -367,7 +368,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr colons = ntuple(Returns(Colon()), dims) start_sizes = ntuple(i -> size(src, i), dims) results = map(CartesianIndices(idxs)) do k - res = src[colons..., Tuple(idxs[k])...] + res = @allowscalar src[colons..., Tuple(idxs[k])...] res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,))) return reshape(res, start_sizes..., :) end diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index b676ce5f6..59e36e5f0 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -86,6 +86,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El return data # XLA.from_row_major(data) end +Base.Array(x::ConcreteRArray) = convert(Array, x) function synchronize(x::Union{ConcreteRArray,ConcreteRNumber}) XLA.synced_buffer(x.data) @@ -145,6 +146,20 @@ for T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) end end +function Base.isapprox(x::ConcreteRArray, y::AbstractArray; kwargs...) + return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) +end +function Base.isapprox(x::AbstractArray, y::ConcreteRArray; kwargs...) + return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) +end +function Base.isapprox(x::ConcreteRArray, y::ConcreteRArray; kwargs...) + return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) +end + +Base.:(==)(x::ConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y) +Base.:(==)(x::AbstractArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y) +Base.:(==)(x::ConcreteRArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y) + function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} if X.data == XLA.AsyncEmptyBuffer println(io, "") @@ -171,12 +186,11 @@ function Base.show(io::IO, X::ConcreteRArray) return print(io, "$(typeof(X))($(str))") end -const getindex_warned = Ref(false) function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N} if a.data == XLA.AsyncEmptyBuffer throw("Cannot getindex from empty buffer") end - # error("""Scalar indexing is disallowed.""") + XLA.await(a.data) if XLA.BufferOnCPU(a.data.buffer) buf = a.data.buffer @@ -193,16 +207,8 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N} return unsafe_load(ptr, start) end end - if !getindex_warned[] - @warn( - """Performing scalar get-indexing on task $(current_task()). - Invocation resulted in scalar indexing of a ConcreteRArray. - This is typically caused by calling an iterating implementation of a method. - Such implementations *do not* execute on device, but very slowly on the CPU, - and require expensive copies and synchronization each time and therefore should be avoided.""" - ) - getindex_warned[] = true - end + + GPUArraysCore.assertscalar("getindex(::ConcreteRArray, ::Vararg{Int, N})") return convert(Array, a)[args...] end @@ -211,12 +217,11 @@ function mysetindex!(a, v, args::Vararg{Int,N}) where {N} return nothing end -const setindex_warned = Ref(false) - function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N} if a.data == XLA.AsyncEmptyBuffer throw("Cannot setindex! to empty buffer") end + XLA.await(a.data) if XLA.BufferOnCPU(a.data.buffer) buf = a.data.buffer @@ -234,19 +239,8 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N end return a end - if !setindex_warned[] - @warn( - """Performing scalar set-indexing on task $(current_task()). - Invocation resulted in scalar indexing of a ConcreteRArray. - This is typically caused by calling an iterating implementation of a method. - Such implementations *do not* execute on device, but very slowly on the CPU, - and require expensive copies and synchronization each time and therefore should be avoided. - - This error message will only be printed for the first invocation for brevity. -""" - ) - setindex_warned[] = true - end + + GPUArraysCore.assertscalar("setindex!(::ConcreteRArray, ::Any, ::Vararg{Int, N})") fn = Reactant.compile(mysetindex!, (a, v, args...)) fn(a, v, args...) return a diff --git a/src/Reactant.jl b/src/Reactant.jl index b6bc42f29..cf61ba284 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -4,6 +4,9 @@ using ReactantCore: ReactantCore, @trace, MissingTracedValue using LinearAlgebra: LinearAlgebra using Adapt: Adapt, WrappedArray +using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` + +export @allowscalar # re-exported from GPUArraysCore # auxiliary types and functions include("OrderedIdDict.jl") @@ -114,8 +117,7 @@ function set_default_backend(backend::XLA.Client) end function set_default_backend(backend::String) - backend = XLA.backends[backend] - return XLA.default_backend[] = backend + return set_default_backend(XLA.backends[backend]) end end # module diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 72f67e106..8377f9900 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -59,13 +59,7 @@ end function Base.getindex( a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} - @warn( - """Performing scalar indexing on task $(current_task()). -Invocation resulted in scalar indexing of a TracedRArray. -This is typically caused by calling an iterating implementation of a method. -Such implementations *do not* execute on device, but very slowly on the CPU, -and require expensive copies and synchronization each time and therefore should be avoided.""" - ) + GPUArraysCore.assertscalar("getindex(::TracedRArray, ::Vararg{Int, N})") start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] slice_sizes = [Int64(1) for _ in index] diff --git a/src/XLA.jl b/src/XLA.jl index 2c98b797f..a8a315a7b 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -130,6 +130,7 @@ function __init__() end end end + return nothing end diff --git a/test/basic.jl b/test/basic.jl index e386ae545..4331b1090 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -3,8 +3,6 @@ using Test using Enzyme using Statistics -# Reactant.set_default_backend("gpu") - fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf)) using InteractiveUtils @@ -16,7 +14,7 @@ using InteractiveUtils a = Reactant.ConcreteRArray(x) - c_res = sum(a) + c_res = @allowscalar sum(a) @test c_res ≈ r_res @test @jit(sum(a)) ≈ r_res @@ -29,7 +27,7 @@ end a = Reactant.ConcreteRArray(x) - c_res = fastmax(a) + c_res = @allowscalar fastmax(a) @test c_res ≈ r_res @test @jit(fastmax(a)) ≈ r_res @@ -45,7 +43,7 @@ sinexpbc(x) = sinexp.(x) a = Reactant.ConcreteRArray(x) - c_res = sinexpbc(a) + c_res = @allowscalar sinexpbc(a) @test c_res ≈ r_res @test @jit(sinexpbc(a)) ≈ r_res @@ -427,10 +425,10 @@ end @test y ≈ y_ra x_ra_array = Array(x_ra) - @test all(iszero, x_ra_array[1, :]) - @test all(iszero, x_ra_array[2, :]) - @test all(isone, x_ra_array[3, :]) - @test all(isone, x_ra_array[4, :]) + @test @allowscalar all(iszero, x_ra_array[1, :]) + @test @allowscalar all(iszero, x_ra_array[2, :]) + @test @allowscalar all(isone, x_ra_array[3, :]) + @test @allowscalar all(isone, x_ra_array[4, :]) end tuple_byref(x) = (; a=(; b=x)) @@ -504,14 +502,14 @@ end f2 = @compile f1(x_ra) res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=(Number,))) - @test only(res2) ≈ 5 * 3.14 + @test @allowscalar(only(res2)) ≈ 5 * 3.14 @test res2 isa ConcreteRArray x_ra = Reactant.to_rarray(x) f3 = @compile f1(x_ra) res3 = f3(Reactant.to_rarray((5, [3.14]))) - @test only(res3) ≈ only(f1(x)) + @test @allowscalar(only(res3)) ≈ only(f1(x)) @test res3 isa ConcreteRArray end end @@ -544,18 +542,22 @@ end x_ra = Reactant.to_rarray(x) y = @jit(clamp!(x_ra, 0.0, 0.25)) - @test maximum(y) ≤ 0.25 - @test minimum(y) ≥ 0.0 - @test maximum(x_ra) == maximum(y) - @test minimum(x_ra) == minimum(y) + @allowscalar begin + @test maximum(y) ≤ 0.25 + @test minimum(y) ≥ 0.0 + @test maximum(x_ra) == maximum(y) + @test minimum(x_ra) == minimum(y) + end x = randn(2, 3) x_ra = Reactant.to_rarray(x) y = @jit(clamp.(x_ra, 0.0, 0.25)) - @test maximum(y) ≤ 0.25 - @test minimum(y) ≥ 0.0 - @test x_ra ≈ x + @allowscalar begin + @test maximum(y) ≤ 0.25 + @test minimum(y) ≥ 0.0 + @test x_ra ≈ x + end end @testset "dynamic indexing" begin @@ -565,6 +567,8 @@ end idx = [1, 2, 3] idx_ra = Reactant.to_rarray(idx) - y = @jit(getindex(x_ra, idx_ra, :)) + fn(x, idx) = @allowscalar x[idx, :] + + y = @jit(fn(x_ra, idx_ra)) @test y ≈ x[idx, :] end diff --git a/test/closure.jl b/test/closure.jl index b5ad5626a..d6eb35008 100644 --- a/test/closure.jl +++ b/test/closure.jl @@ -4,9 +4,11 @@ using Reactant muler(x) = y -> x * y @testset "closure" begin - x = Reactant.ConcreteRArray(ones(2, 2)) - y = Reactant.ConcreteRArray(ones(2, 2)) + x = ones(2, 2) + y = ones(2, 2) + x_ra = Reactant.ConcreteRArray(x) + y_ra = Reactant.ConcreteRArray(y) - f = muler(x) - @test @jit(f(y)) ≈ x * y + f = muler(x_ra) + @test @jit(f(y_ra)) ≈ x * y end diff --git a/test/compile.jl b/test/compile.jl index fbf88ffad..155fbafb5 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -16,22 +16,25 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= end @testset "world-age" begin - a = Reactant.ConcreteRArray(ones(2, 10)) - b = Reactant.ConcreteRArray(ones(10, 2)) + a = ones(2, 10) + b = ones(10, 2) + a_ra = Reactant.ConcreteRArray(a) + b_ra = Reactant.ConcreteRArray(b) - fworld(x, y) = @jit(*(x, y)) + fworld(x, y) = @jit(x * y) - @test fworld(a, b) ≈ ones(2, 2) * 10 + @test fworld(a_ra, b_ra) ≈ ones(2, 2) * 10 end @testset "type casting & optimized out returns" begin - a = Reactant.ConcreteRArray(rand(2, 10)) + a = ones(2, 10) + a_ra = Reactant.ConcreteRArray(a) ftype1(x) = Float64.(x) ftype2(x) = Float32.(x) - y1 = @jit ftype1(a) - y2 = @jit ftype2(a) + y1 = @jit ftype1(a_ra) + y2 = @jit ftype2(a_ra) @test y1 isa Reactant.ConcreteRArray{Float64,2} @test y2 isa Reactant.ConcreteRArray{Float32,2} diff --git a/test/complex.jl b/test/complex.jl index d66652d4f..3bf19a051 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -99,6 +99,7 @@ end end @testset "complex reduction" begin - x_ra = Reactant.ConcreteRArray(randn(ComplexF32, 10, 10)) - @test @jit(sum(abs2, x_ra)) ≈ sum(abs2, x_ra) + x = randn(ComplexF32, 10, 10) + x_ra = Reactant.ConcreteRArray(x) + @test @jit(sum(abs2, x_ra)) ≈ sum(abs2, x) end diff --git a/test/control_flow.jl b/test/control_flow.jl index 1254e0c00..9b4ee9fcf 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -365,20 +365,20 @@ end x_ra = Reactant.to_rarray(x) res_ra = @jit(condition10_condition_with_setindex(x_ra)) - @test res_ra[1, 1] == -1.0 - @test res_ra[2, 1] == -1.0 - @test x_ra[1, 1] == -1.0 broken = true - @test x_ra[2, 1] == -1.0 broken = true + @test @allowscalar(res_ra[1, 1]) == -1.0 + @test @allowscalar(res_ra[2, 1]) == -1.0 + @test @allowscalar(x_ra[1, 1]) == -1.0 broken = true + @test @allowscalar(x_ra[2, 1]) == -1.0 broken = true x = -rand(2, 10) x[2, 1] = 0.0 x_ra = Reactant.to_rarray(x) res_ra = @jit(condition10_condition_with_setindex(x_ra)) - @test res_ra[1, 1] == 1.0 - @test res_ra[2, 1] == 0.0 - @test x_ra[1, 1] == 1.0 broken = true - @test x_ra[2, 1] == 0.0 + @test @allowscalar(res_ra[1, 1]) == 1.0 + @test @allowscalar(res_ra[2, 1]) == 0.0 + @test @allowscalar(x_ra[1, 1]) == 1.0 broken = true + @test @allowscalar(x_ra[2, 1]) == 0.0 end function condition11_nested_ifff(x, y, z) @@ -538,7 +538,7 @@ end function cumsum!(x) v = zero(eltype(x)) @trace for i in 1:length(x) - v += x[i] + v += @allowscalar x[i] x[i] = v end return x diff --git a/test/layout.jl b/test/layout.jl index fbcaaad45..91c5d5e70 100644 --- a/test/layout.jl +++ b/test/layout.jl @@ -12,8 +12,10 @@ using Test # @show [y[1,1], y[1,2], y[2, 1], y[2, 2]] - @test y[1, 1] == x[1, 1] - @test y[1, 2] == x[1, 2] - @test y[2, 1] == x[2, 1] - @test y[2, 2] == x[2, 2] + @allowscalar begin + @test y[1, 1] == x[1, 1] + @test y[1, 2] == x[1, 2] + @test y[2, 1] == x[2, 1] + @test y[2, 2] == x[2, 2] + end end diff --git a/test/struct.jl b/test/struct.jl index c3a93e3b4..82263a61f 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -88,6 +88,6 @@ end # TODO this should be able to run without problems, but crashes @test_broken isapprox(@jit(identity(x3)), x3) - @test isapprox(sum(x3), only(@jit(sum(x3)))) + @test isapprox(@allowscalar(sum(x3)), only(@jit(sum(x3)))) end end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index 6e6c3630f..f2aafbe2c 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -20,7 +20,7 @@ end x = rand(4, 4, 3) x_ra = Reactant.to_rarray(x) - @test @jit(view_getindex_1(x_ra)) ≈ view_getindex_1(x) + @test @allowscalar(@jit(view_getindex_1(x_ra))) ≈ view_getindex_1(x) @test @jit(view_getindex_2(x_ra)) ≈ view_getindex_2(x) @test @jit(view_getindex_3(x_ra)) ≈ view_getindex_3(x) end @@ -97,5 +97,6 @@ end @testset "PermutedDimsArray" begin x = rand(4, 4, 3) x_ra = Reactant.to_rarray(x) - @test @jit(bypass_permutedims(x_ra)) ≈ bypass_permutedims(x) + y_ra = @jit(bypass_permutedims(x_ra)) + @test @allowscalar(Array(y_ra)) ≈ bypass_permutedims(x) end