From 7dff6c6ae1f5b72b6984131d6fdb1ecbbc0f0a9f Mon Sep 17 00:00:00 2001 From: LeoXing Date: Fri, 31 Dec 2021 15:41:36 +0800 Subject: [PATCH 1/5] add file dataset --- mmgen/datasets/__init__.py | 3 +- mmgen/datasets/pipelines/formatting.py | 5 +- mmgen/datasets/unconditional_file_dataset.py | 102 ++++++++++++++++++ tests/data/file/test.npz | Bin 0 -> 12816 bytes .../test_datasets/test_file_image_dataset.py | 31 ++++++ 5 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 mmgen/datasets/unconditional_file_dataset.py create mode 100644 tests/data/file/test.npz create mode 100644 tests/test_datasets/test_file_image_dataset.py diff --git a/mmgen/datasets/__init__.py b/mmgen/datasets/__init__.py index 3a35a4aba..571e3fb79 100644 --- a/mmgen/datasets/__init__.py +++ b/mmgen/datasets/__init__.py @@ -8,6 +8,7 @@ from .quick_test_dataset import QuickTestImageDataset from .samplers import DistributedSampler from .singan_dataset import SinGANDataset +from .unconditional_file_dataset import FileDataset from .unconditional_image_dataset import UnconditionalImageDataset from .unpaired_image_dataset import UnpairedImageDataset @@ -16,5 +17,5 @@ 'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor', 'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset', 'Normalize', 'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset', - 'UnpairedImageDataset', 'QuickTestImageDataset' + 'UnpairedImageDataset', 'QuickTestImageDataset', 'FileDataset' ] diff --git a/mmgen/datasets/pipelines/formatting.py b/mmgen/datasets/pipelines/formatting.py index 37a52d39d..93ce1125d 100644 --- a/mmgen/datasets/pipelines/formatting.py +++ b/mmgen/datasets/pipelines/formatting.py @@ -129,8 +129,9 @@ def __call__(self, results): """ data = {} img_meta = {} - for key in self.meta_keys: - img_meta[key] = results[key] + if self.meta_keys is not None: + for key in self.meta_keys: + img_meta[key] = results[key] data['meta'] = DC(img_meta, cpu_only=True) for key in self.keys: data[key] = results[key] diff --git a/mmgen/datasets/unconditional_file_dataset.py b/mmgen/datasets/unconditional_file_dataset.py new file mode 100644 index 000000000..18f226d3d --- /dev/null +++ b/mmgen/datasets/unconditional_file_dataset.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import numpy as np +from torch.utils.data import Dataset + +from .builder import DATASETS +from .pipelines import Compose + + +@DATASETS.register_module() +class FileDataset(Dataset): + """Uncoditional file Dataset. + + This dataset contains raw images for training unconditional GANs. Given + the path of a file, we will load all image in this file. The + transformation on data is defined by the pipeline. Please ensure that + ``LoadImageFromFile`` is not in your pipeline configs because we directly + get images in ``np.ndarray`` from the given file. + + Args: + file_path (str): Path of the file. + img_keys (str): Key of the images in npz file. + pipeline (list[dict | callable]): A sequence of data transforms. + test_mode (bool, optional): If True, the dataset will work in test + mode. Otherwise, in train mode. Default to False. + npz_keys (str | list[str], optional): Key of the images to load in the + npz file. Must with the input file is as npz file. + """ + + _VALID_FILE_SUFFIX = ('.npz') + + def __init__(self, file_path, pipeline, test_mode=False): + super().__init__() + assert any([ + file_path.endswith(suffix) for suffix in self._VALID_FILE_SUFFIX + ]), (f'We only support \'{self._VALID_FILE_SUFFIX}\' in this dataset, ' + f'but receive {file_path}.') + + self.file_path = file_path + self.pipeline = Compose(pipeline) + self.test_mode = test_mode + self.load_annotations() + + # print basic dataset information to check the validity + mmcv.print_log(repr(self), 'mmgen') + + def load_annotations(self): + """Load annotations.""" + if self.file_path.endswith('.npz'): + data_info_list = self._load_annotations_from_npz() + self.data_infos = data_info_list + + def _load_annotations_from_npz(self): + npz_file = np.load(self.file_path) + data_info_dict = dict() + data_info_list = [] + npz_keys = list(npz_file.keys()) + + num_samples = None + for k in npz_keys: + data_info_dict[k] = npz_file[k] + # check number of samples + if num_samples is None: + num_samples = npz_file[k].shape[0] + else: + assert num_samples == npz_file[k].shape[0] + + # save to list + for idx in range(num_samples): + data_info = dict() + for k in npz_keys: + var = data_info_dict[k][idx] + if var.shape == (): + var = np.array([var]) + data_info[k] = var + data_info_list.append(data_info) + + return data_info_list + + def prepare_data(self, idx): + """Prepare data. + + Args: + idx (int): Index of current batch. + + Returns: + dict: Prepared training data batch. + """ + return self.pipeline(self.data_infos[idx]) + + def __len__(self): + return len(self.data_infos) + + def __getitem__(self, idx): + return self.prepare_data(idx) + + def __repr__(self): + dataset_name = self.__class__ + file_path = self.file_path + num_imgs = len(self) + return (f'dataset_name: {dataset_name}, total {num_imgs} images in ' + f'file_path: {file_path}') diff --git a/tests/data/file/test.npz b/tests/data/file/test.npz new file mode 100644 index 0000000000000000000000000000000000000000..7210664f210ff1e880acd3bac1c67cef0199a3d1 GIT binary patch literal 12816 zcmbW82{@E*`0qz(RfwcgREi`^i?r~iMYK{Vl9at{p+%`|m5@|IB|9l4YuWcL`!d$C z%w~)+GgSC#b>83qf6jH5>zwO4=e(}D-fPCZGd%D8-1qnXe4oegh@g-p=|7LPq{H5C zn|;(sq<|_NMir-`oHh}_xJxE&yJfMGc*$*T_fF6u(9>9 zc30Rfr?A&fQ$bly!Oqp))BS?Wc~^HE+kaozyWr$uOT6ykaKX)%c&@lrSx#N!KOfsR zDa+lK`+q(bo-0(@_dte?^;JvN{d{R~$|i?D$Ynr^uI4;&x(%t-`?U;jWh44i+`QBK zyJ2HZ!o5Z|tk#}gAnwtEYx6xj^se?{`hc)1mr21-Rm-1!@|~Diyl2Dy$T!e0BHL~b zZv#E?&enyy#t=pJtrT59gx(K+LJ}PuXdeA3-g&AC(&1VDLKdyC^RvI=l+g~qTK=(= z_A!k0sL#1zKY|&h21`NDVVtx+?^eHZ1oOwOA4k-;q0k?07u`m1b00ajJd2IqnH6_c z&yGXz>)37WD&q5H5*Ap?=R#$gd!%7YJES8I-aY@j0oPO2FYn4`K%|Sis8XgK!pyue ztCI}GojfSDJ+KO-?7+|dyNY0Ld)ZGVhkG8}B#i#(cZahcL<5=$M1{@RWDJ3@MJYSmm3>fi{rb5tDP&azQB*KV`*0NysB__yS1!U!nlcu!<4j}Bo@SO9^wfYCZT=#E5};mx`#4))w)gN=*dq%*4@yL z^!Ufj7>5#jtKWD3H zAx^8N4$T$i!$VW*m+OO3oII)dyj6mS-~Br_y8Ri#+|Lz`9ecZwY!GmL&p{SWeSDIt zYsJHz(2OOk?+n0js1&|`87RoleHPi=i86_8L!HNn_>8?@xUIhsho3a8Cs|Oilpdma z_BJ0q7mxn&ecgeXt!w=SCX2u=-LUdfAq8Esg01_y*@(OK^3kjk88g{-()P#NV6&RS zTF_UCvOSWoMhzc;|6E(!PKJ!#L9WIBG>*Vzm*@@KUKXs61-xBn#)6HQ{L9#t!;r~J zclqWq2zheq@a^Ui)Lp$B{?rp6{l}+dXkHs8eM@-n?|utD(llt{s;~%v!yfsX~&M{c`dPXMiDj?TOvMP z3cB)1yTl(9{Jqn$kF3Rp&$Q*zzM)=tnB1y(<rzM+p`!Xu9hAKZC-N*u|X7J4ssw>!PiV0nH z#V8)3o6m1x;MP#1f1x&Ue>ZJB)YVlBj(AH>`QkCW-51?J*iu=1*J|?23n%0(OqwI?NS*0Kr=8P|X z_UC9T_O4m(KTc<1byVU0lrk>FUe+9WYZ8Z2mHb~G^f6f3n#ONe&%%3KJpltsH_{m2 z{?soS#g;eLC*4Uj{HE?-+xUZrJk!6;tz!9*FIeMq>01M4uD@OEv^XDPPOHnt^O&$n zD9-%US_`Vn`2;1(IKsa4a#*G1$k^{9w{OEcSX?%j4BSbF`_T(OX5_~4V!lRL=%F#p zo>{j=!;J~O3j;q2HjhEa{-l>s1PfKYJ3`p+EASzi<}t}2@?6d1#1p-JFbUjO+F8U$ zMbJM+-G^x?Q?UzuFcpfQ0-kZa6TKL|^*pom(Kz;AyT0D_LkFnA%*S62KS7Po^~G1E zIfxrD_}Hua0{Tr%IhTjU82>btwfqqg=LfUpOPea79~5Xj`L+R4U*=zO(x$@d!Q|x} z|9)I6xE}fB~)uq%Ah470?KcisduP&K4E z*Kx}zT=HT5WDf-v`og=ZJ&o|$ducRbYdw5In)RI`qoF7pe`KT0IJCd+y5Q(G1mSC3 z%Q6ir^6oN{>}9(!|7g9Bg;Nf0-q~R}u$70yy88o#q8ad;XY+;kn+~5VuYUd%%YuZ> zAN7^ZMBdCCPIRMWqD-^g_2)nFh&#uh&`y^9o^BTNK5_6%(c*@-5G*u>G7MX4%z<3lyZ-Yp{^z zApNAkyAx;co^z0kX5&+ksmtS3GAs?m*M9RL^1h?;oQUJ;SR{O6_-_{n*LdD;^LG>T zjM+XfKxG6`oHX7W*<#2Sr%Bix>VwcXm&se=^(eZT>Asi9~5YG_$J&dZf1i!7dqsaFR+tX1pf;a2sT-}S=ctpvLB1iCW zV`^wob7c#hM=li9fQpNb$s@+@?GW!V*eL8eM(7D$3obFw-z8hjy3US6F`fP}Us){g{XBvQ-amS=IXwvPQ&}z)QVYF$m5+L-$Ou}P#~j(+51XS>vDf^l zm`K;;tUJJj<-;FyZ~Lgs=>L8yy=g;ZP z0g!sAS9{`uVBw-AF1m;h{@RV*ruD6u^Yyj6Rw_53Du#eAAU|2syXJ zH2ms<=_3JwRsJK$dbzjcc}yRu>sOt*MomB+{hQP3*$Ql5`pp|DYTT`nmhGA_vExdj|4L-|u<_S24 z;CEiwo!EqG7@kU2Q9jUtK)nSTKZ@(2XQL7lsYJ|Av#{mIJsFr-+P$jxcqYh~Oj;U2 z!CljpUC-As@$-ZHwCUYG%y+xpGVgK;NPR!L8Ydla;?IT&10w#z?!LbwKBo`DA9ieB z(nyDSOWZ69X~_7X-?ru%3-WKOnvTnNV`JZ>N~s|gWo5nFMP%tnRglf9GUGwW@>Yn4 zLI?Z?tMNFJ1K&NzmhOw@qWB+g8sCD051*Ap_aCA`Jfducg*grJ({3&+{xDD`=kxZN zLoi4~i~5xAu^~AqwF7U;5Ve@Ei=6$XigtTo}+o_Y?IQg`rkMegI?5qWgRWl;gE4hNqr$1T|18{o>m@!T5G%hL~|9E7g#>~ z`-cY$<${qV>FscDczt!wcP=6>NNo*F=m0~m{%DUp8yC*)mI{v>!GuBlO54JIFh30( zTPr<|bT7q-zfHt=UBG)Ky8)=2?-n~JKLq{xyTYYz_u|i+8O>c!hOzF~$@5wkbPPsh z$9^5B;ij!f@y3ENkY&^rPdHSAI`d+c*IO2TrrIdDyD(u85%O-o>?nK=IkAdwd0=1g zyVr|^#}H<|^SEYV56(t}4s9KyVbhyeuCEAPoMU;*h*LtrHEM71pmHmWt1dn+IaG$_ zVJ7j6WI9%LEWTPA&cbUao4b;Pj^{0-SNx)KQS(o+>BYa@*q41-FuaM1Z}V??R)~yY zrfZ(?k$@T~${!sz3hjZ$*+ZcPUL9Do=8EUXb)7JE>NLC)-Hx2wH+0Te_Mu2pOwQ*- z6_T2RzI!j=Lq&3bXSB&68Z>#ab-_e`!Qw}hJ)uJ@+AySj4g+l1Fo4dpo&gVu{;K@vTV~s0 z4ie9L-PifRhh)#Povc;E2<*5z)#66Q9D2U5fnpmp(>zJjrA+k36o_A+VM4elaA|@L z4;l+f!fTDXL2gP}|CCh^p^wq-mZQUnFNOnF9zEbHZf66MKfjnh1(AI8Ep%MMwjRZw23Lu< zexq1;xKMVX1Oo}%pgH-d9rDp8p6e6A-h z^B;ibw|n#I)rtA>*JA2s?GUK9XAQm;N5Ifb@0|jPiKXks3r~cyFhA(Qij&n%usR+3 zaJ_sRChOmAbMNiMnta1vTdy&w$or~GZ5YN_ZBFZ>aUyQWu?_{*pJBLN|4GKz511&v z>0IF*k0i^)&YLzp@V@L+)|&DK8GSt!DvGT*bnJ#=>Xl(!H9Bop_k)hwwc+U}(pcEF zUO;5lh7OXQ>-)5^m+*3ewqz-?`hWu++WNsrcF;ZihN=s=<9tSrQ-;^1Scj10cYkzs?7%F`Deg@{<&{1cs+D_z|Wfi55(*=mU zeEu}w;SmQmZ;fNp9SCm0F-+;1LqVBc-C%L-3{jM{8Yo~advXn*F`o{#Nb8&<+vD+a)d(r}I-wsNzIV#4eF@FLHvgrx zY=jGlUo|Kgh0vc%&ozlUr6%O2VJCWYYNOsV(XhQ zgvqZl!GF`)2)tz^yCb(8aed!=Rwz6JyWT60?OhH-`VHkSp)PE@e5E%-h6etHrwd<- zQ&Esue@Z&74P4?-nuOVq#|8utGFqO=McIDDLs&o31G!v_7h-<0j0E~wF#9}VBrkBwFPy1a;f?yp9X zh{^R{oP4FRh;fz;)1pg^Iei^?E%--m!!I6svq`}s*=)2gmL=!a^U!mkcV_T#CrHOP zml)fH!%%yhmz`b}mhHamoHoC_0|5}`*V$sE5k>Vi^%CfwfhqeZy`E+;7PW~t~#}z1lYK!smHfxlUJ%ZnM z@|n6ZU$Ex`-R~};1A9MKTXt#=<9hDCxUyqoQ1fWEihbV#*-S3G<`f6iQDviKZX?RX zoi{HyN$8_Gc^${&YJ+GA!YhwOTrVN`!J{6TvEsMbkoE1gw;mma zg9Dom6ZMcKbUi?$mJXLej=F+76BB0}=6@9^hk3WPbWafrP0{O&oCJAzochmdlXncn zC$2gjGM5FrjX(W)*ZV+ImtVX*n+LJSI#*BYv2ea<|H7byqhO28xp?3V!5Pw=Hxx5U zaY9K_sKjJtL3vLZz38JE%2pS4eT~FwOq>RT-ETJl_b~`-JY#kNu8#lYR zvD=XT&*~k#$Ng~fUv>7x+y+=x&AN$DVldb5qw^+m8%pl>vj%h+IDG4Ri%57Y-JX^@!O?n0W9S^qCxb>J;j^pmt8H2;cZ_hkiI^dNg zNJ)9g#JT>HoKFP*Q;d9F(MagNW&6$NR8|mr+uzu3uL%$O4pNJCZ#1HC+WuR)8&S`u zo~DINjbo9LrGy)q(0%co*--T{q=(QKC1((Q|3}wd6(ZSqdPbLGw2BPz!#jeu1~Ngm z{_~|=y#jFy4>B4D5)m&_rCs)gg=_1x#B$R*5z;L=J4&2q&R*DRCtiZz$#Wiw9O{Sp zc4z(Qjy}xYV=JpZN`cKP!KM|ytw`*5Y_e$?he!Z!tNIc7LiJUj3wIEwTZVImQwV;e z`;+pqry5b)4DXl|T>AF;OM|lq`5@=-7jgYrfu3h$Qo-R>c$#=eBb(rqCT|MXX?~@` zX{pTp#*86+k9d8tn5ffBFIE<$GYOqG(5rQW(8~`(o@xa4QPKBRayY`h6@I;!-qlM| zz-%tGe$@U2t#)_V1|bX(gZZSrjEd*xY6e$_=y_8;Z z|HU4W4b0&oXWL4Rb1#Ol=v_d}LW27{)js5!r*jbDdun5-*&sYC+r6UiFyUmNa{0q< z9voix`Lf0fkyrEe(!IOgFtknGe`;MNau&x=N0w4C`Y`o|9nsJBpWOQWh#Cd!=J-6A zxKG3T^QrR=TegAGpl@5#&PKt4Wu?LjREW_8W|l9g!QIx)@wx}N@H{r1abIE#N>@jn zxQlpjKeRYqEU*hXu&9+H^1zwA?e04(`jFmZS6Nm~Mf*e6;3SES-S=`!4_+kl{;q8? zLg{2|(XbnG+r$F%>^P35SQT|OqxL^Ud#XMx0xmdjt~j5%zL1{q?W=e0Ys zF39ka(YQcRsl61T(fh4lx%PmicGo*ChRC<}Kdwcr$bs0GGb^^uv=H!KhNi4?)R%{DXv8ST>LDA9^Q`hv%UA4O*^6J_$o8gZ~)eRF;CVVVPIpm z<#{@V23pL8RQH2in5Q)CB~=U(`9+nJW3F^vCNYfh|2KiW&gpMeYKV0+OgBN{<@7GGkS&o+{#)&?8i%ay~EF=fa>)uB6t?}#6EsCfg zhkEnF?}hVd2(z?GJhZnHMZQ1sEM8P$;9bs2trNqTB-^SykmO=(Rm(lOGaMu&|9a+J z*9Dy~Z+3><$;BGl*&X}F316#|uQaf?1^mk^;zc4xalka*q05~OZlJ4HzX=;51Rp^pY6M#bwhI^3sMd*ACieag(^{TyU8sAt&*;$MtxOaY{^Nn zA0365X2)QH|1iM?Uj1F49E3Fg)Ofn*7{O5wDh&7Zf)ro=z2<&0)ZBilI?EN}Y0Jm7 zq4k5vx_8R#(gYhbp}#gSP$>pk_iN|Xpe{&uMxbX2Gjye$5tPIt(6uR^0iA7}wH6i8I1C5OL7sfG&#*QexS*h@Nt21@Go9TFAuS zl@XU}bo=0T&EUb;FY%DNI9{?%xE*ridrthRsz#W5V4!w6ab2qXtz$xiDDdn^318BW zpmzsO{h|lw${dkx7m39ecO%ca6bC0FUrg8Cj6nJ7nW=M7=h4b*@$v# z6hWUS>p4xI9E^ZxGDz zB0g7q_wE{TItu=nnA9*C2vbQ~T{ou_8%}Mx;Bt$Ba)A{3Ks-^`Wx@k9w~xZujN&CS zNCn$qbHWFwejI%ut}T$zj!z+VOB;+Sh_kd9-w^NwXU>i#r6y;>vQNQG_X5HD&L3*@ zb7Nzl^OaR+O37%Q?D^_AMesLgb62wjF4Q)z8Q53dhb>H>3%BH%XnQsu{Fb;6bjI(j zHD||g)+F^npDP#r8)W7$U&Vpwl7MSl-xD0v(fXa`e1apLS&+CUQWqv)_IXORjDXqZ zmK^xJ263(f!VCPeP?t-6>ubhG_=>r6w4Z#z3)yQ+=h8=rB==_dSI`nrV)h4DmvK4oNO)e1>o)K5|+9|zM^mK@ZjLs5|Y>f8^4 z-#d{!ucwYcZIg)7h-xF&S@=I4JweC(M9b%e%ZT_%X$Vi+Ovf6NyYChYP*C75D=xmO z2a%JivjPM^F|?gnu{*H{)Rh(OhAvSUntuN`w`>FzzfQ?Bo^o-lxbwXJ(_!3y`)q~f z#cn*i5pr_%>rp(R%@}MFV81TV|H1`VW)KLf@DAb=I`mw!^LMyJ=G{Ve5$G^CL1k#q{ zPtz4U!4lcDF3)@z0nOqod}E1o(HBjMKMk2v+sAT7ILKIfV?g+MJQUB~ljMIGK}nXG zQj8iMb$JoGXOhdIYod4S4$&_v)>lmLz01Q|wft7GJL8Zm_a5mMje$+}4_WGFCThdx z+btsURBC|IPQP9rVhF2#!f6m82e^m!d~d|)#=w!$0WK!rTr^wm^#vrM_KlGp)@VL= zqvz5ZGQqoUF(y;1@u)%m^64Z#9A?S>?(;vv*Iv6YJB5#3x|;g}-V-_@NJqWEgjkng zcJ|#4qE9q3zG|ucc^s2pA~Ujn6Z)#zbYRU~HuNF+dwDA1Yin;AJa5zr#-p^3&KWxL zbRA2B2_J}mp8Zwp5y2T8LvK?K`{JpU$H%FdZiu~I@^=^EXXgEVEA?r?Cp2bpI5$@F zKuP3=j4tFsZqDxSiThZPE8?)W79_$=^{4#y#XP)^VxHXmhJ(Ijrm_9|RT2uZXmGFbuUnk{6SHu`qpSmDX=rDpsGLyVJXf4!4s* z`g*Y=2wXXlmHmKBt4;6GP^6?;FF(HuOC-R!y>b+tPh)%vf;jjg zeuVO}l79S3*xN{bCh85e{=N6U#^!E^f?~CoP(y zp*B!HQED$6#ua9J(`%Vvc|Mle?%4^^css|>tT93#rOIc1Cpencl|`x61n2u+<#F4y z2NkS(y_fNwFnm**vq85R?NTzXznKUQY_csG+oHyVA`=cMXPcM4QzucNy{-v-tOObj_~iqRv4(- zon)bAZ2JWtLkikrj%mJK&IX(Hy*MPk3yL9DD_Lt=P#4U9m1oidmbT9MRDy?AUF=>~ zaJvcgQ+dhp7URf0zF@OuEeBF9#V2*{_P~n&TqC|J3uR2rm3hW-_uz2ZNZ#=N2?#9fbSp!ufruG3o4Y`^rLMDyKn&(9iWL%JTi=qX-lcbljhn0ikL6 zd0UCPGw3+vs%J;&XulHs>L(l&pWFSMk~IRoN{Qf?tC`sB=M!@5%Qz~Iyl7cY`xX%O_^AA)omnqO3644-c&)q>2cDb%6I4)- zvbByzen~|r=QL59M2F$v^U3OIQzs<cWnB~+vkKoI_WzG*N0!N2m5w1~xyNJA zjy9aST1T=^Cca4!o~htYM_`}%KME1s9^TDRY31j}AX znsvP(>YuXn#-~JIJ3lw}Zn-~mmpG&GjyBT?~Dm^M#jMSsnU0S(WpwxZK z9+G3iF<7oHZ0FvEU9_R2;@HdScY~EqVVQD zsn@o{AborG9XRB$Sq>M&H6=KkLyv;_bYAX*PC59 zYLIeZk3=1ABtN932^1o@Ok*;)hXMVvp4^Cv0w_P7yHV9@2rq_D-Aek%MCWDK&VgSv z=smFg1Q{ZJcO6cbyVijk`@+1vpSr=@k{`710#T=}8<>kt2Vnd#YJIUN8Kk0%GT&2k zQOujYkS4;x_9i7pk zi@MPHW47}x!BzFB?Vqi0?UJrX_z2NFHHDT0GJ1CYIyPg{3T4rTC)A8KgiiPe?FsKgsGrfr zD%CDnsYot5r}z?QdT*>w`bMmqKBPabdMO-dde$ek9U!CQT5zL(7Y$Y~@&y|uhaqxf z(Hr;gU%-6aX;p2{f{XwqyKjPu?bQ8jQ@L&&Nz*d()@7lt=gNzq`Ch-``@M*4OL{x7i2+W=JO8BaL9EI?x6mq# z2afRVtMgwlV3njaJ8Dnxjm=kPCJ26G^G$S%_|9&;?`%KN>>DRPet(^1j>K0Zr0aeQzHV`ulBF z-2+E(c(d;8xr<~_7c>Q?ia&$jlgvm5f`f#zCyOin$XKcrU~%|p4>C6Q3Y_dD=C#Ky z&l;i*WK=o{N?Md-9bGEfXO@_s?}F%0w&Wo1d){Qp_9p0ex74}>_TgCB7C%ch;=X-k z99e#^9sagt#-Am8T#&l?B$Vh2a^?T=5)veQk+@W$11kvs>(tQaCq#W8{5v2!$R0$v zpv~9&4peBbl|1rYvj>L%)9wY4-!h|N#IA&Y9}9@x3r-g<+B*HWtqcFY`0w%G&j0(Z z3y%M{TNnQSHZ1)29S#cx{`1cv&H1lq!y^K7BnAIZTQUCYn~0Ooe_Z%KZp`>!ul%2f g;eWicM2+;{2gUG+@Z5iYIfwYULYyC}|NHE}0P5i8WdHyG literal 0 HcmV?d00001 diff --git a/tests/test_datasets/test_file_image_dataset.py b/tests/test_datasets/test_file_image_dataset.py new file mode 100644 index 000000000..e87c467ef --- /dev/null +++ b/tests/test_datasets/test_file_image_dataset.py @@ -0,0 +1,31 @@ +import os.path as osp + +from mmgen.datasets import FileDataset + + +class TestFileDataset(object): + + @classmethod + def setup_class(cls): + cls.file_path = osp.join( + osp.dirname(__file__), '..', 'data/file/test.npz') + cls.default_pipeline = [ + dict(type='Resize', scale=(32, 32), keys=['fake_img']), + dict(type='ToTensor', keys=['label']), + dict(type='ImageToTensor', keys=['fake_img']), + dict(type='Collect', keys=['fake_img', 'label']) + ] + + def test_unconditional_imgs_dataset(self): + dataset = FileDataset(self.file_path, pipeline=self.default_pipeline) + + assert len(dataset) == 2 + data_dict = dataset[0] + img = data_dict['fake_img'] + lab = data_dict['label'] + assert img.shape == (3, 32, 32) + assert lab == 1 + print(repr(dataset)) + assert repr(dataset) == ( + f'dataset_name: {dataset.__class__}, ' + f'total {2} images in file_path: {self.file_path}') From c0c306f726db03f89bade8882b40b4f87273b3bb Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 2 Jan 2022 22:54:46 +0800 Subject: [PATCH 2/5] add unit test for meta_keys is None --- tests/test_datasets/test_pipelines/test_formatting.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_datasets/test_pipelines/test_formatting.py b/tests/test_datasets/test_pipelines/test_formatting.py index 4ccb91005..226bf3c26 100644 --- a/tests/test_datasets/test_pipelines/test_formatting.py +++ b/tests/test_datasets/test_pipelines/test_formatting.py @@ -90,7 +90,6 @@ def test_collect(): collect = Collect(keys, meta_keys=meta_keys) results = collect(inputs) assert set(list(results.keys())) == set(['img', 'label', 'meta']) - inputs.pop('img') assert set(results['meta'].data.keys()) == set(meta_keys) for key in results['meta'].data: assert results['meta'].data[key] == inputs[key] @@ -98,3 +97,9 @@ def test_collect(): assert repr(collect) == ( collect.__class__.__name__ + f'(keys={keys}, meta_keys={collect.meta_keys})') + + # test meta is None + collect = Collect(keys) + results = collect(inputs) + print(results['meta'].data) + assert results['meta'].data == {} From 1c58c029637b902d02f4a34c43d21bf3a8f3d9d2 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 2 Jan 2022 22:56:36 +0800 Subject: [PATCH 3/5] support npz saving and loading in offline evaluation --- mmgen/core/evaluation/evaluation.py | 174 +++++++++++++++++- mmgen/datasets/__init__.py | 2 +- ...tional_file_dataset.py => file_dataset.py} | 56 ++++-- tools/evaluation.py | 10 +- 4 files changed, 211 insertions(+), 31 deletions(-) rename mmgen/datasets/{unconditional_file_dataset.py => file_dataset.py} (65%) diff --git a/mmgen/core/evaluation/evaluation.py b/mmgen/core/evaluation/evaluation.py index b5b32bce6..2de8d190f 100644 --- a/mmgen/core/evaluation/evaluation.py +++ b/mmgen/core/evaluation/evaluation.py @@ -5,6 +5,7 @@ from copy import deepcopy import mmcv +import numpy as np import torch import torch.distributed as dist from mmcv.runner import get_dist_info @@ -62,6 +63,90 @@ def make_vanilla_dataloader(img_path, batch_size, dist=False): return dataloader +def make_npz_dataloader(npz_path, batch_size, dist=False): + # TODO: align keys + pipeline = [ + # permute the color channel. Because in npz_dataloader's pipeline, we + # direct the image files from npz file which is RGB order and we must + # convert it to BGR. + dict( + type='Normalize', + keys=['real_img'], + mean=[127.5] * 3, + std=[127.5] * 3, + to_rgb=True), + dict(type='ImageToTensor', keys=['real_img']), + dict(type='Collect', keys=['real_img']) + ] + + dataset = build_dataset( + dict(type='FileDataset', file_path=npz_path, pipeline=pipeline)) + dataloader = build_dataloader( + dataset, + samples_per_gpu=batch_size, + workers_per_gpu=0, + dist=dist, + shuffle=True) + return dataloader + + +def parse_npz_file_name(npz_file): + """Parse the basic information from the give npz file. + Args: + npz_file (str): The file name of the npz file. + + Returns: + tuple(int): A tuple of (num_samples, H, W, num_channels). + """ + # remove 'samples_' (8) and '.npz' (4) + num_samples, H, W, num_channles = npz_file[8:-4].split('x') + assert num_samples.isdigit() + assert H.isdigit() + assert W.isdigit() + assert num_channles.isdigit() + return int(num_samples), int(H), int(W), int(num_channles) + + +def parse_npz_folder(npz_folder_path): + """Parse the npz files under the given folder. + Args: + npz_folder_path: The folder contains npz file. + + Returns: + tuple(list, int): A tuple contains a list valid npz files' names and + a int of existing image numbers. + """ + + npz_files = [ + f for f in list(mmcv.scandir(npz_folder_path, suffix=('.npz'))) + if 'samples' in f + ] + valid_npz_files = [] + img_shape = None + num_exist = 0 + for npz_file in npz_files: + try: + n_samples, H, W, n_channels = parse_npz_file_name(npz_file) + + # shape checking + if img_shape is None: + img_shape = (H, W, n_channels) + else: + if img_shape != (H, W, n_channels): + raise ValueError( + 'Image shape conflicting under sample path:' + f'\'{npz_folder_path}\'. Find {img_shape} vs. ' + f'{(H, W, n_channels)}.') + + valid_npz_files.append(npz_file) + num_exist += n_samples + except AssertionError: + mmcv.print_log( + f'Find npz file \'{npz_file}\' does not conform to the ' + 'standard naming convention.', 'mmgen') + return valid_npz_files, num_exist + + @torch.no_grad() def offline_evaluation(model, data_loader, @@ -70,6 +155,7 @@ def offline_evaluation(model, basic_table_info, batch_size, samples_path=None, + save_npz=False, **kwargs): """Evaluate model in offline mode. @@ -87,6 +173,10 @@ def offline_evaluation(model, samples_path (str): Used to save generated images. If it's none, we'll give it a default directory and delete it after finishing the evaluation. Default to None. + save_npz (bool, optional): Whether save the generated images to a npz + file named 'samples_{NUM_IMAGES}x{H}x{W}x{NUM_CHANNELS}.npz' If + true, dataset will be build upon npz file instead of image files. + Defaults to True. kwargs (dict): Other arguments. """ # eval special and recon metric online only @@ -111,11 +201,14 @@ def offline_evaluation(model, os.makedirs(samples_path) delete_samples_path = True - # sample images - num_exist = len( - list( - mmcv.scandir( - samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG')))) + # check existing images + if save_npz: + npz_file_exist, num_exist = parse_npz_folder(samples_path) + else: + num_exist = len( + list( + mmcv.scandir( + samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG')))) if basic_table_info['num_samples'] > 0: max_num_images = basic_table_info['num_samples'] else: @@ -128,6 +221,7 @@ def offline_evaluation(model, # define mmcv progress bar pbar = mmcv.ProgressBar(num_needed) + fake_img_list = [] # if no images, `num_needed` should be zero total_batch_size = batch_size * ws for begin in range(0, num_needed, total_batch_size): @@ -163,8 +257,65 @@ def offline_evaluation(model, images = fakes[i:i + 1] images = ((images + 1) / 2) images = images.clamp_(0, 1) - image_name = str(num_exist + begin + i) + '.png' - save_image(images, os.path.join(samples_path, image_name)) + if save_npz: + # permute to [H, W, chn] and rescale to [0, 255] + fake_img_list.append( + images.permute(0, 2, 3, 1).cpu().numpy() * 255) + else: + image_name = str(num_exist + begin + i) + '.png' + save_image(images, os.path.join(samples_path, image_name)) + + if save_npz: + # only one npz file and fake_img_list is empty --> do not need to save + if len(npz_file_exist) == 1 and len(fake_img_list) == 0: + npz_path = os.path.join(samples_path, npz_file_exist[0]) + if rank == 0: + mmcv.print_log( + f'Existing npz file \'{npz_path}\' has already met ' + 'requirements.', 'mmgen') + else: + # load from locl file and merge to one + if rank == 0: + fake_img_exist_list = [] + for exist_npz_file in npz_file_exist: + fake_imgs_ = np.load( + os.path.join(samples_path, exist_npz_file))['real_img'] + fake_img_exist_list.append(fake_imgs_) + + # merge fake_img_exist_list and fake_img_list + fake_imgs = np.concatenate( + fake_img_exist_list + fake_img_list, axis=0) + num_imgs, H, W, num_channels = fake_imgs.shape + + npz_path = os.path.join( + samples_path, + f'samples_{num_imgs}x{H}x{W}x{num_channels}.npz') + + # save new npz file + np.savez(npz_path, real_img=fake_imgs) + mmcv.print_log(f'Save new npz_file to \'{npz_path}\'.', + 'mmgen') + + # delete old npz files + for npz_file in npz_file_exist: + os.remove(os.path.join(samples_path, npz_file)) + mmcv.print_log( + 'Remove useless npz file ' + f'\'{os.path.join(samples_path, npz_file)}\'.', + 'mmgen') + + # waiting for rank-0 to save the new npz file + if ws > 1: + dist.barrier() + # get npz_path. + # We have delete useless npz files then there should only one + # file under the sample_path. Check and directly load it! + npz_files = [ + f for f in list(mmcv.scandir(samples_path, suffix=('.npz'))) + if 'samples' in f + ] + assert len(npz_files) == 1 + npz_path = os.path.join(samples_path, npz_files[0]) if num_needed > 0 and rank == 0: sys.stdout.write('\n') @@ -175,8 +326,13 @@ def offline_evaluation(model, # empty cache to release GPU memory torch.cuda.empty_cache() - fake_dataloader = make_vanilla_dataloader( - samples_path, batch_size, dist=ws > 1) + if save_npz: + fake_dataloader = make_npz_dataloader( + npz_path, batch_size, dist=ws > 1) + else: + fake_dataloader = make_vanilla_dataloader( + samples_path, batch_size, dist=ws > 1) + for metric in metrics: mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen') metric.prepare() diff --git a/mmgen/datasets/__init__.py b/mmgen/datasets/__init__.py index 571e3fb79..83d458769 100644 --- a/mmgen/datasets/__init__.py +++ b/mmgen/datasets/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .builder import build_dataloader, build_dataset from .dataset_wrappers import RepeatDataset +from .file_dataset import FileDataset from .grow_scale_image_dataset import GrowScaleImgDataset from .paired_image_dataset import PairedImageDataset from .pipelines import (Collect, Compose, Flip, ImageToTensor, @@ -8,7 +9,6 @@ from .quick_test_dataset import QuickTestImageDataset from .samplers import DistributedSampler from .singan_dataset import SinGANDataset -from .unconditional_file_dataset import FileDataset from .unconditional_image_dataset import UnconditionalImageDataset from .unpaired_image_dataset import UnpairedImageDataset diff --git a/mmgen/datasets/unconditional_file_dataset.py b/mmgen/datasets/file_dataset.py similarity index 65% rename from mmgen/datasets/unconditional_file_dataset.py rename to mmgen/datasets/file_dataset.py index 18f226d3d..96559c10b 100644 --- a/mmgen/datasets/unconditional_file_dataset.py +++ b/mmgen/datasets/file_dataset.py @@ -47,15 +47,25 @@ def __init__(self, file_path, pipeline, test_mode=False): def load_annotations(self): """Load annotations.""" if self.file_path.endswith('.npz'): - data_info_list = self._load_annotations_from_npz() - self.data_infos = data_info_list + data_info, data_length = self._load_annotations_from_npz() + data_fetch_fn = self._npz_data_fetch_fn + + self.data_infos = data_info + self.data_fetch_fn = data_fetch_fn + self.data_length = data_length def _load_annotations_from_npz(self): - npz_file = np.load(self.file_path) + """Load annotations from npz file and check number of samples are + consistent among all items. + + Returns: + tuple: dict and int + """ + npz_file = np.load(self.file_path, mmap_mode='r') data_info_dict = dict() - data_info_list = [] npz_keys = list(npz_file.keys()) + # checnk num samples num_samples = None for k in npz_keys: data_info_dict[k] = npz_file[k] @@ -64,35 +74,45 @@ def _load_annotations_from_npz(self): num_samples = npz_file[k].shape[0] else: assert num_samples == npz_file[k].shape[0] + return data_info_dict, num_samples - # save to list - for idx in range(num_samples): - data_info = dict() - for k in npz_keys: - var = data_info_dict[k][idx] - if var.shape == (): - var = np.array([var]) - data_info[k] = var - data_info_list.append(data_info) + @staticmethod + def _npz_data_fetch_fn(data_infos, idx): + """Fetch data from npz file by idx and package them to a dict. - return data_info_list + Args: + data_infos (array, tuple, dict): Data infos in the npz file. + idx (int): Index of current batch. + + Returns: + dict: Data infos of the given idx. + """ + data_dict = dict() + for k in data_infos.keys(): + data_dict[k] = data_infos[k][idx] + return data_dict - def prepare_data(self, idx): + def prepare_data(self, idx, data_fetch_fn=None): """Prepare data. Args: idx (int): Index of current batch. + data_fetch_fn (callable): Function to fetch data. Returns: dict: Prepared training data batch. """ - return self.pipeline(self.data_infos[idx]) + if data_fetch_fn is None: + data = self.data_infos[idx] + else: + data = data_fetch_fn(self.data_infos, idx) + return self.pipeline(data) def __len__(self): - return len(self.data_infos) + return self.data_length def __getitem__(self, idx): - return self.prepare_data(idx) + return self.prepare_data(idx, self.data_fetch_fn) def __repr__(self): dataset_name = self.__class__ diff --git a/tools/evaluation.py b/tools/evaluation.py index 731c623a8..4b5a98a03 100644 --- a/tools/evaluation.py +++ b/tools/evaluation.py @@ -50,6 +50,12 @@ def parse_args(): default=None, help='path to store images. If not given, remove it after evaluation\ finished') + parser.add_argument( + '--save-npz', + action='store_true', + help=('whether to save generated images to a npz file named ' + '\'NUM_IMAGES.npz\'. The npz file will be saved at ' + '`samples-path`. (only work in offline mode)')) parser.add_argument( '--sample-model', type=str, @@ -113,8 +119,6 @@ def main(): init_dist(args.launcher, **cfg.dist_params) rank, world_size = get_dist_info() cfg.gpu_ids = range(world_size) - assert args.online or world_size == 1, ( - 'We only support online mode for distrbuted evaluation.') dirname = os.path.dirname(args.checkpoint) ckpt = os.path.basename(args.checkpoint) @@ -218,7 +222,7 @@ def main(): else: offline_evaluation(model, data_loader, metrics, logger, basic_table_info, args.batch_size, - args.samples_path, **args.sample_cfg) + args.samples_path, args.save_npz, **args.sample_cfg) if __name__ == '__main__': From 3c6d9e05da92bb444a6c482da61debba155dd704 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 2 Jan 2022 23:47:01 +0800 Subject: [PATCH 4/5] fix some comment --- mmgen/core/evaluation/evaluation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmgen/core/evaluation/evaluation.py b/mmgen/core/evaluation/evaluation.py index 2de8d190f..e38170c11 100644 --- a/mmgen/core/evaluation/evaluation.py +++ b/mmgen/core/evaluation/evaluation.py @@ -64,11 +64,10 @@ def make_vanilla_dataloader(img_path, batch_size, dist=False): def make_npz_dataloader(npz_path, batch_size, dist=False): - # TODO: align keys pipeline = [ # permute the color channel. Because in npz_dataloader's pipeline, we - # direct the image files from npz file which is RGB order and we must - # convert it to BGR. + # direct load image in RGB order from npz file and we must convert it + # to BGR by setting ``to_rgb=True``. dict( type='Normalize', keys=['real_img'], @@ -291,7 +290,8 @@ def offline_evaluation(model, samples_path, f'samples_{num_imgs}x{H}x{W}x{num_channels}.npz') - # save new npz file + # save new npz file --> + # set key as ``real_img`` to align with vanilla dataset np.savez(npz_path, real_img=fake_imgs) mmcv.print_log(f'Save new npz_file to \'{npz_path}\'.', 'mmgen') From 0a336130696a2f369f4fcb145260c5a669ab27e1 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Mon, 3 Jan 2022 13:00:26 +0800 Subject: [PATCH 5/5] fix bug when loading variables shape like () --- mmgen/datasets/file_dataset.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mmgen/datasets/file_dataset.py b/mmgen/datasets/file_dataset.py index 96559c10b..761ec94ea 100644 --- a/mmgen/datasets/file_dataset.py +++ b/mmgen/datasets/file_dataset.py @@ -11,8 +11,8 @@ class FileDataset(Dataset): """Uncoditional file Dataset. - This dataset contains raw images for training unconditional GANs. Given - the path of a file, we will load all image in this file. The + This dataset load data information from files for training GANs. Given + the path of a file, we will load all information in the file. The transformation on data is defined by the pipeline. Please ensure that ``LoadImageFromFile`` is not in your pipeline configs because we directly get images in ``np.ndarray`` from the given file. @@ -89,7 +89,11 @@ def _npz_data_fetch_fn(data_infos, idx): """ data_dict = dict() for k in data_infos.keys(): - data_dict[k] = data_infos[k][idx] + if data_infos[k][idx].shape == (): + v = np.array([data_infos[k][idx]]) + else: + v = data_infos[k][idx] + data_dict[k] = v return data_dict def prepare_data(self, idx, data_fetch_fn=None):