From 9e5d5313fbfc5fc862940c03ec933305b3481769 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 24 Jan 2025 14:09:26 -0500 Subject: [PATCH] refactor: move tests around a bit --- test/basic.jl | 228 ----------------------------------------- test/indexing.jl | 223 ++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + test/wrapped_arrays.jl | 6 ++ 4 files changed, 230 insertions(+), 228 deletions(-) create mode 100644 test/indexing.jl diff --git a/test/basic.jl b/test/basic.jl index c3952549a..c0b836345 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -419,37 +419,6 @@ end end end -function update_on_copy(x) - y = x[1:2, 2:4, :] - y[1:1, 1:1, :] = ones(1, 1, 3) - return y -end - -@testset "view / setindex" begin - x = rand(2, 4, 3) - y = copy(x) - x_concrete = Reactant.to_rarray(x) - y_concrete = Reactant.to_rarray(y) - - y1 = update_on_copy(x) - y2 = @jit update_on_copy(x_concrete) - @test x == y - @test x_concrete == y_concrete - @test y1 == y2 - - # function update_inplace(x) - # y = view(x, 1:2, 1:2, :) - # y[1, 1, :] .= 1 - # return y - # end - - # get_indices(x) = x[1:2, 1:2, :] - # get_view(x) = view(x, 1:2, 1:2, :) - - # get_indices_compiled = @compile get_indices(x_concrete) - # get_view_compiled = @compile get_view(x_concrete) -end - function write_with_broadcast1!(x, y) x[1, :, :] .= reshape(y, 4, 3) return x @@ -483,63 +452,6 @@ end @test res[:, 1, :] ≈ view(y, :, 1:3) end -function masking(x) - y = similar(x) - y[1:2, :] .= 0 - y[3:4, :] .= 1 - return y -end - -function masking!(x) - x[1:2, :] .= 0 - x[3:4, :] .= 1 - return x -end - -@testset "setindex! with views" begin - x = rand(4, 4) .+ 2.0 - x_ra = Reactant.to_rarray(x) - - y = masking(x) - y_ra = @jit(masking(x_ra)) - @test y ≈ y_ra - - x_ra_array = Array(x_ra) - @test !(any(iszero, x_ra_array[1, :])) - @test !(any(iszero, x_ra_array[2, :])) - @test !(any(isone, x_ra_array[3, :])) - @test !(any(isone, x_ra_array[4, :])) - - y_ra = @jit(masking!(x_ra)) - @test y ≈ y_ra - - x_ra_array = Array(x_ra) - @test @allowscalar all(iszero, x_ra_array[1, :]) - @test @allowscalar all(iszero, x_ra_array[2, :]) - @test @allowscalar all(isone, x_ra_array[3, :]) - @test @allowscalar all(isone, x_ra_array[4, :]) -end - -function non_contiguous_setindex!(x) - x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0 - return x -end - -@testset "non-contiguous setindex!" begin - x = rand(6, 6) - x_ra = Reactant.to_rarray(x) - - y = @jit(non_contiguous_setindex!(x_ra)) - y = Array(y) - x_ra = Array(x_ra) - @test all(isone, y[1:3, 1:4]) - @test all(isone, x_ra[1:3, 1:4]) - @test !all(isone, y[4:end, :]) - @test !all(isone, x_ra[4:end, :]) - @test !all(isone, y[:, 5:end]) - @test !all(isone, x_ra[:, 5:end]) -end - tuple_byref(x) = (; a=(; b=x)) tuple_byref2(x) = abs2.(x), tuple_byref2(x) @@ -681,19 +593,6 @@ end end end -@testset "dynamic indexing" begin - x = randn(5, 3) - x_ra = Reactant.to_rarray(x) - - idx = [1, 2, 3] - idx_ra = Reactant.to_rarray(idx) - - fn(x, idx) = @allowscalar x[idx, :] - - y = @jit(fn(x_ra, idx_ra)) - @test y ≈ x[idx, :] -end - @testset "aos_to_soa" begin using ArrayInterface @@ -822,102 +721,6 @@ end @test res[2] isa ConcreteRNumber{Float64} end -@testset "non-contiguous indexing" begin - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing1(x) = x[[1, 3, 2], :, :] - non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]] - - @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) - @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) - - x = rand(4, 2) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing3(x) = x[[1, 3, 2], :] - non_contiguous_indexing4(x) = x[:, [1, 2, 2]] - - @test @jit(non_contiguous_indexing3(x_ra)) ≈ non_contiguous_indexing3(x) - @test @jit(non_contiguous_indexing4(x_ra)) ≈ non_contiguous_indexing4(x) - - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2 - non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2 - - @jit(non_contiguous_indexing1!(x_ra)) - non_contiguous_indexing1!(x) - @test x_ra ≈ x - - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - @jit(non_contiguous_indexing2!(x_ra)) - non_contiguous_indexing2!(x) - @test x_ra ≈ x - - x = rand(4, 2) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2 - non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2 - - @jit(non_contiguous_indexing3!(x_ra)) - non_contiguous_indexing3!(x) - @test x_ra ≈ x - - x = rand(4, 2) - x_ra = Reactant.to_rarray(x) - - @jit(non_contiguous_indexing4!(x_ra)) - non_contiguous_indexing4!(x) - @test x_ra ≈ x -end - -@testset "indexing with traced arrays" begin - x = rand(4, 4, 3) - idx1 = [1, 3, 2] - idx3 = [1, 2, 1, 3] - - x_ra = Reactant.to_rarray(x) - idx1_ra = Reactant.to_rarray(idx1) - idx3_ra = Reactant.to_rarray(idx3) - - getindex1(x, idx1) = x[idx1, :, :] - getindex2(x, idx1) = x[:, idx1, :] - getindex3(x, idx3) = x[:, :, idx3] - getindex4(x, idx1, idx3) = x[idx1, :, idx3] - - @test @jit(getindex1(x_ra, idx1_ra)) ≈ getindex1(x, idx1) - @test @jit(getindex2(x_ra, idx1_ra)) ≈ getindex2(x, idx1) - @test @jit(getindex3(x_ra, idx3_ra)) ≈ getindex3(x, idx3) - @test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) ≈ getindex4(x, idx1, idx3) -end - -@testset "linear indexing" begin - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - getindex_linear_scalar(x, idx) = @allowscalar x[idx] - - @testset for i in 1:length(x) - @test @jit(getindex_linear_scalar(x_ra, i)) ≈ getindex_linear_scalar(x, i) - @test @jit( - getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=Number)) - ) ≈ getindex_linear_scalar(x, i) - end - - idx = rand(1:length(x), 8) - idx_ra = Reactant.to_rarray(idx) - - getindex_linear_vector(x, idx) = x[idx] - - @test @jit(getindex_linear_vector(x_ra, idx_ra)) ≈ getindex_linear_vector(x, idx) - @test @jit(getindex_linear_vector(x_ra, idx)) ≈ getindex_linear_vector(x, idx) -end - @testset "stack" begin x = rand(4, 4) y = rand(4, 4) @@ -985,18 +788,6 @@ end @test @jit(s4(x, y)) isa Any end -@testset "Boolean Indexing" begin - x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) - idxs_ra = Reactant.to_rarray(rand(Bool, 16)) - - fn(x, idxs) = x[:, idxs] - - @test_throws ErrorException @jit(fn(x_ra, idxs_ra)) - - res = @jit fn(x_ra, Array(idxs_ra)) - @test res ≈ fn(Array(x_ra), Array(idxs_ra)) -end - @testset "duplicate args (#226)" begin first_arg(x, y) = x x_ra = Reactant.to_rarray(rand(2, 2)) @@ -1052,25 +843,6 @@ end @test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number)) end -@testset "inconsistent indexing" begin - x_ra = Reactant.to_rarray(rand(3, 4, 3)) - idx_ra = Reactant.to_rarray(1; track_numbers=Number) - - fn1(x) = x[:, :, 1] - fn2(x, idx) = x[:, :, idx] - fn3(x, idx) = x[idx, :, 1] - - @test ndims(@jit(fn1(x_ra))) == 2 - @test ndims(@jit(fn2(x_ra, idx_ra))) == 2 - @test ndims(@jit(fn3(x_ra, idx_ra))) == 1 -end - -@testset "reshaped subarray indexing" begin - fn(x) = view(x, 1:2) .+ 1 - x_ra = Reactant.to_rarray(rand(3, 4, 3)) - @test @jit(fn(x_ra)) == fn(Array(x_ra)) -end - @testset "reduce integers" begin x = rand(Bool, 100) x_ra = Reactant.to_rarray(x) diff --git a/test/indexing.jl b/test/indexing.jl new file mode 100644 index 000000000..33ef3b862 --- /dev/null +++ b/test/indexing.jl @@ -0,0 +1,223 @@ +using LinearAlgebra, Reactant, Test + +function update_on_copy(x) + y = x[1:2, 2:4, :] + y[1:1, 1:1, :] = ones(1, 1, 3) + return y +end + +@testset "view / setindex" begin + x = rand(2, 4, 3) + y = copy(x) + x_concrete = Reactant.to_rarray(x) + y_concrete = Reactant.to_rarray(y) + + y1 = update_on_copy(x) + y2 = @jit update_on_copy(x_concrete) + @test x == y + @test x_concrete == y_concrete + @test y1 == y2 + + # function update_inplace(x) + # y = view(x, 1:2, 1:2, :) + # y[1, 1, :] .= 1 + # return y + # end + + # get_indices(x) = x[1:2, 1:2, :] + # get_view(x) = view(x, 1:2, 1:2, :) + + # get_indices_compiled = @compile get_indices(x_concrete) + # get_view_compiled = @compile get_view(x_concrete) +end + +function masking(x) + y = similar(x) + y[1:2, :] .= 0 + y[3:4, :] .= 1 + return y +end + +function masking!(x) + x[1:2, :] .= 0 + x[3:4, :] .= 1 + return x +end + +@testset "setindex! with views" begin + x = rand(4, 4) .+ 2.0 + x_ra = Reactant.to_rarray(x) + + y = masking(x) + y_ra = @jit(masking(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test !(any(iszero, x_ra_array[1, :])) + @test !(any(iszero, x_ra_array[2, :])) + @test !(any(isone, x_ra_array[3, :])) + @test !(any(isone, x_ra_array[4, :])) + + y_ra = @jit(masking!(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test @allowscalar all(iszero, x_ra_array[1, :]) + @test @allowscalar all(iszero, x_ra_array[2, :]) + @test @allowscalar all(isone, x_ra_array[3, :]) + @test @allowscalar all(isone, x_ra_array[4, :]) +end + +function non_contiguous_setindex!(x) + x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0 + return x +end + +@testset "non-contiguous setindex!" begin + x = rand(6, 6) + x_ra = Reactant.to_rarray(x) + + y = @jit(non_contiguous_setindex!(x_ra)) + y = Array(y) + x_ra = Array(x_ra) + @test all(isone, y[1:3, 1:4]) + @test all(isone, x_ra[1:3, 1:4]) + @test !all(isone, y[4:end, :]) + @test !all(isone, x_ra[4:end, :]) + @test !all(isone, y[:, 5:end]) + @test !all(isone, x_ra[:, 5:end]) +end + +@testset "dynamic indexing" begin + x = randn(5, 3) + x_ra = Reactant.to_rarray(x) + + idx = [1, 2, 3] + idx_ra = Reactant.to_rarray(idx) + + fn(x, idx) = @allowscalar x[idx, :] + + y = @jit(fn(x_ra, idx_ra)) + @test y ≈ x[idx, :] +end + +@testset "non-contiguous indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1(x) = x[[1, 3, 2], :, :] + non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]] + + @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) + @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing3(x) = x[[1, 3, 2], :] + non_contiguous_indexing4(x) = x[:, [1, 2, 2]] + + @test @jit(non_contiguous_indexing3(x_ra)) ≈ non_contiguous_indexing3(x) + @test @jit(non_contiguous_indexing4(x_ra)) ≈ non_contiguous_indexing4(x) + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2 + non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2 + + @jit(non_contiguous_indexing1!(x_ra)) + non_contiguous_indexing1!(x) + @test x_ra ≈ x + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing2!(x_ra)) + non_contiguous_indexing2!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2 + non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2 + + @jit(non_contiguous_indexing3!(x_ra)) + non_contiguous_indexing3!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing4!(x_ra)) + non_contiguous_indexing4!(x) + @test x_ra ≈ x +end + +@testset "indexing with traced arrays" begin + x = rand(4, 4, 3) + idx1 = [1, 3, 2] + idx3 = [1, 2, 1, 3] + + x_ra = Reactant.to_rarray(x) + idx1_ra = Reactant.to_rarray(idx1) + idx3_ra = Reactant.to_rarray(idx3) + + getindex1(x, idx1) = x[idx1, :, :] + getindex2(x, idx1) = x[:, idx1, :] + getindex3(x, idx3) = x[:, :, idx3] + getindex4(x, idx1, idx3) = x[idx1, :, idx3] + + @test @jit(getindex1(x_ra, idx1_ra)) ≈ getindex1(x, idx1) + @test @jit(getindex2(x_ra, idx1_ra)) ≈ getindex2(x, idx1) + @test @jit(getindex3(x_ra, idx3_ra)) ≈ getindex3(x, idx3) + @test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) ≈ getindex4(x, idx1, idx3) +end + +@testset "linear indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + getindex_linear_scalar(x, idx) = @allowscalar x[idx] + + @testset for i in 1:length(x) + @test @jit(getindex_linear_scalar(x_ra, i)) ≈ getindex_linear_scalar(x, i) + @test @jit( + getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=Number)) + ) ≈ getindex_linear_scalar(x, i) + end + + idx = rand(1:length(x), 8) + idx_ra = Reactant.to_rarray(idx) + + getindex_linear_vector(x, idx) = x[idx] + + @test @jit(getindex_linear_vector(x_ra, idx_ra)) ≈ getindex_linear_vector(x, idx) + @test @jit(getindex_linear_vector(x_ra, idx)) ≈ getindex_linear_vector(x, idx) +end + +@testset "Boolean Indexing" begin + x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) + idxs_ra = Reactant.to_rarray(rand(Bool, 16)) + + fn(x, idxs) = x[:, idxs] + + @test_throws ErrorException @jit(fn(x_ra, idxs_ra)) + + res = @jit fn(x_ra, Array(idxs_ra)) + @test res ≈ fn(Array(x_ra), Array(idxs_ra)) +end + +@testset "inconsistent indexing" begin + x_ra = Reactant.to_rarray(rand(3, 4, 3)) + idx_ra = Reactant.to_rarray(1; track_numbers=Number) + + fn1(x) = x[:, :, 1] + fn2(x, idx) = x[:, :, idx] + fn3(x, idx) = x[idx, :, 1] + + @test ndims(@jit(fn1(x_ra))) == 2 + @test ndims(@jit(fn2(x_ra, idx_ra))) == 2 + @test ndims(@jit(fn3(x_ra, idx_ra))) == 1 +end diff --git a/test/runtests.jl b/test/runtests.jl index be1775004..9c5e03690 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") @safetestset "Sorting" include("sorting.jl") + @safetestset "Indexing" include("indexing.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index 5069e665d..0ae723f61 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -228,3 +228,9 @@ end @jit(broadcast_reshaped_array(x_ra, idx1_ra, idx3)) ≈ @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra), Int64(idx3))) end + +@testset "reshaped subarray indexing" begin + fn(x) = view(x, 1:2) .+ 1 + x_ra = Reactant.to_rarray(rand(3, 4, 3)) + @test @jit(fn(x_ra)) == fn(Array(x_ra)) +end