Skip to content

Commit

Permalink
feat: update LuxCore to latest Functors
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 13, 2024
1 parent 44750b6 commit 854917b
Show file tree
Hide file tree
Showing 18 changed files with 41 additions and 35 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.3"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -73,7 +73,7 @@ ArrayInterface = "7.10"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.15"
ComponentArrays = "0.15.16"
ComponentArrays = "0.15.18"
ConcreteStructs = "0.2.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.13"
Expand All @@ -82,7 +82,7 @@ FastClosures = "0.3.2"
Flux = "0.14.25"
ForwardDiff = "0.10.36"
FunctionWrappers = "1.1.3"
Functors = "0.4.12"
Functors = "0.5"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
LossFunctions = "0.11.1"
Expand Down
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ADTypes = "1.3"
Adapt = "4"
ChainRulesCore = "1.24"
ComponentArrays = "0.15"
ComponentArrays = "0.15.18"
Documenter = "1.4"
DocumenterVitepress = "0.1.3"
Enzyme = "0.13.13"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
Functors = "0.4.12"
Functors = "0.5"
GPUArraysCore = "0.1, 0.2"
KernelAbstractions = "0.9"
LinearAlgebra = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion examples/Basics/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.15"
ComponentArrays = "0.15.18"
ForwardDiff = "0.10"
Lux = "1"
LuxCUDA = "0.3"
Expand Down
2 changes: 1 addition & 1 deletion examples/BayesianNN/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CairoMakie = "0.12"
Functors = "0.4"
Functors = "0.5"
LinearAlgebra = "1"
Lux = "1"
Random = "1"
Expand Down
2 changes: 1 addition & 1 deletion examples/GravitationalWaveForm/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"

[compat]
CairoMakie = "0.12"
ComponentArrays = "0.15"
ComponentArrays = "0.15.18"
LineSearches = "7"
Lux = "1"
Optimization = "4"
Expand Down
2 changes: 1 addition & 1 deletion examples/HyperNet/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
ComponentArrays = "0.15"
ComponentArrays = "0.15.18"
Lux = "1"
LuxCUDA = "0.3"
MLDatasets = "0.7"
Expand Down
2 changes: 1 addition & 1 deletion examples/NeuralODE/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.15"
ComponentArrays = "0.15.18"
Lux = "1"
LuxCUDA = "0.3"
MLDatasets = "0.7"
Expand Down
2 changes: 1 addition & 1 deletion examples/OptimizationIntegration/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"

[compat]
CairoMakie = "0.12.10"
ComponentArrays = "0.15.17"
ComponentArrays = "0.15.18"
Lux = "1"
LuxCUDA = "0.3.3"
MLUtils = "0.4.4"
Expand Down
6 changes: 3 additions & 3 deletions lib/LuxCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxCore"
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.1"
version = "1.2.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down Expand Up @@ -35,8 +35,8 @@ ChainRulesCore = "1.24"
Compat = "4.15.0"
DispatchDoctor = "0.4.10"
EnzymeCore = "0.8.5"
Functors = "0.4.12"
MLDataDevices = "1"
Functors = "0.5"
MLDataDevices = "1.5"
Random = "1.10"
Reactant = "0.2.4"
ReverseDiff = "1.15"
Expand Down
11 changes: 7 additions & 4 deletions lib/LuxCore/ext/LuxCoreFunctorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ function LuxCore.Internal.fmap_with_path_impl(args...; kwargs...)
end
LuxCore.Internal.fleaves_impl(args...; kwargs...) = Functors.fleaves(args...; kwargs...)

function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}},
x) where {layers}
function Functors.functor(::Type{<:LuxCore.AbstractLuxLayer}, x)
return Functors.NoChildren(), Returns(x)
end

function Functors.functor(
::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, x) where {layers}
children = NamedTuple{layers}(getproperty.((x,), layers))
layer_reconstructor = let x = x, layers = layers
z -> reduce(LuxCore.Internal.setfield, zip(layers, z); init=x)
end
return children, layer_reconstructor
end

function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}},
x) where {layer}
function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}}, x) where {layer}
children = NamedTuple{(layer,)}((getproperty(x, layer),))
layer_reconstructor = let x = x, layer = layer
z -> LuxCore.Internal.setfield(x, layer, getproperty(z, layer))
Expand Down
4 changes: 2 additions & 2 deletions lib/LuxCore/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Aqua = "0.8.7"
EnzymeCore = "0.8.5"
ExplicitImports = "1.9.0"
Functors = "0.4.12"
MLDataDevices = "1.0.0"
Functors = "0.5"
MLDataDevices = "1.5"
Optimisers = "0.3.3, 0.4"
Random = "1.10"
Test = "1.10"
7 changes: 5 additions & 2 deletions lib/LuxCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,11 @@ end
end

@testset "Convenience Checks" begin
models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))),
Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]]
models1 = [
Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))),
Chain2(Dense(5, 10), Dense(10, 5)),
[Dense(5, 10), Dense(10, 5)]
]

@test LuxCore.contains_lux_layer(models1)

Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Aqua = "0.8.7"
BLISBLAS = "0.1"
BenchmarkTools = "1.5"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.16"
ComponentArrays = "0.15.18"
Enzyme = "0.13.13"
EnzymeCore = "0.8.5"
ExplicitImports = "1.9.0"
Expand Down
8 changes: 4 additions & 4 deletions lib/LuxTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxTestUtils"
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
authors = ["Avik Pal <[email protected]>"]
version = "1.5.0"
version = "1.6.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -24,14 +24,14 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ADTypes = "1.8.1"
ArrayInterface = "7.9"
ChainRulesCore = "1.24.0"
ComponentArrays = "0.15.14"
ComponentArrays = "0.15.18"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.13"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
Functors = "0.4.11"
Functors = "0.5"
JET = "0.9.6"
MLDataDevices = "1.0.0"
MLDataDevices = "1.5"
ReverseDiff = "1.15.3"
Test = "1.10"
Tracker = "0.2.34"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxTestUtils/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CUDA = "5"
ComponentArrays = "0.15"
ComponentArrays = "0.15.18"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
MetaTesting = "0.1"
Expand Down
4 changes: 2 additions & 2 deletions lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.5.3"
version = "1.6.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -56,7 +56,7 @@ ChainRules = "1.51"
ChainRulesCore = "1.23"
Compat = "4.15"
FillArrays = "1"
Functors = "0.4.8"
Functors = "0.5"
GPUArrays = "10, 11"
MLUtils = "0.4.4"
Metal = "1"
Expand Down
6 changes: 3 additions & 3 deletions lib/MLDataDevices/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ Adapt = "4"
Aqua = "0.8.4"
ArrayInterface = "7.11"
ChainRulesTestUtils = "1.13.0"
ComponentArrays = "0.15.8"
ComponentArrays = "0.15.18"
ExplicitImports = "1.9.0"
FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.8"
MLUtils = "0.4"
Functors = "0.5"
MLUtils = "0.4.4"
OneHotArrays = "0.2.5"
Pkg = "1.10"
Random = "1.10"
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ ADTypes = "1.8.1"
Adapt = "4"
Aqua = "0.8.4"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.16"
ComponentArrays = "0.15.18"
DispatchDoctor = "0.4.12"
Documenter = "1.4"
Enzyme = "0.13.13"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Functors = "0.4.12"
Functors = "0.5"
Hwloc = "3.2.0"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
Expand Down

0 comments on commit 854917b

Please sign in to comment.