-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Centroid Only Predictor #1535
Conversation
WalkthroughThis update introduces a more flexible Changes
TipsChat with CodeRabbit Bot (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Review Status
Actionable comments generated: 1
Files selected for processing (3)
- sleap/gui/app.py (5 hunks)
- sleap/nn/inference.py (8 hunks)
- tests/nn/test_inference.py (3 hunks)
Additional comments (Suppressed): 16
tests/nn/test_inference.py (3)
13-26: The import of
CentroidPredictor
is new and is used in the test functiontest_centroid_predictor
. Ensure that theCentroidPredictor
class is correctly implemented and available in thesleap.nn.inference
module.635-661: The new test function
test_centroid_predictor
is testing theCentroidPredictor
class. It checks the number of labels predicted, the number of instances in the first label, and the grayscale property of the predictor. It also checks the correctness of the predicted points by comparing them with the ground truth points. The test seems comprehensive and well-structured.1654-1655: The
if __name__ == "__main__":
block is used to run thetest_centroid_predictor
function when the script is run directly. This is a good practice for testing individual functions.sleap/gui/app.py (5)
101-107: The
MainWindow
constructor now accepts an optionallabels
parameter. Ensure that all calls to this constructor throughout the codebase have been updated to match the new signature.122-122: The
labels
attribute of theMainWindow
instance is now initialized with thelabels
parameter if it is provided, otherwise a newLabels
object is created. This change allows for greater flexibility in initializing theMainWindow
instance.1598-1601: The
create_parser
function has been added. This function creates an argument parser and is now used in themain
function. This change improves the modularity of the code.1641-1647: The
main
function now accepts an optionallabels
parameter and uses thecreate_parser
function to create an argument parser. Ensure that all calls to this function throughout the codebase have been updated to match the new signature.1660-1665: The
MainWindow
instance is now initialized with thelabels
parameter if it is provided. This change allows for greater flexibility in initializing theMainWindow
instance.sleap/nn/inference.py (8)
51-57: The import statement
from sleap.nn.data.dataset_ops import Batcher
is new. Ensure that theBatcher
class is used appropriately in the code.374-398: The
process_batch
method is a new addition. It runs inference on a batch of examples, updates the input data example with model outputs, converts certain fields to numpy arrays if they are not already, and adjusts for potential SizeMatcher scaling. Ensure that this method is called correctly in the code.418-423: The check for
self.verbosity == "rich"
is now outside theif self.inference_model is None:
block. Ensure that this change does not affect the expected behavior.434-440: The
process_batch
method is now being used in the loop over data batches. Ensure that this change is intended and that the method is used correctly.455-461: The
process_batch
method is now being used in the loop over data batches. Ensure that this change is intended and that the method is used correctly.491-496: The
process_batch
method is now being used in the loop over data batches. Ensure that this change is intended and that the method is used correctly.2199-2520: The
CentroidPredictor
class is a new addition. It extends thePredictor
class and is designed to handle initialization, preprocessing, and tracking using a trained top-down multi-instance SLEAP model. Ensure that this class is used correctly in the code.5801-5830: The code block under
if __name__ == "__main__":
is new. It seems to be a test block for theCentroidPredictor
class. Ensure that this block is intended to be included in the final code.
def test_topdown_predictor_centroid(min_labels, min_centroid_model_path): | ||
predictor = TopDownPredictor.from_trained_models( | ||
centroid_model_path=min_centroid_model_path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function test_topdown_predictor_centroid
is not new, but the indentation seems to be off. This might cause a syntax error. Please verify.
- def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
+ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix which node is assigned the centroid in _make_labeled_frames
.
Also, in the CentroidInferenceModel
in __init__
we should add a return_crops
attribute to be used in call
to keep the "crops"
key if return_crops
and otherwise remove the "crops"
key so that images are not transferred back to the CPU.
|
||
predicted_instances.append( | ||
PredictedInstance.from_numpy( | ||
points=[pts], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will set the first node in the skeleton to pts which is fine for single node skeletons, but is incorrect if a different anchor point is set.
inds1, inds2 = sleap.nn.utils.match_points(points_gt, points_pr) | ||
points_gt_closest = points_gt[inds1.numpy()][2:] | ||
points_pr = points_pr[inds2.numpy()][2:] | ||
assert_allclose(points_gt, points_gt_closest, atol=40) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test fails for two reasons...
- the coordinates returned in the centroid predictor are assigned to just the first node of the instance
- if an anchor point is not specified, then the midpoint of the bounding box is used as the centroid:
sleap/sleap/nn/data/instance_centroids.py
Lines 182 to 189 in ed77b49
def find_centroids(frame_data): """Local processing function for dataset mapping.""" # Find the bounding box midpoints. mid_pts = find_points_bbox_midpoint(frame_data["instances"]) # Update and return. frame_data["centroids"] = mid_pts return frame_data
""" | ||
|
||
centroid_config: TrainingJobConfig | ||
centroid_model: Optional[Model] = attr.ib(default=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to validate that the model either:
- is a single node skeleton
- specifies an anchor point
Description
Previously, we only had a top-down predictor (which utilizes a centroid inference layer), but it might be beneficial to add a predictor for just the centroid say for cropping about to create a traing dataset of single animal centered crops (for many animals in a frame), or to output a neat egocentric video as suggested in #1522.
Things to resolve
Types of changes
Does this address any currently open issues?
Outside contributors checklist
Thank you for contributing to SLEAP!
❤️
Summary by CodeRabbit
New Features:
CentroidPredictor
class to thesleap.nn.inference
module for handling initialization, preprocessing, and tracking using a trained top-down multi-instance SLEAP model.MainWindow
class insleap.gui.app
to accept an optionallabels
parameter, allowing users to initialize with pre-existing labels.Improvements:
process_batch
function insleap.nn.inference
for better modularity and maintainability. It is now a method of theCentroidPredictor
class.Tests:
test_centroid_predictor
intests.nn.test_inference
to ensure the correctness of theCentroidPredictor
class.Error Handling:
sleap.gui.app
by wrapping thesleap.use_cpu_only()
call in a try-except block to handle potentialRuntimeError
exceptions.