Skip to content

Commit

Permalink
test: mean and var
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 27, 2024
1 parent 008df40 commit 6516bce
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ function Base.isapprox(x::ConcreteRArray{ElType,(),0}, y; kwargs...) where {ElTy
end

function Base.isapprox(x, y::ConcreteRArray{ElType,(),0}; kwargs...) where {ElType}
return Base.isapprox(to_float(x), y; kwargs...)
return Base.isapprox(x, to_float(y); kwargs...)
end

function Base.isapprox(
x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; kwargs...
) where {ElType,ElType2}
return Base.isapprox(to_float(x), y; kwargs...)
return Base.isapprox(to_float(x), to_float(y); kwargs...)
end

function Base.print_array(io::IO, X::ConcreteRArray)
Expand Down
35 changes: 35 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,38 @@ end

@test contains(res_repr, "stablehlo.dot_general")
end

@testset "Statistics: `mean` & `var`" begin
x = randn(2, 3, 4)
x_ca = Reactant.ConcreteRArray(x)

mean_fn1(x) = mean(x)
mean_fn2(x) = mean(x; dims=1)
mean_fn3(x) = mean(x; dims=(1, 2))
mean_fn4(x) = mean(x; dims=(1, 3))

mean_fn1_compiled = Reactant.compile(mean_fn1, (x_ca,))
mean_fn2_compiled = Reactant.compile(mean_fn2, (x_ca,))
mean_fn3_compiled = Reactant.compile(mean_fn3, (x_ca,))
mean_fn4_compiled = Reactant.compile(mean_fn4, (x_ca,))

@test mean_fn1(x) mean_fn1_compiled(x_ca)
@test mean_fn2(x) mean_fn2_compiled(x_ca)
@test mean_fn3(x) mean_fn3_compiled(x_ca)
@test mean_fn4(x) mean_fn4_compiled(x_ca)

var_fn1(x) = var(x)
var_fn2(x) = var(x; dims=1)
var_fn3(x) = var(x; dims=(1, 2), corrected=false)
var_fn4(x) = var(x; dims=(1, 3), corrected=false)

var_fn1_compiled = Reactant.compile(var_fn1, (x_ca,))
var_fn2_compiled = Reactant.compile(var_fn2, (x_ca,))
var_fn3_compiled = Reactant.compile(var_fn3, (x_ca,))
var_fn4_compiled = Reactant.compile(var_fn4, (x_ca,))

@test var_fn1(x) var_fn1_compiled(x_ca)
@test var_fn2(x) var_fn2_compiled(x_ca)
@test var_fn3(x) var_fn3_compiled(x_ca)
@test var_fn4(x) var_fn4_compiled(x_ca)
end

0 comments on commit 6516bce

Please sign in to comment.