-
Notifications
You must be signed in to change notification settings - Fork 0
/
ood_train.py
69 lines (55 loc) · 2.07 KB
/
ood_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import json
from dataclasses import dataclass
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from utils.helper_functions import load_dataset
from utils.arguments import get_parser
from nets.wrapperood import WrapperOod
@dataclass
class Arguments:
    """Run configuration for OOD fine-tuning; also the schema handed to ``get_parser``."""

    # Root directory forwarded to load_dataset.
    dataset_dir: str = "./datasets"
    # Dataset name; load_pretrained only recognises "SHIFT" and "StreetHazards".
    dataset: str = ""
    # Forwarded to load_dataset (meaning defined there).
    horizon: float = 1.0
    # Forwarded to load_dataset (meaning defined there).
    alpha: float = 1.0
    # Forwarded to load_dataset (meaning defined there).
    blur: int = 0
    # Forwarded to load_dataset (meaning defined there).
    histogram: bool = False
    # Hyperparameters forwarded to WrapperOod via load_pretrained.
    lr: float = 1e-5
    beta: float = 1e-2
    beta2: float = 1e-3
    # Trainer default_root_dir; result.json is also written here.
    checkpoint: str = "./"
    # Trainer max_epochs.
    epochs: int = 5
def load_pretrained(dataset_name, lr=.00001, beta=.01, beta2=.001):
    """Build a ResNet-50 ``WrapperOod`` and load its bundled pretrained weights.

    Args:
        dataset_name: ``"SHIFT"`` (21 classes) or ``"StreetHazards"`` (14 classes).
        lr, beta, beta2: Hyperparameters forwarded unchanged to ``WrapperOod``.

    Returns:
        The model with the matching ``pretrained/*.ckpt`` state dict loaded, or
        ``None`` when ``dataset_name`` is not one of the known datasets.
    """
    # pylint: disable=no-value-for-parameter
    # (dataset name, number of classes, checkpoint path)
    known_checkpoints = (
        ("SHIFT", 21, "pretrained/shift_weights.ckpt"),
        ("StreetHazards", 14, "pretrained/sh_weights.ckpt"),
    )
    for name, num_classes, weights_path in known_checkpoints:
        if dataset_name == name:
            net = WrapperOod(
                backbone="resnet50", num_classes=num_classes,
                lr=lr, beta=beta, beta2=beta2)
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
            net.load_state_dict(state_dict=torch.load(weights_path))
            return net
    return None
def main(args: Arguments):
    """Fine-tune a pretrained OOD model, run the test loop, and save the results.

    Steps: build the datamodule with ``load_dataset``; load pretrained
    ``WrapperOod`` weights with ``load_pretrained``; fit for ``args.epochs``
    epochs on CUDA with a checkpoint saved at each training-epoch end; run
    ``Trainer.test`` and write ``{"results": ..., "args": ...}`` as JSON to
    ``<args.checkpoint>/result.json``.

    Args:
        args: Run configuration. At runtime this is whatever
            ``get_parser(Arguments).parse_args()`` returns; it must expose the
            ``Arguments`` fields as attributes.

    Raises:
        ValueError: if no pretrained weights exist for ``args.dataset``
            (``load_pretrained`` returned ``None``).
    """
    dm = load_dataset(
        args.dataset, args.dataset_dir, args.horizon, args.alpha, args.histogram, args.blur)
    model = load_pretrained(args.dataset, args.lr, args.beta, args.beta2)
    if model is None:
        # Fail with a clear message instead of an opaque error from Trainer.fit(model=None).
        raise ValueError(
            f"No pretrained weights are available for dataset {args.dataset!r} "
            "(see load_pretrained for the supported names).")
    tr = Trainer(
        default_root_dir=args.checkpoint, accelerator="cuda",
        callbacks=[
            ModelCheckpoint(save_on_train_epoch_end=True),
            TQDMProgressBar(refresh_rate=2)
        ],
        max_epochs=args.epochs,
        # With the default 5 epochs, validation never runs during fit
        # (presumably intentional).
        check_val_every_n_epoch=100)
    tr.fit(model=model, datamodule=dm)
    # NOTE(review): presumably WrapperOod accumulates these during testing;
    # reset them so the test pass starts clean -- confirm in nets/wrapperood.
    model.ood_scores = []
    model.ood_masks = []
    out = tr.test(model=model, datamodule=dm)
    print(out)
    with open(os.path.join(args.checkpoint, "result.json"), "w", encoding="utf-8") as f:
        json.dump({"results": out, "args": vars(args)}, f)
if __name__ == "__main__":
    # Build the CLI from the Arguments dataclass, echo the parsed config, then run.
    cli_args = get_parser(Arguments).parse_args()
    print(cli_args)
    main(cli_args)