diff --git a/mmgen/core/evaluation/evaluation.py b/mmgen/core/evaluation/evaluation.py index b5b32bce6..e38170c11 100644 --- a/mmgen/core/evaluation/evaluation.py +++ b/mmgen/core/evaluation/evaluation.py @@ -5,6 +5,7 @@ from copy import deepcopy import mmcv +import numpy as np import torch import torch.distributed as dist from mmcv.runner import get_dist_info @@ -62,6 +63,89 @@ def make_vanilla_dataloader(img_path, batch_size, dist=False): return dataloader +def make_npz_dataloader(npz_path, batch_size, dist=False): + pipeline = [ + # permute the color channel. Because in npz_dataloader's pipeline, we + # directly load image in RGB order from npz file and we must convert it + # to BGR by setting ``to_rgb=True``. + dict( + type='Normalize', + keys=['real_img'], + mean=[127.5] * 3, + std=[127.5] * 3, + to_rgb=True), + dict(type='ImageToTensor', keys=['real_img']), + dict(type='Collect', keys=['real_img']) + ] + + dataset = build_dataset( + dict(type='FileDataset', file_path=npz_path, pipeline=pipeline)) + dataloader = build_dataloader( + dataset, + samples_per_gpu=batch_size, + workers_per_gpu=0, + dist=dist, + shuffle=True) + return dataloader + + +def parse_npz_file_name(npz_file): + """Parse the basic information from the given npz file. + Args: + npz_file (str): The file name of the npz file. + + Returns: + tuple(int): A tuple of (num_samples, H, W, num_channels). + """ + # remove 'samples_' (8) and '.npz' (4) + num_samples, H, W, num_channles = npz_file[8:-4].split('x') + assert num_samples.isdigit() + assert H.isdigit() + assert W.isdigit() + assert num_channles.isdigit() + return int(num_samples), int(H), int(W), int(num_channles) + + +def parse_npz_folder(npz_folder_path): + """Parse the npz files under the given folder. + Args: + npz_folder_path: The folder containing npz files. + + Returns: + tuple(list, int): A tuple containing a list of valid npz files' names and + an int giving the number of existing images. 
+ """ + + npz_files = [ + f for f in list(mmcv.scandir(npz_folder_path, suffix=('.npz'))) + if 'samples' in f + ] + valid_npz_files = [] + img_shape = None + num_exist = 0 + for npz_file in npz_files: + try: + n_samples, H, W, n_channels = parse_npz_file_name(npz_file) + + # shape checking + if img_shape is None: + img_shape = (H, W, n_channels) + else: + if img_shape != (H, W, n_channels): + raise ValueError( + 'Image shape conflicting under sample path:' + f'\'{npz_folder_path}\'. Find {img_shape} vs. ' + f'{(H, W, n_channels)}.') + + valid_npz_files.append(npz_file) + num_exist += n_samples + except AssertionError: + mmcv.print_log( + f'Find npz file \'{npz_file}\' does not conform to the ' + 'standard naming convention.', 'mmgen') + return valid_npz_files, num_exist + + @torch.no_grad() def offline_evaluation(model, data_loader, @@ -70,6 +154,7 @@ def offline_evaluation(model, basic_table_info, batch_size, samples_path=None, + save_npz=False, **kwargs): """Evaluate model in offline mode. @@ -87,6 +172,10 @@ def offline_evaluation(model, samples_path (str): Used to save generated images. If it's none, we'll give it a default directory and delete it after finishing the evaluation. Default to None. + save_npz (bool, optional): Whether to save the generated images to a npz + file named 'samples_{NUM_IMAGES}x{H}x{W}x{NUM_CHANNELS}.npz'. If + true, dataset will be built upon npz file instead of image files. + Defaults to False. kwargs (dict): Other arguments. 
""" # eval special and recon metric online only @@ -111,11 +200,14 @@ def offline_evaluation(model, os.makedirs(samples_path) delete_samples_path = True - # sample images - num_exist = len( - list( - mmcv.scandir( - samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG')))) + # check existing images + if save_npz: + npz_file_exist, num_exist = parse_npz_folder(samples_path) + else: + num_exist = len( + list( + mmcv.scandir( + samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG')))) if basic_table_info['num_samples'] > 0: max_num_images = basic_table_info['num_samples'] else: @@ -128,6 +220,7 @@ def offline_evaluation(model, # define mmcv progress bar pbar = mmcv.ProgressBar(num_needed) + fake_img_list = [] # if no images, `num_needed` should be zero total_batch_size = batch_size * ws for begin in range(0, num_needed, total_batch_size): @@ -163,8 +256,66 @@ def offline_evaluation(model, images = fakes[i:i + 1] images = ((images + 1) / 2) images = images.clamp_(0, 1) - image_name = str(num_exist + begin + i) + '.png' - save_image(images, os.path.join(samples_path, image_name)) + if save_npz: + # permute to [H, W, chn] and rescale to [0, 255] + fake_img_list.append( + images.permute(0, 2, 3, 1).cpu().numpy() * 255) + else: + image_name = str(num_exist + begin + i) + '.png' + save_image(images, os.path.join(samples_path, image_name)) + + if save_npz: + # only one npz file and fake_img_list is empty --> do not need to save + if len(npz_file_exist) == 1 and len(fake_img_list) == 0: + npz_path = os.path.join(samples_path, npz_file_exist[0]) + if rank == 0: + mmcv.print_log( + f'Existing npz file \'{npz_path}\' has already met ' + 'requirements.', 'mmgen') + else: + # load from locl file and merge to one + if rank == 0: + fake_img_exist_list = [] + for exist_npz_file in npz_file_exist: + fake_imgs_ = np.load( + os.path.join(samples_path, exist_npz_file))['real_img'] + fake_img_exist_list.append(fake_imgs_) + + # merge fake_img_exist_list and fake_img_list + fake_imgs = 
np.concatenate( + fake_img_exist_list + fake_img_list, axis=0) + num_imgs, H, W, num_channels = fake_imgs.shape + + npz_path = os.path.join( + samples_path, + f'samples_{num_imgs}x{H}x{W}x{num_channels}.npz') + + # save new npz file --> + # set key as ``real_img`` to align with vanilla dataset + np.savez(npz_path, real_img=fake_imgs) + mmcv.print_log(f'Save new npz_file to \'{npz_path}\'.', + 'mmgen') + + # delete old npz files + for npz_file in npz_file_exist: + os.remove(os.path.join(samples_path, npz_file)) + mmcv.print_log( + 'Remove useless npz file ' + f'\'{os.path.join(samples_path, npz_file)}\'.', + 'mmgen') + + # waiting for rank-0 to save the new npz file + if ws > 1: + dist.barrier() + # get npz_path. + # We have delete useless npz files then there should only one + # file under the sample_path. Check and directly load it! + npz_files = [ + f for f in list(mmcv.scandir(samples_path, suffix=('.npz'))) + if 'samples' in f + ] + assert len(npz_files) == 1 + npz_path = os.path.join(samples_path, npz_files[0]) if num_needed > 0 and rank == 0: sys.stdout.write('\n') @@ -175,8 +326,13 @@ def offline_evaluation(model, # empty cache to release GPU memory torch.cuda.empty_cache() - fake_dataloader = make_vanilla_dataloader( - samples_path, batch_size, dist=ws > 1) + if save_npz: + fake_dataloader = make_npz_dataloader( + npz_path, batch_size, dist=ws > 1) + else: + fake_dataloader = make_vanilla_dataloader( + samples_path, batch_size, dist=ws > 1) + for metric in metrics: mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen') metric.prepare() diff --git a/mmgen/datasets/__init__.py b/mmgen/datasets/__init__.py index 3a35a4aba..83d458769 100644 --- a/mmgen/datasets/__init__.py +++ b/mmgen/datasets/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .builder import build_dataloader, build_dataset from .dataset_wrappers import RepeatDataset +from .file_dataset import FileDataset from .grow_scale_image_dataset import GrowScaleImgDataset from .paired_image_dataset import PairedImageDataset from .pipelines import (Collect, Compose, Flip, ImageToTensor, @@ -16,5 +17,5 @@ 'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor', 'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset', 'Normalize', 'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset', - 'UnpairedImageDataset', 'QuickTestImageDataset' + 'UnpairedImageDataset', 'QuickTestImageDataset', 'FileDataset' ] diff --git a/mmgen/datasets/file_dataset.py b/mmgen/datasets/file_dataset.py new file mode 100644 index 000000000..761ec94ea --- /dev/null +++ b/mmgen/datasets/file_dataset.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import numpy as np +from torch.utils.data import Dataset + +from .builder import DATASETS +from .pipelines import Compose + + +@DATASETS.register_module() +class FileDataset(Dataset): + """Unconditional file Dataset. + + This dataset loads data information from files for training GANs. Given + the path of a file, we will load all information in the file. The + transformation on data is defined by the pipeline. Please ensure that + ``LoadImageFromFile`` is not in your pipeline configs because we directly + get images in ``np.ndarray`` from the given file. + + Args: + file_path (str): Path of the file. + img_keys (str): Key of the images in npz file. + pipeline (list[dict | callable]): A sequence of data transforms. + test_mode (bool, optional): If True, the dataset will work in test + mode. Otherwise, in train mode. Default to False. + npz_keys (str | list[str], optional): Key of the images to load in the + npz file. Must with the input file is as npz file. 
+ """ + + _VALID_FILE_SUFFIX = ('.npz') + + def __init__(self, file_path, pipeline, test_mode=False): + super().__init__() + assert any([ + file_path.endswith(suffix) for suffix in self._VALID_FILE_SUFFIX + ]), (f'We only support \'{self._VALID_FILE_SUFFIX}\' in this dataset, ' + f'but receive {file_path}.') + + self.file_path = file_path + self.pipeline = Compose(pipeline) + self.test_mode = test_mode + self.load_annotations() + + # print basic dataset information to check the validity + mmcv.print_log(repr(self), 'mmgen') + + def load_annotations(self): + """Load annotations.""" + if self.file_path.endswith('.npz'): + data_info, data_length = self._load_annotations_from_npz() + data_fetch_fn = self._npz_data_fetch_fn + + self.data_infos = data_info + self.data_fetch_fn = data_fetch_fn + self.data_length = data_length + + def _load_annotations_from_npz(self): + """Load annotations from npz file and check that the number of samples is + consistent among all items. + + Returns: + tuple: dict and int + """ + npz_file = np.load(self.file_path, mmap_mode='r') + data_info_dict = dict() + npz_keys = list(npz_file.keys()) + + # check num samples + num_samples = None + for k in npz_keys: + data_info_dict[k] = npz_file[k] + # check number of samples + if num_samples is None: + num_samples = npz_file[k].shape[0] + else: + assert num_samples == npz_file[k].shape[0] + return data_info_dict, num_samples + + @staticmethod + def _npz_data_fetch_fn(data_infos, idx): + """Fetch data from npz file by idx and package them to a dict. + + Args: + data_infos (array, tuple, dict): Data infos in the npz file. + idx (int): Index of the sample to fetch. + + Returns: + dict: Data infos of the given idx. + """ + data_dict = dict() + for k in data_infos.keys(): + if data_infos[k][idx].shape == (): + v = np.array([data_infos[k][idx]]) + else: + v = data_infos[k][idx] + data_dict[k] = v + return data_dict + + def prepare_data(self, idx, data_fetch_fn=None): + """Prepare data. 
+ + Args: + idx (int): Index of current batch. + data_fetch_fn (callable): Function to fetch data. + + Returns: + dict: Prepared training data batch. + """ + if data_fetch_fn is None: + data = self.data_infos[idx] + else: + data = data_fetch_fn(self.data_infos, idx) + return self.pipeline(data) + + def __len__(self): + return self.data_length + + def __getitem__(self, idx): + return self.prepare_data(idx, self.data_fetch_fn) + + def __repr__(self): + dataset_name = self.__class__ + file_path = self.file_path + num_imgs = len(self) + return (f'dataset_name: {dataset_name}, total {num_imgs} images in ' + f'file_path: {file_path}') diff --git a/mmgen/datasets/pipelines/formatting.py b/mmgen/datasets/pipelines/formatting.py index 37a52d39d..93ce1125d 100644 --- a/mmgen/datasets/pipelines/formatting.py +++ b/mmgen/datasets/pipelines/formatting.py @@ -129,8 +129,9 @@ def __call__(self, results): """ data = {} img_meta = {} - for key in self.meta_keys: - img_meta[key] = results[key] + if self.meta_keys is not None: + for key in self.meta_keys: + img_meta[key] = results[key] data['meta'] = DC(img_meta, cpu_only=True) for key in self.keys: data[key] = results[key] diff --git a/tests/data/file/test.npz b/tests/data/file/test.npz new file mode 100644 index 000000000..7210664f2 Binary files /dev/null and b/tests/data/file/test.npz differ diff --git a/tests/test_datasets/test_file_image_dataset.py b/tests/test_datasets/test_file_image_dataset.py new file mode 100644 index 000000000..e87c467ef --- /dev/null +++ b/tests/test_datasets/test_file_image_dataset.py @@ -0,0 +1,31 @@ +import os.path as osp + +from mmgen.datasets import FileDataset + + +class TestFileDataset(object): + + @classmethod + def setup_class(cls): + cls.file_path = osp.join( + osp.dirname(__file__), '..', 'data/file/test.npz') + cls.default_pipeline = [ + dict(type='Resize', scale=(32, 32), keys=['fake_img']), + dict(type='ToTensor', keys=['label']), + dict(type='ImageToTensor', keys=['fake_img']), + 
dict(type='Collect', keys=['fake_img', 'label']) + ] + + def test_unconditional_imgs_dataset(self): + dataset = FileDataset(self.file_path, pipeline=self.default_pipeline) + + assert len(dataset) == 2 + data_dict = dataset[0] + img = data_dict['fake_img'] + lab = data_dict['label'] + assert img.shape == (3, 32, 32) + assert lab == 1 + print(repr(dataset)) + assert repr(dataset) == ( + f'dataset_name: {dataset.__class__}, ' + f'total {2} images in file_path: {self.file_path}') diff --git a/tests/test_datasets/test_pipelines/test_formatting.py b/tests/test_datasets/test_pipelines/test_formatting.py index 4ccb91005..226bf3c26 100644 --- a/tests/test_datasets/test_pipelines/test_formatting.py +++ b/tests/test_datasets/test_pipelines/test_formatting.py @@ -90,7 +90,6 @@ def test_collect(): collect = Collect(keys, meta_keys=meta_keys) results = collect(inputs) assert set(list(results.keys())) == set(['img', 'label', 'meta']) - inputs.pop('img') assert set(results['meta'].data.keys()) == set(meta_keys) for key in results['meta'].data: assert results['meta'].data[key] == inputs[key] @@ -98,3 +97,9 @@ def test_collect(): assert repr(collect) == ( collect.__class__.__name__ + f'(keys={keys}, meta_keys={collect.meta_keys})') + + # test meta is None + collect = Collect(keys) + results = collect(inputs) + print(results['meta'].data) + assert results['meta'].data == {} diff --git a/tools/evaluation.py b/tools/evaluation.py index 731c623a8..4b5a98a03 100644 --- a/tools/evaluation.py +++ b/tools/evaluation.py @@ -50,6 +50,12 @@ def parse_args(): default=None, help='path to store images. If not given, remove it after evaluation\ finished') + parser.add_argument( + '--save-npz', + action='store_true', + help=('whether to save generated images to a npz file named ' + '\'NUM_IMAGES.npz\'. The npz file will be saved at ' + '`samples-path`. 
(only work in offline mode)')) parser.add_argument( '--sample-model', type=str, @@ -113,8 +119,6 @@ def main(): init_dist(args.launcher, **cfg.dist_params) rank, world_size = get_dist_info() cfg.gpu_ids = range(world_size) - assert args.online or world_size == 1, ( - 'We only support online mode for distrbuted evaluation.') dirname = os.path.dirname(args.checkpoint) ckpt = os.path.basename(args.checkpoint) @@ -218,7 +222,7 @@ def main(): else: offline_evaluation(model, data_loader, metrics, logger, basic_table_info, args.batch_size, - args.samples_path, **args.sample_cfg) + args.samples_path, args.save_npz, **args.sample_cfg) if __name__ == '__main__':