-
Notifications
You must be signed in to change notification settings - Fork 15
/
inat.py
86 lines (68 loc) · 2.87 KB
/
inat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch.utils.data as data
from PIL import Image
import os
import json
from torchvision import transforms
import random
import numpy as np
def default_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    The file is opened explicitly inside a ``with`` block so that the
    underlying handle is closed as soon as the pixel data has been decoded,
    instead of staying open until the PIL image object is garbage collected.
    This avoids ``ResourceWarning`` noise and file-descriptor exhaustion when
    many images are loaded (e.g. from several DataLoader workers).

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The image produced by ``Image.open(...).convert('RGB')``.
    """
    with open(path, 'rb') as img_file:
        img = Image.open(img_file)
        # convert() forces the decode, so it must run while the file is open.
        return img.convert('RGB')
class INAT(data.Dataset):
    """Image classification dataset described by a JSON annotation file.

    The annotation file must contain an ``images`` list whose entries have a
    ``file_name`` key. When the file also contains an ``annotations`` list,
    the ``category_id`` of each entry is used as the label; otherwise every
    image gets the label ``0``.

    Each item is an ``(image, label)`` pair. The image is augmented
    (random resized crop, horizontal flip, colour jitter) when ``is_train``
    is true, or resized and centre-cropped otherwise, and is then passed
    through ``ToTensor`` and ``Normalize``.

    Args:
        root: Directory that the ``file_name`` entries are relative to. It
            may be given with or without a trailing path separator.
        ann_file: Path of the JSON annotation file.
        is_train: Selects the training augmentations when true and the
            deterministic resize + centre crop when false.
    """

    def __init__(self, root, ann_file, is_train=True):
        # load annotations
        print('Loading annotations from: ' + os.path.basename(ann_file))
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(ann_file, encoding='utf-8') as data_file:
            ann_data = json.load(data_file)

        # set up the filenames and annotations
        self.imgs = [aa['file_name'] for aa in ann_data['images']]

        # if we dont have class labels set them to '0'
        # NOTE(review): labels are matched to images purely by list position,
        # which assumes one annotation per image in the same order as
        # ``images`` - confirm against the files being loaded.
        if 'annotations' in ann_data:
            self.classes = [aa['category_id'] for aa in ann_data['annotations']]
        else:
            self.classes = [0] * len(self.imgs)

        # print out some stats
        print('\t' + str(len(self.imgs)) + ' images')
        print('\t' + str(len(set(self.classes))) + ' classes')

        self.root = root
        self.is_train = is_train
        self.loader = default_loader

        # augmentation params
        self.im_size = [224, 224]  # can change this to train on higher res
        self.mu_data = [0.485, 0.456, 0.406]
        self.std_data = [0.229, 0.224, 0.225]
        self.brightness = 0.4
        self.contrast = 0.4
        self.saturation = 0.4
        self.hue = 0.25

        # augmentations
        crop_size = (self.im_size[0], self.im_size[1])
        self.resize_aug = transforms.Resize(256)
        self.center_crop = transforms.CenterCrop(crop_size)
        # Use the full (height, width) here as well so that the training and
        # evaluation crops agree when im_size is not square.
        self.scale_aug = transforms.RandomResizedCrop(size=crop_size)
        self.flip_aug = transforms.RandomHorizontalFlip()
        self.color_aug = transforms.ColorJitter(
            self.brightness, self.contrast, self.saturation, self.hue)
        self.tensor_aug = transforms.ToTensor()
        self.norm_aug = transforms.Normalize(mean=self.mu_data, std=self.std_data)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``."""
        file_name = self.imgs[index]
        if self.root:
            # os.path.join discards ``root`` when the second component is
            # absolute, so drop leading slashes before joining.
            file_name = file_name.lstrip('/')
        # Joining (rather than concatenating) makes ``root`` work both with
        # and without a trailing separator.
        path = os.path.join(self.root, file_name)

        img = self.loader(path)
        species_id = self.classes[index]

        if self.is_train:
            img = self.scale_aug(img)
            img = self.flip_aug(img)
            img = self.color_aug(img)
        else:
            img = self.resize_aug(img)
            img = self.center_crop(img)

        img = self.tensor_aug(img)
        img = self.norm_aug(img)

        return img, species_id

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.imgs)
if __name__ == '__main__':
    # Smoke test: build the three splits and print each split's size and
    # first sample.
    import sys

    # The dataset directory can be supplied as the first command-line
    # argument; the previously hard-coded location remains the default.
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "/home/ubuntu/inat/"
    # Ensure a trailing separator so the directory can be used as a prefix.
    data_dir = os.path.join(data_dir, "")

    train_dataset = INAT(root=data_dir, ann_file=os.path.join(data_dir, "new_train.json"), is_train=True)
    # NOTE(review): the validation split is built with is_train=True, i.e.
    # with the random training augmentations - confirm this is intended.
    val_dataset = INAT(root=data_dir, ann_file=os.path.join(data_dir, "new_val.json"), is_train=True)
    test_dataset = INAT(root=data_dir, ann_file=os.path.join(data_dir, "new_test.json"), is_train=False)

    print(len(train_dataset), train_dataset[0])
    print(len(val_dataset), val_dataset[0])
    print(len(test_dataset), test_dataset[0])