Commit f2a045e

style: Apply black and isort

loic-lb committed Aug 13, 2024
1 parent bc19b01 commit f2a045e

Showing 24 changed files with 686 additions and 399 deletions.
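The diffs below are pure formatting. isort rewrites each file's import block: imports are split into standard-library, third-party, and local groups separated by blank lines and alphabetized within each group. black then normalizes layout: 4-space continuation indents, magic trailing commas, two blank lines around top-level definitions, and wrapping at 88 characters. A commit like this is typically produced by running both tools over the package, presumably something like (commands assumed, not recorded in the commit):

    isort src/prismtoolbox
    black src/prismtoolbox

When the two tools are combined, isort is usually configured with its black-compatible profile (profile = "black") so they agree on how wrapped imports are formatted.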
3 changes: 2 additions & 1 deletion src/prismtoolbox/nucleiseg/models/sop/architectures.py
@@ -1,7 +1,8 @@
+import functools
+
 import torch
 import torch.nn as nn
 from torch.nn import init
-import functools
 
 
 class Identity(nn.Module):
1 change: 1 addition & 0 deletions src/prismtoolbox/nucleiseg/models/sop/modules.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+
 from .architectures import define_G
 
 
2 changes: 1 addition & 1 deletion src/prismtoolbox/nucleiseg/models/sop/postprocessing.py
@@ -1,6 +1,6 @@
 import numpy as np
-from skimage import morphology as skmorphology
 from scipy import ndimage as nd
+from skimage import morphology as skmorphology
 from skimage.feature import peak_local_max
 from skimage.segmentation import watershed
 
24 changes: 13 additions & 11 deletions src/prismtoolbox/nucleiseg/seg_utils.py
@@ -1,21 +1,23 @@
 from __future__ import annotations
 
-import torch
-import shapely
-import numpy as np
-from shapely import Polygon, MultiPolygon
-from tqdm import tqdm
 from functools import partial
 
+import numpy as np
+import shapely
+import torch
 from cellpose import models as cp_models
-from .models import create_sop_segmenter, create_sop_postprocessing
+from shapely import MultiPolygon, Polygon
+from tqdm import tqdm
+
+from .models import create_sop_postprocessing, create_sop_segmenter
 
+
 def create_cellpose_tools(
-    device: str = "cuda",
-    model_type: str = "cyto3",
-    min_size=15,
-    flow_threshold=0.4,
-    channel_cellpose=0
+    device: str = "cuda",
+    model_type: str = "cyto3",
+    min_size=15,
+    flow_threshold=0.4,
+    channel_cellpose=0,
 ):
     model = cp_models.Cellpose(model_type=model_type, device=torch.device(device))
     model_infer = partial(
110 changes: 66 additions & 44 deletions src/prismtoolbox/nucleiseg/segmenter.py
@@ -1,40 +1,46 @@
+import logging
+import multiprocessing as mp
 import os
 import time
-import logging
-import ipywidgets as widgets
-from IPython.display import display
+
 import cv2
-import multiprocessing as mp
+import ipywidgets as widgets
 import numpy as np
+from IPython.display import display
 from PIL import Image
 from tqdm import tqdm
-from .seg_utils import create_segmentation_tools, solve_conflicts
+
 from prismtoolbox import WSI
-from prismtoolbox.wsicore.core_utils import contour_mask
+from prismtoolbox.utils.qupath_utils import (
+    contoursToPolygons,
+    export_polygons_to_qupath,
+)
 from prismtoolbox.utils.torch_utils import BaseSlideHandler, ClipCustom
-from prismtoolbox.utils.qupath_utils import contoursToPolygons, export_polygons_to_qupath
+from prismtoolbox.wsicore.core_utils import contour_mask
+
+from .seg_utils import create_segmentation_tools, solve_conflicts
 
 log = logging.getLogger(__name__)
 
 
 class NucleiSegmenter(BaseSlideHandler):
     def __init__(
-        self,
-        slide_dir,
-        model_name,
-        pretrained_weights,
-        batch_size,
-        num_workers,
-        transforms_dict=None,
-        device="cuda",
-        engine="openslide",
-        coords_dir=None,
-        patch_size=None,
-        patch_level=None,
-        deconvolve_channel=None,
-        deconvolve_matrix="HE",
-        threshold_overlap=0.3,
-        **kwargs_seg_tool,
+        self,
+        slide_dir,
+        model_name,
+        pretrained_weights,
+        batch_size,
+        num_workers,
+        transforms_dict=None,
+        device="cuda",
+        engine="openslide",
+        coords_dir=None,
+        patch_size=None,
+        patch_level=None,
+        deconvolve_channel=None,
+        deconvolve_matrix="HE",
+        threshold_overlap=0.3,
+        **kwargs_seg_tool,
     ):
         super().__init__(
             slide_dir,
@@ -59,9 +65,14 @@ def __init__(
         self.threshold_overlap = threshold_overlap
         self.nuclei_seg = {}
 
-    def set_clip_custom_params_on_ex(self, slide_name, slide_ext, coords=None, deconvolve_channel=None):
+    def set_clip_custom_params_on_ex(
+        self, slide_name, slide_ext, coords=None, deconvolve_channel=None
+    ):
         dataset = self.create_dataset(
-            slide_name, slide_ext=slide_ext, coords=coords, deconvolve_channel=deconvolve_channel
+            slide_name,
+            slide_ext=slide_ext,
+            coords=coords,
+            deconvolve_channel=deconvolve_channel,
         )
         sample_patch = dataset.get_sample_patch()
 
@@ -71,31 +82,36 @@ def apply_clip_custom(min_value, max_value):
         result_image = Image.fromarray((result_tensor.numpy() * 255).astype("uint8"))
         display(result_image)
 
-        w = widgets.interact(apply_clip_custom,
-            min_value=widgets.FloatSlider(min=0, max=0.5, step=0.01, value=0.1),
-            max_value=widgets.FloatSlider(min=0.5, max=1.0, step=0.01, value=0.9))
+        w = widgets.interact(
+            apply_clip_custom,
+            min_value=widgets.FloatSlider(min=0, max=0.5, step=0.01, value=0.1),
+            max_value=widgets.FloatSlider(min=0.5, max=1.0, step=0.01, value=0.9),
+        )
         display(w)
 
     def segment_nuclei(
-        self,
-        slide_name,
-        slide_ext,
-        deconvolve_channel=None,
-        coords=None,
-        merge=False,
-        show_progress=True,
+        self,
+        slide_name,
+        slide_ext,
+        deconvolve_channel=None,
+        coords=None,
+        merge=False,
+        show_progress=True,
    ):
         log.info(f"Extracting embeddings from the patches of {slide_name}.")
         dataset = self.create_dataset(
-            slide_name, slide_ext=slide_ext, coords=coords, deconvolve_channel=deconvolve_channel
+            slide_name,
+            slide_ext=slide_ext,
+            coords=coords,
+            deconvolve_channel=deconvolve_channel,
        )
         dataloader = self.create_dataloader(dataset)
         start_time = time.time()
         masks = []
         for patches, coord in tqdm(
-            dataloader,
-            desc=f"Extracting nuclei from the patches of {slide_name}",
-            disable=not show_progress,
+            dataloader,
+            desc=f"Extracting nuclei from the patches of {slide_name}",
+            disable=not show_progress,
         ):
             if self.preprocessing_fct:
                 patches = self.preprocessing_fct(patches)
@@ -117,8 +133,16 @@ def _process_mask(coord, labeled_mask, border_width=3, min_point_cnt=5):
 
             canvas = (labeled_mask == nucleus_idx).astype("uint8")
 
-            border_mask = np.pad(np.zeros((canvas.shape[0] - 2 * border_width, canvas.shape[1] - 2 * border_width)),
-                border_width, constant_values=1)
+            border_mask = np.pad(
+                np.zeros(
+                    (
+                        canvas.shape[0] - 2 * border_width,
+                        canvas.shape[1] - 2 * border_width,
+                    )
+                ),
+                border_width,
+                constant_values=1,
+            )
 
             if np.any((canvas * border_mask) != 0):
                 continue
@@ -149,9 +173,7 @@ def process_masks(self, masks, merge=False):
 
     def save_nuclei(self, output_directory, slide_ext, flush_memory=True):
         for slide_name, nuclei in self.nuclei_seg.items():
-            WSI_object = WSI(
-                os.path.join(self.slide_dir, f"{slide_name}.{slide_ext}")
-            )
+            WSI_object = WSI(os.path.join(self.slide_dir, f"{slide_name}.{slide_ext}"))
             offset = WSI_object.offset
             output_path = os.path.join(output_directory, f"{slide_name}.geojson")
             export_polygons_to_qupath(
24 changes: 15 additions & 9 deletions src/prismtoolbox/utils/data_utils.py
@@ -1,16 +1,17 @@
 from __future__ import annotations
 
-import pickle
 import json
-import h5py
+import pickle
+from typing import Any, Tuple
+
 import geopandas as gpd
+import h5py
 import numpy as np
-from typing import Tuple, Any
 
 
 def save_obj_with_pickle(obj: object, file_path: str) -> None:
     """Save an object to a file using pickle.
-
+
     Args:
         obj: A pickeable object.
         file_path: The path to the file.
@@ -21,7 +22,7 @@ def save_obj_with_pickle(obj: object, file_path: str) -> None:
 
 def save_obj_with_json(obj: object, file_path: str) -> None:
     """Save an object to a file using json.
-
+
     Args:
         obj: A json object.
         file_path: The path to the file.
@@ -71,7 +72,10 @@ def read_h5_file(file_path: str, key: str) -> Tuple[np.ndarray, dict]:
         attrs = {key: value for key, value in f[key].attrs.items()}
     return object, attrs
 
-def read_json_with_geopandas(file_path: str, offset: tuple[int, int] = (0, 0)) -> gpd.GeoDataFrame:
+
+def read_json_with_geopandas(
+    file_path: str, offset: tuple[int, int] = (0, 0)
+) -> gpd.GeoDataFrame:
     """Read a json file with geopandas.
 
     Args:
@@ -84,7 +88,9 @@ def read_json_with_geopandas(file_path: str, offset: tuple[int, int] = (0, 0)) -
     df = gpd.GeoDataFrame.from_features(data)
     df.translate(xoff=offset[0], yoff=offset[1])
     if not df.is_valid.any():
-        df.loc[~df.is_valid,:] = df.loc[~df.is_valid, :].buffer(0)
+        df.loc[~df.is_valid, :] = df.loc[~df.is_valid, :].buffer(0)
     if "classification" in df.columns:
-        df["classification"] = df["classification"].apply(lambda x: x["name"] if type(x) == dict else x)
-    return df
+        df["classification"] = df["classification"].apply(
+            lambda x: x["name"] if type(x) == dict else x
+        )
+    return df
36 changes: 26 additions & 10 deletions src/prismtoolbox/utils/qupath_utils.py
@@ -1,18 +1,23 @@
 import logging
 import os
 import uuid
+from typing import List, Optional, Tuple, Union
+
 import numpy as np
 from shapely import MultiPolygon, Polygon, box
-from shapely.geometry import mapping, shape
 from shapely.affinity import translate
+from shapely.geometry import mapping, shape
 from shapely.ops import unary_union
-from typing import Optional, Tuple, List, Union
-from .data_utils import load_obj_with_json, save_obj_with_json, read_json_with_geopandas
+
+from .data_utils import load_obj_with_json, read_json_with_geopandas, save_obj_with_json
 
 log = logging.getLogger(__name__)
 
 
 def contoursToPolygons(
-    contours: List[np.ndarray], merge: Optional[bool] = False, make_valid: Optional[bool] = False,
+    contours: List[np.ndarray],
+    merge: Optional[bool] = False,
+    make_valid: Optional[bool] = False,
 ) -> Union[Polygon, MultiPolygon]:
     """Converts list of arrays to shapely polygons.
@@ -49,8 +54,13 @@ def PolygonsToContours(polygons: MultiPolygon):
         for poly in polygons.geoms
     ]
 
-def read_qupath_annotations(path: str, offset: Optional[Tuple[int, int]] = (0, 0), class_name: str = "annotation",
-    column_to_select: str = "objectType"):
+
+def read_qupath_annotations(
+    path: str,
+    offset: Optional[Tuple[int, int]] = (0, 0),
+    class_name: str = "annotation",
+    column_to_select: str = "objectType",
+):
     """Reads pathologist annotations from a .geojson file.
 
     :param path: path to the .geojson file
@@ -73,6 +83,7 @@ def read_qupath_annotations(path: str, offset: Optional[Tuple[int, int]] = (0, 0
         polygons = polygons.buffer(0)
     return polygons
 
+
 def convert_rgb_to_java_int_signed(rgb: Tuple[int, int, int]) -> int:
     """Converts RGB tuple to Java signed integer.
@@ -118,14 +129,19 @@ def export_polygons_to_qupath(
     }
     polygons = translate(polygons, xoff=offset[0], yoff=offset[1])
     for poly in polygons.geoms:
-        features.append(
-            {
+        features.append(
+            {
                 "type": "Feature",
                 "id": str(uuid.uuid4()),
                 "geometry": mapping(poly),
                 "properties": properties,
-            })
-    features = {"type": "FeatureCollection", "features": features} if as_feature_collection else features
+            }
+        )
+    features = (
+        {"type": "FeatureCollection", "features": features}
+        if as_feature_collection
+        else features
+    )
     if os.path.exists(path) and append_to_existing_file:
         previous_features = load_obj_with_json(path)
         if len(previous_features) == 0:
3 changes: 2 additions & 1 deletion src/prismtoolbox/utils/stain_utils.py
@@ -13,12 +13,13 @@ def retrieve_conv_matrix(conv_matrix_name="HED"):
         conv_matrix = skcolor.hed_from_rgb
     elif conv_matrix_name == "HD":
         conv_matrix = skcolor.hdx_from_rgb
-    elif conv_matrix_name == "HD_custom": # to fix
+    elif conv_matrix_name == "HD_custom":  # to fix
         conv_matrix = np.linalg.inv(IHC_custom)
     else:
         raise ValueError("conv_matrix_name must be 'HED', 'HD' or 'HD_custom'")
     return conv_matrix
 
+
 def deconvolve_img(img, conv_matrix_name="HED"):
     conv_matrix = retrieve_conv_matrix(conv_matrix_name)
     stains = skcolor.separate_stains(img, conv_matrix)
(Diff truncated: the remaining 16 of the 24 changed files were not loaded on this page.)

0 comments on commit f2a045e
