diff --git a/pyproject.toml b/pyproject.toml index 7b9e55b..7512cfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ [project.optional-dependencies] cellpose = [ - "cellpose>=3.0.0,<=3.0.10" + "cellpose>=3.0.0" ] testing = [ "tox", diff --git a/requirements.txt b/requirements.txt index 2ef4679..2c57471 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ tensorstore==0.1.59 ome-zarr==0.9.0 -cellpose-napari==0.1.5 magicgui==0.8.3 qtpy==2.4.1 scikit-image==0.24.0 diff --git a/src/napari_activelearning/_models.py b/src/napari_activelearning/_models.py index 8d06a6f..c6c9f73 100644 --- a/src/napari_activelearning/_models.py +++ b/src/napari_activelearning/_models.py @@ -91,8 +91,12 @@ def _run_pred(self, img, *args, **kwargs): x = self._transform(img) with torch.no_grad(): - y, _ = core.run_net(self._model_dropout.net, x) - logits = torch.from_numpy(y[:, :, 2]) + try: + y, _ = core.run_net(self._model_dropout.net, x) + logits = torch.from_numpy(y[:, :, 2]) + except ValueError: + y, _ = core.run_net(self._model_dropout.net, x[None, ...]) + logits = torch.from_numpy(y[0, :, :, 2]) probs = logits.sigmoid().numpy() return probs