From 9d4ca5ecd589e846045d3047d58611f44f0e55c2 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Thu, 7 Nov 2024 15:16:53 +0100 Subject: [PATCH] fixed toml sam-2 dependency; clean-up after post-processing --- pyproject.toml | 2 +- src/featureforest/postprocess/postprocess_with_sam.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5a61892..9d8a6ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ dependencies = [ "iopath>=0.1.10", "timm==1.0.9", "segment-anything-py", - "sam2 @ git+https://github.com/facebookresearch/sam2.git" + "sam-2 @ git+https://github.com/facebookresearch/sam2.git" ] [project.optional-dependencies] # development dependencies and tooling diff --git a/src/featureforest/postprocess/postprocess_with_sam.py b/src/featureforest/postprocess/postprocess_with_sam.py index c348ccc..f290195 100644 --- a/src/featureforest/postprocess/postprocess_with_sam.py +++ b/src/featureforest/postprocess/postprocess_with_sam.py @@ -193,6 +193,10 @@ def postprocess_with_sam( # put the final label mask into final result mask final_mask[sam_label_mask] = label + # clean-up + del predictor + torch.cuda.empty_cache() + return final_mask @@ -232,6 +236,10 @@ def get_sam_auto_masks(input_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray] sam_masks = np.array([mask["segmentation"] for mask in sam_generated_masks]) sam_areas = np.array([mask["area"] for mask in sam_generated_masks]) + # clean-up + del mask_generator + torch.cuda.empty_cache() + return sam_masks, sam_areas