
Support for float32 #27

Open
phisanti opened this issue May 2, 2024 · 1 comment

Comments

@phisanti

phisanti commented May 2, 2024

I am working on a MacBook M2 and get the following error when using the plugin:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1048, in SignalInstance._run_emit_loop(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>, args=(False,))
   1046     with Signal._emitting(self):
   1047         # allow receiver to query sender with Signal.current_emitter()
-> 1048         self._run_emit_loop_inner()
        self = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        self._run_emit_loop_inner = <bound method SignalInstance._run_emit_loop_immediate of <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>>
   1049 except RecursionError as e:

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1067, in SignalInstance._run_emit_loop_immediate(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>)
   1066 self._caller = caller
-> 1067 caller.cb(args)
        args = (False,)
        caller = <WeakMethod on napari_segment_anything._widget.SAMWidget._on_auto_run>

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_weak_callback.py:453, in WeakMethod.cb(self=<WeakMethod on napari_segment_anything._widget.SAMWidget._on_auto_run>, args=())
    452     args = args[: self._max_args]
--> 453 func(obj, *self._args, *args, **self._kwargs)
        obj = <Container ()>
        func = <function SAMWidget._on_auto_run at 0x30c869b20>
        args = ()
        self = <WeakMethod on napari_segment_anything._widget.SAMWidget._on_auto_run>
        self._args = ()
        self._kwargs = {}

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/napari_segment_anything/_widget.py:192, in SAMWidget._on_auto_run(self=<Container ()>)
    191 mask_gen = SamAutomaticMaskGenerator(self._sam)
--> 192 preds = mask_gen.generate(self._image)
        mask_gen = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
        self._image = <class 'numpy.ndarray'> (2400, 2400, 3) uint8
        self = <Container ()>
    194 labels = self._labels_layer.data

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args=(<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, <class 'numpy.ndarray'> (2400, 2400, 3) uint8), **kwargs={})
    114 with ctx_factory():
--> 115     return func(*args, **kwargs)
        func = <function SamAutomaticMaskGenerator.generate at 0x30c8691c0>
        args = (<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>, <class 'numpy.ndarray'> (2400, 2400, 3) uint8)
        kwargs = {}

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:163, in SamAutomaticMaskGenerator.generate(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, image=<class 'numpy.ndarray'> (2400, 2400, 3) uint8)
    162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
        image = <class 'numpy.ndarray'> (2400, 2400, 3) uint8
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
    165 # Filter small disconnected regions and holes in masks

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:206, in SamAutomaticMaskGenerator._generate_masks(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, image=<class 'numpy.ndarray'> (2400, 2400, 3) uint8)
    205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206     crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
        orig_size = (2400, 2400)
        crop_box = [0, 0, 2400, 2400]
        layer_idx = 0
        image = <class 'numpy.ndarray'> (2400, 2400, 3) uint8
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
    207     data.cat(crop_data)

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:245, in SamAutomaticMaskGenerator._process_crop(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, image=<class 'numpy.ndarray'> (2400, 2400, 3) uint8, crop_box=[0, 0, 2400, 2400], crop_layer_idx=0, orig_size=(2400, 2400))
    244 for (points,) in batch_iterator(self.points_per_batch, points_for_image):
--> 245     batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
        crop_box = [0, 0, 2400, 2400]
        cropped_im_size = (2400, 2400)
        points = <class 'numpy.ndarray'> (64, 2) float64
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
        orig_size = (2400, 2400)
    246     data.cat(batch_data)

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:277, in SamAutomaticMaskGenerator._process_batch(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, points=<class 'numpy.ndarray'> (64, 2) float64, im_size=(2400, 2400), crop_box=[0, 0, 2400, 2400], orig_size=(2400, 2400))
    276 transformed_points = self.predictor.transform.apply_coords(points, im_size)
--> 277 in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        transformed_points = <class 'numpy.ndarray'> (64, 2) float64
        self.predictor = <segment_anything.predictor.SamPredictor object at 0x34318d890>
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
    278 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

The above exception was the direct cause of the following exception:

EmitLoopError                             Traceback (most recent call last)
File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/magicgui/widgets/bases/_value_widget.py:71, in ValueWidget._on_value_change(self=PushButton(value=False, annotation=None, name=''), value=False)
     69 if value is self.null_value and not self._nullable:
     70     return
---> 71 self.changed.emit(value)
        value = False
        self.changed = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        self = PushButton(value=False, annotation=None, name='')

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1025, in SignalInstance.emit(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>, check_nargs=False, check_types=False, *args=(False,))
   1021     from ._group import EmissionInfo
   1023     SignalInstance._debug_hook(EmissionInfo(self, args))
-> 1025 self._run_emit_loop(args)
        self = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        args = (False,)

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1055, in SignalInstance._run_emit_loop(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>, args=(False,))
   1050     raise RecursionError(
   1051         f"RecursionError when "
   1052         f"emitting signal {self.name!r} with args {args}"
   1053     ) from e
   1054 except Exception as e:
-> 1055     raise EmitLoopError(
        self = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        self._args = ()
        self._caller = None
   1056         cb=self._caller, args=self._args, exc=e, signal=self
   1057     ) from e
   1058 finally:
   1059     self._emit_queue.clear()

EmitLoopError: 
While emitting signal 'magicgui.widgets.PushButton.changed', an error occurred in callback 'napari_segment_anything._widget.SAMWidget._on_auto_run'.
The args passed to the callback were: (False,)
This is not a bug in psygnal.  See 'TypeError' above for details.

I think the error might be solved if float32 were supported. Please let me know whether this would be possible.
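For context, the traceback shows the failing call: `_process_batch` passes `transformed_points`, a float64 NumPy array of shape (64, 2), to `torch.as_tensor(..., device=self.predictor.device)`, and the MPS backend cannot hold float64 tensors. A minimal sketch of the kind of downcast that avoids this follows. It assumes you patch that call locally; `to_float32_if_double` is a hypothetical helper, not part of segment-anything or the plugin, and this is not necessarily how any upstream fix handles it.

```python
import numpy as np


def to_float32_if_double(arr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: downcast float64 arrays to float32.

    The MPS backend does not support float64, so point coordinates must be
    float32 before they are turned into a tensor on an "mps" device.
    Arrays of any other dtype are returned unchanged (same object).
    """
    if arr.dtype == np.float64:
        return arr.astype(np.float32)
    return arr


# Stand-in for the (64, 2) float64 point batch seen in the traceback.
points = np.random.rand(64, 2) * 2400
safe_points = to_float32_if_double(points)
print(safe_points.dtype, safe_points.shape)  # float32 (64, 2)

# In a local patch of SamAutomaticMaskGenerator._process_batch, the
# conversion would then read (assumption, not the upstream code):
#   in_points = torch.as_tensor(to_float32_if_double(transformed_points),
#                               device=self.predictor.device)
```

Other float64 arrays that reach the MPS device further down the pipeline would need the same treatment, so a local patch like this may not be sufficient on its own.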

@JoOkuma
Member

JoOkuma commented May 2, 2024

Hi @phisanti,

This is a limitation of segment-anything from Meta. There is an open PR that adds Apple silicon support, facebookresearch/segment-anything#122, but it hasn't been merged.

You can install that version from git with:

pip install git+https://github.com/DrSleep/segment-anything@mps-support

Let me know if it works.
