-
Notifications
You must be signed in to change notification settings - Fork 3
/
datasets.py
125 lines (100 loc) · 4.28 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from pathlib import Path
from PIL import Image, ImageFile
from torchvision import transforms
from torch.utils import data
import numpy as np
from skimage.color import rgb2lab
import torch
# Let Pillow decode files whose data ends early rather than raising
# "OSError: image file is truncated".
ImageFile.LOAD_TRUNCATED_IMAGES = True
# A pixel limit of None switches off Pillow's DecompressionBombError check,
# so arbitrarily large images can be opened.
Image.MAX_IMAGE_PIXELS = None
def train_transform(img_size):
    """Build the training-time image transform.

    Args:
        img_size: side length; images are resized to ``(img_size, img_size)``.

    Returns:
        A ``transforms.Compose`` that resizes and then applies ``ToTensor``.
    """
    square = (img_size, img_size)
    return transforms.Compose([transforms.Resize(size=square), transforms.ToTensor()])
def test_transform(size, crop):
    """Build the evaluation-time image transform.

    Args:
        size: value forwarded to ``transforms.Resize`` (and ``CenterCrop`` when
            ``crop`` is set). ``0`` disables both resizing and cropping.
        crop: if true and ``size`` is non-zero, center-crop to ``size`` after
            resizing.

    Returns:
        A ``transforms.Compose`` that always ends with ``ToTensor``.
    """
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
        # Cropping needs a real target size: CenterCrop(0) would reduce the
        # image to an empty 0x0 crop, so it is only applied alongside Resize.
        if crop:
            transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)
class TrainDataset(data.Dataset):
    """Training dataset of CIE-Lab tensors built from every file in ``root``.

    Each image is loaded as RGB and center-cropped to a square when it is not
    already square. It is then resized to ``img_size`` x ``img_size`` with
    Lanczos resampling and converted with ``skimage.color.rgb2lab``. Channels
    are clipped to fixed ranges and scaled to [0, 1]:

    * L: [0, 100]
    * a: [-86, 98]
    * b: [-107, 94]

    Args:
        root: directory whose direct children are the images.
        img_size: side length of the square output images.
        gray_only: if true, ``__getitem__`` returns only the L tensor.
    """

    def __init__(self, root, img_size, gray_only=False):
        super().__init__()
        self.root = root
        # Keep regular files only (a sub-directory cannot be opened as an
        # image). Sort so the index -> path mapping does not depend on the
        # filesystem's listing order.
        self.paths = sorted(p for p in Path(self.root).glob('*') if p.is_file())
        self.img_size = img_size
        self.gray_only = gray_only

    @staticmethod
    def _center_square_box(w, h):
        """Return the PIL box (left, top, right, bottom) of the largest centered square."""
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        return left, top, left + side, top + side

    @staticmethod
    def _to_chw(array):
        """Convert an (H, W, C) numpy array to a float32 (C, H, W) tensor."""
        return torch.from_numpy(np.transpose(array, (2, 0, 1)).astype(np.float32))

    @classmethod
    def _l_tensor(cls, lab_image):
        """Scale L from [0, 100] to [0, 1] and repeat it to a (3, H, W) tensor."""
        l_image = np.clip(lab_image[:, :, 0:1], 0.0, 100.0) / 100.0
        return cls._to_chw(l_image).repeat((3, 1, 1))

    @classmethod
    def _ab_tensor(cls, lab_image):
        """Scale a/b to [0, 1] and prepend a zero channel, giving a (3, H, W) tensor."""
        a_image = (np.clip(lab_image[:, :, 1:2], -86.0, 98.0) + 86.0) / (98.0 + 86.0)
        b_image = (np.clip(lab_image[:, :, 2:3], -107.0, 94.0) + 107.0) / (94.0 + 107.0)
        ab_tensor = cls._to_chw(np.concatenate((a_image, b_image), axis=2))
        zero = torch.zeros((1, ab_tensor.shape[1], ab_tensor.shape[2]))
        return torch.cat([zero, ab_tensor], dim=0)

    def __getitem__(self, index):
        """Return the L tensor, or ``(l_image, ab_image)``; both are (3, S, S) with S = img_size."""
        path = self.paths[index]
        # The context manager closes the file; convert() loads the pixel data
        # into a new image, which stays valid after the file is closed.
        with Image.open(str(path)) as img:
            rgb_image = img.convert('RGB')
        w, h = rgb_image.size
        if w != h:
            rgb_image = rgb_image.crop(self._center_square_box(w, h))
        rgb_image = rgb_image.resize((self.img_size, self.img_size), Image.LANCZOS)
        lab_image = rgb2lab(np.array(rgb_image))
        l_image = self._l_tensor(lab_image)
        if self.gray_only:
            return l_image
        return l_image, self._ab_tensor(lab_image)

    def get_img_path(self, index):
        """Return the path of the image at ``index``."""
        return self.paths[index]

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'TrainDataset'
class TestDataset(data.Dataset):
    """Evaluation dataset of CIE-Lab tensors built from every file in ``root``.

    Images are used at their native resolution (no crop or resize). By
    default, channels are clipped and scaled to [0, 1] with the training
    ranges:

    * L: [0, 100]
    * a: [-86, 98]
    * b: [-107, 94]

    Args:
        root: directory whose direct children are the images.
        gray_only: if true, ``__getitem__`` returns only the L tensor.
        T_only: if true, the a/b channels are returned raw (unclipped,
            unscaled) as a 2-channel tensor with no zero padding channel.
        C_only: if true, the L channel is returned raw as a 1-channel tensor
            instead of being scaled and repeated to 3 channels.
    """

    def __init__(self, root, gray_only=False, T_only=False, C_only=False):
        super().__init__()
        self.root = root
        # Keep regular files only (a sub-directory cannot be opened as an
        # image). Sort so the index -> path mapping does not depend on the
        # filesystem's listing order.
        self.paths = sorted(p for p in Path(self.root).glob('*') if p.is_file())
        self.gray_only = gray_only
        self.T_only = T_only
        self.C_only = C_only

    @staticmethod
    def _to_chw(array):
        """Convert an (H, W, C) numpy array to a float32 (C, H, W) tensor."""
        return torch.from_numpy(np.transpose(array, (2, 0, 1)).astype(np.float32))

    @classmethod
    def _l_tensor(cls, lab_image, raw=False):
        """Return L as raw (1, H, W), or scaled to [0, 1] and repeated to (3, H, W)."""
        l_image = lab_image[:, :, 0:1]
        if raw:
            return cls._to_chw(l_image)
        l_image = np.clip(l_image, 0.0, 100.0) / 100.0
        return cls._to_chw(l_image).repeat((3, 1, 1))

    @classmethod
    def _ab_tensor(cls, lab_image, raw=False):
        """Return a/b as raw (2, H, W), or scaled to [0, 1] with a zero channel prepended (3, H, W)."""
        if raw:
            return cls._to_chw(lab_image[:, :, 1:3])
        a_image = (np.clip(lab_image[:, :, 1:2], -86.0, 98.0) + 86.0) / (98.0 + 86.0)
        b_image = (np.clip(lab_image[:, :, 2:3], -107.0, 94.0) + 107.0) / (94.0 + 107.0)
        ab_tensor = cls._to_chw(np.concatenate((a_image, b_image), axis=2))
        zero = torch.zeros((1, ab_tensor.shape[1], ab_tensor.shape[2]))
        return torch.cat([zero, ab_tensor], dim=0)

    def __getitem__(self, index):
        """Return the L tensor, or ``(l_image, ab_image)``, shaped per the flags above."""
        path = self.paths[index]
        # The context manager closes the file; convert() loads the pixel data
        # into a new image, which stays valid after the file is closed.
        with Image.open(str(path)) as img:
            rgb_image = img.convert('RGB')
        lab_image = rgb2lab(np.asarray(rgb_image))
        l_image = self._l_tensor(lab_image, raw=self.C_only)
        if self.gray_only:
            return l_image
        return l_image, self._ab_tensor(lab_image, raw=self.T_only)

    def get_img_path(self, index):
        """Return the path of the image at ``index``."""
        return self.paths[index]

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'TestDataset'