From f24138aab3d37449be3d71382d73a4b7191195ff Mon Sep 17 00:00:00 2001 From: Konstantin Hemker <33329141+konst-int-i@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:17:40 -0400 Subject: [PATCH] HESTData: update `to_spatial_data()``` (#63) * HESTData: update `to_spatial_data` to include pyramidal fullres and hires * HESTData: move to_spatial_data related imports to function * HESTData: move to_spatial_data related imports to function * HESTData: add cellvit and tissue_contours to shapes and coordinate systems * HESTData: move SpatialData import * modify imports and add tests * add return type * bug: each scale is relative to prev. one --------- Co-authored-by: Paul Doucet --- src/hest/HESTData.py | 250 +++++++++++++++++++++++++++++++++++-------- tests/hest_tests.py | 10 +- 2 files changed, 216 insertions(+), 44 deletions(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index 140714c..825bc88 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -13,6 +13,8 @@ contours_to_img, wsi_factory) from loguru import logger + + from hest.io.seg_readers import TissueContourReader from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new from hest.segmentation.TissueMask import TissueMask, load_tissue_mask @@ -459,61 +461,223 @@ def save_tissue_vis(self, save_dir: str, name: str) -> None: vis = self.get_tissue_vis() vis.save(os.path.join(save_dir, f'{name}_vis.jpg')) + def to_spatial_data(self, fullres: bool = False) -> SpatialData: + """ + Convert a HESTData sample to a scverse SpatialData object. Note that a large part of this function is based on + spatialdata-io's [``from_legacy_anndata``](https://spatialdata.scverse.org/projects/io/en/latest/generated/spatialdata_io.experimental.from_legacy_anndata.html) + function with some adjustments for ``HESTData``. 
- def to_spatial_data(self, lazy_img=True) -> SpatialData: # type: ignore - """Convert a HESTData sample to a scverse SpatialData object - Args: - lazy_img (bool, optional): whenever to lazily load the image if not already loaded (e.g. self.wsi is of type OpenSlide or CuImage). Defaults to True. + fullres (bool, optional): Includes pyramidal full resolution whole slide image as a ``DataTree`` object for those dimensions compatible with + Image2DModel's downsampling. Defaults to False. Returns: - SpatialData: scverse SpatialData object + SpatialData: scverse SpatialData object containing the ``hires`` and ``lowres`` downsampled versions + of the image and their respective coordinate systems. + + Example: + ```python + from hest import load_hest + hest_data = load_hest('../hest_data', id_list=['TENX68']) + st = hest_data[0] + st.to_spatial_data(fullres=True) + + >>> + + ``` + SpatialData object + ├── Images + │ ├── 'ST_downscaled_hires_image': SpatialImage[cyx] (3, 4779, 2586) + │ ├── 'ST_downscaled_lowres_image': SpatialImage[cyx] (3, 1000, 541) + │ └── 'ST_fullres_image': DataTree[cyx] (3, 38232, 20690), (3, 19116, 10345) + ├── Shapes + │ └── 'locations': GeoDataFrame shape: (1657, 2) (2D shapes) + └── Tables + └── 'table': AnnData (1657, 18085) + with coordinate systems: + ▸ 'ST_downscaled_hires', with elements: + ST_downscaled_hires_image (Images), locations (Shapes) + ▸ 'ST_downscaled_lowres', with elements: + ST_downscaled_lowres_image (Images), locations (Shapes) + ▸ 'ST_fullres', with elements: + ST_fullres_image (Images), locations (Shapes) + ``` + """ - import dask.array as da + # imports specific to spatial data conversion from dask import delayed from dask.array import from_delayed + from spatial_image import SpatialImage from spatialdata import SpatialData from spatialdata.models import Image2DModel, ShapesModel, TableModel - - def read_hest_wsi(wsi: WSI): - return wsi.numpy() - - if lazy_img: - width, height = self.wsi.get_dimensions() - arr = 
from_delayed(delayed(read_hest_wsi)(self.wsi), shape=(height, width, 3), dtype=np.int8) + from spatialdata.transformations import Identity, Scale + + # AnnData keys + SPATIAL = "spatial" + SCALEFACTORS = "scalefactors" + TISSUE_HIRES_SCALEF = "tissue_hires_scalef" + TISSUE_LOWRES_SCALEF = "tissue_downscaled_fullres_scalef" + SPOT_DIAMETER_FULLRES = "spot_diameter_fullres" + + IMAGES = "images" + HIRES = "fullres" + LOWRES = "downscaled_fullres" + + # SpatialData keys + REGION = "locations" + REGION_KEY = "region" + INSTANCE_KEY = "instance_id" + SPOT_DIAMETER_FULLRES_DEFAULT = 10 + + images = {} + shapes = {} + spot_diameter_fullres_list = [] + shapes_transformations = {} + if SPATIAL in self.adata.uns: + dataset_ids = list(self.adata.uns[SPATIAL].keys()) + for dataset_id in dataset_ids: + # read the image data and the scale factors for the shapes + keys = set(self.adata.uns[SPATIAL][dataset_id].keys()) + tissue_hires_scalef = None + tissue_lowres_scalef = None + hires = None + lowres = None + if SCALEFACTORS in keys: + scalefactors = self.adata.uns[SPATIAL][dataset_id][SCALEFACTORS] + if TISSUE_HIRES_SCALEF in scalefactors: + tissue_hires_scalef = scalefactors[TISSUE_HIRES_SCALEF] + else: + pixel_size=self.meta['pixel_size_um_estimated'] + ds_factor = 4/pixel_size # proxy for visium hires scale factor + ds_level = self.wsi.get_best_level_for_downsample(ds_factor) + tissue_hires_scalef = 1/self.wsi.level_downsamples()[ds_level] + + if TISSUE_LOWRES_SCALEF in scalefactors: + tissue_lowres_scalef = scalefactors[TISSUE_LOWRES_SCALEF] + if SPOT_DIAMETER_FULLRES in scalefactors: + spot_diameter_fullres_list.append(scalefactors[SPOT_DIAMETER_FULLRES]) + if IMAGES in keys: + image_data = self.adata.uns[SPATIAL][dataset_id][IMAGES] + if HIRES in image_data: + hires = image_data[HIRES] + else: + + # load wsi + def read_hest_wsi(wsi: WSI, width, height): + return wsi.get_thumbnail(width, height) + + if fullres: + full_width, full_height = self.wsi.get_dimensions() + fullres 
= from_delayed(delayed(read_hest_wsi)(self.wsi, full_width, full_height), shape=(full_height, full_width, 3), dtype=np.int8) + else: + fullres=None + hires_width, hires_height = self.wsi.level_dimensions()[ds_level] + hires = from_delayed(delayed(read_hest_wsi)(self.wsi, hires_width, hires_height), shape=(hires_height, hires_width, 3), dtype=np.int8) + + if LOWRES in image_data: + lowres = image_data[LOWRES] + + # construct the spatialdata elements + if hires is not None: + # prepare the hires image + assert ( + tissue_hires_scalef is not None + ), "tissue_hires_scalef is required when an the hires image is present" + hires = hires.transpose(2, 0, 1) + hires_image = Image2DModel.parse( + hires, + dims=("c", "y", "x"), + transformations={f"{dataset_id}_downscaled_hires": Identity()} + ) + hires_image = SpatialImage(hires_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres_image") + images[f"{dataset_id}_downscaled_hires_image"] = hires_image + + scale_hires = Scale([tissue_hires_scalef, tissue_hires_scalef], axes=("x", "y")) + shapes_transformations[f"{dataset_id}_downscaled_hires"] = scale_hires + + if fullres is not None: + fullres = fullres.transpose(2, 0, 1) + + # compute scale factors: each scale level is relative to the previous level + scale_factors = np.array([int(l) for l in self.wsi.level_downsamples()[1:] if full_height % l == 0 and full_width % l == 0]) + scale_factors[1:] = scale_factors[1:] / scale_factors[:-1] + scale_factors = scale_factors.tolist() + + fullres_image = Image2DModel.parse( + fullres, + dims=("c", "y", "x"), + scale_factors=scale_factors, + transformations={f"{dataset_id}_fullres": Identity()} + ) + images[f"{dataset_id}_fullres_image"] = fullres_image + scale_fullres = Scale([1, 1], axes=("x", "y")) + shapes_transformations[f"{dataset_id}_fullres"] = scale_fullres + + + if lowres is not None: + assert ( + tissue_lowres_scalef is not None + ), "tissue_lowres_scalef is required when an the lowres image is present" + 
lowres = lowres.transpose(2, 0, 1) + lowres_image = Image2DModel.parse( + lowres, dims=("c", "y", "x"), transformations={f"{dataset_id}_downscaled_lowres": Identity()} + ) + lowres_image = SpatialImage(lowres_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres") + images[f"{dataset_id}_downscaled_lowres_image"] = lowres_image + + scale_lowres = Scale([tissue_lowres_scalef, tissue_lowres_scalef], axes=("x", "y")) + shapes_transformations[f"{dataset_id}_downscaled_lowres"] = scale_lowres + + # add cellvit and tissue contours + for it in self.shapes: + shape = it.shapes + key = it.name + if len(shape) > 0 and isinstance(shape.iloc[0], Point): + shape['radius'] = 1 + val = ShapesModel.parse(shape, transformations=shapes_transformations) + shapes[key] = val + if self._tissue_contours is not None: + shapes['tissue_contours'] = ShapesModel.parse(self._tissue_contours, transformations=shapes_transformations) + + # validate the spot_diameter_fullres value + if len(spot_diameter_fullres_list) > 0: + d = np.array(spot_diameter_fullres_list) + if not np.allclose(d, d[0]): + warnings.warn( + "spot_diameter_fullres is not constant across datasets. 
Using the average value.", + UserWarning, + stacklevel=2, + ) + spot_diameter_fullres = d.mean() + else: + spot_diameter_fullres = d[0] else: - img = self.wsi.numpy() - arr = da.from_array(img) - - parsed_image = Image2DModel.parse(arr, dims=("y", "x", "c")) - - shape_validated = [] - shape_names = [] - for it in self.shapes: - shapes = it.shapes - if len(shapes) > 0 and isinstance(shapes.iloc[0], Point): - if 'radius' not in shapes.columns: - shapes['radius'] = 1 - - shape_validated.append(ShapesModel.parse(shapes)) - shape_names.append(it.name) - - if self._tissue_contours is not None: - shape_validated.append(ShapesModel.parse(self.tissue_contours)) - shape_names.append('tissue_contours') - - my_images = {"he": parsed_image} - - self.adata.obs['instance'] = self.adata.obs.index - self.adata.obs['region'] = 'he' - parsed_adata = TableModel.parse(self.adata, instance_key='instance', region_key='region', region='he') - my_tables = {"anndata": parsed_adata} - - st = SpatialData(images=my_images, tables=my_tables, shapes=dict(zip(shape_names, shape_validated))) - - return st + warnings.warn( + f"spot_diameter_fullres is not present. 
Using {SPOT_DIAMETER_FULLRES_DEFAULT} as default value.", + UserWarning, + stacklevel=2, + ) + spot_diameter_fullres = SPOT_DIAMETER_FULLRES_DEFAULT + + # parse and prepare the shapes + if SPATIAL in self.adata.obsm: + xy = self.adata.obsm[SPATIAL] + radius = spot_diameter_fullres / 2 + shapes[REGION] = ShapesModel.parse(xy, geometry=0, radius=radius, transformations=shapes_transformations) + + # link the shapes to the table + new_table = self.adata.copy() + if TableModel.ATTRS_KEY in new_table.uns: + del new_table.uns[TableModel.ATTRS_KEY] + new_table.obs[REGION_KEY] = REGION + new_table.obs[REGION_KEY] = new_table.obs[REGION_KEY].astype("category") + new_table.obs[INSTANCE_KEY] = shapes[REGION].index.values + new_table = TableModel.parse(new_table, region=REGION, region_key=REGION_KEY, instance_key=INSTANCE_KEY) + else: + new_table = self.adata.copy() + return SpatialData(tables=new_table, images=images, shapes=shapes) class VisiumHESTData(HESTData): def __init__(self, diff --git a/tests/hest_tests.py b/tests/hest_tests.py index bbad0d5..94924d4 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -1,12 +1,20 @@ import os import unittest import warnings +from datetime import datetime from os.path import join as _j from hestcore.segmentation import get_path_relative from hestcore.wsi import CucimWarningSingleton +MAX_HEST_IMPORT_S = 2 +start_time = datetime.now() import hest +end_time = datetime.now() +elapsed_time = (end_time - start_time).total_seconds() +if elapsed_time > MAX_HEST_IMPORT_S: + raise ImportError(f"Importing 'hest' took too long ({elapsed_time:.2f} seconds). Maximum allowed time is {MAX_HEST_IMPORT_S} seconds. 
Please, keep large large imports conditional") + from hest.autoalign import autoalign_visium from hest.readers import VisiumReader from hest.utils import load_image @@ -158,8 +166,8 @@ def test_spatialdata(self): def test_patching(self): """ Save patches as .h5 then load with H5HESTDataset """ from hestcore.datasets import H5HESTDataset - from torch.utils.data import DataLoader from PIL import Image, ImageDraw + from torch.utils.data import DataLoader output_dir = os.path.join(self.output_dir, 'test_patching') for idx, st in enumerate(self.sts):