Merge pull request #60 from SciML/ap/paper
Experiments for the Updated Manuscript
avik-pal authored Jan 10, 2023
2 parents 1ab3691 + d18c610 commit 6f5a6b7
Showing 24 changed files with 51,314 additions and 7 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -7,4 +7,6 @@ build
 statprof
 profs
 logs
-benchmarking
+benchmarking
+*/tensorflow_datasets/
+checkpoints
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "DeepEquilibriumNetworks"
 uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 authors = ["Avik Pal <[email protected]>"]
-version = "0.2.3"
+version = "0.2.4"

 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4 changes: 2 additions & 2 deletions docs/make.jl
@@ -1,7 +1,7 @@
 using Documenter, DocumenterCitations, DeepEquilibriumNetworks

-cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml", force = true)
-cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
+cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force=true)
+cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force=true)

 bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); sorting=:nyt)
40 changes: 40 additions & 0 deletions experiments/Project.toml
@@ -0,0 +1,40 @@
name = "DEQExperiments"
uuid = "6748aba7-04a0-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <[email protected]>"]
version = "0.1.0"

[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Augmentor = "02898b10-1f73-11ea-317c-6393d7073e15"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JLSO = "9da8a3cd-07a3-59c0-a743-3fdc52c30d11"
JpegTurbo = "b835a17e-a41a-41e7-81f0-2f016b05efe0"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleConfig = "f2d95530-262a-480f-aff0-1c0431e662a7"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Wandb = "ad70616a-06c9-5745-b1f1-6a5f42545108"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
MLUtils = "0.2.10"
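
A minimal setup sketch, assuming the standard Pkg workflow run from the repository root (that the experiments/ directory holds the project file above is inferred from the file path):

using Pkg
Pkg.activate("experiments")  # activate the DEQExperiments project declared above
Pkg.instantiate()            # install the [deps] listed in this Project.toml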
74 changes: 74 additions & 0 deletions experiments/cifar10/large.yml
@@ -0,0 +1,74 @@
seed: 0 # Always control using command-line arguments

model:
  num_classes: 10
  dropout_rate: 0.3
  group_count: 8
  weight_norm: true
  downsample_times: 0
  expansion_factor: 4
  image_size:
    - 32
    - 32
  num_branches: 4
  big_kernels:
    - 0
    - 0
    - 0
    - 0
  head_channels:
    - 14
    - 28
    - 56
    - 112
  num_channels:
    - 32
    - 64
    - 128
    - 256
  fuse_method: "sum"
  final_channelsize: 1680
  model_type: "vanilla"
  maxiters: 18
  in_channels: 3
  sensealg:
    jfb: false
    abstol: 5.0e-2
    reltol: 5.0e-2
    maxiters: 20
  solver:
    continuous: true
    abstol: 5.0e-2
    reltol: 5.0e-2
    ode_solver: "vcab3"
    stop_mode: "rel_deq_best"
    abstol_termination: 5.0e-2
    reltol_termination: 5.0e-2

optimizer:
  lr_scheduler: "cosine"
  optimizer: "adam"
  learning_rate: 0.001
  nesterov: false
  momentum: 0.0
  weight_decay: 0.0000
  cycle_length: 90000

dataset:
  augment: true
  data_root: "data/cifar10"
  eval_batchsize: 128
  train_batchsize: 128

train:
  total_steps: 90000
  pretrain_steps: 0
  evaluate_every: 500
  resume: ""
  evaluate: false
  checkpoint_dir: "checkpoints/"
  log_dir: "logs/"
  expt_subdir: "cifar10/large/"
  expt_id: ""
  print_frequency: 100
  w_skip: 0.01
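
A hedged sketch of how the solver block above plausibly maps onto the DeepEquilibriumNetworks.jl v0.2 API; ContinuousDEQSolver and VCAB3 are exported by DeepEquilibriumNetworks.jl and OrdinaryDiffEq.jl respectively, but the exact construction performed inside DEQExperiments.construct is an assumption:

using DeepEquilibriumNetworks, OrdinaryDiffEq

# continuous: true with ode_solver: "vcab3" suggests an ODE-based fixed-point
# solve using the tolerances and termination mode given in the YAML above
solver = ContinuousDEQSolver(VCAB3(); mode=:rel_deq_best, abstol=5.0f-2, reltol=5.0f-2,
                             abstol_termination=5.0f-2, reltol_termination=5.0f-2)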
245 changes: 245 additions & 0 deletions experiments/cifar10/main.jl
@@ -0,0 +1,245 @@
import CUDA, DEQExperiments, FluxMPI, Logging, Lux, OneHotArrays, Optimisers, PyCall,
       Random, Setfield, SimpleConfig, Statistics, Wandb
import Lux.Training

# Dataloaders
function get_dataloaders(; augment, data_root, eval_batchsize, train_batchsize)
    tf = PyCall.pyimport("tensorflow")
    tfds = PyCall.pyimport("tensorflow_datasets")

    # Keep TensorFlow off the GPU; it is used only as a data pipeline here
    tf.config.set_visible_devices([], "GPU")

    ds_train, ds_test = tfds.load("cifar10"; split=["train", "test"], as_supervised=true,
                                  data_dir=data_root)

    image_mean = tf.constant([[[0.4914f0, 0.4822f0, 0.4465f0]]])
    image_std = tf.constant([[[0.2023f0, 0.1994f0, 0.2010f0]]])

    function normalize(img, label)
        img = tf.cast(img, tf.float32) / 255.0f0
        img = (img - image_mean) / image_std
        return img, label
    end

    ds_train = ds_train.cache()
    ds_test = ds_test.cache().map(normalize)
    if augment
        tf_rng = tf.random.Generator.from_seed(12345; alg="philox")
        function augmentation(img, label)
            seed = tf_rng.make_seeds(2)[1]

            img, label = normalize(img, label)
            img = tf.image.stateless_random_flip_left_right(img, seed)
            img = tf.image.resize(img, (42, 42))
            img = tf.image.stateless_random_crop(img, (32, 32, 3), seed)

            return img, label
        end
        ds_train = ds_train.map(augmentation; num_parallel_calls=tf.data.AUTOTUNE)
    else
        ds_train = ds_train.map(normalize; num_parallel_calls=tf.data.AUTOTUNE)
    end

    if DEQExperiments.is_distributed()
        ds_train = ds_train.shard(FluxMPI.total_workers(), FluxMPI.local_rank())
        ds_test = ds_test.shard(FluxMPI.total_workers(), FluxMPI.local_rank())
    end

    ds_train = ds_train.prefetch(tf.data.AUTOTUNE).shuffle(1024).repeat(-1)
    ds_test = ds_test.prefetch(tf.data.AUTOTUNE).repeat(1)

    return (tfds.as_numpy(ds_train.batch(train_batchsize)),
            tfds.as_numpy(ds_test.batch(eval_batchsize)))
end

# Permute NHWC batches from TensorFlow into Julia's WHCN layout and move them,
# along with the one-hot encoded labels, to the GPU
function _data_postprocess(image, label)
    return (Lux.gpu(permutedims(image, (3, 2, 4, 1))),
            Lux.gpu(OneHotArrays.onehotbatch(label, 0:9)))
end

function main(filename, args)
    cfg = SimpleConfig.define_configuration(args, DEQExperiments.ExperimentConfig,
                                            filename)

    return main(splitext(basename(filename))[1], cfg)
end

function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
    rng = Random.Xoshiro()
    Random.seed!(rng, cfg.seed)

    model = DEQExperiments.construct(cfg.model)

    loss_function = DEQExperiments.get_loss_function(cfg)

    opt, sched = DEQExperiments.construct(cfg.optimizer)

    tstate = if cfg.model.model_type != "neural_ode"
        Training.TrainState(rng, model, opt; transform_variables=Lux.gpu)
    else
        ps, st = Lux.setup(rng, model)
        ps = ps |> Lux.ComponentArray |> Lux.gpu
        st = st |> Lux.gpu
        opt_state = Optimisers.setup(opt, ps)
        Training.TrainState(model, ps, st, opt_state, 0)
    end
    vjp_rule = Training.ZygoteVJP()

    DEQExperiments.warmup_model(loss_function, model, tstate.parameters, tstate.states,
                                cfg; transform_input=Lux.gpu)

    ds_train, ds_test = get_dataloaders(; cfg.dataset.augment, cfg.dataset.data_root,
                                        cfg.dataset.eval_batchsize,
                                        cfg.dataset.train_batchsize)
    _, ds_train_iter = iterate(ds_train)

    # Setup
    expt_name = ("config-$(config_name)_continuous-$(cfg.model.solver.continuous)" *
                 "_type-$(cfg.model.model_type)_seed-$(cfg.seed)" *
                 "_jfb-$(cfg.model.sensealg.jfb)_id-$(cfg.train.expt_id)")

    ckpt_dir = joinpath(cfg.train.expt_subdir, cfg.train.checkpoint_dir, expt_name)
    log_dir = joinpath(cfg.train.expt_subdir, cfg.train.log_dir, expt_name)
    if cfg.train.resume == ""
        rpath = joinpath(ckpt_dir, "model_current.jlso")
    else
        rpath = cfg.train.resume
    end

    ckpt = DEQExperiments.load_checkpoint(rpath)
    if !isnothing(ckpt)
        tstate = ckpt.tstate
        initial_step = ckpt.step
        DEQExperiments.should_log() && @info "Training Started from Step: $initial_step"
    else
        initial_step = 1
    end

    if cfg.train.pretrain_steps != 0
        if DEQExperiments.should_log()
            @info "Will pretrain for $(cfg.train.pretrain_steps) steps"
        end
        Setfield.@set! tstate.states = Lux.update_state(tstate.states, :fixed_depth,
                                                        Val(5))
    end

    # Setup Logging
    loggers = DEQExperiments.create_logger(log_dir, cfg.train.total_steps - initial_step,
                                           cfg.train.total_steps - initial_step,
                                           expt_name,
                                           SimpleConfig.flatten_configuration(cfg))

    best_test_accuracy = 0

    for step in initial_step:(cfg.train.total_steps)
        # Train Step
        t = time()
        (x, y), ds_train_iter = iterate(ds_train, ds_train_iter)
        x, y = _data_postprocess(x, y)
        data_time = time() - t

        bsize = size(x, ndims(x))

        ret_val = DEQExperiments.run_training_step(vjp_rule, loss_function, (x, y),
                                                   tstate)
        loss, _, stats, tstate, gs, step_stats = ret_val
        Setfield.@set! tstate.states = Lux.update_state(tstate.states, :update_mask,
                                                        Val(true))

        # LR Update
        lr_new = sched(step + 1)
        Setfield.@set! tstate.optimizer_state = Optimisers.adjust(tstate.optimizer_state,
                                                                  lr_new)

        accuracy = DEQExperiments.accuracy(Lux.cpu(stats.y_pred), Lux.cpu(y))
        residual = abs(Statistics.mean(stats.residual))

        # Logging
        loggers.progress_loggers.train.avg_meters.batch_time(data_time +
                                                             step_stats.fwd_time +
                                                             step_stats.bwd_time +
                                                             step_stats.opt_time, bsize)
        loggers.progress_loggers.train.avg_meters.data_time(data_time, bsize)
        loggers.progress_loggers.train.avg_meters.fwd_time(step_stats.fwd_time, bsize)
        loggers.progress_loggers.train.avg_meters.bwd_time(step_stats.bwd_time, bsize)
        loggers.progress_loggers.train.avg_meters.opt_time(step_stats.opt_time, bsize)
        loggers.progress_loggers.train.avg_meters.loss(loss, bsize)
        loggers.progress_loggers.train.avg_meters.ce_loss(stats.ce_loss, bsize)
        loggers.progress_loggers.train.avg_meters.skip_loss(stats.skip_loss, bsize)
        loggers.progress_loggers.train.avg_meters.residual(residual, bsize)
        loggers.progress_loggers.train.avg_meters.top1(accuracy, bsize)
        loggers.progress_loggers.train.avg_meters.top5(-1, bsize)
        loggers.progress_loggers.train.avg_meters.nfe(stats.nfe, bsize)

        if step % cfg.train.print_frequency == 1 && DEQExperiments.should_log()
            DEQExperiments.print_meter(loggers.progress_loggers.train.progress, step)
            log_vals = DEQExperiments.get_loggable_values(loggers.progress_loggers.train.progress)
            loggers.csv_loggers.train(step, log_vals...)
            Wandb.log(loggers.wandb_logger,
                      loggers.log_functions.train(step, log_vals...))
            DEQExperiments.reset_meter!(loggers.progress_loggers.train.progress)
        end

        if step == cfg.train.pretrain_steps
            DEQExperiments.should_log() && @info "Pretraining Completed!!!"
            Setfield.@set! tstate.states = Lux.update_state(tstate.states, :fixed_depth,
                                                            Val(0))
        end

        # Free memory eagerly
        CUDA.unsafe_free!(x)
        CUDA.unsafe_free!(y)

        if step % cfg.train.evaluate_every == 1 || step == cfg.train.total_steps
            is_best = true

            st_eval = Lux.testmode(tstate.states)
            for (x, y) in ds_test
                t = time()
                x, y = _data_postprocess(x, y)
                dtime = time() - t

                t = time()
                loss, st_, stats = loss_function(model, tstate.parameters, st_eval,
                                                 (x, y))
                fwd_time = time() - t

                bsize = size(x, ndims(x))

                acc = DEQExperiments.accuracy(Lux.cpu(stats.y_pred), Lux.cpu(y))

                loggers.progress_loggers.eval.avg_meters.batch_time(dtime + fwd_time,
                                                                    bsize)
                loggers.progress_loggers.eval.avg_meters.data_time(dtime, bsize)
                loggers.progress_loggers.eval.avg_meters.fwd_time(fwd_time, bsize)
                loggers.progress_loggers.eval.avg_meters.loss(loss, bsize)
                loggers.progress_loggers.eval.avg_meters.ce_loss(stats.ce_loss, bsize)
                loggers.progress_loggers.eval.avg_meters.skip_loss(stats.skip_loss,
                                                                   bsize)
                loggers.progress_loggers.eval.avg_meters.residual(abs(Statistics.mean(stats.residual)),
                                                                  bsize)
                loggers.progress_loggers.eval.avg_meters.top1(acc, bsize)
                loggers.progress_loggers.eval.avg_meters.top5(-1, bsize)
                loggers.progress_loggers.eval.avg_meters.nfe(stats.nfe, bsize)

                # Free memory eagerly
                CUDA.unsafe_free!(x)
                CUDA.unsafe_free!(y)
            end

            if DEQExperiments.should_log()
                DEQExperiments.print_meter(loggers.progress_loggers.eval.progress, step)
                log_vals = DEQExperiments.get_loggable_values(loggers.progress_loggers.eval.progress)
                loggers.csv_loggers.eval(step, log_vals...)
                Wandb.log(loggers.wandb_logger,
                          loggers.log_functions.eval(step, log_vals...))
                DEQExperiments.reset_meter!(loggers.progress_loggers.eval.progress)
            end

            accuracy = loggers.progress_loggers.eval.avg_meters.top1.average
            is_best = accuracy >= best_test_accuracy
            if is_best
                best_test_accuracy = accuracy
            end

            # Save the checkpoint at the current step so training can resume from it
            ckpt = (tstate=tstate, step=step)
            DEQExperiments.save_checkpoint(ckpt; is_best,
                                           filename=joinpath(ckpt_dir,
                                                             "model_$(step).jlso"))
        end
    end

    return nothing
end

if abspath(PROGRAM_FILE) == @__FILE__
    main(ARGS[1], ARGS[2:end])
end

2 comments on commit 6f5a6b7

@avik-pal (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/75437

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or it can be done manually through the GitHub interface, or via:

git tag -a v0.2.4 -m "<description of version>" 6f5a6b7596c13befbad2c0071d6050ffe6370a66
git push origin v0.2.4
