From 38e965fc432f9a41cb720c75bd933007210f6fac Mon Sep 17 00:00:00 2001 From: jingt2ch Date: Tue, 27 Sep 2022 04:26:38 +0900 Subject: [PATCH] [Fix] Fixed typo in loop condition for testing model. --- mmselfsup/utils/collect.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmselfsup/utils/collect.py b/mmselfsup/utils/collect.py index 26837e036..6277219b7 100644 --- a/mmselfsup/utils/collect.py +++ b/mmselfsup/utils/collect.py @@ -22,8 +22,8 @@ def nondist_forward_collect(func, data_loader, length): results_all (dict(np.ndarray)): The concatenated outputs. """ results = [] - prog_bar = mmcv.ProgressBar(len(data_loader)) - for i, data in enumerate(data_loader): + prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) + for i, data in enumerate(data_loader.dataset): input_data = dict(img=data['img']) with torch.no_grad(): result = func(**input_data) # feat_dict @@ -58,8 +58,8 @@ def dist_forward_collect(func, data_loader, rank, length, ret_rank=-1): """ results = [] if rank == 0: - prog_bar = mmcv.ProgressBar(len(data_loader)) - for idx, data in enumerate(data_loader): + prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) + for idx, data in enumerate(data_loader.dataset): with torch.no_grad(): result = func(**data) # dict{key: tensor} results.append(result)