-
Notifications
You must be signed in to change notification settings - Fork 134
/
dataloader.py
66 lines (45 loc) · 1.88 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import torch as torch
import numpy as np
from io import BytesIO
import scipy.misc
#import tensorflow as tf
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.autograd import Variable
from matplotlib import pyplot as plt
from PIL import Image
class dataloader:
    """Resolution-aware wrapper around an ``ImageFolder`` + ``DataLoader`` pair.

    The image side length and the batch size are both derived from a
    resolution level ``resl`` (side length ``2 ** resl``).  Call
    :meth:`renew` to (re)build the underlying dataset and loader for a
    given level before iterating or requesting batches.
    """

    def __init__(self, config):
        """Store the data root and the defaults for the initial 4x4 level.

        Args:
            config: object exposing ``train_data_root``, the directory handed
                to ``ImageFolder`` as ``root``.
        """
        self.root = config.train_data_root
        # Image side length -> batch size.
        # change this according to available gpu memory.
        self.batch_table = {
            4: 32,
            8: 32,
            16: 32,
            32: 16,
            64: 16,
            128: 16,
            256: 12,
            512: 3,
            1024: 1,
        }
        self.batchsize = int(self.batch_table[pow(2, 2)])  # we start from 2^2=4
        self.imsize = int(pow(2, 2))
        self.num_workers = 4
        # Built by renew(); kept as None until then so that misuse can be
        # reported with a clear message instead of an AttributeError.
        self.dataset = None
        self.dataloader = None
        # Persistent iterator backing __next__; created lazily.
        self._batch_iter = None

    def _require_loader(self):
        """Return the active ``DataLoader`` or fail with an explicit error.

        Raises:
            RuntimeError: if :meth:`renew` has not been called yet.
        """
        if self.dataloader is None:
            raise RuntimeError(
                'dataloader is not initialised; call renew(resl) first.'
            )
        return self.dataloader

    def renew(self, resl):
        """Rebuild the dataset and loader for resolution level ``resl``.

        Args:
            resl: resolution level; images are resized to
                ``2 ** resl`` x ``2 ** resl`` and the batch size is looked up
                in ``batch_table``.

        Raises:
            KeyError: if ``2 ** resl`` is not a key of ``batch_table``.
        """
        print('[*] Renew dataloader configuration, load data from {}.'.format(self.root))

        self.batchsize = int(self.batch_table[pow(2, resl)])
        self.imsize = int(pow(2, resl))
        self.dataset = ImageFolder(
            root=self.root,
            transform=transforms.Compose([
                transforms.Resize(size=(self.imsize, self.imsize), interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ]))

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batchsize,
            shuffle=True,
            num_workers=self.num_workers
        )
        # Any iterator handed out by __next__ belongs to the previous loader.
        self._batch_iter = None

    def __iter__(self):
        """Return a fresh iterator over the current ``DataLoader``."""
        return iter(self._require_loader())

    def __next__(self):
        """Return the next batch from a persistent pass over the loader.

        A ``DataLoader`` is an iterable, not an iterator, so ``next()`` cannot
        be applied to it directly; an iterator is created on first use and
        kept between calls.

        Raises:
            StopIteration: when the current pass is exhausted.  The following
                call starts a new pass.
            RuntimeError: if :meth:`renew` has not been called yet.
        """
        if self._batch_iter is None:
            self._batch_iter = iter(self._require_loader())
        try:
            return next(self._batch_iter)
        except StopIteration:
            # Drop the spent iterator so the next call begins a new pass.
            self._batch_iter = None
            raise

    def __len__(self):
        """Return the number of samples in the dataset (not of batches)."""
        return len(self._require_loader().dataset)

    def get_batch(self):
        """Return one batch of images rescaled with ``x * 2 - 1``.

        A new iterator is created on every call and only its first batch is
        used, so with ``shuffle=True`` each call draws from a freshly shuffled
        pass.  NOTE: with ``num_workers > 0`` this starts the worker processes
        anew on each call.

        Raises:
            RuntimeError: if :meth:`renew` has not been called yet.
        """
        fresh_iter = iter(self._require_loader())
        images = next(fresh_iter)[0]  # element 0 of the batch; 1 would be the labels
        return images.mul(2).add(-1)  # pixel range [-1, 1]