-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtv_utils.py
144 lines (119 loc) · 4.5 KB
/
tv_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from PIL import Image
import PIL
from PIL import ImageFilter
import numbers
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import numpy as np
import os
import torch
import torch.nn as nn
import torchvision.datasets
# Maps PIL resampling filter constants to their qualified names so that
# SpatialAffine.__repr__ can print a readable `resample` value.
_pil_interpolation_to_str = {
    getattr(Image, filter_name): 'PIL.Image.' + filter_name
    for filter_name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS')
}
# Image file-name extensions (lower-case, leading dot included).
# NOTE(review): not referenced by the code visible in this file.
IMG_EXTENSIONS = tuple(
    '.' + suffix
    for suffix in ('jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp')
)
class Permute(nn.Module):
    """Reorder the entries along dimension 1 of the input.

    Args:
        permutation: Index sequence applied to dimension 1. When omitted,
            ``[2, 1, 0]`` is used, which reverses three entries
            (presumably an RGB <-> BGR channel swap -- confirm with callers).
    """

    def __init__(self, permutation=None):
        super().__init__()
        # Build the default inside the call: a list literal in the signature
        # would be one shared object mutated across all default instances.
        self.permutation = [2, 1, 0] if permutation is None else permutation

    def forward(self, input):
        """Return ``input`` with dimension 1 indexed by ``self.permutation``."""
        return input[:, self.permutation]
class ImageNet(Dataset):
    """Dataset driven by a text file of ``<image path> <integer label>`` lines.

    Args:
        root_dir: Directory holding the label file; every image path read
            from the file is joined onto this directory.
        csv_name: Name of the label file inside ``root_dir``.
        transform: Optional callable applied to each opened PIL image.

    Raises:
        ValueError: If a non-blank line is not ``<path> <int label>``.
    """

    def __init__(self, root_dir, csv_name='labels', transform=None):
        self.transform = transform
        # List of (absolute-or-joined image path, integer label) pairs.
        self.datas = []
        with open(os.path.join(root_dir, csv_name)) as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines in the label file.
                    continue
                # Split on the LAST whitespace run so that paths containing
                # spaces, tab separators and repeated spaces all parse.
                fields = line.rsplit(None, 1)
                try:
                    img_path, gt_label = fields
                    label = int(gt_label)
                except ValueError:
                    raise ValueError(
                        "line {} of {!r} is not '<path> <int label>': {!r}".format(
                            line_no, csv_name, line))
                self.datas.append((os.path.join(root_dir, img_path), label))

    def __len__(self):
        """Return the number of (image, label) entries."""
        return len(self.datas)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry ``idx``.

        The image is the opened PIL image, passed through ``transform``
        when one was given.
        """
        filename, label_source = self.datas[idx]
        in_img_t = Image.open(filename)
        # Image.open is lazy; decode now so Pillow can release the underlying
        # file handle instead of leaving it open until garbage collection.
        in_img_t.load()
        if self.transform is not None:
            in_img_t = self.transform(in_img_t)
        return in_img_t, label_source
class GaussianSmoothing(object):
    """Blur a PIL image with a Gaussian filter of (optionally random) radius.

    Args:
        radius: Either a single number (fixed radius) or a two-element
            list/tuple ``[min_radius, max_radius]`` from which a radius is
            drawn uniformly on every call.

    Raises:
        ValueError: If the sequence does not have exactly two elements or
            its bounds are out of order.
        TypeError: If ``radius`` is neither a number nor a list/tuple.
    """

    def __init__(self, radius):
        if isinstance(radius, numbers.Number):
            self.min_radius = radius
            self.max_radius = radius
        elif isinstance(radius, (list, tuple)):
            # Tuples are accepted alongside lists: `(lo, hi)` is the natural
            # way to spell a range and was previously rejected.
            if len(radius) != 2:
                raise ValueError(
                    "`radius` should be a number or a list of two numbers")
            if radius[1] < radius[0]:
                raise ValueError(
                    "radius[0] should be <= radius[1]")
            self.min_radius = radius[0]
            self.max_radius = radius[1]
        else:
            raise TypeError(
                "`radius` should be a number or a list of two numbers")

    def __call__(self, image):
        """Return ``image`` filtered with a Gaussian blur.

        The blur radius is sampled uniformly between ``min_radius`` and
        ``max_radius`` using NumPy's global random state.
        """
        radius = np.random.uniform(self.min_radius, self.max_radius)
        return image.filter(ImageFilter.GaussianBlur(radius))
class SpatialAffine(object):
    """Apply a fixed affine transformation to an image.

    The stored parameters are forwarded unchanged to
    ``torchvision.transforms.functional.affine``; nothing is sampled
    randomly.

    Args:
        degrees: Rotation angle passed to ``TF.affine``.
        translate: Optional list/tuple ``(dx, dy)``; ``None`` means (0, 0).
        scale: Optional scale factor; ``None`` means 1.0.
        shear: Optional shear value; ``None`` means 0.0.
        resample: Resampling filter forwarded as ``resample=``.
        fillcolor: Fill value forwarded as ``fillcolor=``.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
        self.degrees = degrees
        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
        # Stored unconditionally: __call__ and __repr__ read it even when None.
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, img_size):
        """Get parameters for affine transformation

        ``None`` values are replaced by identity defaults. ``img_size`` is
        accepted for signature compatibility but is not used.

        Returns:
            sequence: params to be passed to the affine transformation
        """
        angle = degrees
        if translate is not None:
            translations = (translate[0], translate[1])
        else:
            translations = (0, 0)
        scale = scale_ranges if scale_ranges is not None else 1.0
        shear = shears if shears is not None else 0.0
        return angle, translations, scale, shear

    def __call__(self, img):
        """
            img (PIL Image): Image to be transformed.
        Returns:
            PIL Image: Affine transformed image.
        """
        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
        # NOTE(review): `resample` / `fillcolor` are keyword names of older
        # torchvision releases -- confirm against the pinned version.
        return TF.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)

    def __repr__(self):
        """Return a constructor-like string listing the non-default settings."""
        s = '{name}(degrees={degrees}'
        d = dict(self.__dict__)
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        if self.resample > 0:
            s += ', resample={resample}'
            # Fall back to the raw value for filters missing from the table
            # (a plain lookup made repr() raise KeyError for those).
            d['resample'] = _pil_interpolation_to_str.get(self.resample, self.resample)
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        return s.format(name=self.__class__.__name__, **d)