-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_data.py
102 lines (91 loc) · 4.36 KB
/
prepare_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python3
from glob import glob
import os
from io import BytesIO
import argparse
import multiprocessing
from ipdb import set_trace
# import cv2
import lmdb
from tqdm import tqdm
from PIL import Image
from torchvision.transforms import functional as trans_fn
def format_for_lmdb(*args):
    """Build an LMDB key from *args*: parts joined with '-' and UTF-8 encoded.

    Integer parts are rendered zero-padded to seven characters (so byte-wise
    key order matches numeric order for non-negative values below 10**7);
    every other part is used unchanged and must already be a ``str``.
    """
    def _render(part):
        return str(part).zfill(7) if isinstance(part, int) else part

    return '-'.join(map(_render, args)).encode('utf-8')
class Resizer:
    """Callable that loads one input file, resizes it and encodes it as JPEG.

    Calling ``resizer(index, filename)`` returns ``(index, encoded_bytes)``;
    the index is passed through untouched so the caller can tell which input
    a result belongs to.
    """

    def __init__(self, data_type, *, size, quality):
        """Store the resize configuration.

        Args:
            data_type: 'images' or 'videos'.
            size: forwarded unchanged to ``torchvision.transforms.functional.resize``.
            quality: JPEG quality passed to ``PIL.Image.Image.save``.

        Raises:
            ValueError: if *data_type* is neither 'images' nor 'videos'.
        """
        # An explicit raise (unlike ``assert``) stays active under ``python -O``.
        if data_type not in ('images', 'videos'):
            raise ValueError(f'unsupported data_type: {data_type!r}')
        self.data_type = data_type
        self.size = size
        self.quality = quality

    def get_resized_bytes(self, img):
        """Resize the PIL image *img* and return it encoded as JPEG bytes."""
        img = trans_fn.resize(img, self.size)
        buf = BytesIO()
        img.save(buf, format='jpeg', quality=self.quality)
        return buf.getvalue()

    def prepare(self, filename):
        """Load *filename* and return its resized JPEG encoding.

        Raises:
            NotImplementedError: when ``data_type == 'videos'``. Video decoding
                is not implemented; returning ``None`` here would only surface
                later as an unrelated error in the caller.
        """
        if self.data_type != 'images':
            raise NotImplementedError('video preprocessing is not implemented')
        # Close the underlying file handle once the pixels are decoded;
        # convert() returns a new image that does not depend on the file.
        with Image.open(filename) as img:
            rgb = img.convert('RGB')
        return self.get_resized_bytes(rgb)

    def __call__(self, index, filename):
        """Return ``(index, self.prepare(filename))``."""
        return index, self.prepare(filename)
def _resize_job(job):
    """Unpack a ``(resizer, index, filename)`` work item and run the resizer.

    ``multiprocessing.Pool.imap_unordered`` hands each work item to the worker
    as a single argument, while ``Resizer.__call__`` takes ``index`` and
    ``filename`` separately.
    """
    resizer, index, filename = job
    return resizer(index, filename)


def _collect_filenames(data_type, path):
    """Return the sorted paths of regular files under *path* matching *data_type*.

    The search is recursive. Extensions are compared case-insensitively
    (``a.jpg``, ``b.JPG`` and ``c.Jpg`` all count as images), and each file
    is listed exactly once.
    """
    from glob import escape, iglob

    image_extensions = {'jpg', 'png', 'jpeg', 'webp'}
    video_extensions = {'mp4'}
    extensions = image_extensions if data_type == 'images' else video_extensions
    # Escape *path* so glob metacharacters in directory names are matched
    # literally, and filter extensions in Python instead of globbing once per
    # spelling: that missed mixed-case names and, on case-insensitive
    # filesystems, listed every file twice.
    pattern = os.path.join(escape(path), '**', '*')
    return sorted(
        name for name in iglob(pattern, recursive=True)
        if os.path.splitext(name)[1][1:].lower() in extensions
        and os.path.isfile(name)
    )


def _iter_results(resizer, filenames, pool, chunksize):
    """Yield ``(index, encoded)`` for every entry of *filenames* with a progress bar.

    Runs in-process when *pool* is None; otherwise distributes the work with
    ``pool.imap_unordered``, so results may arrive out of order.
    """
    total = len(filenames)
    if pool is None:
        bar = tqdm(enumerate(filenames), total=total)
        for index, filename in bar:
            # Show the file before resizing it, so a stall names its culprit.
            bar.set_description(filename)
            yield resizer(index, filename)
    else:
        jobs = ((resizer, index, filename) for index, filename in enumerate(filenames))
        bar = tqdm(pool.imap_unordered(_resize_job, jobs, chunksize=chunksize), total=total)
        for index, result in bar:
            bar.set_description(filenames[index])
            yield index, result


def prepare_data(data_type, path, out, n_worker, sizes, quality, chunksize):
    """Resize every matching file under *path* and store the results in LMDB.

    One LMDB environment is written per entry of *sizes*, at ``<out>/<size>``.
    Each stores the key ``length`` (the number of input files). For images it
    also stores one JPEG per file under its zero-padded index. For videos it
    stores ``<index>-length`` (frame count) and one entry per
    ``<index>-<frame>``.

    Args:
        data_type: 'images' or 'videos'.
        path: input directory, searched recursively.
        out: output directory; created if missing.
        n_worker: number of worker processes; values <= 1 resize in-process.
        sizes: iterable of target sizes.
        quality: JPEG quality.
        chunksize: work items sent to a worker at a time (only used when
            ``n_worker > 1``).
    """
    print('Starting...')
    filenames = _collect_filenames(data_type, path)
    total = len(filenames)
    os.makedirs(out, exist_ok=True)

    # Start the workers before any LMDB environment is opened so forked
    # children never inherit an open environment or an active write txn.
    pool = multiprocessing.Pool(n_worker) if n_worker > 1 else None
    try:
        for size in sizes:
            resizer = Resizer(data_type, size=size, quality=quality)
            lmdb_path = os.path.join(out, str(size))
            with lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) as env:
                with env.begin(write=True) as txn:
                    txn.put(format_for_lmdb('length'), format_for_lmdb(total))
                    for idx, result in _iter_results(resizer, filenames, pool, chunksize):
                        if data_type == 'images':
                            txn.put(format_for_lmdb(idx), result)
                        else:
                            txn.put(format_for_lmdb(idx, 'length'), format_for_lmdb(len(result)))
                            for frame_idx, frame in enumerate(result):
                                txn.put(format_for_lmdb(idx, frame_idx), frame)
    finally:
        if pool is not None:
            pool.terminate()
            pool.join()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('data_type', type=str, help='data type', choices=['images', 'videos'])
    parser.add_argument('path', type=str, help='a path to input directory')
    # Required: prepare_data() passes --out straight to os.makedirs(), which
    # fails with an unhelpful TypeError when the option is omitted (None).
    parser.add_argument('--out', type=str, required=True, help='a path to output directory')
    parser.add_argument('--sizes', type=int, nargs='+', default=(8, 16, 32, 64, 128, 256, 512, 1024),
                        help='target sizes; one LMDB database is written per size')
    parser.add_argument('--quality', type=int, help='output jpeg quality', default=85)
    parser.add_argument('--n_worker', type=int, help='number of worker processes', default=1)
    parser.add_argument('--chunksize', type=int, help='approximate chunksize for each worker', default=10)
    args = parser.parse_args()
    prepare_data(**vars(args))