From 7134550f54860f8e05751959965f7335a2ecee4d Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 15 Nov 2023 19:08:51 +0100 Subject: [PATCH] Refactor masking input logic --- micro_sam/training/sam_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index 48fe4eab..1c0a290c 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -258,7 +258,7 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc if self.mask_prob > 0: # using mask inputs for iterative prompting while training, with a probability - use_mask_inputs = True if random.random() < self.mask_prob else False + use_mask_inputs = (random.random() < self.mask_prob) if use_mask_inputs: _inp["mask_inputs"] = logits else: # remove previously existing mask inputs to avoid using them in next sub-iteration