diff --git a/Project.toml b/Project.toml
index 7aa08abd1..42c4c3a87 100644
--- a/Project.toml
+++ b/Project.toml
@@ -31,6 +31,18 @@ Reactant_jll = "0.0.12"
 julia = "1.9"
 
 [extras]
-Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
-NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
+InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["BenchmarkTools", "Flux", "Hwloc", "InteractiveUtils", "Lux", "MLUtils", "OneHotArrays", "Optimisers", "Random", "ReTestItems", "Statistics", "Test"]
diff --git a/test/Project.toml b/test/Project.toml
deleted file mode 100644
index 10bc878e9..000000000
--- a/test/Project.toml
+++ /dev/null
@@ -1,13 +0,0 @@
-[deps]
-BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
-MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
-OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
-Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/basic.jl b/test/basic_tests.jl
similarity index 74%
rename from test/basic.jl
rename to test/basic_tests.jl
index 07870de97..9df410027 100644
--- a/test/basic.jl
+++ b/test/basic_tests.jl
@@ -1,14 +1,36 @@
-using Reactant
-using Test
-using Enzyme
+@testsetup module BasicTestSetup
 
-# Reactant.set_default_backend("gpu")
+using Enzyme
 
 fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
+sinexp(x) = sin(exp(x))
+sinexpbc(x) = sinexp.(x)
+sumexp(x) = sum(exp, x)
+mysoftmax(x) = x .- fastmax(x)
+
+function sumcos(x)
+    return sum(cos.(x))
+end
+
+function grad_ip(x)
+    dx = Enzyme.make_zero(x)
+    Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx))
+    return dx
+end
+
+function resgrad_ip(x)
+    dx = Enzyme.make_zero(x)
+    res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx))
+    return (res, dx)
+end
 
-using InteractiveUtils
+mul(A, B) = A * B
+
+export fastmax, sinexp, sinexpbc, sumexp, mysoftmax, sumcos, grad_ip, resgrad_ip, mul
+
+end
 
-@testset "2D sum" begin
+@testitem "2D sum" setup=[BasicTestSetup] begin
     r_res = sum(ones(2, 10))
 
     a = Reactant.ConcreteRArray(ones(2, 10))
@@ -23,7 +45,7 @@ using InteractiveUtils
     @test f_res ≈ r_res
 end
 
-@testset "Basic reduce max" begin
+@testitem "Basic reduce max" setup=[BasicTestSetup] begin
     r_res = fastmax(ones(2, 10))
 
     a = Reactant.ConcreteRArray(ones(2, 10))
@@ -38,10 +60,8 @@ end
     @test f_res ≈ r_res
 end
 
-sinexp(x) = sin(exp(x))
-sinexpbc(x) = sinexp.(x)
 
-@testset "Broadcast combined" begin
+@testitem "Broadcast combined" setup=[BasicTestSetup] begin
     r_res = sinexpbc(ones(2, 10))
 
     a = Reactant.ConcreteRArray(ones(2, 10))
@@ -56,9 +76,7 @@ sinexpbc(x) = sinexp.(x)
     @test f_res ≈ r_res
 end
 
-sumexp(x) = sum(exp, x)
-
-@testset "Basic mapreduce" begin
+@testitem "Basic mapreduce" setup=[BasicTestSetup] begin
     x = ones(Float32, 10)
     a = Reactant.ConcreteRArray(x)
     r_res = sumexp(x)
@@ -69,25 +87,20 @@ sumexp(x) = sum(exp, x)
     @test f_res ≈ r_res
 end
 
-function mysoftmax!(x)
-    max_ = fastmax(x)
-    return x .- max_
-end
-
-@testset "Basic softmax" begin
+@testitem "Basic softmax" setup=[BasicTestSetup] begin
     in = ones(2, 10)
 
-    r_res = mysoftmax!(in)
+    r_res = mysoftmax(in)
 
     in = Reactant.ConcreteRArray(ones(2, 10))
 
-    f = Reactant.compile(mysoftmax!, (in,))
+    f = Reactant.compile(mysoftmax, (in,))
     f_res = f(in)
 
     @test f_res ≈ r_res
 end
 
-@testset "Basic cos" begin
+@testitem "Basic cos" setup=[BasicTestSetup] begin
     c = Reactant.ConcreteRArray(ones(3, 2))
 
     f = Reactant.compile(cos, (c,))
@@ -95,23 +108,7 @@ end
     @test r ≈ cos.(ones(3, 2))
 end
 
-function sumcos(x)
-    return sum(cos.(x))
-end
-
-function grad_ip(x)
-    dx = Enzyme.make_zero(x)
-    Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx))
-    return dx
-end
-
-function resgrad_ip(x)
-    dx = Enzyme.make_zero(x)
-    res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx))
-    return (res, dx)
-end
-
-@testset "Basic grad cos" begin
+@testitem "Basic grad cos" setup=[BasicTestSetup] begin
     c = Reactant.ConcreteRArray(ones(3, 2))
 
     f = Reactant.compile(grad_ip, (c,))
@@ -126,10 +123,7 @@ end
     @test r ≈ -sin.(ones(3, 2))
 end
 
-function mul(A, B)
-    return A * B
-end
-@testset "Basic grad cos" begin
+@testitem "Basic grad cos mul" setup=[BasicTestSetup] begin
     c = Reactant.ConcreteRArray(ones(50, 70))
     d = Reactant.ConcreteRArray(ones(70, 30))
 
@@ -139,12 +133,12 @@ end
     @test r ≈ mul(ones(50, 70), ones(70, 30))
 end
 
-@testset "ConcreteRArray" begin
+@testitem "ConcreteRArray" setup=[BasicTestSetup] begin
    c = Reactant.ConcreteRArray(ones(50, 70))
    similar(c)
 end
 
-@testset "Reactant.@code_hlo" begin
+@testitem "Reactant.@code_hlo" setup=[BasicTestSetup] begin
     W = Reactant.ConcreteRArray(randn(Float32, 10, 20))
     x = Reactant.ConcreteRArray(randn(Float32, 20, 5))
     res = Reactant.@code_hlo W * x
diff --git a/test/bcast.jl b/test/bcast.jl
deleted file mode 100644
index 9d05200ad..000000000
--- a/test/bcast.jl
+++ /dev/null
@@ -1,58 +0,0 @@
-
-using Reactant
-
-using Reactant.MLIR
-
-@noinline function no(@nospecialize(x))
-    x = @ccall $(Base.@cfunction(identity, Any, (Any,)))(x::Any)::Any
-    return x[]::Any
-end
-
-mutable struct Data
-    v::(Reactant.TracedRArray{Float64,S,1} where {S})
-end
-@noinline function tmp(a, b, d)
-    @show d
-    @show typeof(d)
-    c = d.v
-    @show typeof(c)
-
-    return reshape(a, (4,)) ./ sqrt.(b .+ a)
-end
-
-function test()
-    ctx = MLIR.IR.Context()
-    Base.append!(Reactant.registry[]; context=ctx)
-    @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
-
-    MLIR.IR.context!(ctx) do
-        mod = MLIR.IR.Module(MLIR.IR.Location())
-        modbody = MLIR.IR.body(mod)
-
-        in_tys = [MLIR.IR.TensorType([4], MLIR.IR.Type(Float64))]
-
-        func = MLIR.Dialects.func.func_(;
-            sym_name="main_tmp",
-            function_type=MLIR.IR.FunctionType(in_tys, []),
-            body=MLIR.IR.Region(),
-        )
-
-        fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for _ in in_tys])
-        push!(MLIR.IR.region(func, 1), fnbody)
-
-        GC.@preserve mod func fnbody begin
-            MLIR.IR.block!(fnbody) do
-                a = ones(4)
-                b = ones(4)
-                d = Data(
-                    Reactant.TracedRArray{Float64,(4,),1}((), MLIR.IR.argument(fnbody, 1))
-                )
-
-                return tmp(a, b, d)
-            end
-        end
-
-        return println(string(mod))
-    end
-end
-test()
diff --git a/test/bcast_tests.jl b/test/bcast_tests.jl
new file mode 100644
index 000000000..946492c45
--- /dev/null
+++ b/test/bcast_tests.jl
@@ -0,0 +1,60 @@
+@testitem "Broadcast" begin
+    using Reactant.MLIR
+
+    @noinline function no(@nospecialize(x))
+        x = @ccall $(Base.@cfunction(identity, Any, (Any,)))(x::Any)::Any
+        return x[]::Any
+    end
+
+    mutable struct Data
+        v::(Reactant.TracedRArray{Float64,S,1} where {S})
+    end
+    @noinline function tmp(a, b, d)
+        @show d
+        @show typeof(d)
+        c = d.v
+        @show typeof(c)
+
+        return reshape(a, (4,)) ./ sqrt.(b .+ a)
+    end
+
+    function test()
+        ctx = MLIR.IR.Context()
+        Base.append!(Reactant.registry[]; context=ctx)
+        @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
+
+        MLIR.IR.context!(ctx) do
+            mod = MLIR.IR.Module(MLIR.IR.Location())
+            modbody = MLIR.IR.body(mod)
+
+            in_tys = [MLIR.IR.TensorType([4], MLIR.IR.Type(Float64))]
+
+            func = MLIR.Dialects.func.func_(;
+                sym_name="main_tmp",
+                function_type=MLIR.IR.FunctionType(in_tys, []),
+                body=MLIR.IR.Region(),
+            )
+
+            fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for _ in in_tys])
+            push!(MLIR.IR.region(func, 1), fnbody)
+
+            GC.@preserve mod func fnbody begin
+                MLIR.IR.block!(fnbody) do
+                    a = ones(4)
+                    b = ones(4)
+                    d = Data(
+                        Reactant.TracedRArray{Float64,(4,),1}(
+                            (), MLIR.IR.argument(fnbody, 1)
+                        ),
+                    )
+
+                    return tmp(a, b, d)
+                end
+            end
+
+            return println(string(mod))
+        end
+    end
+
+    test()
+end
diff --git a/test/closure.jl b/test/closure_tests.jl
similarity index 73%
rename from test/closure.jl
rename to test/closure_tests.jl
index b38698882..cd39447fb 100644
--- a/test/closure.jl
+++ b/test/closure_tests.jl
@@ -1,8 +1,6 @@
-using Reactant
+@testitem "closure" begin
+    muler(x) = y -> x * y
 
-muler(x) = y -> x * y
-
-@testset "closure" begin
     x = Reactant.ConcreteRArray(ones(2, 2))
     y = Reactant.ConcreteRArray(ones(2, 2))
 
diff --git a/test/compile.jl b/test/compile_tests.jl
similarity index 70%
rename from test/compile.jl
rename to test/compile_tests.jl
index 4578a45e3..3ff857f26 100644
--- a/test/compile.jl
+++ b/test/compile_tests.jl
@@ -1,9 +1,8 @@
-using Reactant
-using Test
-
-Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=sum(x.a))
+@testitem "compile" begin
+    function Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray}
+        return (; a=sum(x.a))
+    end
 
-@testset "compile" begin
     @testset "create_result" begin
         @testset "NamedTuple" begin
             x = (; a=rand(4, 3))
diff --git a/test/flux_tests.jl b/test/flux_tests.jl
new file mode 100644
index 000000000..76bd6324c
--- /dev/null
+++ b/test/flux_tests.jl
@@ -0,0 +1,70 @@
+@testitem "Flux" begin
+    using Flux, BenchmarkTools, Statistics
+
+    # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
+    noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
+    truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
+
+    # Define our model, a multi-layer perceptron with one hidden layer of size 3:
+    model = Chain(
+        Dense(2 => 3, tanh), # activation function inside layer
+        BatchNorm(3),
+        Dense(3 => 2),
+        softmax,
+    )
+
+    origout = model(noisy)
+    @show origout[3]
+    @btime model(noisy)
+
+    cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
+    cnoisy = Reactant.ConcreteRArray(noisy)
+
+    # c_o = cmodel(noisy)
+    # @show c_o[3]
+    # @btime cmodel(noisy)
+    #
+    # o_c = model(cnoisy)
+    # @show o_c[3]
+    # @btime model(cnoisy)
+    #
+    # c_c = cmodel(cnoisy)
+    # @show c_c[3]
+    # @btime cmodel(cnoisy)
+    f = Reactant.compile((a, b) -> a(b), (cmodel, cnoisy))
+
+    # using InteractiveUtils
+    # @show @code_typed f(cmodel,cnoisy)
+    # @show @code_llvm f(cmodel,cnoisy)
+    comp = f(cmodel, cnoisy)
+    @show comp[3]
+    @btime f(cmodel, cnoisy)
+
+    # The code below doesn't test anything Reactant-specific so commented out
+
+    # # To train the model, we use batches of 64 samples, and one-hot encoding:
+    # target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
+    # loader = Flux.DataLoader((noisy, target); batchsize=64, shuffle=true);
+    # # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
+
+    # optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc.
+
+    # # Training loop, using the whole data set 1000 times:
+    # losses = []
+    # for epoch in 1:1_000
+    #     for (x, y) in loader
+    #         loss, grads = Flux.withgradient(model) do m
+    #             # Evaluate model and loss inside gradient context:
+    #             y_hat = m(x)
+    #             return Flux.crossentropy(y_hat, y)
+    #         end
+    #         Flux.update!(optim, model, grads[1])
+    #         push!(losses, loss) # logging, outside gradient context
+    #     end
+    # end
+
+    # optim # parameters, momenta and output have all changed
+    # out2 = model(noisy) # first row is prob. of true, second row p(false)
+
+    # mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
+end
diff --git a/test/layout.jl b/test/layout_tests.jl
similarity index 85%
rename from test/layout.jl
rename to test/layout_tests.jl
index fbcaaad45..f5a441a4d 100644
--- a/test/layout.jl
+++ b/test/layout_tests.jl
@@ -1,7 +1,4 @@
-using Reactant
-using Test
-
-@testset "Layout" begin
+@testitem "Layout" begin
     x = reshape([1.0, 2.0, 3.0, 4.0], (2, 2))
     y = Reactant.ConcreteRArray(x)
 
diff --git a/test/lux_tests.jl b/test/lux_tests.jl
new file mode 100644
index 000000000..66286b2b0
--- /dev/null
+++ b/test/lux_tests.jl
@@ -0,0 +1,79 @@
+@testitem "Lux" skip=:(VERSION < v"1.10") begin
+    using Lux, Random, Statistics, Enzyme, Test, BenchmarkTools
+    using MLUtils, OneHotArrays, Optimisers
+
+    # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
+    noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
+    truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
+
+    # Define our model, a multi-layer perceptron with one hidden layer of size 3:
+    model = Lux.Chain(
+        Lux.Dense(2 => 3, tanh), # activation function inside layer
+        Lux.Dense(3 => 2),
+        softmax,
+    )
+    ps, st = Lux.setup(Xoshiro(123), model)
+
+    origout, _ = model(noisy, ps, st)
+    @show origout[3]
+    @btime model($noisy, $ps, $st) # 52.731 μs (10 allocations: 32.03 KiB)
+
+    cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
+    cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete)
+    cst = Reactant.make_tracer(IdDict(), st, (), Reactant.ArrayToConcrete)
+    cnoisy = Reactant.ConcreteRArray(noisy)
+
+    f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cst))
+
+    # # using InteractiveUtils
+    # # @show @code_typed f(cmodel,cnoisy)
+    # # @show @code_llvm f(cmodel,cnoisy)
+    comp = f(cmodel, cnoisy, cps, cst)
+    @show comp[3]
+    @btime f($cmodel, $cnoisy, $cps, $cst) # 4.430 μs (5 allocations: 160 bytes)
+
+    # To train the model, we use batches of 64 samples, and one-hot encoding:
+
+    target = onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
+    ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
+    loader = DataLoader((noisy, target); batchsize=64, shuffle=true);
+    # # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
+
+    opt = Optimisers.Adam(0.01f0)
+    losses = []
+
+    # Lux.Experimental.TrainState is very specialized for Lux models, so we write out the
+    # training loop manually:
+    function crossentropy(ŷ, y)
+        logŷ = log.(ŷ)
+        result = y .* logŷ
+        # result = ifelse.(y .== 0.0f0, zero.(result), result)
+        return -sum(result)
+    end
+
+    function loss_function(model, x, y, ps, st)
+        y_hat, _ = model(x, ps, st)
+        return crossentropy(y_hat, y)
+    end
+
+    function gradient_loss_function(model, x, y, ps, st)
+        dps = Enzyme.make_zero(ps)
+        _, res = Enzyme.autodiff(
+            ReverseWithPrimal,
+            loss_function,
+            Active,
+            Const(model),
+            Const(x),
+            Const(y),
+            Duplicated(ps, dps),
+            Const(st),
+        )
+        return res, dps
+    end
+
+    gradient_loss_function(model, noisy, target, ps, st)
+
+    compiled_gradient = Reactant.compile(
+        gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst)
+    )
+end
diff --git a/test/nn.jl b/test/nn.jl
deleted file mode 100644
index a65e7f9e5..000000000
--- a/test/nn.jl
+++ /dev/null
@@ -1,71 +0,0 @@
-# This will prompt if neccessary to install everything, including CUDA:
-
-using Reactant
-using Flux
-using Test
-# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
-noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
-truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
-
-# Define our model, a multi-layer perceptron with one hidden layer of size 3:
-model = Chain(
-    Dense(2 => 3, tanh), # activation function inside layer
-    BatchNorm(3),
-    Dense(3 => 2),
-    softmax,
-)
-
-using BenchmarkTools
-
-origout = model(noisy)
-@show origout[3]
-@btime model(noisy)
-
-cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
-cnoisy = Reactant.ConcreteRArray(noisy)
-
-# c_o = cmodel(noisy)
-# @show c_o[3]
-# @btime cmodel(noisy)
-#
-# o_c = model(cnoisy)
-# @show o_c[3]
-# @btime model(cnoisy)
-#
-# c_c = cmodel(cnoisy)
-# @show c_c[3]
-# @btime cmodel(cnoisy)
-f = Reactant.compile((a, b) -> a(b), (cmodel, cnoisy))
-
-# using InteractiveUtils
-# @show @code_typed f(cmodel,cnoisy)
-# @show @code_llvm f(cmodel,cnoisy)
-comp = f(cmodel, cnoisy)
-@show comp[3]
-@btime f(cmodel, cnoisy)
-
-# To train the model, we use batches of 64 samples, and one-hot encoding:
-target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
-loader = Flux.DataLoader((noisy, target); batchsize=64, shuffle=true);
-# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
-
-optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc.
-
-# Training loop, using the whole data set 1000 times:
-losses = []
-for epoch in 1:1_000
-    for (x, y) in loader
-        loss, grads = Flux.withgradient(model) do m
-            # Evaluate model and loss inside gradient context:
-            y_hat = m(x)
-            return Flux.crossentropy(y_hat, y)
-        end
-        Flux.update!(optim, model, grads[1])
-        push!(losses, loss) # logging, outside gradient context
-    end
-end
-
-optim # parameters, momenta and output have all changed
-out2 = model(noisy) # first row is prob. of true, second row p(false)
-
-mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
diff --git a/test/nn_lux.jl b/test/nn_lux.jl
deleted file mode 100644
index e9bf1b204..000000000
--- a/test/nn_lux.jl
+++ /dev/null
@@ -1,101 +0,0 @@
-using Reactant, Lux, Random, Statistics
-using Enzyme
-using Test
-
-# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
-noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
-truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
-
-# Define our model, a multi-layer perceptron with one hidden layer of size 3:
-model = Lux.Chain(
-    Lux.Dense(2 => 3, tanh), # activation function inside layer
-    Lux.Dense(3 => 2),
-    softmax,
-)
-ps, st = Lux.setup(Xoshiro(123), model)
-
-using BenchmarkTools
-
-origout, _ = model(noisy, ps, st)
-@show origout[3]
-@btime model($noisy, $ps, $st) # 52.731 μs (10 allocations: 32.03 KiB)
-
-cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
-cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete)
-cst = Reactant.make_tracer(IdDict(), st, (), Reactant.ArrayToConcrete)
-cnoisy = Reactant.ConcreteRArray(noisy)
-
-f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cst))
-
-# # using InteractiveUtils
-# # @show @code_typed f(cmodel,cnoisy)
-# # @show @code_llvm f(cmodel,cnoisy)
-comp = f(cmodel, cnoisy, cps, cst)
-@show comp[3]
-@btime f($cmodel, $cnoisy, $cps, $cst) # 4.430 μs (5 allocations: 160 bytes)
-
-# To train the model, we use batches of 64 samples, and one-hot encoding:
-
-using MLUtils, OneHotArrays, Optimisers
-
-target = onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
-ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
-loader = DataLoader((noisy, target); batchsize=64, shuffle=true);
-# # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
-
-opt = Optimisers.Adam(0.01f0)
-losses = []
-
-# Lux.Exprimental.TrainState is very specialized for Lux models, so we write out the
-# training loop manually:
-function crossentropy(ŷ, y)
-    logŷ = log.(ŷ)
-    result = y .* logŷ
-    # result = ifelse.(y .== 0.0f0, zero.(result), result)
-    return -sum(result)
-end
-
-function loss_function(model, x, y, ps, st)
-    y_hat, _ = model(x, ps, st)
-    return crossentropy(y_hat, y)
-end
-
-function gradient_loss_function(model, x, y, ps, st)
-    dps = Enzyme.make_zero(ps)
-    _, res = Enzyme.autodiff(
-        ReverseWithPrimal,
-        loss_function,
-        Active,
-        Const(model),
-        Const(x),
-        Const(y),
-        Duplicated(ps, dps),
-        Const(st),
-    )
-    return res, dps
-end
-
-gradient_loss_function(model, noisy, target, ps, st)
-
-compiled_gradient = Reactant.compile(
-    gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst)
-)
-
-# # Training loop, using the whole data set 1000 times:
-# losses = []
-# for epoch in 1:1_000
-#     for (x, y) in loader
-#         loss, grads = Flux.withgradient(model) do m
-#             # Evaluate model and loss inside gradient context:
-#             y_hat = m(x)
-#             return Flux.crossentropy(y_hat, y)
-#         end
-#         Flux.update!(optim, model, grads[1])
-#         push!(losses, loss) # logging, outside gradient context
-#     end
-# end
-
-# optim # parameters, momenta and output have all changed
-# out2 = model(noisy) # first row is prob. of true, second row p(false)
-
-# mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
diff --git a/test/runtests.jl b/test/runtests.jl
index 77ecc0637..d758862c3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,4 +1,4 @@
-using Reactant
+using Reactant, InteractiveUtils, Hwloc, ReTestItems, Test
 
 # parse some command-line arguments
 function extract_flag!(args, flag, default=nothing; typ=typeof(default))
@@ -35,15 +35,15 @@ end
 do_gpu_list, gpu_list = extract_flag!(ARGS, "--gpu")
 
 if do_gpu_list
-    Reactant.set_default_backend("gpu")
-    # TODO set which gpu
+    Reactant.set_default_backend("gpu") # TODO set which gpu
 end
 
-include("layout.jl")
-include("basic.jl")
-include("bcast.jl")
-include("nn.jl")
-include("struct.jl")
-include("closure.jl")
-include("compile.jl")
-include("nn_lux.jl")
+@info sprint(io -> versioninfo(io; verbose=true))
+
+const RETESTITEMS_NWORKERS = parse(
+    Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))
+)
+
+@testset "Reactant" begin
+    ReTestItems.runtests(@__DIR__; nworkers=RETESTITEMS_NWORKERS, testitem_timeout=3600)
+end
diff --git a/test/struct.jl b/test/struct.jl
deleted file mode 100644
index a7668bf82..000000000
--- a/test/struct.jl
+++ /dev/null
@@ -1,96 +0,0 @@
-using Reactant
-using Test
-
-# from bsc-quantic/Tenet.jl
-struct MockTensor{T,N,A<:AbstractArray{T,N}}
-    data::A
-    inds::Vector{Symbol}
-end
-
-MockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} = MockTensor{T,N,A}(data, inds)
-Base.parent(t::MockTensor) = t.data
-
-Base.cos(x::MockTensor) = MockTensor(cos(parent(x)), x.inds)
-
-mutable struct MutableMockTensor{T,N,A<:AbstractArray{T,N}}
-    data::A
-    inds::Vector{Symbol}
-end
-
-function MutableMockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}}
-    return MutableMockTensor{T,N,A}(data, inds)
-end
-Base.parent(t::MutableMockTensor) = t.data
-
-Base.cos(x::MutableMockTensor) = MutableMockTensor(cos(parent(x)), x.inds)
-
-# modified from JuliaCollections/DataStructures.jl
-# NOTE original uses abstract type instead of union, which is not supported
-mutable struct MockLinkedList{T}
-    head::T
-    tail::Union{MockLinkedList{T},Nothing}
-end
-
-function list(x::T...) where {T}
-    l = nothing
-    for i in Iterators.reverse(eachindex(x))
-        l = MockLinkedList{T}(x[i], l)
-    end
-    return l
-end
-
-function Base.sum(x::MockLinkedList{T}) where {T}
-    if isnothing(x.tail)
-        return sum(x.head)
-    else
-        return sum(x.head) + sum(x.tail)
-    end
-end
-
-@testset "Struct" begin
-    @testset "MockTensor" begin
-        @testset "immutable" begin
-            x = MockTensor(rand(4, 4), [:i, :j])
-            x2 = MockTensor(Reactant.ConcreteRArray(parent(x)), x.inds)
-
-            f = Reactant.compile(cos, (x2,))
-            y = f(x2)
-
-            @test y isa MockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}}
-            @test isapprox(parent(y), cos.(parent(x)))
-            @test x.inds == [:i, :j]
-        end
-
-        @testset "mutable" begin
-            x = MutableMockTensor(rand(4, 4), [:i, :j])
-            x2 = MutableMockTensor(Reactant.ConcreteRArray(parent(x)), x.inds)
-
-            f = Reactant.compile(cos, (x2,))
-            y = f(x2)
-
-            @test y isa
-                MutableMockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}}
-            @test isapprox(parent(y), cos.(parent(x)))
-            @test x.inds == [:i, :j]
-        end
-    end
-
-    @testset "MockLinkedList" begin
-        x = [rand(2, 2) for _ in 1:2]
-        x2 = list(x...)
-        x3 = Reactant.make_tracer(IdDict(), x2, (), Reactant.ArrayToConcrete)
-
-        # TODO this should be able to run without problems, but crashes
-        @test_broken begin
-            f = Reactant.compile(identity, (x3,))
-            isapprox(f(x3), x3)
-        end
-
-        f = Reactant.compile(sum, (x3,))
-
-        y = sum(x2)
-        y3 = f(x3)
-
-        @test isapprox(y, only(y3))
-    end
-end
diff --git a/test/struct_tests.jl b/test/struct_tests.jl
new file mode 100644
index 000000000..edcf57f8b
--- /dev/null
+++ b/test/struct_tests.jl
@@ -0,0 +1,95 @@
+@testitem "struct" begin
+    # from bsc-quantic/Tenet.jl
+    struct MockTensor{T,N,A<:AbstractArray{T,N}}
+        data::A
+        inds::Vector{Symbol}
+    end
+
+    MockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} = MockTensor{T,N,A}(data, inds)
+    Base.parent(t::MockTensor) = t.data
+
+    Base.cos(x::MockTensor) = MockTensor(cos(parent(x)), x.inds)
+
+    mutable struct MutableMockTensor{T,N,A<:AbstractArray{T,N}}
+        data::A
+        inds::Vector{Symbol}
+    end
+
+    function MutableMockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}}
+        return MutableMockTensor{T,N,A}(data, inds)
+    end
+    Base.parent(t::MutableMockTensor) = t.data
+
+    Base.cos(x::MutableMockTensor) = MutableMockTensor(cos(parent(x)), x.inds)
+
+    # modified from JuliaCollections/DataStructures.jl
+    # NOTE original uses abstract type instead of union, which is not supported
+    mutable struct MockLinkedList{T}
+        head::T
+        tail::Union{MockLinkedList{T},Nothing}
+    end
+
+    function list(x::T...) where {T}
+        l = nothing
+        for i in Iterators.reverse(eachindex(x))
+            l = MockLinkedList{T}(x[i], l)
+        end
+        return l
+    end
+
+    function Base.sum(x::MockLinkedList{T}) where {T}
+        if isnothing(x.tail)
+            return sum(x.head)
+        else
+            return sum(x.head) + sum(x.tail)
+        end
+    end
+
+    @testset "Struct" begin
+        @testset "MockTensor" begin
+            @testset "immutable" begin
+                x = MockTensor(rand(4, 4), [:i, :j])
+                x2 = MockTensor(Reactant.ConcreteRArray(parent(x)), x.inds)
+
+                f = Reactant.compile(cos, (x2,))
+                y = f(x2)
+
+                @test y isa MockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}}
+                @test isapprox(parent(y), cos.(parent(x)))
+                @test x.inds == [:i, :j]
+            end
+
+            @testset "mutable" begin
+                x = MutableMockTensor(rand(4, 4), [:i, :j])
+                x2 = MutableMockTensor(Reactant.ConcreteRArray(parent(x)), x.inds)
+
+                f = Reactant.compile(cos, (x2,))
+                y = f(x2)
+
+                @test y isa
+                    MutableMockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}}
+                @test isapprox(parent(y), cos.(parent(x)))
+                @test x.inds == [:i, :j]
+            end
+        end
+
+        @testset "MockLinkedList" begin
+            x = [rand(2, 2) for _ in 1:2]
+            x2 = list(x...)
+            x3 = Reactant.make_tracer(IdDict(), x2, (), Reactant.ArrayToConcrete)
+
+            # TODO this should be able to run without problems, but crashes
+            @test_broken begin
+                f = Reactant.compile(identity, (x3,))
+                isapprox(f(x3), x3)
+            end
+
+            f = Reactant.compile(sum, (x3,))
+
+            y = sum(x2)
+            y3 = f(x3)
+
+            @test isapprox(y, only(y3))
+        end
+    end
+end
\ No newline at end of file