-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
50 lines (42 loc) · 2.03 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from PIL import Image
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import random
import torch
import numpy as np
class CIFAR10Pair(CIFAR10):
    """CIFAR-10 dataset that yields two views of each image plus its label.

    Each item is ``(x_i, x_j, target)``. When ``self.transform`` is set it is
    called twice on the same source image. With a random augmentation pipeline
    the two views therefore differ; with a deterministic pipeline they are equal.
    """

    def __getitem__(self, index):
        """Return two views of sample ``index`` and its (optionally transformed) label.

        Args:
            index: position of the sample in ``self.data`` / ``self.targets``.

        Returns:
            Tuple ``(x_i, x_j, target)``. If no ``transform`` was configured,
            both views are the untransformed PIL image.
        """
        img, target = self.data[index], self.targets[index]
        # Wrap the raw array in a PIL image so torchvision transforms accept it.
        img = Image.fromarray(img)

        if self.transform is not None:
            # Two separate calls -> two independently sampled augmentations.
            x_i = self.transform(img)
            x_j = self.transform(img)
        else:
            # Without this branch x_i/x_j were unbound and the return below
            # raised UnboundLocalError.
            x_i = x_j = img

        if self.target_transform is not None:
            target = self.target_transform(target)

        return x_i, x_j, target
def transform(train=True, seed=2147483647):
    """Build the torchvision preprocessing pipeline.

    Side effect: every call reseeds both Python's ``random`` module and torch's
    global RNG with ``seed``.

    Args:
        train: when true, prepend the augmentation steps (random resized crop
            to 32, horizontal flip, colour jitter applied with p=0.8, random
            grayscale). The tensor-conversion and normalisation steps are always
            included.
        seed: value passed to ``random.seed`` and ``torch.random.manual_seed``.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)

    # Mean / std arguments to transforms.Normalize, shared by both pipelines.
    # NOTE(review): these std values are a widely copied figure; a per-channel
    # std computed over the CIFAR-10 training set is closer to
    # [0.247, 0.243, 0.262]. Confirm before changing, since trained models
    # depend on the exact normalisation.
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]

    jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)

    steps = []
    if train:
        steps += [
            transforms.RandomResizedCrop(32),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
def get_data(batch_size, num_workers=16, root='data'):
    """Create the train / val / test ``DataLoader``s over ``CIFAR10Pair``.

    Args:
        batch_size: samples per batch for all three loaders.
        num_workers: worker processes used by each loader.
        root: directory the CIFAR-10 files are downloaded to / read from.
            Defaults to ``'data'``, the location previously hard-coded.

    Returns:
        ``(train_loader, val_loader, test_loader)``:

        - train: training split, ``transform()`` (augmenting) pipeline,
          shuffled, last incomplete batch dropped.
        - val: the *training* split again, with the non-augmenting
          ``transform(False)`` pipeline and no shuffling.
          NOTE(review): presumably used as a reference/memory set during
          evaluation; confirm with callers.
        - test: test split, ``transform(False)`` pipeline, not shuffled.
    """
    # transform() is called in the same order as before, so the global RNG
    # reseeding sequence is unchanged.
    train_data = CIFAR10Pair(root=root, train=True, transform=transform(), download=True)
    val_data = CIFAR10Pair(root=root, train=True, transform=transform(False), download=True)
    test_data = CIFAR10Pair(root=root, train=False, transform=transform(False), download=True)

    # Options shared by every loader.
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **common)
    val_loader = DataLoader(val_data, shuffle=False, **common)
    test_loader = DataLoader(test_data, shuffle=False, **common)
    return train_loader, val_loader, test_loader