multimnist3.py
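"""Hydra entry point for multi-task experiments on the 3-digit MultiMNIST benchmark.

Builds the data module, constructs a model for the configured method (phn, cosmos,
rotograd, pamal, autol, or a plain shared-bottom baseline), trains it, evaluates on
the test split, and logs results to Weights & Biases.
"""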
import logging
import hydra
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from src.datasets.multimnist3digits import MultiMnistThreeDataModule
from src.models.base_model import SharedBottom
from src.models.factory.cosmos.upsampler import Upsampler
from src.models.factory.lenet import MultiLeNetO, MultiLeNetR
from src.models.factory.phn.phn_wrappers import HyperModel
from src.models.factory.rotograd import RotogradWrapper
from src.utils import set_seed
from src.utils._selectors import get_callbacks, get_ensemble_model, get_optimizer, get_trainer
from src.utils.callbacks.auto_lambda_callback import AutoLambdaCallback
from src.utils.logging_utils import initialize_wandb, install_logging
from src.utils.losses import MultiTaskCrossEntropyLoss


@hydra.main(config_path="configs/experiment/multimnist3", config_name="multimnist3")
def my_app(config: DictConfig) -> None:
    install_logging()
    logging.info(OmegaConf.to_yaml(config))
    set_seed(config.seed)
    initialize_wandb(config)
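
    # Set up the 3-digit MultiMNIST data module from the experiment config.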
    dm = MultiMnistThreeDataModule(
        batch_size=config.data.batch_size,
        num_workers=config.data.num_workers,
    )
    logging.info(f"Using benchmark: {dm.name}")
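
    # Build the model for the selected multi-task method; the default is a
    # shared-bottom network with one LeNet head per digit.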
    if config.method.name == "phn":
        model = HyperModel(dm.name)
    elif config.method.name == "cosmos":
        model = SharedBottom(encoder=MultiLeNetR(in_channels=4), decoder=MultiLeNetO(), num_tasks=3)
    elif config.method.name == "rotograd":
        backbone = MultiLeNetR(in_channels=1)
        head1, head2, head3 = MultiLeNetO(), MultiLeNetO(), MultiLeNetO()
        model = RotogradWrapper(backbone=backbone, heads=[head1, head2, head3], latent_size=50)
    else:
        model = SharedBottom(
            encoder=MultiLeNetR(in_channels=1), decoder=[MultiLeNetO(), MultiLeNetO(), MultiLeNetO()], num_tasks=3
        )
    logging.info(model)
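
    # Some methods wrap the base model further: pamal trains an ensemble of
    # models; cosmos adds an Upsampler that feeds the task preference vector to
    # the encoder as extra input channels (hence in_channels=4 above).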
    if config.method.name == "pamal":
        model = get_ensemble_model(model, dm.num_tasks, config)
    elif config.method.name == "cosmos":
        model = Upsampler(dm.num_tasks, model, input_dim=dm.input_dims)
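
    # Build the optimizer from the config over all model parameters.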
    param_groups = model.parameters()
    optimizer = get_optimizer(config, param_groups)
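
    # RotoGrad uses its own optimizer with separate parameter groups: one per
    # backbone/head and one for the wrapper's (rotation) parameters at a tenth
    # of the learning rate.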
    if config.method.name == "rotograd":
        optimizer = torch.optim.Adam(
            [{"params": m.parameters()} for m in [backbone, head1, head2, head3]]
            + [{"params": model.parameters(), "lr": config.optimizer.lr * 0.1}],
            lr=config.optimizer.lr,
        )
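
    # Method-specific callbacks; Auto-Lambda additionally registers its
    # callback with the configured meta learning rate.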
    callbacks = get_callbacks(config, dm.num_tasks)
    if config.method.name == "autol":
        callbacks.append(AutoLambdaCallback(config.method.meta_lr))
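
    # Assemble the trainer and run training for the configured number of epochs.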
    trainer_kwargs = dict(
        model=model,
        benchmark=dm,
        optimizer=optimizer,
        gpu=0,
        callbacks=callbacks,
        loss_fn=MultiTaskCrossEntropyLoss(),
    )
    trainer = get_trainer(config, trainer_kwargs, dm.num_tasks, model)
    trainer.fit(epochs=config.training.epochs)
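
    # Final evaluation on the test split; pamal evaluates interpolations across
    # its ensemble members instead of a single model.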
    if config.method.name == "pamal":
        trainer.predict_interpolations(dm.test_dataloader())
    else:
        trainer.predict(test_loader=dm.test_dataloader())
    wandb.finish()


if __name__ == "__main__":
    my_app()
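
# Example invocation (hypothetical override values; the available options are
# defined by the Hydra configs under configs/experiment/multimnist3):
#   python multimnist3.py seed=0 method.name=cosmos training.epochs=10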