Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Centroid Only Predictor #1535

Closed
wants to merge 4 commits into from
Closed

Conversation

roomrys
Copy link
Collaborator

@roomrys roomrys commented Oct 6, 2023

Description

Previously, we only had a top-down predictor (which utilizes a centroid inference layer), but it might be beneficial to add a predictor for just the centroid say for cropping about to create a traing dataset of single animal centered crops (for many animals in a frame), or to output a neat egocentric video as suggested in #1522.

Things to resolve

  • Why does the predictor perform terribly compared to ground truth in the tests?
  • How should the crop size be passed in: from the centroid model configuration?
  • Create a CLI for the centroid only predictor.

Types of changes

  • Bugfix
  • New feature
  • Refactor / Code style update (no logical changes)
  • Build / CI changes
  • Documentation Update
  • Other (explain)

Does this address any currently open issues?

Outside contributors checklist

  • Review the guidelines for contributing to this repository
  • Read and sign the CLA and add yourself to the authors list
  • Make sure you are making a pull request against the develop branch (not main). Also you should start your branch off develop
  • Add tests that prove your fix is effective or that your feature works
  • Add necessary documentation (if appropriate)

Thank you for contributing to SLEAP!

❤️

Summary by CodeRabbit

New Features:

  • Added a new CentroidPredictor class to the sleap.nn.inference module for handling initialization, preprocessing, and tracking using a trained top-down multi-instance SLEAP model.
  • Enhanced the MainWindow class in sleap.gui.app to accept an optional labels parameter, allowing users to initialize with pre-existing labels.

Improvements:

  • Refactored the process_batch function in sleap.nn.inference for better modularity and maintainability. It is now a method of the CentroidPredictor class.

Tests:

  • Added a new test function test_centroid_predictor in tests.nn.test_inference to ensure the correctness of the CentroidPredictor class.

Error Handling:

  • Improved error handling in sleap.gui.app by wrapping the sleap.use_cpu_only() call in a try-except block to handle potential RuntimeError exceptions.

@roomrys roomrys requested a review from talmo October 10, 2023 16:22
@roomrys roomrys marked this pull request as ready for review October 10, 2023 16:22
@coderabbitai
Copy link

coderabbitai bot commented Oct 10, 2023

Walkthrough

This update introduces a more flexible MainWindow class in the GUI, a new CentroidPredictor class for improved inference, and a corresponding test function. The changes enhance the application's robustness and maintainability, while also improving the user experience by handling potential runtime errors.

Changes

File Summary
sleap/gui/app.py The MainWindow class now accepts an optional labels parameter. A create_parser function is introduced for argument parsing. The sleap.use_cpu_only() call is now error-handled.
sleap/nn/inference.py Introduced CentroidPredictor class for handling initialization, preprocessing, and tracking. Moved process_batch function from _predict_generator method to CentroidPredictor class.
tests/nn/test_inference.py Added test_centroid_predictor function to test the new CentroidPredictor class. The function loads a trained model, makes predictions, and asserts the expected results.

🐇
"In the land of code, where the shadows lie,
A rabbit hops, under the cloudless sky.
With each key press, a new change is born,
A poem to celebrate, in the early morn.
The code is clean, the logic is tight,
The rabbit's work, is a beautiful sight. 🌟"


Tips

Chat with CodeRabbit Bot (@coderabbitai)

  • Mention @coderabbitai in any review comment for bot assistance.
  • Note: Review comments are made on code diffs or files, not on the PR overview.

Pause Incremental Reviews

  • Insert @coderabbitai: ignore in the PR description to halt the bot's ongoing reviews. Remove the line to resume.
  • Additionally, enter @coderabbitai pause to halt the bot's ongoing reviews as any issue or review comment. To
    resume reviews, enter @coderabbitai resume as any issue or review comment.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review Status

Actionable comments generated: 1

Commits Files that changed from the base of the PR and between ed77b49 and a6f9c20.
Files selected for processing (3)
  • sleap/gui/app.py (5 hunks)
  • sleap/nn/inference.py (8 hunks)
  • tests/nn/test_inference.py (3 hunks)
Additional comments (Suppressed): 16
tests/nn/test_inference.py (3)
  • 13-26: The import of CentroidPredictor is new and is used in the test function test_centroid_predictor. Ensure that the CentroidPredictor class is correctly implemented and available in the sleap.nn.inference module.

  • 635-661: The new test function test_centroid_predictor is testing the CentroidPredictor class. It checks the number of labels predicted, the number of instances in the first label, and the grayscale property of the predictor. It also checks the correctness of the predicted points by comparing them with the ground truth points. The test seems comprehensive and well-structured.

  • 1654-1655: The if __name__ == "__main__": block is used to run the test_centroid_predictor function when the script is run directly. This is a good practice for testing individual functions.

sleap/gui/app.py (5)
  • 101-107: The MainWindow constructor now accepts an optional labels parameter. Ensure that all calls to this constructor throughout the codebase have been updated to match the new signature.

  • 122-122: The labels attribute of the MainWindow instance is now initialized with the labels parameter if it is provided, otherwise a new Labels object is created. This change allows for greater flexibility in initializing the MainWindow instance.

  • 1598-1601: The create_parser function has been added. This function creates an argument parser and is now used in the main function. This change improves the modularity of the code.

  • 1641-1647: The main function now accepts an optional labels parameter and uses the create_parser function to create an argument parser. Ensure that all calls to this function throughout the codebase have been updated to match the new signature.

  • 1660-1665: The MainWindow instance is now initialized with the labels parameter if it is provided. This change allows for greater flexibility in initializing the MainWindow instance.

sleap/nn/inference.py (8)
  • 51-57: The import statement from sleap.nn.data.dataset_ops import Batcher is new. Ensure that the Batcher class is used appropriately in the code.

  • 374-398: The process_batch method is a new addition. It runs inference on a batch of examples, updates the input data example with model outputs, converts certain fields to numpy arrays if they are not already, and adjusts for potential SizeMatcher scaling. Ensure that this method is called correctly in the code.

  • 418-423: The check for self.verbosity == "rich" is now outside the if self.inference_model is None: block. Ensure that this change does not affect the expected behavior.

  • 434-440: The process_batch method is now being used in the loop over data batches. Ensure that this change is intended and that the method is used correctly.

  • 455-461: The process_batch method is now being used in the loop over data batches. Ensure that this change is intended and that the method is used correctly.

  • 491-496: The process_batch method is now being used in the loop over data batches. Ensure that this change is intended and that the method is used correctly.

  • 2199-2520: The CentroidPredictor class is a new addition. It extends the Predictor class and is designed to handle initialization, preprocessing, and tracking using a trained top-down multi-instance SLEAP model. Ensure that this class is used correctly in the code.

  • 5801-5830: The code block under if __name__ == "__main__": is new. It seems to be a test block for the CentroidPredictor class. Ensure that this block is intended to be included in the final code.

Comment on lines 663 to 665
def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
predictor = TopDownPredictor.from_trained_models(
centroid_model_path=min_centroid_model_path
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function test_topdown_predictor_centroid is not new, but the indentation seems to be off. This might cause a syntax error. Please verify.

-  def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
+ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):

Copy link
Collaborator Author

@roomrys roomrys left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix which node is assigned the centroid in _make_labeled_frames.

Also, in the CentroidInferenceModel in __init__ we should add a return_crops attribute to be used in call to keep the "crops" key if return_crops and otherwise remove the "crops" key so that images are not transferred back to the CPU.


predicted_instances.append(
PredictedInstance.from_numpy(
points=[pts],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will set the first node in the skeleton to pts which is fine for single node skeletons, but is incorrect if a different anchor point is set.

inds1, inds2 = sleap.nn.utils.match_points(points_gt, points_pr)
points_gt_closest = points_gt[inds1.numpy()][2:]
points_pr = points_pr[inds2.numpy()][2:]
assert_allclose(points_gt, points_gt_closest, atol=40)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test fails for two reasons...

  1. the coordinates returned in the centroid predictor are assigned to just the first node of the instance
  2. if an anchor point is not specified, then the midpoint of the bounding box is used as the centroid:
    def find_centroids(frame_data):
    """Local processing function for dataset mapping."""
    # Find the bounding box midpoints.
    mid_pts = find_points_bbox_midpoint(frame_data["instances"])
    # Update and return.
    frame_data["centroids"] = mid_pts
    return frame_data

"""

centroid_config: TrainingJobConfig
centroid_model: Optional[Model] = attr.ib(default=None)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to validate that the model either:

  1. is a single node skeleton
  2. specifies an anchor point

@roomrys roomrys self-assigned this Jan 5, 2024
@roomrys roomrys closed this May 30, 2024
@roomrys roomrys deleted the liezl/centroid-only-predictor branch May 30, 2024 19:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant