From 87b2b7caaccd4b48dcc92a7e2827342be48eb1c5 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Wed, 31 Jul 2024 13:09:00 +0100 Subject: [PATCH] fix: correct marker for test-suite --- .github/workflows/ci.yml | 2 +- sam2/utils/misc.py | 8 ++++++++ tests/test_build_model.py | 5 ++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fc8228dd9..82981af58 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,4 +47,4 @@ jobs: - name: Test with PyTest run: | python -m venv .venv && export PATH=".venv/bin:$PATH" - pytest -v -m "all" + pytest -v -m "full" diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index bf6a17999..d4c8c7dda 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -7,12 +7,20 @@ import os import warnings from threading import Thread +from typing import Dict import numpy as np import torch from PIL import Image from tqdm import tqdm +variant_to_config_mapping: Dict[str, str] = { + "tiny": "sam2_hiera_t.yaml", + "small": "sam2_hiera_s.yaml", + "base_plus": "sam2_hiera_b+.yaml", + "large": "sam2_hiera_l.yaml", +} + def get_sdpa_settings(): if torch.cuda.is_available(): diff --git a/tests/test_build_model.py b/tests/test_build_model.py index 4be992f11..912a1b77c 100644 --- a/tests/test_build_model.py +++ b/tests/test_build_model.py @@ -2,6 +2,7 @@ import torch from sam2.build_sam import build_sam2 +from sam2.utils.misc import variant_to_config_mapping @pytest.mark.full @@ -11,7 +12,9 @@ ) def test_build_sam(download_weights, variant: str): model = build_sam2( - "sam2_hiera_t.yaml", "./artifacts/sam2_hiera_tiny.pt", device="cpu" + variant_to_config_mapping[variant], + f"./artifacts/sam2_hiera_{variant}.pt", + device="cpu", ) assert isinstance(model, torch.nn.Module)