diff --git a/Project.toml b/Project.toml index b0c81f2698..29c71f9318 100644 --- a/Project.toml +++ b/Project.toml @@ -106,7 +106,7 @@ MPI = "0.20.19" MacroTools = "0.5.13" Markdown = "1.10" NCCL = "0.1.1" -NNlib = "0.9.24" +NNlib = "0.9.26" Optimisers = "0.4.1" Preferences = "1.4.3" Random = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 0ba60d55ec..0420577f80 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -51,7 +51,7 @@ LuxCore = "1.2" LuxLib = "1.3.4" LuxTestUtils = "1.5" MLDataDevices = "1.6" -NNlib = "0.9.24" +NNlib = "0.9.26" Optimisers = "0.4.1" Pkg = "1.10" Printf = "1.10" diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 27d1562ec8..acb9f2ec12 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -40,7 +40,7 @@ EnzymeCore = "0.8.6" Functors = "0.5" MLDataDevices = "1.6" Random = "1.10" -Reactant = "0.2.11" +Reactant = "0.2.6" ReverseDiff = "1.15" Setfield = "1" Tracker = "0.2.36" diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 9966f52fa5..5c22c0be3f 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -77,7 +77,7 @@ LuxCore = "1.2" MKL = "0.7" MLDataDevices = "1.6" Markdown = "1.10" -NNlib = "0.9.24" +NNlib = "0.9.26" Octavian = "0.3.28" Preferences = "1.4.3" Polyester = "0.7.15" diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 403bc57fb5..df34c29520 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -49,7 +49,7 @@ LoopVectorization = "0.12.171" LuxTestUtils = "1.5" MKL = "0.7" MLDataDevices = "1.6" -NNlib = "0.9.21" +NNlib = "0.9.26" Octavian = "0.3.28" Pkg = "1.10" Random = "1.10" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index eef790884e..2bc4613633 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -66,7 +66,7 @@ Metal = "1" OneHotArrays = "0.2.5" Preferences = "1.4" Random = "1.10" -Reactant = "0.2.11" +Reactant = "0.2.6" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" diff --git a/src/helpers/training.jl b/src/helpers/training.jl index e5bbee3959..c11f74b93f 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -11,7 +11,7 @@ using Static: StaticBool, Static, False, True using ..Lux: Lux, Utils, ReactantCompatibleOptimisers using LuxCore: LuxCore, AbstractLuxLayer -using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, cpu_device +using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type """ TrainState diff --git a/test/Project.toml b/test/Project.toml index 58dd94c2ee..7f9cb93e5c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -62,7 +62,7 @@ LuxLib = "1.3.4" LuxTestUtils = "1.5" MLDataDevices = "1.6" MLUtils = "0.4.3" -NNlib = "0.9.24" +NNlib = "0.9.26" Octavian = "0.3.28" OneHotArrays = "0.2.5" Optimisers = "0.4.1" diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index ba79343140..3c0113b5c0 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -46,12 +46,12 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES x = rand(10) |> aType - __f = sum ∘ Broadcast.BroadcastFunction(LuxOps.xlogx) - @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) + @test_gradients(sum∘Broadcast.BroadcastFunction(LuxOps.xlogx), + x; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) y = rand(10) |> aType - __f = sum ∘ Broadcast.BroadcastFunction(LuxOps.xlogy) - @test_gradients(__f, x, y; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) + @test_gradients(sum∘Broadcast.BroadcastFunction(LuxOps.xlogy), + x, y; atol=1.0f-3, rtol=1.0f-3, soft_fail=[AutoFiniteDiff()]) end end @@ -79,8 +79,7 @@ end @jet loss_mean(ŷ, y) @jet loss_sum(ŷ, y) - __f = Base.Fix2(loss_mean, y) - @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(loss_mean, y), ŷ; atol=1.0f-3, rtol=1.0f-3) end @testset "MSLE" begin @@ -93,8 +92,7 @@ end @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu - __f = Base.Fix2(MSLELoss(), y) - @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(MSLELoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3) end end end @@ -203,9 +201,8 @@ end @test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any - __f = Base.Fix2(bceloss, y) - σlogŷ = σ.(logŷ) - @test_gradients(__f, σlogŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(bceloss, y), σ.(logŷ); atol=1.0f-3, rtol=1.0f-3, + enzyme_set_runtime_activity=true) end @testset "Logit BinaryCrossEntropyLoss" begin @@ -225,8 +222,8 @@ end @test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any - __f = Base.Fix2(logitbceloss, y) - @test_gradients(__f, logŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(logitbceloss, y), logŷ; atol=1.0f-3, rtol=1.0f-3, + enzyme_set_runtime_activity=true) end @testset "BinaryFocalLoss" begin @@ -248,8 +245,7 @@ end @test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu - __f = Base.Fix2(BinaryFocalLoss(), y) - @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(Base.Fix2(BinaryFocalLoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3) end @testset "FocalLoss" begin diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index 6b545e15b6..ec3704d90e 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -56,7 +56,7 @@ @jet m(x, ps, Lux.testmode(st)) @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3, - rtol=1.0f-3, skip_backends=[AutoFiniteDiff()], broken_backends) + rtol=1.0f-3, skip_backends=[AutoFiniteDiff()]) # with activation function m = BatchNorm(2, sigmoid; affine) diff --git a/test/runtests.jl b/test/runtests.jl index 0f96e8b49f..91db71bcb2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -127,9 +127,7 @@ const RETESTITEMS_NWORKER_THREADS = parse( string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) @testset "Lux.jl Tests" begin - for (i, tag) in enumerate(LUX_TEST_GROUP) - @info "Running tests for group: [$(i)/$(length(LUX_TEST_GROUP))] $tag" - + @testset "[$(tag)] [$(i)/$(length(LUX_TEST_GROUP))]" for (i, tag) in enumerate(LUX_TEST_GROUP) nworkers = (tag == "reactant") || (BACKEND_GROUP == "amdgpu") ? 0 : RETESTITEMS_NWORKERS diff --git a/test/setup_modes.jl b/test/setup_modes.jl index 1617179a5b..b7c581ccca 100644 --- a/test/setup_modes.jl +++ b/test/setup_modes.jl @@ -1,4 +1,4 @@ -using Lux, MLDataDevices +using Lux, MLDataDevices, Pkg if !@isdefined(BACKEND_GROUP) const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))