-
Notifications
You must be signed in to change notification settings - Fork 3
/
generate_validation.py
80 lines (68 loc) · 3.5 KB
/
generate_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Build a train/validation split of CIFAR-10/100 and save the DataLoaders.

Usage: python generate_validation.py --dataset {cifar10,cifar100}
"""
import argparse
import os

import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data.sampler import SubsetRandomSampler

# Root directory where torchvision stores / looks for the CIFAR data.
PATH = 'data/cifar_data'  # YOUR DATAPATH HERE

# NOTE: parsing happens at import time (this file is meant to run as a script).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cifar10')
args = parser.parse_args()
dataset = args.dataset
def _make_split_loaders(dataset_cls, num_classes, n_valid, save_path):
    """Split the torchvision training set into train/validation loaders.

    The first ``n_valid`` images of every class become the validation set
    (loaded without augmentation); the rest form the training set (with
    random-crop/flip augmentation). Both loaders are serialized to
    ``save_path`` with ``torch.save``.

    Args:
        dataset_cls: torchvision dataset class (``datasets.CIFAR10`` / ``CIFAR100``).
        num_classes: number of classes in the dataset.
        n_valid: images held out per class for validation.
        save_path: destination file for ``torch.save([train_loader, valid_loader])``.
    """
    train_transform_ = transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    transform_ = transforms.Compose([transforms.ToTensor()])
    # Same underlying images, two transform pipelines: augmented for training,
    # plain ToTensor for validation.
    ori_set = dataset_cls(PATH, train=True, download=True, transform=train_transform_)
    valid_set = dataset_cls(PATH, train=True, download=True, transform=transform_)
    ori_label = torch.tensor([y for (_, y) in ori_set])
    valid_index, train_index = [], []
    for i in range(num_classes):
        # Compute the per-class index tensor once (the original called
        # nonzero() twice per class).
        class_index = (ori_label == i).nonzero()
        valid_index.append(class_index[:n_valid])
        train_index.append(class_index[n_valid:])
    valid_index = torch.cat(valid_index, dim=0).flatten()
    train_index = torch.cat(train_index, dim=0).flatten()
    # SubsetRandomSampler already draws in random order; this extra shuffle
    # mirrors the original script's behavior.
    train_index = train_index[np.random.permutation(len(train_index))]
    train_sampler = SubsetRandomSampler(train_index)
    valid_sampler = SubsetRandomSampler(valid_index)
    train_loader = torch.utils.data.DataLoader(
        ori_set, batch_size=128, shuffle=False, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=100, shuffle=False, sampler=valid_sampler)
    # NOTE(review): pickling DataLoader objects ties the file to this
    # torch/torchvision version; saving the index tensors instead would be
    # more portable — kept as-is to preserve the existing file format.
    torch.save([train_loader, valid_loader], save_path)


if __name__ == '__main__':
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists / os.mkdir pair.
    os.makedirs('data', exist_ok=True)
    if dataset == 'cifar10':
        # 100 held out per class = 2% of the 5000 training images per class.
        _make_split_loaders(datasets.CIFAR10, 10, 100, 'data/split_dataset.pth')
    else:
        # CIFAR-100: 25 held out of the 500 training images per class.
        _make_split_loaders(datasets.CIFAR100, 100, 25, 'data/split_dataset_100.pth')