-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
54 lines (43 loc) · 1.61 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import config
import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
class DRDataset(Dataset):
    """Dataset of image files in a folder, usable for both training and test.

    In train mode, samples are the rows of the CSV at ``path_to_csv``; each row
    must unpack into exactly ``(image_name, label)``. In test mode, samples are
    the entries of ``images_folder`` and every label is ``-1``, so the same class
    can be reused to build a test-set submission. The CSV is read in both modes.
    """

    def __init__(self, images_folder, path_to_csv, train=True, transform=None,
                 image_ext=".jpeg"):
        """
        Args:
            images_folder: Directory containing the image files.
            path_to_csv: CSV file loaded with ``pd.read_csv``. In train mode each
                row must have exactly two columns: image name and label.
            train: If True, index CSV rows; otherwise index the (sorted) entries
                of ``images_folder`` and return ``-1`` as the label.
            transform: Optional callable invoked as ``transform(image=array)``
                whose result is indexed with ``["image"]``
                (albumentations-style interface — assumption, confirm).
            image_ext: Extension used to build the path of each image file.
                Defaults to ``".jpeg"``, the previously hard-coded value.
        """
        super().__init__()
        self.data = pd.read_csv(path_to_csv)
        self.images_folder = images_folder
        # os.listdir returns entries in arbitrary order; sort so that the
        # test-mode index -> file mapping is reproducible across runs/machines.
        self.image_files = sorted(os.listdir(images_folder))
        self.transform = transform
        self.train = train
        self.image_ext = image_ext

    def __len__(self):
        """Return the CSV row count in train mode, else the folder entry count."""
        return len(self.data) if self.train else len(self.image_files)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A tuple ``(image, label, image_name)``:
            - ``image``: the decoded image as a NumPy array, or whatever
              ``transform`` returns under its ``"image"`` key;
            - ``label``: the CSV label in train mode, ``-1`` in test mode;
            - ``image_name``: the file name without ``image_ext``.
        """
        if self.train:
            image_file, label = self.data.iloc[index]
        else:
            # Test mode has no labels: return -1 so the same dataset class can
            # be reused for the test-set submission later on.
            image_file, label = self.image_files[index], -1

        # pandas may parse purely numeric names as non-strings; normalize first.
        image_file = str(image_file)
        # Strip only a *trailing* extension (str.replace would remove it from
        # anywhere in the name). Guard against an empty extension, for which
        # slicing with -0 would wipe the whole name.
        if self.image_ext and image_file.endswith(self.image_ext):
            image_file = image_file[: -len(self.image_ext)]

        path = os.path.join(self.images_folder, image_file + self.image_ext)
        # The context manager closes the underlying file handle deterministically.
        with Image.open(path) as img:
            image = np.array(img)

        if self.transform:
            image = self.transform(image=image)["image"]
        return image, label, image_file
if __name__ == "__main__":
    # Smoke test: build the training dataset and print the shapes of one batch.
    dataset = DRDataset(
        images_folder="../train/images_resized_650/",
        path_to_csv="../train/trainLabels.csv",
        transform=config.val_transforms,
    )
    loader = DataLoader(
        dataset=dataset, batch_size=32, num_workers=2, shuffle=True, pin_memory=True
    )
    for images, labels, _files in tqdm(loader):
        print(images.shape)
        print(labels.shape)
        # One batch is enough to check the pipeline: leave the loop normally
        # instead of importing sys mid-loop and exiting the interpreter.
        break