Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Black boy zeus batchwisevectorization amg utils #237

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,116 changes: 680 additions & 436 deletions notebooks/automatic_mask_generator_example.ipynb

Large diffs are not rendered by default.

3,262 changes: 1,917 additions & 1,345 deletions notebooks/video_predictor_example.ipynb

Large diffs are not rendered by default.

85 changes: 85 additions & 0 deletions sam2/utils/amg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
import torch
import torch.nn.functional as F

# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py

Expand Down Expand Up @@ -346,3 +347,87 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
out = out[0]

return out

def mask_to_rle_pytorch_optimized(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""
Optimized version of mask_to_rle_pytorch().
"""
b, h, w = tensor.shape
flattened_masks = tensor.permute(0, 2, 1).flatten(1) # B x H*W

# Vectorized Change Index Calculation using torch.diff()
change_indices = (torch.diff(flattened_masks, dim=1) != 0).nonzero()

# Create a tensor with start, end, and end indices for each run length
change_indices = torch.cat([
torch.zeros((b, 1), dtype=change_indices.dtype, device=change_indices.device),
change_indices + 1,
torch.full((b, 1), h * w, dtype=change_indices.dtype, device=change_indices.device)
], dim=1)

# Batch-wise run length calculation
run_lengths = torch.diff(change_indices, dim=1)

rle = []
for i in range(b):
counts = [] if flattened_masks[i, 0] == 0 else [0]
counts.extend(run_lengths[i].detach().cpu().tolist())
rle.append({"size": [h, w], "counts": counts})
return rle

def fill_holes_in_mask_scores(mask, max_area):
"""
A post processor to fill small holes in mask scores with area under `max_area`.
Uses PyTorch operations for hole filling.
"""
assert max_area > 0, "max_area must be positive"

# 1. Identify Holes using PyTorch:
# - We use thresholding for faster hole detection.
# - If you need a more accurate approach, you can replace this with
# a connected component algorithm implemented in PyTorch.
holes = (mask <= 0).float()

# 2. Calculate Hole Areas using PyTorch:
hole_areas = holes.sum(dim=(-1, -2))

# 3. Filter Small Holes and Fill with a Small Positive Value:
mask = torch.where(
(holes > 0) & (hole_areas <= max_area), # Condition for small holes
0.1, # Fill holes with a small positive value
mask
)

return mask

def mask_intersection(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor:
return mask1 * mask2


def mask_union(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor:
return (mask1 + mask2) > 0


def mask_subtraction(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor:
return mask1 * (1 - mask2)


def boxes_to_masks(boxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
"""Generates masks from bounding boxes in XYXY format."""
h, w = image_size
masks = torch.zeros((*boxes.shape[:-1], h, w), dtype=torch.bool, device=boxes.device)
for i, box in enumerate(boxes):
x1, y1, x2, y2 = box.long()
masks[i, y1:y2, x1:x2] = True
return masks



'''
Key Changes from the Previous Script:

mask_to_rle_pytorch_optimized Added: This function replaces the original mask_to_rle_pytorch function and provides significant performance improvements using vectorization and batch-wise processing.
fill_holes_in_mask_scores Updated: This function now uses PyTorch operations (thresholding and torch.where) for hole filling instead of cv2.connectedComponentsWithStats. This makes it faster and eliminates the need for data conversions.
No Changes to mask_intersection, mask_union, mask_subtraction, and boxes_to_masks: These functions are already tensor-based and remain unchanged.

'''
99 changes: 99 additions & 0 deletions sam2_configs/sam2_hiera_b+.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,102 @@ model:
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False


#Let's break down this configuration file, which appears to define the architecture of a deep learning model, likely for object detection or segmentation.
#Key Components
#The configuration uses a Python-like syntax to define the model's structure. Here's a breakdown of the key sections and their functions:
#@package _global_
#This might be a convention used by the framework to signal that the following configuration is intended to be applied globally within the project.
#model:
#Defines the main model structure.
#_target_: sam2.modeling.sam2_base.SAM2Base
#This indicates that the model is based on the SAM2Base class, likely from a library named sam2.
#image_encoder:
#Describes the image encoder part of the model, which takes an image as input and extracts features.
#_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
#Points to a specific image encoder implementation.
#scalp: 1
#Likely a parameter related to the scaling or resolution of the image encoder.
#trunk:
#Represents the backbone of the image encoder, often a convolutional neural network.
#_target_: sam2.modeling.backbones.hieradet.Hiera
#Specifies a specific type of backbone called Hiera (likely referring to a hierarchical architecture).
#embed_dim: 112
#The dimensionality of the feature vectors produced by the backbone.
#num_heads: 2
#The number of attention heads in a transformer-based layer within the backbone (likely in the Hiera architecture).
#neck:
#A feature pyramid network (FPN) responsible for combining features from different levels of the backbone.
#_target_: sam2.modeling.backbones.image_encoder.FpnNeck
#Points to a specific FPN implementation.
#position_encoding:
#A mechanism to encode spatial information into the features, which is important for tasks like object detection and segmentation.
#_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
#A common method for positional encoding using sinusoidal functions.
#d_model: 256
#The dimensionality of the features processed by the FPN.
#backbone_channel_list: [896, 448, 224, 112]
#Specifies the number of channels in different levels of the backbone.
#fpn_top_down_levels: [2, 3]
#Defines the levels of the FPN that directly use features from the backbone.
#fpn_interp_model: nearest
#The interpolation method used in the FPN for upsampling features.
#memory_attention:
#Defines a memory attention mechanism, often used to maintain a history of past observations or interactions.
#_target_: sam2.modeling.memory_attention.MemoryAttention
#Points to the implementation of the memory attention module.
#d_model: 256
#The dimensionality of the features used in the memory attention.
#pos_enc_at_input: true
#Indicates that positional encoding is applied at the input of the memory attention.
#layer:
#The specific layer within the memory attention mechanism.
#_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
#The class of the layer within the memory attention.
#self_attention:
#A self-attention mechanism, which attends to different parts of the input sequence or feature map.
#_target_: sam2.modeling.sam.transformer.RoPEAttention
#A specific type of self-attention using a mechanism called RoPE (Rotary Position Embedding) to encode positional information.
#cross_attention:
#A cross-attention mechanism, which allows the model to attend to information from different parts of the input or from other sources.
#_target_: sam2.modeling.sam.transformer.RoPEAttention
#The type of cross-attention, likely also using RoPE.
#memory_encoder:
#Defines the memory encoder, which processes and encodes the information from the memory attention module.
#_target_: sam2.modeling.memory_encoder.MemoryEncoder
#The class of the memory encoder.
#out_dim: 64
#The dimensionality of the output from the memory encoder.
#position_encoding:
#Positional encoding used in the memory encoder.
#mask_downsampler:
#A module for downsampling the mask representations.
#fuser:
#A module for fusing features from different levels of the memory encoder.
#num_maskmem: 7
#Likely the number of memory slots or cells used for storing mask representations.
#image_size: 1024
#The expected input image size for the model.
#sigmoid_scale_for_mem_enc: 20.0
#A scaling factor applied to the sigmoid function for the memory encoder's output.
#sigmoid_bias_for_mem_enc: -10.0
#A bias term added to the sigmoid function for the memory encoder's output.
#use_mask_input_as_output_without_sam: true
#An option indicating that the input mask is used as the output mask without further processing by a module called SAM.
#directly_add_no_mem_embed: true
#A flag for how the memory embedding is integrated into the model.
#use_high_res_features_in_sam: true
#Indicates that high-resolution features are used in the SAM (Segment Anything Model) module.
#multimask_output_in_sam: true
#Suggests that the SAM module can produce multiple masks as output.
#iou_prediction_use_sigmoid: True
#A flag related to the use of a sigmoid function in predicting Intersection over Union (IoU) scores.
#use_obj_ptrs_in_encoder: true
#Indicates that object pointers (information about object locations) are used in the encoder.
#pred_obj_scores: true
#A flag indicating that the model predicts object scores (confidence scores for object presence).
#multimask_output_for_tracking: true
#Suggests that the model supports multi-mask tracking, likely for tracking multiple objects over time.
#compile_image_encoder: False
#A flag for whether the image encoder should be compiled for optimization.
Loading