[Feature] Support saving and loading npz file in offline evaluation mode. #201

Open · wants to merge 6 commits into master
174 changes: 165 additions & 9 deletions mmgen/core/evaluation/evaluation.py
@@ -5,6 +5,7 @@
from copy import deepcopy

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
@@ -62,6 +63,89 @@ def make_vanilla_dataloader(img_path, batch_size, dist=False):
    return dataloader


def make_npz_dataloader(npz_path, batch_size, dist=False):
    pipeline = [
        # Swap the color channels: in this pipeline, images are loaded
        # directly from the npz file in RGB order, so they must be
        # converted to BGR by setting ``to_rgb=True``.
        dict(
            type='Normalize',
            keys=['real_img'],
            mean=[127.5] * 3,
            std=[127.5] * 3,
            to_rgb=True),
        dict(type='ImageToTensor', keys=['real_img']),
        dict(type='Collect', keys=['real_img'])
    ]

    dataset = build_dataset(
        dict(type='FileDataset', file_path=npz_path, pipeline=pipeline))
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=batch_size,
        workers_per_gpu=0,
        dist=dist,
        shuffle=True)
    return dataloader


def parse_npz_file_name(npz_file):
    """Parse the basic information from the given npz file name.

    Args:
        npz_file (str): The file name of the npz file.

    Returns:
        tuple(int): A tuple of (num_samples, H, W, num_channels).
    """
    # strip the 'samples_' prefix (8 chars) and the '.npz' suffix (4 chars)
    num_samples, H, W, num_channels = npz_file[8:-4].split('x')
    assert num_samples.isdigit()
    assert H.isdigit()
    assert W.isdigit()
    assert num_channels.isdigit()
    return int(num_samples), int(H), int(W), int(num_channels)


def parse_npz_folder(npz_folder_path):
    """Parse the npz files under the given folder.

    Args:
        npz_folder_path (str): Path of the folder that contains npz files.

    Returns:
        tuple(list, int): A tuple containing a list of valid npz file names
            and an int giving the number of existing images.
    """

    npz_files = [
        f for f in list(mmcv.scandir(npz_folder_path, suffix=('.npz')))
        if 'samples' in f
    ]
    valid_npz_files = []
    img_shape = None
    num_exist = 0
    for npz_file in npz_files:
        try:
            n_samples, H, W, n_channels = parse_npz_file_name(npz_file)

            # shape checking
            if img_shape is None:
                img_shape = (H, W, n_channels)
            else:
                if img_shape != (H, W, n_channels):
                    raise ValueError(
                        'Image shape conflicting under sample path: '
                        f'\'{npz_folder_path}\'. Found {img_shape} vs. '
                        f'{(H, W, n_channels)}.')

            valid_npz_files.append(npz_file)
            num_exist += n_samples
        except AssertionError:
            mmcv.print_log(
                f'Found npz file \'{npz_file}\' that does not conform to '
                'the standard naming convention.', 'mmgen')
    return valid_npz_files, num_exist


@torch.no_grad()
def offline_evaluation(model,
                       data_loader,
@@ -70,6 +154,7 @@ def offline_evaluation(model,
                       basic_table_info,
                       batch_size,
                       samples_path=None,
                       save_npz=False,
                       **kwargs):
    """Evaluate model in offline mode.

@@ -87,6 +172,10 @@ def offline_evaluation(model,
        samples_path (str): Used to save generated images. If it's None,
            we'll give it a default directory and delete it after finishing
            the evaluation. Default to None.
        save_npz (bool, optional): Whether to save the generated images to
            an npz file named
            'samples_{NUM_IMAGES}x{H}x{W}x{NUM_CHANNELS}.npz'. If True, the
            dataset will be built upon the npz file instead of image files.
            Defaults to False.
        kwargs (dict): Other arguments.
    """
    # eval special and recon metric online only
@@ -111,11 +200,14 @@ def offline_evaluation(model,
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    # check existing images
    if save_npz:
        npz_file_exist, num_exist = parse_npz_folder(samples_path)
    else:
        num_exist = len(
            list(
                mmcv.scandir(
                    samples_path,
                    suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
@@ -128,6 +220,7 @@ def offline_evaluation(model,
    # define mmcv progress bar
    pbar = mmcv.ProgressBar(num_needed)

    fake_img_list = []
    # if no images, `num_needed` should be zero
    total_batch_size = batch_size * ws
    for begin in range(0, num_needed, total_batch_size):
@@ -163,8 +256,66 @@
            images = fakes[i:i + 1]
            images = ((images + 1) / 2)
            images = images.clamp_(0, 1)
            image_name = str(num_exist + begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))
            if save_npz:
                # permute to [H, W, chn] and rescale to [0, 255]
                fake_img_list.append(
                    images.permute(0, 2, 3, 1).cpu().numpy() * 255)
            else:
                image_name = str(num_exist + begin + i) + '.png'
                save_image(images, os.path.join(samples_path, image_name))

    if save_npz:
        # only one npz file exists and fake_img_list is empty
        # --> no need to save a new file
        if len(npz_file_exist) == 1 and len(fake_img_list) == 0:
            npz_path = os.path.join(samples_path, npz_file_exist[0])
            if rank == 0:
                mmcv.print_log(
                    f'Existing npz file \'{npz_path}\' has already met '
                    'requirements.', 'mmgen')
        else:
            # load from local files and merge into one
            if rank == 0:
                fake_img_exist_list = []
                for exist_npz_file in npz_file_exist:
                    fake_imgs_ = np.load(
                        os.path.join(samples_path,
                                     exist_npz_file))['real_img']
                    fake_img_exist_list.append(fake_imgs_)

                # merge fake_img_exist_list and fake_img_list
                fake_imgs = np.concatenate(
                    fake_img_exist_list + fake_img_list, axis=0)
                num_imgs, H, W, num_channels = fake_imgs.shape

                npz_path = os.path.join(
                    samples_path,
                    f'samples_{num_imgs}x{H}x{W}x{num_channels}.npz')

                # save the new npz file -->
                # set key as ``real_img`` to align with the vanilla dataset
                np.savez(npz_path, real_img=fake_imgs)
                mmcv.print_log(f'Save new npz file to \'{npz_path}\'.',
                               'mmgen')

                # delete old npz files
                for npz_file in npz_file_exist:
                    os.remove(os.path.join(samples_path, npz_file))
                    mmcv.print_log(
                        'Remove useless npz file '
                        f'\'{os.path.join(samples_path, npz_file)}\'.',
                        'mmgen')

            # wait for rank-0 to save the new npz file
            if ws > 1:
                dist.barrier()
            # get npz_path:
            # the stale npz files have been deleted, so there should be
            # exactly one file under ``samples_path``; check and load it
            # directly.
            npz_files = [
                f for f in list(mmcv.scandir(samples_path, suffix=('.npz')))
                if 'samples' in f
            ]
            assert len(npz_files) == 1
            npz_path = os.path.join(samples_path, npz_files[0])

    if num_needed > 0 and rank == 0:
        sys.stdout.write('\n')
@@ -175,8 +326,13 @@ def offline_evaluation(model,

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(
        samples_path, batch_size, dist=ws > 1)
    if save_npz:
        fake_dataloader = make_npz_dataloader(
            npz_path, batch_size, dist=ws > 1)
    else:
        fake_dataloader = make_vanilla_dataloader(
            samples_path, batch_size, dist=ws > 1)

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
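A minimal sketch of the npz round-trip the functions above rely on, assuming only the ``samples_{N}x{H}x{W}x{C}.npz`` naming scheme and the ``real_img`` key that this diff introduces:

import numpy as np

# Build a dummy batch of generated images in (N, H, W, C) layout.
fake_imgs = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
num_imgs, h, w, c = fake_imgs.shape

# The file name encodes the array shape, so parse_npz_file_name() can
# recover (8, 32, 32, 3) from the name alone, without loading the file.
npz_path = f'samples_{num_imgs}x{h}x{w}x{c}.npz'

# Store under the key ``real_img`` to align with make_npz_dataloader().
np.savez(npz_path, real_img=fake_imgs)
assert np.load(npz_path)['real_img'].shape == (num_imgs, h, w, c)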
3 changes: 2 additions & 1 deletion mmgen/datasets/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .file_dataset import FileDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .paired_image_dataset import PairedImageDataset
from .pipelines import (Collect, Compose, Flip, ImageToTensor,
@@ -16,5 +17,5 @@
    'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor',
    'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset', 'Normalize',
    'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset',
    'UnpairedImageDataset', 'QuickTestImageDataset'
    'UnpairedImageDataset', 'QuickTestImageDataset', 'FileDataset'
]
126 changes: 126 additions & 0 deletions mmgen/datasets/file_dataset.py
@@ -0,0 +1,126 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from torch.utils.data import Dataset

from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class FileDataset(Dataset):
"""Uncoditional file Dataset.

This dataset load data information from files for training GANs. Given
the path of a file, we will load all information in the file. The
transformation on data is defined by the pipeline. Please ensure that
``LoadImageFromFile`` is not in your pipeline configs because we directly
get images in ``np.ndarray`` from the given file.

Args:
file_path (str): Path of the file.
img_keys (str): Key of the images in npz file.
pipeline (list[dict | callable]): A sequence of data transforms.
test_mode (bool, optional): If True, the dataset will work in test
mode. Otherwise, in train mode. Default to False.
npz_keys (str | list[str], optional): Key of the images to load in the
npz file. Must with the input file is as npz file.
"""

_VALID_FILE_SUFFIX = ('.npz')

def __init__(self, file_path, pipeline, test_mode=False):
super().__init__()
assert any([
file_path.endswith(suffix) for suffix in self._VALID_FILE_SUFFIX
]), (f'We only support \'{self._VALID_FILE_SUFFIX}\' in this dataset, '
f'but receive {file_path}.')

self.file_path = file_path
self.pipeline = Compose(pipeline)
self.test_mode = test_mode
self.load_annotations()

# print basic dataset information to check the validity
mmcv.print_log(repr(self), 'mmgen')

    def load_annotations(self):
        """Load annotations."""
        if self.file_path.endswith('.npz'):
            data_info, data_length = self._load_annotations_from_npz()
            data_fetch_fn = self._npz_data_fetch_fn

        self.data_infos = data_info
        self.data_fetch_fn = data_fetch_fn
        self.data_length = data_length

    def _load_annotations_from_npz(self):
        """Load annotations from an npz file and check that the number of
        samples is consistent among all items.

        Returns:
            tuple: dict and int
        """
        npz_file = np.load(self.file_path, mmap_mode='r')
        data_info_dict = dict()
        npz_keys = list(npz_file.keys())

        # check num samples
        num_samples = None
        for k in npz_keys:
            data_info_dict[k] = npz_file[k]
            # check number of samples
            if num_samples is None:
                num_samples = npz_file[k].shape[0]
            else:
                assert num_samples == npz_file[k].shape[0]
        return data_info_dict, num_samples

    @staticmethod
    def _npz_data_fetch_fn(data_infos, idx):
        """Fetch data from the npz file by index and package it into a dict.

        Args:
            data_infos (array, tuple, dict): Data infos in the npz file.
            idx (int): Index of the current batch.

        Returns:
            dict: Data infos of the given idx.
        """
        data_dict = dict()
        for k in data_infos.keys():
            if data_infos[k][idx].shape == ():
                v = np.array([data_infos[k][idx]])
            else:
                v = data_infos[k][idx]
            data_dict[k] = v
        return data_dict

    def prepare_data(self, idx, data_fetch_fn=None):
        """Prepare data.

        Args:
            idx (int): Index of current batch.
            data_fetch_fn (callable): Function to fetch data.

        Returns:
            dict: Prepared training data batch.
        """
        if data_fetch_fn is None:
            data = self.data_infos[idx]
        else:
            data = data_fetch_fn(self.data_infos, idx)
        return self.pipeline(data)

    def __len__(self):
        return self.data_length

    def __getitem__(self, idx):
        return self.prepare_data(idx, self.data_fetch_fn)

    def __repr__(self):
        dataset_name = self.__class__.__name__
        file_path = self.file_path
        num_imgs = len(self)
        return (f'dataset_name: {dataset_name}, total {num_imgs} images in '
                f'file_path: {file_path}')
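A hypothetical usage sketch for the new ``FileDataset``, assuming an npz file that stores images under the key ``real_img`` (as the evaluation code above writes them); the pipeline mirrors the one built in ``make_npz_dataloader``:

from mmgen.datasets import FileDataset

pipeline = [
    dict(type='ImageToTensor', keys=['real_img']),
    dict(type='Collect', keys=['real_img']),
]
# 'tests/data/file/test.npz' stands in for any file produced by
# offline_evaluation with save_npz=True.
dataset = FileDataset('tests/data/file/test.npz', pipeline=pipeline)
print(len(dataset))  # number of samples stored in the npz file
sample = dataset[0]  # dict with 'real_img' as a CHW tensor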
5 changes: 3 additions & 2 deletions mmgen/datasets/pipelines/formatting.py
@@ -129,8 +129,9 @@ def __call__(self, results):
"""
data = {}
img_meta = {}
for key in self.meta_keys:
img_meta[key] = results[key]
if self.meta_keys is not None:
for key in self.meta_keys:
img_meta[key] = results[key]
data['meta'] = DC(img_meta, cpu_only=True)
for key in self.keys:
data[key] = results[key]
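The ``Collect`` fix above makes ``meta_keys`` optional: previously, a pipeline that omitted meta information would fail when iterating over ``None``. A sketch of a config that now works, mirroring the pipeline used in ``make_npz_dataloader``:

pipeline = [
    dict(type='ImageToTensor', keys=['real_img']),
    # No meta_keys given: with this fix, Collect emits an empty meta dict
    # instead of raising a TypeError.
    dict(type='Collect', keys=['real_img']),
]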
Binary file added tests/data/file/test.npz