diff --git a/mmselfsup/utils/collect.py b/mmselfsup/utils/collect.py index 26837e036..6277219b7 100644 --- a/mmselfsup/utils/collect.py +++ b/mmselfsup/utils/collect.py @@ -22,8 +22,8 @@ def nondist_forward_collect(func, data_loader, length): results_all (dict(np.ndarray)): The concatenated outputs. """ results = [] - prog_bar = mmcv.ProgressBar(len(data_loader)) - for i, data in enumerate(data_loader): + prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) + for i, data in enumerate(data_loader.dataset): input_data = dict(img=data['img']) with torch.no_grad(): result = func(**input_data) # feat_dict @@ -58,8 +58,8 @@ def dist_forward_collect(func, data_loader, rank, length, ret_rank=-1): """ results = [] if rank == 0: - prog_bar = mmcv.ProgressBar(len(data_loader)) - for idx, data in enumerate(data_loader): + prog_bar = mmcv.ProgressBar(len(data_loader.dataset)) + for idx, data in enumerate(data_loader.dataset): with torch.no_grad(): result = func(**data) # dict{key: tensor} results.append(result)