diff --git a/test/test_sam_annotator/test_cli.py b/test/test_sam_annotator/test_cli.py index 1df399f9..8a413e89 100644 --- a/test/test_sam_annotator/test_cli.py +++ b/test/test_sam_annotator/test_cli.py @@ -6,6 +6,7 @@ import imageio.v3 as imageio import micro_sam.util as util +import pytest import zarr from skimage.data import binary_blobs @@ -35,19 +36,13 @@ def test_annotator_tracking(self): def test_image_series_annotator(self): self._test_command("micro_sam.image_series_annotator") + @pytest.mark.skipif(platform.system() == "Windows", reason="Gui test is not working on windows.") def test_precompute_embeddings(self): self._test_command("micro_sam.precompute_embeddings") - def test_automatic_segmentation(self): - self._test_command("micro_sam.automatic_segmentation") - - # The filepaths can't be found on windows, probably due different filepath conventions. - # The actual functionality likely works despite this issue. - if platform.system() == "Windows": - return - - # Create 3 images as testdata. - for i in range(3): + # Create 2 images as testdata. + n_images = 2 + for i in range(n_images): im_path = os.path.join(self.tmp_folder, f"image-{i}.tif") image_data = binary_blobs(512).astype("uint8") * 255 imageio.imwrite(im_path, image_data) @@ -73,7 +68,7 @@ def test_automatic_segmentation(self): self.assertTrue(os.path.exists(emb_path2)) with zarr.open(emb_path2, "r") as f: self.assertIn("features", f) - self.assertEqual(f["features"].shape[0], 3) + self.assertEqual(f["features"].shape[0], n_images) ais_path = os.path.join(emb_path2, "is_state.h5") self.assertTrue(os.path.exists(ais_path)) @@ -83,11 +78,14 @@ def test_automatic_segmentation(self): "micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path3, "-m", self.model_type, "--pattern", "*.tif", "--precompute_amg_state" ]) - for i in range(3): + for i in range(n_images): self.assertTrue(os.path.exists(os.path.join(emb_path3, f"image-{i}.zarr"))) ais_path = os.path.join(emb_path3, f"image-{i}.zarr", "is_state.h5") self.assertTrue(os.path.exists(ais_path)) + def test_automatic_segmentation(self): + self._test_command("micro_sam.automatic_segmentation") + if __name__ == "__main__": unittest.main()