-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
124 lines (99 loc) · 4.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import fnmatch
import os
import PIL
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import BatchNorm2d
# NOTE(review): a `global` statement at module level has no effect and does not
# define `args`; functions in this file that need it (e.g. adjust_learning_rate)
# receive `args` as an explicit parameter instead.
global args
def lr_poly(base_lr, iter, max_iter, power):
    """Return the polynomial ("poly") decayed learning rate.

    Computes ``base_lr * (1 - iter / max_iter) ** power``. Training progress
    ``iter / max_iter`` is clamped at 1, so iterations at or past ``max_iter``
    yield a learning rate of 0. Without the clamp, a negative base raised to a
    fractional ``power`` evaluates to a complex number in Python 3, and an
    integer ``power`` makes the rate grow again.

    Args:
        base_lr: Learning rate at iteration 0.
        iter: Current iteration. The name shadows the builtin but is kept for
            keyword-argument compatibility.
        max_iter: Iteration at which the rate reaches 0.
        power: Decay exponent.

    Returns:
        The decayed learning rate as a float.

    Raises:
        ZeroDivisionError: if ``max_iter`` is 0.
    """
    progress = min(float(iter) / max_iter, 1.0)
    return base_lr * ((1 - progress) ** power)
def adjust_learning_rate(optimizer, i_iter, config, args):
    """Write poly-decayed learning rates into ``optimizer`` for step ``i_iter``.

    The first parameter group receives
    ``lr_poly(config.lr, i_iter, config.num_steps, args.power)``. When a second
    group exists, it receives ten times that value. Any further groups are not
    modified. Nothing is returned.
    """
    new_lr = lr_poly(config.lr, i_iter, config.num_steps, args.power)
    groups = optimizer.param_groups
    groups[0]['lr'] = new_lr
    if len(groups) >= 2:
        # The second group is given a 10x learning rate.
        groups[1]['lr'] = new_lr * 10
def freeze_model(model, exclude_layers=('inconv',)):
    """Disable gradients for all parameters except those of excluded layers.

    A parameter keeps ``requires_grad=True`` only if at least one entry of
    ``exclude_layers`` is a substring of its qualified name, as reported by
    ``model.named_parameters()``. Every other parameter gets
    ``requires_grad=False``.
    """
    for param_name, parameter in model.named_parameters():
        matches = [pattern for pattern in exclude_layers if pattern in param_name]
        parameter.requires_grad = bool(matches)
def freeze_norm_stats(model, exclude_layers=('inconv',)):
    """Freeze BatchNorm2d running statistics except in excluded layers.

    Handling of each ``BatchNorm2d`` submodule, based on its qualified name from
    ``model.named_modules()``:

    * Name contains no entry of ``exclude_layers``: the module is switched to
      eval mode and gets ``track_running_stats = False``, so its running
      buffers are not updated.
    * Name contains an entry: the module is switched to train mode and its
      ``track_running_stats`` setting is left unchanged, so it keeps updating
      its running statistics. Previously it was also set to ``False``, which
      silently froze the excluded layers too.

    Note: a later ``model.train()`` puts every layer back into train mode.
    """
    for name, module in model.named_modules():
        if not isinstance(module, BatchNorm2d):
            continue
        if any(layer in name for layer in exclude_layers):
            module.train()
        else:
            module.eval()
            module.track_running_stats = False
class CrossEntropy2d(nn.Module):
    """Pixel-wise cross-entropy loss for semantic segmentation.

    Pixels whose target is negative or equal to ``ignore_label`` are excluded.
    The loss is the (optionally class-weighted) mean over the remaining pixels.
    """

    def __init__(self, size_average=True, ignore_label=255):
        super(CrossEntropy2d, self).__init__()
        # Only mean reduction is implemented. The flag exists for API
        # compatibility and must be True.
        self.size_average = size_average
        assert self.size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight=None):
        """
        Args:
            predict: logits of shape (n, c, h, w).
            target: class indices of shape (n, h, w) or (n, 1, h, w).
            weight (Tensor, optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of size "nclasses"

        Returns:
            Scalar loss tensor. If every pixel is ignored, a zero is returned
            that is computed from ``predict``. It therefore lives on the same
            device and ``backward()`` works, producing zero gradients.
        """
        assert not target.requires_grad
        assert predict.dim() == 4
        if target.dim() != 3:
            assert target.dim() == 4
            target = target.squeeze(1)
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))
        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask]
        if target.numel() == 0:
            # Boolean indexing always returns a 1-D tensor, so emptiness must be
            # checked with numel(). A mean over zero pixels would be NaN.
            return predict.sum() * 0.0
        # (n, c, h, w) -> (n, h, w, c). The (n, h, w) mask then selects whole
        # logit vectors, giving shape (num_valid, c) in the same row-major order
        # as target[target_mask].
        predict = predict.permute(0, 2, 3, 1)[target_mask]
        return F.cross_entropy(predict, target, weight=weight, reduction='mean')
def loss_calc(pred, label, gpu):
    """Return the cross-entropy segmentation loss of ``pred`` against ``label``.

    Args:
        pred: logits; CrossEntropy2d asserts shape (N, C, H, W).
        label: class-index map; CrossEntropy2d accepts (N, H, W) or
            (N, 1, H, W). It is cast to int64 and moved to ``gpu``. Pixels that
            are negative or equal to 255 (CrossEntropy2d's default
            ``ignore_label``) are ignored.
        gpu: device passed to ``.to()`` for the label and the criterion.

    Returns:
        The loss tensor produced by CrossEntropy2d.
    """
    # No reshaping happens here. CrossEntropy2d validates the shapes and squeezes
    # a singleton channel dimension of the label itself. The legacy
    # autograd.Variable wrapper is unnecessary for plain tensors.
    label = label.long().to(gpu)
    criterion = CrossEntropy2d().to(gpu)
    return criterion(pred, label)
def include_patterns(*patterns):
    """Build an ``ignore`` callable for ``shutil.copytree()`` that keeps matching files.

    Each positional argument is a glob-style pattern in ``fnmatch`` syntax.
    The returned function receives a directory ``path`` and the entry ``names``
    inside it. It returns the set of names copytree should skip: every entry
    that matches none of the patterns and is not itself a directory.
    Directories are never skipped, so copytree still descends into them.
    """
    def _ignore_patterns(path, names):
        kept = set()
        for pattern in patterns:
            kept.update(fnmatch.filter(names, pattern))
        return {
            entry
            for entry in names
            if entry not in kept and not os.path.isdir(os.path.join(path, entry))
        }
    return _ignore_patterns
def tensor_to_image(tensor):
    """Convert an image tensor with values in [0, 1] to a PIL image.

    Args:
        tensor: torch tensor of shape (H, W), (C, H, W) or (1, C, H, W), where
            C is 1 (grayscale), 3 (RGB) or 4 (RGBA). Values are multiplied by
            255, clamped to [0, 255] and truncated to uint8.

    Returns:
        A ``PIL.Image.Image`` in mode "L", "RGB" or "RGBA".

    Raises:
        ValueError: if a 4-D tensor has batch size other than 1, or the shape or
            channel count is not one of the supported layouts.
    """
    # `import PIL` alone does not load the PIL.Image submodule, so import it
    # explicitly.
    from PIL import Image

    # Clamp before the uint8 cast: out-of-range floats would otherwise wrap
    # around instead of saturating.
    array = (tensor.detach().cpu() * 255).clamp(0, 255).numpy().astype(np.uint8)

    if array.ndim == 4:
        if array.shape[0] != 1:
            raise ValueError(
                "expected a batch of one image, got batch size {0}".format(array.shape[0]))
        array = array[0]

    if array.ndim == 3:
        channels = array.shape[0]
        if channels == 1:
            # Drop only the channel axis. A bare squeeze() would also collapse
            # H or W of size 1.
            array = array[0]
        elif channels in (3, 4):
            # PIL expects channel-last (H, W, C) layout.
            array = np.ascontiguousarray(array.transpose(1, 2, 0))
        else:
            raise ValueError("unsupported channel count {0}".format(channels))
    elif array.ndim != 2:
        raise ValueError("unsupported tensor shape {0}".format(tuple(array.shape)))

    return Image.fromarray(array)