Embedding extraction testing (#9)
* fix warnings while testing


Former-commit-id: 727f40c

* delete old tests


Former-commit-id: 7049b91

* add test for embedding extraction

* update embedding extraction test to include new models

* fix requirements typo

* extract embedding extraction function from widget

* fix stop sequence by adding yield to patches

* add GPU resource cleanup

* fix test cases

* hide device inside model adapters

* fix ci error on windows

* reduce the number of python versions in CI tests

* rewrite tests to run inference through mock model

* pin numpy version

* register pytest mark

* fix typo
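
For illustration, the "mock model" approach from the bullets above could look like the sketch below. extract_embeddings_to_file and its yielded (slice_index, num_slices) progress tuples appear in the diffs further down; the get_slice_features hook and the MockAdapter fields are assumptions, not part of this commit.

import h5py
import numpy as np
import pytest

from featureforest.utils import extract


class MockAdapter:
    """Minimal stand-in for a real model adapter (hypothetical)."""
    name = "Mock"
    patch_size = 512
    overlap = 384


@pytest.mark.slow
def test_extraction_writes_one_group_per_slice(tmp_path, monkeypatch):
    # skip real inference: stub out the per-slice feature computation
    # (assumes the new extract module still exposes get_slice_features)
    monkeypatch.setattr(extract, "get_slice_features", lambda *a, **kw: None)

    image = np.zeros((2, 64, 64), dtype=np.uint8)  # a tiny 2-slice stack
    storage_path = str(tmp_path / "features.h5")
    progress = list(
        extract.extract_embeddings_to_file(image, storage_path, MockAdapter())
    )

    assert progress == [(0, 2), (1, 2)]  # yields (slice_index, num_slices)
    with h5py.File(storage_path, "r") as f:
        assert set(f.keys()) == {"0", "1"}
        assert f.attrs["num_slices"] == 2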
veegalinova authored Jul 17, 2024
1 parent 7a8554d commit a590a0c
Showing 22 changed files with 301 additions and 160 deletions.
@@ -7,23 +7,21 @@ on:
   push:
     branches:
       - main
-      - npe2
     tags:
       - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
   pull_request:
     branches:
       - main
-      - npe2
   workflow_dispatch:
 
 jobs:
   test:
     name: ${{ matrix.platform }} py${{ matrix.python-version }}
     runs-on: ${{ matrix.platform }}
 
     strategy:
       matrix:
         platform: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ['3.8', '3.9', '3.10']
+        python-version: ['3.10']
 
     steps:
       - uses: actions/checkout@v3
@@ -63,7 +61,7 @@ jobs:
       - name: Coverage
         uses: codecov/codecov-action@v3
 
-  deploy:
+  publish:
     # this will run when you have tagged a commit, starting with "v*"
     # and requires that you have put your twine API key in your
     # github secrets (see readme for details)
2 changes: 1 addition & 1 deletion env.yml
@@ -22,7 +22,7 @@ dependencies:
   - pooch
   - pip
   - pip:
-      - numpy
+      - numpy==1.23.5  # timm emits deprecation warnings with newer numpy versions
       - matplotlib
       - opencv-python
       - timm
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -3,6 +3,11 @@ requires = ["setuptools>=42.0.0", "wheel"]
 build-backend = "setuptools.build_meta"
 
 
+[tool.pytest.ini_options]
+markers = [
+    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+]
+
 
 [tool.black]
 line-length = 88
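
For context, the marker registered above is meant to be attached to long-running tests so they can be deselected in CI; a minimal, hypothetical example:

import time

import pytest


@pytest.mark.slow  # registered in pyproject.toml above
def test_full_embedding_extraction():
    # placeholder for a long-running test that would download real model weights
    time.sleep(10)

Running pytest -m "not slow" then skips it.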
2 changes: 1 addition & 1 deletion setup.cfg
@@ -32,7 +32,7 @@ package_dir =
 
 python_requires = >=3.8
 install_requires =
-    numpy
+    numpy == 1.23.5
     opencv-python
     scikit-learn
     scikit-image
67 changes: 19 additions & 48 deletions src/featureforest/_feature_extractor_widget.py
@@ -1,32 +1,28 @@
 import napari
 import napari.utils.notifications as notif
-from napari.utils.events import Event
+import torch
 from napari.qt.threading import create_worker
 from napari.utils import progress as np_progress
+from napari.utils.events import Event
 
+from qtpy.QtCore import Qt
 from qtpy.QtWidgets import (
     QHBoxLayout, QVBoxLayout, QWidget,
     QGroupBox,
     QPushButton, QLabel, QComboBox, QLineEdit,
     QFileDialog, QProgressBar,
 )
-from qtpy.QtCore import Qt
-
-import h5py
 
-from .widgets import (
-    ScrollWidgetWrapper,
-    get_layer,
-)
 from .models import get_available_models, get_model
 from .utils import (
     config
 )
 from .utils.data import (
     get_stack_dims,
 )
-from .utils.extract import (
-    get_slice_features
-)
+from .utils.extract import extract_embeddings_to_file
+from .widgets import (
+    ScrollWidgetWrapper,
+    get_layer,
+)


@@ -35,10 +31,7 @@ def __init__(self, napari_viewer: napari.Viewer):
         super().__init__()
         self.viewer = napari_viewer
         self.extract_worker = None
-        self.storage = None
         self.model_adapter = None
-        self.device = None
-
         self.prepare_widget()
 
     def prepare_widget(self):
@@ -163,53 +156,25 @@ def extract_embeddings(self):
         # initialize the selected model
         _, img_height, img_width = get_stack_dims(image_layer.data)
         model_name = self.model_combo.currentText()
-        self.model_adapter, self.device = get_model(model_name, img_height, img_width)
+        self.model_adapter = get_model(model_name, img_height, img_width)
 
         self.extract_button.setEnabled(False)
         self.stop_button.setEnabled(True)
         self.extract_worker = create_worker(
-            self.get_stack_sam_embeddings,
-            image_layer, storage_path
+            extract_embeddings_to_file,
+            image=image_layer.data,
+            storage_file_path=storage_path,
+            model_adapter=self.model_adapter
         )
         self.extract_worker.yielded.connect(self.update_extract_progress)
         self.extract_worker.finished.connect(self.extract_is_done)
         self.extract_worker.errored.connect(self.stop_extracting)
         self.extract_worker.run()
 
-    def get_stack_sam_embeddings(
-        self, image_layer, storage_path
-    ):
-        # prepare the storage hdf5 file
-        self.storage = h5py.File(storage_path, "w")
-        # get sam embeddings slice by slice and save them into storage file
-        num_slices, img_height, img_width = get_stack_dims(image_layer.data)
-        self.storage.attrs["num_slices"] = num_slices
-        self.storage.attrs["img_height"] = img_height
-        self.storage.attrs["img_width"] = img_width
-        self.storage.attrs["model"] = self.model_combo.currentText()
-        self.storage.attrs["patch_size"] = self.model_adapter.patch_size
-        self.storage.attrs["overlap"] = self.model_adapter.overlap
-
-        for slice_index in np_progress(
-            range(num_slices), desc="extract features for slices"
-        ):
-            image = image_layer.data[slice_index] if num_slices > 1 else image_layer.data
-            slice_grp = self.storage.create_group(str(slice_index))
-            get_slice_features(
-                image, self.model_adapter.patch_size, self.model_adapter.overlap,
-                self.model_adapter, self.device, slice_grp
-            )
-
-            yield (slice_index, num_slices)
-
-        self.storage.close()
-
     def stop_extracting(self):
         if self.extract_worker is not None:
             self.extract_worker.quit()
             self.extract_worker = None
-        if isinstance(self.storage, h5py.File):
-            self.storage.close()
-            self.storage = None
         self.stop_button.setEnabled(False)
 
     def update_extract_progress(self, values):
@@ -224,3 +189,9 @@ def extract_is_done(self):
         self.stop_button.setEnabled(False)
         print("Extracting is done!")
         notif.show_info("Extracting is done!")
+        self.free_resource()
+
+    def free_resource(self):
+        if self.model_adapter is not None:
+            del self.model_adapter
+            torch.cuda.empty_cache()
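
Judging from the widget method deleted above, the extracted extract_embeddings_to_file in .utils.extract plausibly looks like the following sketch (not the committed source); the post-refactor get_slice_features signature is an assumption, and get_slice_features lives in the same module per the old import block.

import h5py
from napari.utils import progress as np_progress

from .data import get_stack_dims


def extract_embeddings_to_file(image, storage_file_path, model_adapter):
    """Extract features slice by slice into an HDF5 file.

    A generator: yields (slice_index, num_slices) after each slice so the
    napari worker can report progress.
    """
    num_slices, img_height, img_width = get_stack_dims(image)
    # "with" closes the file even on errors, replacing the manual
    # self.storage bookkeeping the widget used to do
    with h5py.File(storage_file_path, "w") as storage:
        storage.attrs["num_slices"] = num_slices
        storage.attrs["img_height"] = img_height
        storage.attrs["img_width"] = img_width
        storage.attrs["model"] = model_adapter.name
        storage.attrs["patch_size"] = model_adapter.patch_size
        storage.attrs["overlap"] = model_adapter.overlap

        for slice_index in np_progress(
            range(num_slices), desc="extract features for slices"
        ):
            slice_image = image[slice_index] if num_slices > 1 else image
            slice_grp = storage.create_group(str(slice_index))
            # the device argument is gone: it now lives on the adapter
            get_slice_features(slice_image, model_adapter, slice_grp)
            yield slice_index, num_slices

Moving the HDF5 handle into a context manager here is also why stop_extracting no longer has to close self.storage itself.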
3 changes: 1 addition & 2 deletions src/featureforest/_segmentation_widget.py
@@ -48,7 +48,6 @@ def __init__(self, napari_viewer: napari.Viewer):
         self.storage = None
         self.rf_model = None
         self.model_adapter = None
-        self.device = None
         self.patch_size = 512  # default values
         self.overlap = 384
         self.stride = self.patch_size - self.overlap
@@ -368,7 +367,7 @@ def select_storage(self):
         img_width = self.storage.attrs["img_width"]
         # TODO: raise an error if current image dims are in conflict with the storage
         model_name = self.storage.attrs["model"]
-        self.model_adapter, self.device = get_model(model_name, img_height, img_width)
+        self.model_adapter = get_model(model_name, img_height, img_width)
         print(model_name, self.patch_size, self.overlap)
 
     def add_labels_layer(self):
Empty file.
66 changes: 0 additions & 66 deletions src/featureforest/_tests/test_widget.py

This file was deleted.

6 changes: 4 additions & 2 deletions src/featureforest/models/DinoV2/adapter.py
@@ -18,14 +18,16 @@ def __init__(
         self,
         model: nn.Module,
         img_height: float,
-        img_width: float
+        img_width: float,
+        device: torch.device
     ) -> None:
-        super().__init__(model, img_height, img_width)
+        super().__init__(model, img_height, img_width, device)
         self.name = "DinoV2"
         self.model = self.model
         self.dino_patch_size = 14
         self.dino_out_channels = 384
         self._set_patch_size()
+        self.device = device
 
         # input transform for dinov2
         self.input_transforms = tv_transforms2.Compose([
6 changes: 3 additions & 3 deletions src/featureforest/models/DinoV2/model.py
@@ -8,7 +8,7 @@
 
 def get_model(
     img_height: float, img_width: float, *args, **kwargs
-) -> Tuple[DinoV2Adapter, torch.device]:
+) -> DinoV2Adapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # get the pretrained model
@@ -19,7 +19,7 @@ def get_model(
 
     # create the model adapter
     dino_model_adapter = DinoV2Adapter(
-        model, img_height, img_width
+        model, img_height, img_width, device
     )
 
-    return dino_model_adapter, device
+    return dino_model_adapter
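
The same signature change repeats for MobileSAM and SAM below. As a call-site sketch (the exact model-name string is an assumption):

from featureforest.models import get_model

img_height, img_width = 1024, 1024

# before this commit, callers had to carry the device around:
#     model_adapter, device = get_model("DinoV2", img_height, img_width)
# after it, the device is owned by the adapter:
model_adapter = get_model("DinoV2", img_height, img_width)
print(model_adapter.device)  # chosen internally from CUDA availability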
5 changes: 4 additions & 1 deletion src/featureforest/models/MobileSAM/adapter.py
@@ -20,14 +20,17 @@ def __init__(
         model: nn.Module,
         img_height: float,
         img_width: float,
+        device: torch.device
     ) -> None:
-        super().__init__(model, img_height, img_width)
+        super().__init__(model, img_height, img_width, device)
         self.name = "MobileSAM"
         # we need sam image encoder part
         self.encoder = self.model.image_encoder
         self.encoder_num_channels = 256
         self.embed_layer_num_channels = 64
         self._set_patch_size()
+        self.device = device
+
         # input transform for sam
         self.sam_input_dim = 1024
         self.input_transforms = tv_transforms2.Compose([
6 changes: 3 additions & 3 deletions src/featureforest/models/MobileSAM/model.py
@@ -11,7 +11,7 @@
 
 def get_model(
     img_height: float, img_width: float, *args, **kwargs
-) -> Tuple[MobileSAMAdapter, torch.device]:
+) -> MobileSAMAdapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # get the model
@@ -33,10 +33,10 @@ def get_model(
 
     # create the model adapter
     sam_model_adapter = MobileSAMAdapter(
-        model, img_height, img_width
+        model, img_height, img_width, device
     )
 
-    return sam_model_adapter, device
+    return sam_model_adapter
 
 
 def setup_model() -> Sam:
4 changes: 2 additions & 2 deletions src/featureforest/models/MobileSAM/tiny_vit_sam.py
@@ -12,9 +12,9 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath as TimmDropPath, \
+from timm.layers import DropPath as TimmDropPath, \
     to_2tuple, trunc_normal_
-from timm.models.registry import register_model
+from timm.models import register_model
 from typing import Tuple
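
timm moved these symbols around its 0.9 release, and the old paths emit deprecation warnings. If both old and new timm had to be supported, a compatibility shim could look like this (an illustration, not part of this commit):

try:
    # timm >= 0.9
    from timm.layers import DropPath, to_2tuple, trunc_normal_
except ImportError:
    # older timm releases
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from timm.models import register_model  # re-exported on both old and new timm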
5 changes: 4 additions & 1 deletion src/featureforest/models/SAM/adapter.py
@@ -20,14 +20,17 @@ def __init__(
         image_encoder: nn.Module,
         img_height: float,
         img_width: float,
+        device: torch.device
     ) -> None:
-        super().__init__(image_encoder, img_height, img_width)
+        super().__init__(image_encoder, img_height, img_width, device)
         self.name = "SAM"
         # we need sam image encoder part
         self.encoder = image_encoder
         self.encoder_num_channels = 256
         self.embed_layer_num_channels = 1280
         self._set_patch_size()
+        self.device = device
+
         # input transform for sam
         self.sam_input_dim = 1024
         self.input_transforms = tv_transforms2.Compose([
6 changes: 3 additions & 3 deletions src/featureforest/models/SAM/model.py
@@ -11,7 +11,7 @@
 
 def get_model(
     img_height: float, img_width: float, *args, **kwargs
-) -> Tuple[SAMAdapter, torch.device]:
+) -> SAMAdapter:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"running on {device}")
     # download model's weights
@@ -32,7 +32,7 @@ def get_model(
 
     # create the model adapter
     sam_model_adapter = SAMAdapter(
-        sam_image_encoder, img_height, img_width
+        sam_image_encoder, img_height, img_width, device
     )
 
-    return sam_model_adapter, device
+    return sam_model_adapter