Skip to content

Commit

Permalink
Add docstrings, bugfix.
Browse files Browse the repository at this point in the history
  • Loading branch information
davidasamy committed Jul 31, 2023
1 parent 23c77d9 commit 611f6ef
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions sleap_nn/data/confidence_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
def make_confmaps(
points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
):
points = torch.squeeze(points)

x = points[:, 0]
y = points[:, 1]

"""Create confidence maps from points and grid vectors."""
x = torch.reshape(points[:, 0], (1, 1, -1))
y = torch.reshape(points[:, 1], (1, 1, -1))
cm = torch.exp(
-(
(torch.reshape(xv, (1, -1, 1)) - x) ** 2
Expand All @@ -21,26 +19,44 @@ def make_confmaps(
/ (2 * sigma**2)
)

# Replace NaNs with 0.
cm = torch.where(torch.isnan(cm), 0.0, cm)
print(torch.max(cm))
return cm


def make_grid_vectors(image_height: int, image_width: int, output_stride: int = 1):
xv = torch.arange(0, image_width, output_stride, dtype=torch.float32)
yv = torch.arange(0, image_height, output_stride, dtype=torch.float32)
def make_grid_vectors(image_height: int, image_width: int, output_stride: int):
"""Create grid vectors of appropriate shape and stride."""
xv = torch.arange(0, image_width, step=output_stride).to(torch.float32)
yv = torch.arange(0, image_height, step=output_stride).to(torch.float32)
return xv, yv


class ConfidenceMapGenerator(IterDataPipe):
def __init__(self, source_dp, sigma=2.5, output_stride=2):
"""DataPipe for generating confidence maps.
This DataPipe will generate confidence maps for examples from the input pipeline.
Attributes:
source_dp: The input `IterDataPipe` with examples that contain an instance and
an image.
sigma: The standard deviation of the Gaussian distribution that is used to
generate confidence maps.
output_stride: The relative stride to use when generating confidence maps.
A larger stride will generate smaller confidence maps.
"""

def __init__(
self, source_dp: IterDataPipe, sigma: int = 1.5, output_stride: int = 1
):
"""Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride."""
self.source_dp = source_dp
self.sigma = sigma
self.output_stride = output_stride

def __iter__(self):
"""Generate confidence maps for each example."""
for example in self.source_dp:
instance = example["instances"]
instance = example["instance"]
width = example["instance_image"].shape[-1]
height = example["instance_image"].shape[-2]
xv, yv = make_grid_vectors(height, width, self.output_stride)
Expand Down

0 comments on commit 611f6ef

Please sign in to comment.