-
Notifications
You must be signed in to change notification settings - Fork 39
/
hubconf.py
91 lines (52 loc) · 3.04 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Package names that torch.hub checks are importable before it loads any
# entry point from this hubconf.
dependencies = ['torch', 'torchvision']
from torch.hub import load_state_dict_from_url
def _preact_resnet18(msda='fmix', pretrained=False, repeat=0, *args, **kwargs):
    """Build ``models.ResNet18`` and optionally load a downloaded checkpoint.

    Args:
        msda: Tag interpolated into the checkpoint file name.
        pretrained: When true, download the checkpoint selected by ``msda``
            and ``repeat`` and load it into the model.
        repeat: Index interpolated into the checkpoint file name after
            ``msda``.
        *args: Forwarded unchanged to ``ResNet18``.
        **kwargs: Forwarded unchanged to ``ResNet18``.

    Returns:
        The constructed model; its state dict is replaced by the downloaded
        one when ``pretrained`` is true.
    """
    # Local import: ``models`` is only needed once an entry point is called.
    from models import ResNet18

    model = ResNet18(*args, **kwargs)
    if not pretrained:
        return model

    # NOTE(review): the checkpoint is fetched over plain HTTP and then
    # deserialised by torch; prefer HTTPS if the host offers it.
    url = (
        'http://marc.ecs.soton.ac.uk/pytorch-models/cifar10/preact-resnet18/'
        f'cifar10_preact_resnet18_{msda}_{repeat}.pt'
    )
    # map_location='cpu' makes deserialisation independent of the device the
    # checkpoint was saved from, so CPU-only hosts can load it as well;
    # load_state_dict then copies the values into the model's own parameters.
    state = load_state_dict_from_url(url, progress=True, map_location='cpu')
    model.load_state_dict(state)
    return model
def _resnet101(msda='fmix', pretrained=False, *args, **kwargs):
    """Build torchvision's ``resnet101`` and optionally load a checkpoint.

    Args:
        msda: Tag interpolated into the checkpoint file name.
        pretrained: When true, download the checkpoint selected by ``msda``
            and load it into the model.
        *args: Forwarded unchanged to ``resnet101``.
        **kwargs: Forwarded unchanged to ``resnet101``.

    Returns:
        The constructed model; its state dict is replaced by the downloaded
        one when ``pretrained`` is true.
    """
    # Local import: torchvision is only needed once an entry point is called.
    from torchvision.models.resnet import resnet101

    model = resnet101(*args, **kwargs)
    if not pretrained:
        return model

    # NOTE(review): the checkpoint is fetched over plain HTTP and then
    # deserialised by torch; prefer HTTPS if the host offers it.
    url = (
        'http://marc.ecs.soton.ac.uk/pytorch-models/imagenet/resnet101/'
        f'imagenet_resnet101_{msda}.pt'
    )
    # map_location='cpu' makes deserialisation independent of the device the
    # checkpoint was saved from, so CPU-only hosts can load it as well;
    # load_state_dict then copies the values into the model's own parameters.
    state = load_state_dict_from_url(url, progress=True, map_location='cpu')
    model.load_state_dict(state)
    return model
def _pyramidnet(msda='fmix', pretrained=False, *args, **kwargs):
    """Build ``models.aa_PyramidNet`` and optionally load a checkpoint.

    Args:
        msda: Tag interpolated into the checkpoint file name.
        pretrained: When true, download the checkpoint selected by ``msda``
            and load it into the model.
        *args: Forwarded unchanged to ``aa_PyramidNet``.
        **kwargs: Forwarded unchanged to ``aa_PyramidNet``.

    Returns:
        The constructed model; its state dict is replaced by the downloaded
        one when ``pretrained`` is true.
    """
    # Local import: ``models`` is only needed once an entry point is called.
    from models import aa_PyramidNet

    model = aa_PyramidNet(*args, **kwargs)
    if not pretrained:
        return model

    # NOTE(review): the checkpoint is fetched over plain HTTP and then
    # deserialised by torch; prefer HTTPS if the host offers it.
    url = (
        'http://marc.ecs.soton.ac.uk/pytorch-models/cifar10/pyramidnet/'
        f'cifar10_pyramidnet_{msda}.pt'
    )
    # map_location='cpu' makes deserialisation independent of the device the
    # checkpoint was saved from, so CPU-only hosts can load it as well;
    # load_state_dict then copies the values into the model's own parameters.
    state = load_state_dict_from_url(url, progress=True, map_location='cpu')
    model.load_state_dict(state)
    return model
def preact_resnet18_cifar10_baseline(pretrained=False, repeat=0, *args, **kwargs):
    """Hub entry point: call ``_preact_resnet18`` with the 'baseline' tag.

    ``pretrained``, ``repeat`` and any extra positional or keyword arguments
    are passed through unchanged; the resulting model is returned.
    """
    model = _preact_resnet18('baseline', pretrained, repeat, *args, **kwargs)
    return model
def preact_resnet18_cifar10_fmix(pretrained=False, repeat=0, *args, **kwargs):
    """Hub entry point: call ``_preact_resnet18`` with the 'fmix' tag.

    ``pretrained``, ``repeat`` and any extra positional or keyword arguments
    are passed through unchanged; the resulting model is returned.
    """
    model = _preact_resnet18('fmix', pretrained, repeat, *args, **kwargs)
    return model
def preact_resnet18_cifar10_mixup(pretrained=False, repeat=0, *args, **kwargs):
    """Hub entry point: call ``_preact_resnet18`` with the 'mixup' tag.

    ``pretrained``, ``repeat`` and any extra positional or keyword arguments
    are passed through unchanged; the resulting model is returned.
    """
    model = _preact_resnet18('mixup', pretrained, repeat, *args, **kwargs)
    return model
def preact_resnet18_cifar10_cutmix(pretrained=False, repeat=0, *args, **kwargs):
    """Hub entry point: call ``_preact_resnet18`` with the 'cutmix' tag.

    ``pretrained``, ``repeat`` and any extra positional or keyword arguments
    are passed through unchanged; the resulting model is returned.
    """
    model = _preact_resnet18('cutmix', pretrained, repeat, *args, **kwargs)
    return model
def preact_resnet18_cifar10_fmixplusmixup(pretrained=False, repeat=0, *args, **kwargs):
    """Hub entry point: call ``_preact_resnet18`` with the 'fmixplusmixup' tag.

    ``pretrained``, ``repeat`` and any extra positional or keyword arguments
    are passed through unchanged; the resulting model is returned.
    """
    model = _preact_resnet18('fmixplusmixup', pretrained, repeat, *args, **kwargs)
    return model
def pyramidnet_cifar10_baseline(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_pyramidnet`` with the 'baseline' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    model = _pyramidnet('baseline', pretrained, *args, **kwargs)
    return model
def pyramidnet_cifar10_fmix(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_pyramidnet`` with the 'fmix' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    model = _pyramidnet('fmix', pretrained, *args, **kwargs)
    return model
def pyramidnet_cifar10_mixup(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_pyramidnet`` with the 'mixup' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    model = _pyramidnet('mixup', pretrained, *args, **kwargs)
    return model
def pyramidnet_cifar10_cutmix(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_pyramidnet`` with the 'cutmix' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    model = _pyramidnet('cutmix', pretrained, *args, **kwargs)
    return model
def renset101_imagenet_baseline(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_resnet101`` with the 'baseline' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    return _resnet101('baseline', pretrained, *args, **kwargs)


# The entry point name above misspells "resnet". It is kept so existing
# callers that load it by name keep working; the correctly spelled name is
# exposed as an alias of the same function.
resnet101_imagenet_baseline = renset101_imagenet_baseline
def renset101_imagenet_fmix(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_resnet101`` with the 'fmix' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    return _resnet101('fmix', pretrained, *args, **kwargs)


# The entry point name above misspells "resnet". It is kept so existing
# callers that load it by name keep working; the correctly spelled name is
# exposed as an alias of the same function.
resnet101_imagenet_fmix = renset101_imagenet_fmix
def renset101_imagenet_mixup(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_resnet101`` with the 'mixup' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    return _resnet101('mixup', pretrained, *args, **kwargs)


# The entry point name above misspells "resnet". It is kept so existing
# callers that load it by name keep working; the correctly spelled name is
# exposed as an alias of the same function.
resnet101_imagenet_mixup = renset101_imagenet_mixup
def renset101_imagenet_cutmix(pretrained=False, *args, **kwargs):
    """Hub entry point: call ``_resnet101`` with the 'cutmix' tag.

    ``pretrained`` and any extra positional or keyword arguments are passed
    through unchanged; the resulting model is returned.
    """
    return _resnet101('cutmix', pretrained, *args, **kwargs)


# The entry point name above misspells "resnet". It is kept so existing
# callers that load it by name keep working; the correctly spelled name is
# exposed as an alias of the same function.
resnet101_imagenet_cutmix = renset101_imagenet_cutmix