Skip to content

Commit

Permalink
fix: bad rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 18, 2024
1 parent 803a43e commit 33a2d75
Show file tree
Hide file tree
Showing 11 changed files with 8 additions and 302 deletions.
4 changes: 0 additions & 4 deletions ext/LuxMPIExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
module LuxMPIExt

using Lux: MPIBackend, NCCLBackend, DistributedUtils, __unwrap_val, MPI_CUDA_AWARE,
MPI_ROCM_AWARE
using MLDataDevices: AbstractDevice, CUDADevice, AMDGPUDevice, cpu_device, set_device!,
functional
using MPI: MPI

using Lux: Lux, MPIBackend, NCCLBackend, DistributedUtils, Utils
Expand Down
5 changes: 0 additions & 5 deletions ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ using MLDataDevices: CPUDevice
using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal,
@grad_from_chainrules

using Lux: Lux, Utils
using Lux.Training: TrainingBackendCache, TrainState
using LuxCore: LuxCore
using MLDataDevices: CPUDevice

include("utils.jl")
include("rules.jl")
include("training.jl")
Expand Down
18 changes: 0 additions & 18 deletions ext/LuxTrackerExt/LuxTrackerExt.jl

This file was deleted.

21 changes: 0 additions & 21 deletions ext/LuxTrackerExt/rules.jl

This file was deleted.

23 changes: 0 additions & 23 deletions ext/LuxTrackerExt/training.jl

This file was deleted.

18 changes: 0 additions & 18 deletions ext/LuxTrackerExt/utils.jl

This file was deleted.

2 changes: 0 additions & 2 deletions ext/LuxZygoteExt/LuxZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using ArgCheck: @argcheck
using ADTypes: AutoZygote
using ChainRulesCore: ChainRulesCore
using ForwardDiff: ForwardDiff
using Lux: Lux
using MLDataDevices: get_device_type, CPUDevice
using Setfield: @set!
using Zygote: Zygote

Expand Down
2 changes: 1 addition & 1 deletion src/contrib/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ true
```
"""
macro layer_map(f, l, ps, st)
quote
return quote
layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(Meta.quot(l)))
end
end
Expand Down
203 changes: 0 additions & 203 deletions src/helpers/nested_ad.jl

This file was deleted.

12 changes: 6 additions & 6 deletions test/contrib/share_parameters_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
@test ps_1.d3.bias == ps_1.d2.l1.bias

ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4)) |>
device
dev
ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |>
device
dev

ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2))

Expand All @@ -29,7 +29,7 @@
@test ps_2.d3.bias == ps_new_2.bias == ps_2.d2.l1.bias

# Mix in ComponentArray
ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> device
ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> dev

ps_3 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2))

Expand All @@ -47,14 +47,14 @@

# Parameter Structure Mismatch
ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4)) |>
device
dev
ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |>
device
dev

@test_throws ArgumentError Lux.Experimental.share_parameters(
ps, sharing, (ps_new_1, ps_new_2))

ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> device
ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> dev

@test_throws ArgumentError Lux.Experimental.share_parameters(
ps, sharing, (ps_new_ca_1, ps_new_2))
Expand Down
2 changes: 1 addition & 1 deletion test/utils_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ end
@testitem "FP Conversions" setup=[SharedTestSetup] tags=[:others] begin
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
model = Chain(
Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1))

Expand Down

0 comments on commit 33a2d75

Please sign in to comment.