Increase test coverage (#280)
Add tests for get_sam_model and iterative prediction.
constantinpape authored Nov 20, 2023
1 parent d8457fc commit fe13841
Showing 5 changed files with 50 additions and 13 deletions.
1 change: 1 addition & 0 deletions micro_sam/evaluation/__init__.py
```diff
@@ -9,6 +9,7 @@
 from .evaluation import run_evaluation
 from .inference import (
     get_predictor,
+    run_inference_with_iterative_prompting,
     run_inference_with_prompts,
     precompute_all_embeddings,
     precompute_all_prompts,
```
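For orientation, this re-export makes the new iterative-prompting entry point available directly from `micro_sam.evaluation`. A minimal usage sketch follows; the positional argument order is inferred from how `test/test_training.py` (below) calls the function, `get_predictor`'s arguments are assumed, and all paths are hypothetical placeholders:

```python
# Sketch only: argument order mirrors the call in test_training.py below
# (predictor, image_paths, label_paths, embedding_dir, prediction_dir);
# file paths are hypothetical, and get_predictor's arguments are assumed.
from micro_sam.evaluation import get_predictor, run_inference_with_iterative_prompting

predictor = get_predictor("checkpoints/best.pt", model_type="vit_b")
run_inference_with_iterative_prompting(
    predictor,
    ["images/im00.tif"],          # image_paths
    ["labels/im00.tif"],          # label_paths
    "embeddings",                 # embedding_dir
    "predictions-iterative",      # prediction_dir
    start_with_box_prompt=False,  # settings used by the new test
    n_iterations=3,
)
```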
10 changes: 5 additions & 5 deletions micro_sam/util.py
```diff
@@ -49,8 +49,7 @@
     "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
     "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
 }
-_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam')
-_CHECKPOINT_FOLDER = os.path.join(_CACHE_DIR, 'models')
+
 _CHECKSUMS = {
     # the default segment anything models
     "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
@@ -87,7 +86,7 @@ def get_cache_directory() -> None:
     Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
     """
-    default_cache_directory = os.path.expanduser(pooch.os_cache('micro-sam'))
+    default_cache_directory = os.path.expanduser(pooch.os_cache('micro_sam'))
     cache_directory = Path(os.environ.get('MICROSAM_CACHEDIR', default_cache_directory))
     return cache_directory
 
@@ -127,11 +126,12 @@ def _get_checkpoint(model_type, checkpoint_path=None):
     if checkpoint_path is None:
         checkpoint_url = _MODEL_URLS[model_type]
         checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1])
-        checkpoint_path = os.path.join(_CHECKPOINT_FOLDER, checkpoint_name)
+        checkpoint_folder = os.path.join(get_cache_directory(), "models")
+        checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name)
 
         # download the checkpoint if necessary
         if not os.path.exists(checkpoint_path):
-            os.makedirs(_CHECKPOINT_FOLDER, exist_ok=True)
+            os.makedirs(checkpoint_folder, exist_ok=True)
             _download(checkpoint_url, checkpoint_path, model_type)
     elif not os.path.exists(checkpoint_path):
         raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.")
```
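Net effect of the `util.py` change: the checkpoint folder is no longer fixed by a module-level constant at import time but derived from `get_cache_directory()` at call time, so a `MICROSAM_CACHEDIR` set later is still honored. (The `-> None` annotation in the hunk header comes from the file itself; the function actually returns the cache path.) A small sketch of the new lookup, assuming only what the diff shows; the environment value is a hypothetical example:

```python
# Mirrors the lookup now done inside _get_checkpoint; illustration only.
import os
from micro_sam.util import get_cache_directory

os.environ["MICROSAM_CACHEDIR"] = "/tmp/micro-sam-cache"  # hypothetical location
checkpoint_folder = os.path.join(get_cache_directory(), "models")
print(checkpoint_folder)  # /tmp/micro-sam-cache/models
```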
2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -3,7 +3,7 @@ requires = ["setuptools>=42.0.0", "wheel"]
 build-backend = "setuptools.build_meta"
 
 [tool.pytest.ini_options]
-addopts = "-v --durations=10 --cov=micro_sam --cov-report xml:coverage.xml"
+addopts = "-v --durations=10 --cov=micro_sam --cov-report xml:coverage.xml --cov-report term-missing"
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "gui: marks GUI tests (deselect with '-m \"not gui\"')",
```
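The added `--cov-report term-missing` makes pytest print the uncovered line numbers per module in the terminal, on top of the existing `coverage.xml` report. Because the flags live in `addopts` they apply to every pytest invocation automatically; a sketch of driving a run from Python, where the marker expression is just an example:

```python
# Runs the test suite; the addopts from pyproject.toml are picked up
# automatically. Deselecting slow/gui tests here is only an example.
import sys
import pytest

sys.exit(pytest.main(["-m", "not slow and not gui"]))
```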
30 changes: 23 additions & 7 deletions test/test_training.py
```diff
@@ -9,7 +9,7 @@
 import torch_em
 
 from micro_sam.sample_data import synthetic_data
-from micro_sam.util import VIT_T_SUPPORT
+from micro_sam.util import VIT_T_SUPPORT, get_custom_sam_model, SamPredictor
 
 
 @unittest.skipUnless(VIT_T_SUPPORT, "Integration test is only run with vit_t support, otherwise it takes too long.")
@@ -133,6 +133,9 @@ def _run_inference_and_check_results(
         inference_function(predictor, image_paths, label_paths, embedding_dir, prediction_dir)
 
         pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
+        if len(pred_paths) == 0:  # we need to go to subfolder for iterative inference
+            pred_paths = sorted(glob(os.path.join(prediction_dir, "iteration02", "*.tif")))
+
         self.assertEqual(len(pred_paths), len(label_paths))
         eval_res = evaluation.run_evaluation(label_paths, pred_paths, verbose=False)
         result = eval_res["sa50"].values.item()
@@ -150,39 +153,52 @@ def test_training(self):
         checkpoint_path = os.path.join(self.tmp_folder, "checkpoints", "test", "best.pt")
         self.assertTrue(os.path.exists(checkpoint_path))
 
+        # Check that the model can be loaded from a custom checkpoint.
+        predictor = get_custom_sam_model(checkpoint_path, model_type=model_type, device=device)
+        self.assertTrue(isinstance(predictor, SamPredictor))
+
         # Export the model.
         export_path = os.path.join(self.tmp_folder, "exported_model.pth")
         self._export_model(checkpoint_path, export_path, model_type)
         self.assertTrue(os.path.exists(export_path))
 
         # Check the model with inference with a single point prompt.
         prediction_dir = os.path.join(self.tmp_folder, "predictions-points")
-        normal_inference = partial(
+        point_inference = partial(
             evaluation.run_inference_with_prompts,
             use_points=True, use_boxes=False,
             n_positives=1, n_negatives=0,
             batch_size=64,
         )
         self._run_inference_and_check_results(
             export_path, model_type, prediction_dir=prediction_dir,
-            inference_function=normal_inference, expected_sa=0.9
+            inference_function=point_inference, expected_sa=0.9
         )
 
         # Check the model with inference with a box point prompt.
         prediction_dir = os.path.join(self.tmp_folder, "predictions-boxes")
-        normal_inference = partial(
+        box_inference = partial(
             evaluation.run_inference_with_prompts,
             use_points=False, use_boxes=True,
             n_positives=1, n_negatives=0,
             batch_size=64,
         )
         self._run_inference_and_check_results(
             export_path, model_type, prediction_dir=prediction_dir,
-            inference_function=normal_inference, expected_sa=0.95,
+            inference_function=box_inference, expected_sa=0.95,
         )
 
-        # Check the model with interactive inference
-        # TODO
+        # Check the model with interactive inference.
+        prediction_dir = os.path.join(self.tmp_folder, "predictions-iterative")
+        iterative_inference = partial(
+            evaluation.run_inference_with_iterative_prompting,
+            start_with_box_prompt=False,
+            n_iterations=3,
+        )
+        self._run_inference_and_check_results(
+            export_path, model_type, prediction_dir=prediction_dir,
+            inference_function=iterative_inference, expected_sa=0.95,
+        )
 
 
 if __name__ == "__main__":
```
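One detail worth noting: `_run_inference_and_check_results` falls back to globbing in an `iteration02` subfolder. This suggests the iterative inference writes one output folder per prompting round, zero-indexed, so with `n_iterations=3` the final masks land in `iteration02`. That is an assumption read off the test, not verified against the inference code:

```python
# Assumed layout, inferred from the glob in the test above: per-round output
# folders are zero-indexed and zero-padded ("iteration00", "iteration01", ...),
# so the last of n rounds is iteration{n-1:02}.
import os

n_iterations = 3
final_round_dir = os.path.join("predictions-iterative", f"iteration{n_iterations - 1:02}")
print(final_round_dir)  # predictions-iterative/iteration02
```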
20 changes: 20 additions & 0 deletions test/test_util.py
```diff
@@ -8,9 +8,11 @@
 
 from skimage.data import binary_blobs
 from skimage.measure import label
+from micro_sam.util import VIT_T_SUPPORT, SamPredictor, get_cache_directory
 
 
 class TestUtil(unittest.TestCase):
+    model_type = "vit_t" if VIT_T_SUPPORT else "vit_b"
     tmp_folder = "tmp-files"
 
     def setUp(self):
@@ -19,6 +21,24 @@ def setUp(self):
     def tearDown(self):
         rmtree(self.tmp_folder)
 
+    def test_get_sam_model(self):
+        from micro_sam.util import get_sam_model
+
+        def check_predictor(predictor):
+            self.assertTrue(isinstance(predictor, SamPredictor))
+            self.assertEqual(predictor.model_type, self.model_type)
+
+        # check predictor with download
+        predictor = get_sam_model(model_type=self.model_type)
+        check_predictor(predictor)
+
+        # check predictor with checkpoint path (using the cached model)
+        checkpoint_path = os.path.join(
+            get_cache_directory(), "models", "vit_t_mobile_sam.pth" if VIT_T_SUPPORT else "sam_vit_b_01ec64.pth"
+        )
+        predictor = get_sam_model(model_type=self.model_type, checkpoint_path=checkpoint_path)
+        check_predictor(predictor)
+
     def test_compute_iou(self):
         from micro_sam.util import compute_iou
 
```
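Outside the test harness the same two load paths look like this; a minimal sketch mirroring `test_get_sam_model`, assuming the `vit_b` checkpoint name used in the test:

```python
# Mirrors the new test: the first load downloads into the cache if needed,
# the second reuses the cached checkpoint file via an explicit path.
import os
from micro_sam.util import get_sam_model, get_cache_directory

predictor = get_sam_model(model_type="vit_b")  # downloads on first use
checkpoint_path = os.path.join(get_cache_directory(), "models", "sam_vit_b_01ec64.pth")
predictor = get_sam_model(model_type="vit_b", checkpoint_path=checkpoint_path)
assert predictor.model_type == "vit_b"
```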
