From 77871053b627c1e14bd44836e65bb9a19d586f7f Mon Sep 17 00:00:00 2001 From: Chaymaa_bs Date: Wed, 2 Nov 2022 17:00:43 +0100 Subject: [PATCH 1/7] add class mapping for multi class segmentation --- train.py | 4 +++- utils/data_loading.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 6067c72bb3..d2f7b0434a 100644 --- a/train.py +++ b/train.py @@ -32,7 +32,9 @@ def train_net(net, amp: bool = False): # 1. Create dataset try: - dataset = CarvanaDataset(dir_img, dir_mask, img_scale) + # if multi_class semantic segmentation add class mapping + # example for 3 class segmentation : mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2} + dataset = CarvanaDataset(dir_img, dir_mask, img_scale, mapping = {}) except (AssertionError, RuntimeError): dataset = BasicDataset(dir_img, dir_mask, img_scale) diff --git a/utils/data_loading.py b/utils/data_loading.py index 8bb4f9252c..9440b17e67 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -10,12 +10,13 @@ class BasicDataset(Dataset): - def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''): + def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = '', mapping={}): self.images_dir = Path(images_dir) self.masks_dir = Path(masks_dir) assert 0 < scale <= 1, 'Scale must be between 0 and 1' self.scale = scale self.mask_suffix = mask_suffix + self.mapping = mapping self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')] if not self.ids: @@ -53,6 +54,20 @@ def load(filename): else: return Image.open(filename) + @classmethod + def mask_to_class(cls, mask: np.ndarray, mapping): + mask_ = np.empty((mask.shape[1], mask.shape[2])) + for k in mapping: + k_array = np.array(k) + # to have the same dim as the mask + k_array = np.expand_dims(k_array, axis=(1, 2)) + # Extract each class indexes + idx = (mask == k_array) + # check there is 3 channels + validx = (idx.sum(0) == 3) + mask_[validx] = mapping[k] + return mask_ + def __getitem__(self, idx): name = self.ids[idx] mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*')) @@ -69,6 +84,9 @@ def __getitem__(self, idx): img = self.preprocess(img, self.scale, is_mask=False) mask = self.preprocess(mask, self.scale, is_mask=True) + # mapping the class colors + mask = self.mask_to_class(mask, self.mapping) + return { 'image': torch.as_tensor(img.copy()).float().contiguous(), 'mask': torch.as_tensor(mask.copy()).long().contiguous() From e97bb820270bc5cb7b5e55774037007d99e8fd50 Mon Sep 17 00:00:00 2001 From: Chaymaa_bs Date: Wed, 2 Nov 2022 17:04:12 +0100 Subject: [PATCH 2/7] add class mapping for multi class segmentation --- train.py | 2 +- utils/data_loading.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index d2f7b0434a..be302c4c74 100644 --- a/train.py +++ b/train.py @@ -36,7 +36,7 @@ def train_net(net, # example for 3 class segmentation : mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2} dataset = CarvanaDataset(dir_img, dir_mask, img_scale, mapping = {}) except (AssertionError, RuntimeError): - dataset = BasicDataset(dir_img, dir_mask, img_scale) + dataset = BasicDataset(dir_img, dir_mask, img_scale, mapping = {}) # 2. Split into train / validation partitions n_val = int(len(dataset) * val_percent) diff --git a/utils/data_loading.py b/utils/data_loading.py index 9440b17e67..08c8b92af6 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -94,5 +94,5 @@ def __getitem__(self, idx): class CarvanaDataset(BasicDataset): - def __init__(self, images_dir, masks_dir, scale=1): + def __init__(self, images_dir, masks_dir, scale=1, mapping = {}): super().__init__(images_dir, masks_dir, scale, mask_suffix='_mask') From bd65bdf465b6c991fd72d046a4d6ca03309910c4 Mon Sep 17 00:00:00 2001 From: Chaymaa_bs Date: Mon, 7 Nov 2022 12:14:46 +0100 Subject: [PATCH 3/7] remove /255 from preprocessing --- utils/data_loading.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/utils/data_loading.py b/utils/data_loading.py index 08c8b92af6..1698517b36 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -40,8 +40,6 @@ def preprocess(pil_img, scale, is_mask): else: img_ndarray = img_ndarray.transpose((2, 0, 1)) - img_ndarray = img_ndarray / 255 - return img_ndarray @staticmethod @@ -56,14 +54,13 @@ def load(filename): @classmethod def mask_to_class(cls, mask: np.ndarray, mapping): - mask_ = np.empty((mask.shape[1], mask.shape[2])) + mask_ = np.zeros((mask.shape[1], mask.shape[2])) for k in mapping: k_array = np.array(k) # to have the same dim as the mask k_array = np.expand_dims(k_array, axis=(1, 2)) # Extract each class indexes idx = (mask == k_array) - # check there is 3 channels validx = (idx.sum(0) == 3) mask_[validx] = mapping[k] return mask_ From 73ce81f7918771b26bed15b08254d5d8b00d65ae Mon Sep 17 00:00:00 2001 From: Chaymaa_bs Date: Mon, 7 Nov 2022 17:31:07 +0100 Subject: [PATCH 4/7] add reverse mapping in predict.py for multiclass semantic segmentation --- predict.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/predict.py b/predict.py index 956b6f8894..6b672eeb04 100755 --- a/predict.py +++ b/predict.py @@ -68,12 +68,21 @@ def _generate_name(fn): return args.output or list(map(_generate_name, args.input)) - -def mask_to_image(mask: np.ndarray): +# if multiclass semantic segmentation, consider setting the mapping dict used during training, example: +# mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2} +def mask_to_image(mask: np.ndarray, mapping = {}): if mask.ndim == 2: return Image.fromarray((mask * 255).astype(np.uint8)) elif mask.ndim == 3: - return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8)) + # reverse the mapping values we have used during training + rev_mapping = {mapping[k]: k for k in mapping} + # create an empty image with 3 channels of shape : (3, h, w) + pred_image = torch.zeros(3, mask_pred.size(0), mask_pred.size(1), dtype=torch.uint8) + # replace predicted mask values with mapped values + for k in rev_mapping: + pred_image[:, mask_pred == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1) + final_mask_pred = pred_image.permute(1, 2, 0).numpy() + return PIL.Image.fromarray(final_mask_pred) if __name__ == '__main__': @@ -104,6 +113,8 @@ def mask_to_image(mask: np.ndarray): if not args.no_save: out_filename = out_files[i] + # if multiclass semantic segmentation, consider setting the mapping dict used during training, example: + # mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2} result = mask_to_image(mask) result.save(out_filename) logging.info(f'Mask saved to {out_filename}') From 022e7b823fbd4718adeb6e3ac27492658087aced Mon Sep 17 00:00:00 2001 From: Chaymaa_bs Date: Mon, 7 Nov 2022 17:32:35 +0100 Subject: [PATCH 5/7] add reverse mapping in predict.py for multiclass semantic segmentation --- predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predict.py b/predict.py index 6b672eeb04..33ef911354 100755 --- a/predict.py +++ b/predict.py @@ -115,7 +115,7 @@ def mask_to_image(mask: np.ndarray, mapping = {}): out_filename = out_files[i] # if multiclass semantic segmentation, consider setting the mapping dict used during training, example: # mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2} - result = mask_to_image(mask) + result = mask_to_image(mask, mapping={}) result.save(out_filename) logging.info(f'Mask saved to {out_filename}') From 4b00ce9a7fef03d8aa55d2528ee7e0cc69ce1cc6 Mon Sep 17 00:00:00 2001 From: Chaymaa_bs Date: Wed, 9 Nov 2022 16:55:13 +0100 Subject: [PATCH 6/7] remove conditions --- predict.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/predict.py b/predict.py index 33ef911354..72786cf471 100755 --- a/predict.py +++ b/predict.py @@ -71,18 +71,16 @@ def _generate_name(fn): # if multiclass semantic segmentation, consider setting the mapping dict used during training, example: # mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2} def mask_to_image(mask: np.ndarray, mapping = {}): - if mask.ndim == 2: - return Image.fromarray((mask * 255).astype(np.uint8)) - elif mask.ndim == 3: - # reverse the mapping values we have used during training - rev_mapping = {mapping[k]: k for k in mapping} - # create an empty image with 3 channels of shape : (3, h, w) - pred_image = torch.zeros(3, mask_pred.size(0), mask_pred.size(1), dtype=torch.uint8) - # replace predicted mask values with mapped values - for k in rev_mapping: - pred_image[:, mask_pred == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1) - final_mask_pred = pred_image.permute(1, 2, 0).numpy() - return PIL.Image.fromarray(final_mask_pred) + + # reverse the mapping values we have used during training + rev_mapping = {mapping[k]: k for k in mapping} + # create an empty image with 3 channels of shape : (3, h, w) + pred_image = torch.zeros(3, mask_pred.size(0), mask_pred.size(1), dtype=torch.uint8) + # replace predicted mask values with mapped values + for k in rev_mapping: + pred_image[:, mask_pred == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1) + final_mask_pred = pred_image.permute(1, 2, 0).numpy() + return PIL.Image.fromarray(final_mask_pred) if __name__ == '__main__': From 7cb1698eae6896356df9aba67d2dbc0d6726f4c1 Mon Sep 17 00:00:00 2001 From: Chaymaa_bs Date: Thu, 17 Nov 2022 16:06:09 +0100 Subject: [PATCH 7/7] rectify code --- predict.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/predict.py b/predict.py index 72786cf471..01ccb3a0a3 100755 --- a/predict.py +++ b/predict.py @@ -68,19 +68,25 @@ def _generate_name(fn): return args.output or list(map(_generate_name, args.input)) + # if multiclass semantic segmentation, consider setting the mapping dict used during training, example: # mapping = {(0, 0, 0): 0, (255, 0, 255): 1, (0, 255, 255): 2} def mask_to_image(mask: np.ndarray, mapping = {}): - - # reverse the mapping values we have used during training - rev_mapping = {mapping[k]: k for k in mapping} - # create an empty image with 3 channels of shape : (3, h, w) - pred_image = torch.zeros(3, mask_pred.size(0), mask_pred.size(1), dtype=torch.uint8) - # replace predicted mask values with mapped values - for k in rev_mapping: - pred_image[:, mask_pred == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1) - final_mask_pred = pred_image.permute(1, 2, 0).numpy() - return PIL.Image.fromarray(final_mask_pred) + if mask.ndim == 2: + return Image.fromarray((mask * 255).astype(np.uint8)) + else: + # probabilities to indexes --> index of each class that has the highest probability + mask = torch.argmax(mask, axis=0) + # mask shape: (h, w) + # reverse the mapping values we have used during training + rev_mapping = {mapping[k]: k for k in mapping} + # create an empty image with 3 channels of shape : (3, h, w) + pred_image = torch.zeros(3, mask.size(0), mask.size(1), dtype=torch.uint8) + # replace predicted mask values with mapped values + for k in rev_mapping: + pred_image[:, mask == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1) + final_mask_pred = pred_image.permute(1, 2, 0).numpy() + return Image.fromarray(final_mask_pred) if __name__ == '__main__':