More SAM2-fast server improvements #1285

Merged 55 commits on Nov 14, 2024
Commits
056c796
NT for MaskData
cpuhrsch Nov 11, 2024
3e914dc
More annotations
cpuhrsch Nov 11, 2024
c646d5f
use torch.ones
cpuhrsch Nov 11, 2024
bc84e06
remove rles_nt
cpuhrsch Nov 11, 2024
bdc3903
Remove filter calls from fullgraph region
cpuhrsch Nov 11, 2024
91942a6
Bring back rles_nt
cpuhrsch Nov 11, 2024
aea7dd9
One less kernel launch
cpuhrsch Nov 11, 2024
da45f58
Carry counts_init with rles_nt
cpuhrsch Nov 11, 2024
fd4e397
More benchmark shapes
cpuhrsch Nov 11, 2024
0997000
batch rles_nt
cpuhrsch Nov 12, 2024
a04310f
More separate compile regions
cpuhrsch Nov 12, 2024
07eb182
A/B compile regions
cpuhrsch Nov 12, 2024
c3e7a56
More CUDA graphs
cpuhrsch Nov 12, 2024
aac2ed2
More compile
cpuhrsch Nov 12, 2024
2dcc8a5
predict split
cpuhrsch Nov 12, 2024
0c336a3
More annotations
cpuhrsch Nov 12, 2024
dd2c970
Chunk at nonzero time
cpuhrsch Nov 12, 2024
4ffb068
async data transfers
cpuhrsch Nov 12, 2024
5aec26b
filter_by_index and more async transfers
cpuhrsch Nov 12, 2024
14abe87
Compile transforms and do them on GPU
cpuhrsch Nov 12, 2024
1aae906
Add a TODO
cpuhrsch Nov 12, 2024
ff72a69
MaskData None trick and skip boxes/points
cpuhrsch Nov 12, 2024
53f9615
Add a TODO
cpuhrsch Nov 12, 2024
e4a613c
Reorder filtering
cpuhrsch Nov 12, 2024
d7aac29
rles_nt_cpu
cpuhrsch Nov 12, 2024
3bc1906
filter fix
cpuhrsch Nov 13, 2024
609571f
revert torchao/_models/sam2/modeling/sam/transformer.py
cpuhrsch Nov 13, 2024
84d4ae0
ppb less than 1024
cpuhrsch Nov 14, 2024
9150d53
remove nt again
cpuhrsch Nov 14, 2024
0e8ec7f
Remove lines
cpuhrsch Nov 14, 2024
5dc6d59
remove lines
cpuhrsch Nov 14, 2024
7b159dd
remove lines
cpuhrsch Nov 14, 2024
48f0e75
cleanup
cpuhrsch Nov 14, 2024
3a3e3d2
cleanup
cpuhrsch Nov 14, 2024
46524df
cleanup
cpuhrsch Nov 14, 2024
7769ae0
Remove _process_batch_fullgraph_masks
cpuhrsch Nov 14, 2024
bf030e0
cleanup
cpuhrsch Nov 14, 2024
254e5b3
Remove uncrop_boxes_xyxy_torch
cpuhrsch Nov 14, 2024
0e677e8
_process_batch_fullgraph to only return data
cpuhrsch Nov 14, 2024
d565183
cleanup
cpuhrsch Nov 14, 2024
dfa71f8
cleanup
cpuhrsch Nov 14, 2024
8c767ce
TODO
cpuhrsch Nov 14, 2024
6f7e78d
move chunk next to diff
cpuhrsch Nov 14, 2024
9be7f1f
filter before _predict_masks_postprocess
cpuhrsch Nov 14, 2024
5c002f0
Comment out old filter
cpuhrsch Nov 14, 2024
42047ce
Do nonzero once for indexing
cpuhrsch Nov 14, 2024
e5f0b36
use jit script for transforms
cpuhrsch Nov 14, 2024
d2d4461
Use compile for calculate_stability_score
cpuhrsch Nov 14, 2024
0c690fe
RLEData
cpuhrsch Nov 14, 2024
c331fdf
keep_index, more compile, async transfer
cpuhrsch Nov 14, 2024
1f39cbf
Updated README
cpuhrsch Nov 14, 2024
f0cf865
NOTE
cpuhrsch Nov 14, 2024
d5a45b2
RLEData len
cpuhrsch Nov 14, 2024
cb3d586
Results
cpuhrsch Nov 14, 2024
a4f929e
More notes
cpuhrsch Nov 14, 2024
examples/sam2_amg_server/README.md (14 changes: 8 additions & 6 deletions)
@@ -24,17 +24,19 @@ Experiments run on H100 and with batch size 1
| mode | mIoU | mask count mismatch | avg. ms per request | max. memory (MiB (%)) | batch size | points per batch |
| -------------- | ----------------- | ------------------- | ------------------- | --------------------- | ---------- | ---------------- |
| baseline | 1.0 | 0 | 863 | 4013MiB (4%) | 1 | 64 |
- | ao             | 1.0                | 0   | 840 | 4350MiB (4%)   | 1  | 64   |
- | fast           | 0.9897813200950623 | 191 | 661 | 3916MiB (4%)   | 1  | 64   |
- | fast           | 0.9897371530532837 | 192 | 388 | 50787MiB (52%) | 16 | 1024 |
- | fast + furious | 0.974319338798523  | 209 | 461 | 3453MiB (3%)   | 1  | 64   |
- | fast + furious | 0.9702069759368896 | 196 | 195 | 48298MiB (49%) | 16 | 1024 |
+ | ao             | 0.9999980926513672 | 6   | 586 |                | 1  | 64   |
+ | fast           | 0.9937329888343811 | 191 | 333 |                | 1  | 1024 |
+ | fast           | 0.9937219619750977 | 192 | 324 |                | 16 | 1024 |
+ | fast + furious | 0.9804400205612183 | 292 | 131 |                | 1  | 1024 |
+ | fast + furious | 0.9806423187255859 | 282 | 130 |                | 16 | 1024 |

'mask count mismatch' counts the number of requests where the number of masks differs from the baseline.
For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
We exclude these examples from the mIoU calculation.
+ Differences in mask count seem to stem from even slight reorderings of compute, for example preprocessing on GPU instead of CPU.
+ A more relaxed way of measuring mIoU might be useful here to take into account slight differences in the number of masks, which may be caused by additional or missing sub-divisions.
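
Purely as an illustration of what such a relaxed metric could look like (a hypothetical sketch, not part of this PR or the scripts below), masks could be greedily matched to their best-IoU reference before averaging:

```python
import torch

def relaxed_miou(masks, ref_masks):
    # masks, ref_masks: lists of boolean (H, W) tensors.
    # Greedily pair each candidate with its best-IoU reference;
    # unmatched masks on either side count as IoU 0.
    available = list(range(len(ref_masks)))
    ious = []
    for m in masks:
        if not available:
            break
        scores = [
            (torch.logical_and(m, ref_masks[j]).sum()
             / torch.logical_or(m, ref_masks[j]).sum()).item()
            for j in available
        ]
        best = max(range(len(scores)), key=scores.__getitem__)
        ious.append(scores[best])
        available.pop(best)
    total = max(len(masks), len(ref_masks))
    return sum(ious) / total if total else 1.0
```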

- The 'ao' mode is a copy of the baseline with modifications to make the code compile-able and improve the performance of fast.
+ The 'ao' mode is a copy of the baseline with modifications to make the code more compile-able and to speed up run length encoding.
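
For reference, the core of uncompressed run length encoding reduces to a diff/nonzero over the flattened mask. A minimal single-mask sketch (assuming the column-major counts convention COCO uses; the batched torchao implementation is considerably more involved):

```python
import torch

def mask_to_rle(mask: torch.Tensor) -> dict:
    # mask: boolean (H, W) tensor; returns an uncompressed RLE dict.
    h, w = mask.shape
    flat = mask.t().flatten()  # column-major (Fortran) order
    # Indices where a new run starts, plus the two boundaries.
    change = torch.nonzero(flat[1:] != flat[:-1]).flatten() + 1
    idx = torch.cat([
        torch.tensor([0], device=mask.device),
        change,
        torch.tensor([h * w], device=mask.device),
    ])
    counts = (idx[1:] - idx[:-1]).tolist()
    if flat[0]:  # counts must begin with the number of leading zeros
        counts = [0] + counts
    return {"size": [h, w], "counts": counts}
```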

### 0. Download checkpoints and install requirements

examples/sam2_amg_server/compare_rle_lists.py (23 changes: 14 additions & 9 deletions)
@@ -16,8 +16,8 @@ def iou(mask1, mask2):
union = torch.logical_or(mask1, mask2)
return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))


def compare_masks(masks, ref_masks, order_by_area=False, verbose=False):
from torchao._models.sam2.utils.amg import rle_to_mask
v0_areas = []
v1_areas = []
v0_masks = []
@@ -40,17 +40,20 @@ def compare_masks(masks, ref_masks, order_by_area=False, verbose=False):
v0_masks = sorted(v0_masks, key=(lambda x: x[1]), reverse=True)
v1_masks = sorted(v1_masks, key=(lambda x: x[1]), reverse=True)
miou_sum = 0.0
- miou_count = 0
+ miou_count = 0.0
+ equal_count = 0
for ((v0_mask, _), (v1_mask, _)) in zip(v0_masks, v1_masks):
miou_sum += iou(v0_mask, v1_mask)
miou_count += 1
+ equal_count += torch.equal(v0_mask, v1_mask)
if verbose:
print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")

- return miou_sum, miou_count
+ return miou_sum / miou_count, equal_count


- def main(path0, path1):
+ def main(path0, path1, strict=False):
+ # path0 are candidates and path1 the ground truth
fail_count = 0
miou_sum = 0.0
miou_count = 0
@@ -59,11 +62,13 @@ def main(path0, path1):
masks0 = json.loads(line0)
masks1 = json.loads(line1)
if masks0.keys() != masks1.keys():
- fail_count += 1
- continue
- s, c = compare_masks(masks0, masks1, order_by_area=True)
- miou_sum += s
- miou_count += c
+ if strict:
+     fail_count += 1
+     continue
+
+ m, e = compare_masks(masks0, masks1, order_by_area=True)
+ miou_sum += m
+ miou_count += 1

print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")

examples/sam2_amg_server/server.py (74 changes: 45 additions & 29 deletions)
@@ -34,6 +34,7 @@
inductorconfig.coordinate_descent_check_all_directions = True
inductorconfig.allow_buffer_reuse = False

- # torch._dynamo.config.capture_dynamic_output_shape_ops = True
+ torch._dynamo.config.capture_dynamic_output_shape_ops = True


@@ -173,7 +174,7 @@ def masks_to_rle_dict(masks):

# Queue to hold incoming requests
request_queue = asyncio.Queue()
- batch_interval = 0.1  # Time interval to wait before processing a batch
+ batch_interval = 0.01  # Time interval to wait before processing a batch


def process_batch(batch, mask_generator):
@@ -186,7 +187,7 @@ def process_batch(batch, mask_generator):
print(f"Processing batch of len {len(batch)} using generate_batch")
masks = mask_generator.generate_batch(image_tensors)
print(f"Took avg. {(time.time() - t) / len(batch)}s per batch entry")
- max_memory_allocated()
+ # max_memory_allocated()
return masks
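
The queue and interval above imply a collector task that drains requests for one interval and hands them to `process_batch` together. A minimal sketch of that pattern (assumed wiring; the server's actual task differs):

```python
import asyncio
import time

async def batch_collector(request_queue, mask_generator):
    # Sketch: block for the first request, then keep accepting
    # arrivals until batch_interval has elapsed.
    while True:
        batch = [await request_queue.get()]
        deadline = time.monotonic() + batch_interval
        while (remaining := deadline - time.monotonic()) > 0:
            try:
                batch.append(await asyncio.wait_for(request_queue.get(), remaining))
            except asyncio.TimeoutError:
                break
        process_batch(batch, mask_generator)
```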


@@ -220,17 +221,17 @@ async def lifespan(app: FastAPI):
task.cancel()


- def benchmark_fn(func, inp, mask_generator):
+ def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
logging.info("Running 3 warumup iterations.")
for _ in range(3):
logging.info("Running {warmup} warmup iterations.")
for _ in range(warmup):
func(inp, mask_generator)
logging.info("Running 10 benchmark iterations.")
logging.info("Running {runs} benchmark iterations.")
t = time.time()
- for _ in range(10):
+ for _ in range(runs):
func(inp, mask_generator)
print(f"Benchmark took {(time.time() - t)/10.0}s per iteration.")
print(f"Benchmark took {(time.time() - t)/runs}s per iteration.")
max_memory_allocated()
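
One caveat: if `func` leaves work queued on the GPU, the host timer can stop before the device finishes. A synchronized variant would look like this (a sketch, not part of this PR):

```python
import time
import torch

def benchmark_fn_synced(func, inp, mask_generator, warmup=3, runs=10):
    for _ in range(warmup):
        func(inp, mask_generator)
    torch.cuda.synchronize()  # drain warmup work before timing starts
    t = time.time()
    for _ in range(runs):
        func(inp, mask_generator)
    torch.cuda.synchronize()  # include all queued GPU work in the timing
    print(f"Benchmark took {(time.time() - t) / runs}s per iteration.")
```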


@@ -244,11 +245,11 @@ def max_memory_allocated():

def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
from compare_rle_lists import compare_masks
- miou_sum, miou_count = compare_masks(masks, ref_masks, order_by_area=order_by_area, verbose=verbose)
- if miou_count == 0:
+ miou, equal_count = compare_masks(masks, ref_masks, order_by_area=order_by_area, verbose=verbose)
+ if equal_count == len(masks):
print("Masks exactly match reference.")
else:
print(f"mIoU is {miou_sum / miou_count}")
print(f"mIoU is {miou} with equal count {equal_count} out of {len(masks)}")


def main(checkpoint_path,
@@ -290,7 +291,7 @@ def main(checkpoint_path,

logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

logging.info(f"Using {points_per_batch} points_per_batch")
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")

@@ -299,18 +300,31 @@ def main(checkpoint_path,
# TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
mask_generator.predictor.model.image_encoder,
# mode="max-autotune-no-cudagraphs",
mode="max-autotune",
fullgraph=True,
dynamic=False,
)

- mask_generator._process_batch_fullgraph = torch.compile(
-     mask_generator._process_batch_fullgraph,
+ mask_generator.predictor.model.sam_prompt_encoder.forward = torch.compile(
+     mask_generator.predictor.model.sam_prompt_encoder.forward,
mode="max-autotune",
fullgraph=True,
+ dynamic=False,
+ )
+
+ mask_generator.predictor._predict_masks = torch.compile(
+     mask_generator.predictor._predict_masks,
+     mode="max-autotune",
+     fullgraph=True,
- dynamic=True,
+ dynamic=False,
)

+ # mask_generator.predictor._predict_masks_postprocess = torch.compile(
+ #     mask_generator.predictor._predict_masks_postprocess,
+ #     fullgraph=True,
+ #     dynamic=True,
+ # )

if furious:
mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
# NOTE: Not baseline feature
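
The compile changes above replace one large dynamic-shape region with several fixed-shape regions. The general pattern (a sketch, independent of SAM2) is to compile each shape-stable submodule with `fullgraph=True` and keep shape-dependent glue such as `nonzero` and filtering in eager code between the regions:

```python
import torch

def compile_region(fn):
    # Fixed-shape region: fullgraph plus max-autotune enables CUDA graphs.
    return torch.compile(fn, mode="max-autotune", fullgraph=True, dynamic=False)

# Hypothetical usage on a model with shape-stable encoder and head:
# model.encoder.forward = compile_region(model.encoder.forward)
# model.head.forward = compile_region(model.head.forward)
```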
@@ -340,27 +354,28 @@ def main(checkpoint_path,
unittest_fn(masks, ref_masks, order_by_area=True, verbose=verbose)

if benchmark:
print(f"batch size {batch_size} dog benchmark")
if batch_size == 1:
print("batch size 1 test")
benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
+ benchmark_fn(image_tensor_to_masks, torch.tensor(image_tensor).transpose(0, 1).numpy(), mask_generator)
else:
print(f"batch size {batch_size} test")
benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

print(f"batch size {batch_size} example shapes test")
random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in example_shapes()]
random_images = random_images[:batch_size]
benchmark_fn(image_tensors_to_masks, random_images, mask_generator)
for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
print(f"batch size {batch_size} example shapes {i} benchmark")
random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes]

print(f"batch size {batch_size} example shapes 2 test")
random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in example_shapes_2()]
random_images = random_images[:batch_size]
benchmark_fn(image_tensors_to_masks, random_images, mask_generator)
if batch_size == 1:
[benchmark_fn(image_tensor_to_masks, r, mask_generator) for r in random_images]
else:
random_images = random_images[:batch_size]
benchmark_fn(image_tensors_to_masks, random_images, mask_generator)

if profile is not None:
print(f"Saving profile under {profile}")
- profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
+ if batch_size == 1:
+     profiler_runner(profile, image_tensor_to_masks, image_tensor, mask_generator)
+ else:
+     profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

if dry:
return
@@ -406,7 +421,8 @@ async def upload_image(image: UploadFile = File(...)):
return StreamingResponse(buf, media_type="image/png")


- uvicorn.run(app, host=host, port=port, log_level="info")
+ # uvicorn.run(app, host=host, port=port, log_level="info")
+ uvicorn.run(app, host=host, port=port)

if __name__ == "__main__":
fire.Fire(main)