diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 5b65c825d..4f00419a5 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -1647,6 +1647,9 @@ class CentroidCrop(InferenceLayer): crop_size: Integer scalar specifying the height/width of the centered crops. input_scale: Float indicating if the images should be resized before being passed to the model. + precrop_resize: Float indicating the factor by which the original images + (not images resized for centroid model) should be resized before cropping. + Note: this resize only after getting the predictions for centroid model. pad_to_stride: If not 1, input image will be paded to ensure that it is divisible by this value (after scaling). This should be set to the max stride of the model. @@ -1687,6 +1690,7 @@ def __init__( keras_model: tf.keras.Model, crop_size: int, input_scale: float = 1.0, + precrop_resize: Optional[float] = 1.0, pad_to_stride: int = 1, output_stride: Optional[int] = None, peak_threshold: float = 0.2, @@ -1707,6 +1711,7 @@ def __init__( ) self.crop_size = crop_size + self.precrop_resize = precrop_resize self.confmaps_ind = confmaps_ind self.offsets_ind = offsets_ind @@ -1901,6 +1906,13 @@ def call(self, inputs): crop_offsets = centroid_points - (self.crop_size / 2) + # resize full images + if self.precrop_resize: + full_imgs = sleap.nn.data.resizing.resize_image( + full_imgs, self.precrop_resize + ) + centroid_points *= self.precrop_resize + # Crop instances around centroids. bboxes = sleap.nn.data.instance_cropping.make_centered_bboxes( centroid_points, self.crop_size, self.crop_size @@ -2372,6 +2384,7 @@ def _initialize_inference_model(self): keras_model=self.centroid_model.keras_model, crop_size=crop_size, input_scale=self.centroid_config.data.preprocessing.input_scaling, + precrop_resize=None, pad_to_stride=self.centroid_config.data.preprocessing.pad_to_stride, output_stride=self.centroid_config.model.heads.centroid.output_stride, peak_threshold=self.peak_threshold, @@ -2397,6 +2410,10 @@ def _initialize_inference_model(self): ) if use_gt_centroid: centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling + else: + centroid_crop_layer.precrop_resize = ( + cfg.data.preprocessing.input_scaling + ) self.inference_model = TopDownInferenceModel( centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer