From 6516bce84e5de863466736d4b09783bccb3934de Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 18:07:57 -0700 Subject: [PATCH] test: `mean` and `var` --- src/Reactant.jl | 4 ++-- test/basic.jl | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 73d4fa39e..c8ab789d7 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -57,13 +57,13 @@ function Base.isapprox(x::ConcreteRArray{ElType,(),0}, y; kwargs...) where {ElTy end function Base.isapprox(x, y::ConcreteRArray{ElType,(),0}; kwargs...) where {ElType} - return Base.isapprox(to_float(x), y; kwargs...) + return Base.isapprox(x, to_float(y); kwargs...) end function Base.isapprox( x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; kwargs... ) where {ElType,ElType2} - return Base.isapprox(to_float(x), y; kwargs...) + return Base.isapprox(to_float(x), to_float(y); kwargs...) end function Base.print_array(io::IO, X::ConcreteRArray) diff --git a/test/basic.jl b/test/basic.jl index 07870de97..199ed4bd1 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -152,3 +152,38 @@ end @test contains(res_repr, "stablehlo.dot_general") end + +@testset "Statistics: `mean` & `var`" begin + x = randn(2, 3, 4) + x_ca = Reactant.ConcreteRArray(x) + + mean_fn1(x) = mean(x) + mean_fn2(x) = mean(x; dims=1) + mean_fn3(x) = mean(x; dims=(1, 2)) + mean_fn4(x) = mean(x; dims=(1, 3)) + + mean_fn1_compiled = Reactant.compile(mean_fn1, (x_ca,)) + mean_fn2_compiled = Reactant.compile(mean_fn2, (x_ca,)) + mean_fn3_compiled = Reactant.compile(mean_fn3, (x_ca,)) + mean_fn4_compiled = Reactant.compile(mean_fn4, (x_ca,)) + + @test mean_fn1(x) ≈ mean_fn1_compiled(x_ca) + @test mean_fn2(x) ≈ mean_fn2_compiled(x_ca) + @test mean_fn3(x) ≈ mean_fn3_compiled(x_ca) + @test mean_fn4(x) ≈ mean_fn4_compiled(x_ca) + + var_fn1(x) = var(x) + var_fn2(x) = var(x; dims=1) + var_fn3(x) = var(x; dims=(1, 2), corrected=false) + var_fn4(x) = var(x; dims=(1, 3), corrected=false) + + var_fn1_compiled = Reactant.compile(var_fn1, (x_ca,)) + var_fn2_compiled = Reactant.compile(var_fn2, (x_ca,)) + var_fn3_compiled = Reactant.compile(var_fn3, (x_ca,)) + var_fn4_compiled = Reactant.compile(var_fn4, (x_ca,)) + + @test var_fn1(x) ≈ var_fn1_compiled(x_ca) + @test var_fn2(x) ≈ var_fn2_compiled(x_ca) + @test var_fn3(x) ≈ var_fn3_compiled(x_ca) + @test var_fn4(x) ≈ var_fn4_compiled(x_ca) +end