diff --git a/biapy/data/generators/test_pair_data_generators.py b/biapy/data/generators/test_pair_data_generators.py index cb1cc18f..182c8811 100644 --- a/biapy/data/generators/test_pair_data_generators.py +++ b/biapy/data/generators/test_pair_data_generators.py @@ -142,7 +142,8 @@ def __init__(self, ndim, X=None, d_path=None, test_by_chunks=False, provide_Y=Fa self.data[f"sample_{c}"]["gt"] = os.path.join(dm_path,self.data_mask_path[i],gt_image_path) c += 1 - self.len = len(self.data) + self.sample_list = list(self.data.keys()) + self.len = len(self.sample_list) if self.len == 0: raise ValueError("No image found in {}".format(d_path)) else: @@ -159,7 +160,6 @@ def __init__(self, ndim, X=None, d_path=None, test_by_chunks=False, provide_Y=Fa raise ValueError("No test image found in {}".format(d_path)) else: self.len = len(X) - self.o_indexes = np.arange(self.len) # Check if a division is required self.X_norm = {} @@ -278,11 +278,12 @@ def load_sample(self, idx): # Choose the data source if self.X is None: if not self.all_files_in_same_folder: - img = imread(self.data[f"sample_{idx}"]["raw"]) + k = self.sample_list[idx] + img = imread(self.data[k]["raw"]) img = np.squeeze(img) - filename = self.data[f"sample_{idx}"]["raw"] + filename = self.data[k]["raw"] if self.provide_Y: - mask = imread(self.data[f"sample_{idx}"]["gt"]) + mask = imread(self.data[k]["gt"]) mask = np.squeeze(mask) else: filename = os.path.join(self.d_path, self.data_path[idx])