Skip to content

Commit

Permalink
test: use ReTestItems.jl for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 26, 2024
1 parent 2ac8391 commit 53ac644
Show file tree
Hide file tree
Showing 15 changed files with 376 additions and 411 deletions.
18 changes: 15 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ Reactant_jll = "0.0.12"
julia = "1.9"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "Flux", "Hwloc", "InteractiveUtils", "Lux", "MLUtils", "OneHotArrays", "Optimisers", "Random", "ReTestItems", "Statistics", "Test"]
13 changes: 0 additions & 13 deletions test/Project.toml

This file was deleted.

84 changes: 39 additions & 45 deletions test/basic.jl → test/basic_tests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
using Reactant
using Test
using Enzyme
@testsetup module BasicTestSetup

# Reactant.set_default_backend("gpu")
using Enzyme

fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
sinexp(x) = sin(exp(x))
sinexpbc(x) = sinexp.(x)
sumexp(x) = sum(exp, x)
mysoftmax(x) = x .- fastmax(x)

function sumcos(x)
return sum(cos.(x))
end

function grad_ip(x)
dx = Enzyme.make_zero(x)
Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx))
return dx
end

function resgrad_ip(x)
dx = Enzyme.make_zero(x)
res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx))
return (res, dx)
end

using InteractiveUtils
mul(A, B) = A * B

export fastmax, sinexp, sinexpbc, sumexp, mysoftmax, sumcos, grad_ip, resgrad_ip, mul

end

@testset "2D sum" begin
@testitem "2D sum" setup=[BasicTestSetup] begin
r_res = sum(ones(2, 10))

a = Reactant.ConcreteRArray(ones(2, 10))
Expand All @@ -23,7 +45,7 @@ using InteractiveUtils
@test f_res r_res
end

@testset "Basic reduce max" begin
@testitem "Basic reduce max" setup=[BasicTestSetup] begin
r_res = fastmax(ones(2, 10))

a = Reactant.ConcreteRArray(ones(2, 10))
Expand All @@ -38,10 +60,8 @@ end
@test f_res r_res
end

sinexp(x) = sin(exp(x))
sinexpbc(x) = sinexp.(x)

@testset "Broadcast combined" begin
@testitem "Broadcast combined" setup=[BasicTestSetup] begin
r_res = sinexpbc(ones(2, 10))

a = Reactant.ConcreteRArray(ones(2, 10))
Expand All @@ -56,9 +76,7 @@ sinexpbc(x) = sinexp.(x)
@test f_res r_res
end

sumexp(x) = sum(exp, x)

@testset "Basic mapreduce" begin
@testitem "Basic mapreduce" setup=[BasicTestSetup] begin
x = ones(Float32, 10)
a = Reactant.ConcreteRArray(x)
r_res = sumexp(x)
Expand All @@ -69,49 +87,28 @@ sumexp(x) = sum(exp, x)
@test f_res r_res
end

function mysoftmax!(x)
max_ = fastmax(x)
return x .- max_
end

@testset "Basic softmax" begin
@testitem "Basic softmax" setup=[BasicTestSetup] begin
in = ones(2, 10)
r_res = mysoftmax!(in)
r_res = mysoftmax(in)

in = Reactant.ConcreteRArray(ones(2, 10))

f = Reactant.compile(mysoftmax!, (in,))
f = Reactant.compile(mysoftmax, (in,))

f_res = f(in)

@test f_res r_res
end

@testset "Basic cos" begin
@testitem "Basic cos" setup=[BasicTestSetup] begin
c = Reactant.ConcreteRArray(ones(3, 2))

f = Reactant.compile(cos, (c,))
r = f(c)
@test r cos.(ones(3, 2))
end

function sumcos(x)
return sum(cos.(x))
end

function grad_ip(x)
dx = Enzyme.make_zero(x)
Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx))
return dx
end

function resgrad_ip(x)
dx = Enzyme.make_zero(x)
res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx))
return (res, dx)
end

@testset "Basic grad cos" begin
@testitem "Basic grad cos" setup=[BasicTestSetup] begin
c = Reactant.ConcreteRArray(ones(3, 2))

f = Reactant.compile(grad_ip, (c,))
Expand All @@ -126,10 +123,7 @@ end
@test r -sin.(ones(3, 2))
end

function mul(A, B)
return A * B
end
@testset "Basic grad cos" begin
@testitem "Basic grad cos mul" setup=[BasicTestSetup] begin
c = Reactant.ConcreteRArray(ones(50, 70))
d = Reactant.ConcreteRArray(ones(70, 30))

Expand All @@ -139,12 +133,12 @@ end
@test r mul(ones(50, 70), ones(70, 30))
end

@testset "ConcreteRArray" begin
@testitem "ConcreteRArray" setup=[BasicTestSetup] begin
c = Reactant.ConcreteRArray(ones(50, 70))
similar(c)
end

@testset "Reactant.@code_hlo" begin
@testitem "Reactant.@code_hlo" setup=[BasicTestSetup] begin
W = Reactant.ConcreteRArray(randn(Float32, 10, 20))
x = Reactant.ConcreteRArray(randn(Float32, 20, 5))
res = Reactant.@code_hlo W * x
Expand Down
58 changes: 0 additions & 58 deletions test/bcast.jl

This file was deleted.

60 changes: 60 additions & 0 deletions test/bcast_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
@testitem "Broadcast" begin
using Reactant.MLIR

@noinline function no(@nospecialize(x))
x = @ccall $(Base.@cfunction(identity, Any, (Any,)))(x::Any)::Any
return x[]::Any
end

mutable struct Data
v::(Reactant.TracedRArray{Float64,S,1} where {S})
end
@noinline function tmp(a, b, d)
@show d
@show typeof(d)
c = d.v
@show typeof(c)

return reshape(a, (4,)) ./ sqrt.(b .+ a)
end

function test()
ctx = MLIR.IR.Context()
Base.append!(Reactant.registry[]; context=ctx)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

MLIR.IR.context!(ctx) do
mod = MLIR.IR.Module(MLIR.IR.Location())
modbody = MLIR.IR.body(mod)

in_tys = [MLIR.IR.TensorType([4], MLIR.IR.Type(Float64))]

func = MLIR.Dialects.func.func_(;
sym_name="main_tmp",
function_type=MLIR.IR.FunctionType(in_tys, []),
body=MLIR.IR.Region(),
)

fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for _ in in_tys])
push!(MLIR.IR.region(func, 1), fnbody)

GC.@preserve mod func fnbody begin
MLIR.IR.block!(fnbody) do
a = ones(4)
b = ones(4)
d = Data(
Reactant.TracedRArray{Float64,(4,),1}(
(), MLIR.IR.argument(fnbody, 1)
),
)

return tmp(a, b, d)
end
end

return println(string(mod))
end
end

test()
end
6 changes: 2 additions & 4 deletions test/closure.jl → test/closure_tests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using Reactant
@testitem "closure" begin
muler(x) = y -> x * y

muler(x) = y -> x * y

@testset "closure" begin
x = Reactant.ConcreteRArray(ones(2, 2))
y = Reactant.ConcreteRArray(ones(2, 2))

Expand Down
9 changes: 4 additions & 5 deletions test/compile.jl → test/compile_tests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
using Reactant
using Test

Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=sum(x.a))
@testitem "compile" begin
function Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray}
return (; a=sum(x.a))
end

@testset "compile" begin
@testset "create_result" begin
@testset "NamedTuple" begin
x = (; a=rand(4, 3))
Expand Down
Loading

0 comments on commit 53ac644

Please sign in to comment.