diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index c6b3e33..2253dfb 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -96,7 +96,8 @@ def segment_tissue_deep(img: Union[np.ndarray, openslide.OpenSlide, 'CuImage', W new_state_dict[new_key] = checkpoint['state_dict'][key] model.load_state_dict(new_state_dict) - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() @@ -108,7 +109,8 @@ def segment_tissue_deep(img: Union[np.ndarray, openslide.OpenSlide, 'CuImage', W # coords are top left coords of patch imgs, coords = batch - imgs = imgs.cuda() + if torch.cuda.is_available(): + imgs = imgs.cuda() masks = model(imgs)['out'] preds = masks.argmax(1)