finetune_eval.py
from pathlib import Path
from typing import Dict

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn import Module
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import transforms as T

from lightly.data import LightlyDataset
from lightly.transforms.utils import IMAGENET_NORMALIZE
from lightly.utils.benchmarking import LinearClassifier, MetricCallback
from lightly.utils.dist import print_rank_zero
from lightly.utils.scheduler import CosineWarmupScheduler


class FinetuneEvalClassifier(LinearClassifier):
    def configure_optimizers(self):
        # Fine-tune evaluation trains the classification head and the backbone together.
        parameters = list(self.classification_head.parameters())
        parameters += self.model.parameters()
        optimizer = SGD(
            parameters,
            # Linear scaling rule: base LR 0.05 scaled by the global batch size / 256.
            lr=0.05 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=0.0,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=0,
                max_epochs=self.trainer.estimated_stepping_batches,
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]
def finetune_eval(
    model: Module,
    train_dir: Path,
    val_dir: Path,
    log_dir: Path,
    batch_size_per_device: int,
    num_workers: int,
    accelerator: str,
    devices: int,
    precision: str,
    num_classes: int,
) -> Dict[str, float]:
    """Runs fine-tune evaluation on the given model.

    Parameters follow SimCLR [0] settings.

    The most important settings are:
        - Backbone: Fine-tuned together with the classification head (not frozen)
        - Epochs: 30
        - Optimizer: SGD
        - Base Learning Rate: 0.05
        - Momentum: 0.9
        - Weight Decay: 0.0
        - LR Schedule: Cosine without warmup

    References:
        - [0]: SimCLR, 2020, https://arxiv.org/abs/2002.05709

    See the usage sketch at the end of this module.
    """
    print_rank_zero("Running fine-tune evaluation...")

    # Setup training data.
    train_transform = T.Compose(
        [
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"]),
        ]
    )
    train_dataset = LightlyDataset(input_dir=str(train_dir), transform=train_transform)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size_per_device,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        persistent_workers=False,
    )

    # Setup validation data.
    val_transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"]),
        ]
    )
    val_dataset = LightlyDataset(input_dir=str(val_dir), transform=val_transform)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size_per_device,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=False,
    )

    # Train the classifier. In fine-tune evaluation the backbone is trained
    # together with the classification head for 30 epochs.
    metric_callback = MetricCallback()
    trainer = Trainer(
        max_epochs=30,
        accelerator=accelerator,
        devices=devices,
        callbacks=[
            LearningRateMonitor(),
            DeviceStatsMonitor(),
            metric_callback,
        ],
        logger=TensorBoardLogger(save_dir=str(log_dir), name="finetune_eval"),
        precision=precision,
        strategy="ddp_find_unused_parameters_true",
        num_sanity_val_steps=0,
    )
    classifier = FinetuneEvalClassifier(
        model=model,
        batch_size_per_device=batch_size_per_device,
        feature_dim=2048,
        num_classes=num_classes,
        freeze_model=False,  # Backbone stays trainable.
    )
    trainer.fit(
        model=classifier,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )

    # Report the best top-1 and top-5 validation accuracies seen during training.
    metrics_dict: Dict[str, float] = dict()
    for metric in ["val_top1", "val_top5"]:
        print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}")
        metrics_dict[metric] = max(metric_callback.val_metrics[metric])
    return metrics_dict
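

# Usage sketch: a minimal, hypothetical example of how finetune_eval might be
# invoked. The ResNet-50 backbone, dataset paths, and hyperparameter values
# below are illustrative assumptions, not part of the benchmark configuration.
if __name__ == "__main__":
    import torch.nn as nn
    from torchvision.models import resnet50

    # Use a ResNet-50 trunk that outputs 2048-dimensional features, matching
    # the feature_dim=2048 expected by FinetuneEvalClassifier above.
    backbone = resnet50()
    backbone.fc = nn.Identity()

    metrics = finetune_eval(
        model=backbone,
        train_dir=Path("datasets/imagenet/train"),  # assumed dataset location
        val_dir=Path("datasets/imagenet/val"),  # assumed dataset location
        log_dir=Path("benchmark_logs"),
        batch_size_per_device=256,
        num_workers=8,
        accelerator="gpu",
        devices=1,
        precision="16-mixed",
        num_classes=1000,
    )
    print(metrics)  # e.g. {"val_top1": ..., "val_top5": ...}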