Skip to content

Commit

Permalink
Fix scaling in centroidcrop
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Dec 17, 2024
1 parent ad03905 commit d47c98a
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,9 @@ class CentroidCrop(InferenceLayer):
crop_size: Integer scalar specifying the height/width of the centered crops.
input_scale: Float indicating if the images should be resized before being
passed to the model.
precrop_resize: Float indicating the factor by which the original images
(not images resized for centroid model) should be resized before cropping.
Note: this resize only after getting the predictions for centroid model.
pad_to_stride: If not 1, input image will be paded to ensure that it is
divisible by this value (after scaling). This should be set to the max
stride of the model.
Expand Down Expand Up @@ -1687,6 +1690,7 @@ def __init__(
keras_model: tf.keras.Model,
crop_size: int,
input_scale: float = 1.0,
precrop_resize: Optional[float] = 1.0,
pad_to_stride: int = 1,
output_stride: Optional[int] = None,
peak_threshold: float = 0.2,
Expand All @@ -1707,6 +1711,7 @@ def __init__(
)

self.crop_size = crop_size
self.precrop_resize = precrop_resize

self.confmaps_ind = confmaps_ind
self.offsets_ind = offsets_ind
Expand Down Expand Up @@ -1901,6 +1906,13 @@ def call(self, inputs):

crop_offsets = centroid_points - (self.crop_size / 2)

# resize full images
if self.precrop_resize:
full_imgs = sleap.nn.data.resizing.resize_image(
full_imgs, self.precrop_resize
)
centroid_points *= self.precrop_resize

# Crop instances around centroids.
bboxes = sleap.nn.data.instance_cropping.make_centered_bboxes(
centroid_points, self.crop_size, self.crop_size
Expand Down Expand Up @@ -2372,6 +2384,7 @@ def _initialize_inference_model(self):
keras_model=self.centroid_model.keras_model,
crop_size=crop_size,
input_scale=self.centroid_config.data.preprocessing.input_scaling,
precrop_resize=None,
pad_to_stride=self.centroid_config.data.preprocessing.pad_to_stride,
output_stride=self.centroid_config.model.heads.centroid.output_stride,
peak_threshold=self.peak_threshold,
Expand All @@ -2397,6 +2410,10 @@ def _initialize_inference_model(self):
)
if use_gt_centroid:
centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling
else:
centroid_crop_layer.precrop_resize = (
cfg.data.preprocessing.input_scaling
)

self.inference_model = TopDownInferenceModel(
centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer
Expand Down

0 comments on commit d47c98a

Please sign in to comment.