Skip to content

Commit

Permalink
HESTData: update to_spatial_data()`` (#63)
Browse files Browse the repository at this point in the history
* HESTData: update `to_spatial_data` to include pyramidal fullres and hires

* HESTData: move to_spatial_data related imports to function

* HESTData: move to_spatial_data related imports to function

* HESTData: add cellvit and tissue_contours to shapes and coordinate systems

* HESTData: move SpatialData import

* modify imports and add tests

* add return type

* bug: each scale is relative to prev. one

---------

Co-authored-by: Paul Doucet <[email protected]>
  • Loading branch information
konst-int-i and pauldoucet authored Oct 30, 2024
1 parent 8aa1279 commit f24138a
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 44 deletions.
250 changes: 207 additions & 43 deletions src/hest/HESTData.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
contours_to_img, wsi_factory)
from loguru import logger



from hest.io.seg_readers import TissueContourReader
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
from hest.segmentation.TissueMask import TissueMask, load_tissue_mask
Expand Down Expand Up @@ -459,61 +461,223 @@ def save_tissue_vis(self, save_dir: str, name: str) -> None:
vis = self.get_tissue_vis()
vis.save(os.path.join(save_dir, f'{name}_vis.jpg'))

def to_spatial_data(self, fullres: bool = False) -> SpatialData:
"""
Convert a HESTData sample to a scverse SpatialData object. Note that a large part of this function is based on
spatialdata-io's [``from_legacy_anndata``](https://spatialdata.scverse.org/projects/io/en/latest/generated/spatialdata_io.experimental.from_legacy_anndata.html)
function with some adjustments for ``HESTData``.
def to_spatial_data(self, lazy_img=True) -> SpatialData: # type: ignore
"""Convert a HESTData sample to a scverse SpatialData object
Args:
lazy_img (bool, optional): whenever to lazily load the image if not already loaded (e.g. self.wsi is of type OpenSlide or CuImage). Defaults to True.
fullres (bool, optional): Includes pyramidal full resolution whole slide image as a ``DataTree`` object for those dimensions compatible with
Image2DModel's downsampling. Defaults to False.
Returns:
SpatialData: scverse SpatialData object
SpatialData: scverse SpatialData oobject containing the ``hires`` and ``lowres`` downsampled versions
of the image and their respective coordinate systems.
Example:
```python
from hest import load_hest
hest_data = load_hest('../hest_data', id_list=['TENX68'])
st = hest_data[0]
st.to_spatial_data(fullres=True)
>>>
```
SpatialData object
├── Images
│ ├── 'ST_downscaled_hires_image': SpatialImage[cyx] (3, 4779, 2586)
│ ├── 'ST_downscaled_lowres_image': SpatialImage[cyx] (3, 1000, 541)
│ └── 'ST_fullres_image': DataTree[cyx] (3, 38232, 20690), (3, 19116, 10345)
├── Shapes
│ └── 'locations': GeoDataFrame shape: (1657, 2) (2D shapes)
└── Tables
└── 'table': AnnData (1657, 18085)
with coordinate systems:
▸ 'ST_downscaled_hires', with elements:
ST_downscaled_hires_image (Images), locations (Shapes)
▸ 'ST_downscaled_lowres', with elements:
ST_downscaled_lowres_image (Images), locations (Shapes)
▸ 'ST_fullres', with elements:
ST_fullres_image (Images), locations (Shapes)
```
"""

import dask.array as da
# imports specific to spatial data conversion
from dask import delayed
from dask.array import from_delayed
from spatial_image import SpatialImage
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, ShapesModel, TableModel

def read_hest_wsi(wsi: WSI):
return wsi.numpy()

if lazy_img:
width, height = self.wsi.get_dimensions()
arr = from_delayed(delayed(read_hest_wsi)(self.wsi), shape=(height, width, 3), dtype=np.int8)
from spatialdata.transformations import Identity, Scale

# AnnData keys
SPATIAL = "spatial"
SCALEFACTORS = "scalefactors"
TISSUE_HIRES_SCALEF = "tissue_hires_scalef"
TISSUE_LOWRES_SCALEF = "tissue_downscaled_fullres_scalef"
SPOT_DIAMETER_FULLRES = "spot_diameter_fullres"

IMAGES = "images"
HIRES = "fullres"
LOWRES = "downscaled_fullres"

# SpatialData keys
REGION = "locations"
REGION_KEY = "region"
INSTANCE_KEY = "instance_id"
SPOT_DIAMETER_FULLRES_DEFAULT = 10

images = {}
shapes = {}
spot_diameter_fullres_list = []
shapes_transformations = {}
if SPATIAL in self.adata.uns:
dataset_ids = list(self.adata.uns[SPATIAL].keys())
for dataset_id in dataset_ids:
# read the image data and the scale factors for the shapes
keys = set(self.adata.uns[SPATIAL][dataset_id].keys())
tissue_hires_scalef = None
tissue_lowres_scalef = None
hires = None
lowres = None
if SCALEFACTORS in keys:
scalefactors = self.adata.uns[SPATIAL][dataset_id][SCALEFACTORS]
if TISSUE_HIRES_SCALEF in scalefactors:
tissue_hires_scalef = scalefactors[TISSUE_HIRES_SCALEF]
else:
pixel_size=self.meta['pixel_size_um_estimated']
ds_factor = 4/pixel_size # proxy for visium hires scale factor
ds_level = self.wsi.get_best_level_for_downsample(ds_factor)
tissue_hires_scalef = 1/self.wsi.level_downsamples()[ds_level]

if TISSUE_LOWRES_SCALEF in scalefactors:
tissue_lowres_scalef = scalefactors[TISSUE_LOWRES_SCALEF]
if SPOT_DIAMETER_FULLRES in scalefactors:
spot_diameter_fullres_list.append(scalefactors[SPOT_DIAMETER_FULLRES])
if IMAGES in keys:
image_data = self.adata.uns[SPATIAL][dataset_id][IMAGES]
if HIRES in image_data:
hires = image_data[HIRES]
else:

# load wsi
def read_hest_wsi(wsi: WSI, width, height):
return wsi.get_thumbnail(width, height)

if fullres:
full_width, full_height = self.wsi.get_dimensions()
fullres = from_delayed(delayed(read_hest_wsi)(self.wsi, full_width, full_height), shape=(full_height, full_width, 3), dtype=np.int8)
else:
fullres=None
hires_width, hires_height = self.wsi.level_dimensions()[ds_level]
hires = from_delayed(delayed(read_hest_wsi)(self.wsi, hires_width, hires_height), shape=(hires_height, hires_width, 3), dtype=np.int8)

if LOWRES in image_data:
lowres = image_data[LOWRES]

# construct the spatialdata elements
if hires is not None:
# prepare the hires image
assert (
tissue_hires_scalef is not None
), "tissue_hires_scalef is required when an the hires image is present"
hires = hires.transpose(2, 0, 1)
hires_image = Image2DModel.parse(
hires,
dims=("c", "y", "x"),
transformations={f"{dataset_id}_downscaled_hires": Identity()}
)
hires_image = SpatialImage(hires_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres_image")
images[f"{dataset_id}_downscaled_hires_image"] = hires_image

scale_hires = Scale([tissue_hires_scalef, tissue_hires_scalef], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_downscaled_hires"] = scale_hires

if fullres is not None:
fullres = fullres.transpose(2, 0, 1)

# compute scale factors: each scale level is relative to the previous level
scale_factors = np.array([int(l) for l in self.wsi.level_downsamples()[1:] if full_height % l == 0 and full_width % l == 0])
scale_factors[1:] = scale_factors[1:] / scale_factors[:-1]
scale_factors = scale_factors.tolist()

fullres_image = Image2DModel.parse(
fullres,
dims=("c", "y", "x"),
scale_factors=scale_factors,
transformations={f"{dataset_id}_fullres": Identity()}
)
images[f"{dataset_id}_fullres_image"] = fullres_image
scale_fullres = Scale([1, 1], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_fullres"] = scale_fullres


if lowres is not None:
assert (
tissue_lowres_scalef is not None
), "tissue_lowres_scalef is required when an the lowres image is present"
lowres = lowres.transpose(2, 0, 1)
lowres_image = Image2DModel.parse(
lowres, dims=("c", "y", "x"), transformations={f"{dataset_id}_downscaled_lowres": Identity()}
)
lowres_image = SpatialImage(lowres_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres")
images[f"{dataset_id}_downscaled_lowres_image"] = lowres_image

scale_lowres = Scale([tissue_lowres_scalef, tissue_lowres_scalef], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_downscaled_lowres"] = scale_lowres

# add cellvit and tissue contours
for it in self.shapes:
shape = it.shapes
key = it.name
if len(shape) > 0 and isinstance(shape.iloc[0], Point):
shape['radius'] = 1
val = ShapesModel.parse(shape, transformations=shapes_transformations)
shapes[key] = val
if self._tissue_contours is not None:
shapes['tissue_contours'] = ShapesModel.parse(self._tissue_contours, transformations=shapes_transformations)

# validate the spot_diameter_fullres value
if len(spot_diameter_fullres_list) > 0:
d = np.array(spot_diameter_fullres_list)
if not np.allclose(d, d[0]):
warnings.warn(
"spot_diameter_fullres is not constant across datasets. Using the average value.",
UserWarning,
stacklevel=2,
)
spot_diameter_fullres = d.mean()
else:
spot_diameter_fullres = d[0]
else:
img = self.wsi.numpy()
arr = da.from_array(img)

parsed_image = Image2DModel.parse(arr, dims=("y", "x", "c"))

shape_validated = []
shape_names = []
for it in self.shapes:
shapes = it.shapes
if len(shapes) > 0 and isinstance(shapes.iloc[0], Point):
if 'radius' not in shapes.columns:
shapes['radius'] = 1

shape_validated.append(ShapesModel.parse(shapes))
shape_names.append(it.name)

if self._tissue_contours is not None:
shape_validated.append(ShapesModel.parse(self.tissue_contours))
shape_names.append('tissue_contours')

my_images = {"he": parsed_image}

self.adata.obs['instance'] = self.adata.obs.index
self.adata.obs['region'] = 'he'
parsed_adata = TableModel.parse(self.adata, instance_key='instance', region_key='region', region='he')
my_tables = {"anndata": parsed_adata}

st = SpatialData(images=my_images, tables=my_tables, shapes=dict(zip(shape_names, shape_validated)))

return st
warnings.warn(
f"spot_diameter_fullres is not present. Using {SPOT_DIAMETER_FULLRES_DEFAULT} as default value.",
UserWarning,
stacklevel=2,
)
spot_diameter_fullres = SPOT_DIAMETER_FULLRES_DEFAULT

# parse and prepare the shapes
if SPATIAL in self.adata.obsm:
xy = self.adata.obsm[SPATIAL]
radius = spot_diameter_fullres / 2
shapes[REGION] = ShapesModel.parse(xy, geometry=0, radius=radius, transformations=shapes_transformations)

# link the shapes to the table
new_table = self.adata.copy()
if TableModel.ATTRS_KEY in new_table.uns:
del new_table.uns[TableModel.ATTRS_KEY]
new_table.obs[REGION_KEY] = REGION
new_table.obs[REGION_KEY] = new_table.obs[REGION_KEY].astype("category")
new_table.obs[INSTANCE_KEY] = shapes[REGION].index.values
new_table = TableModel.parse(new_table, region=REGION, region_key=REGION_KEY, instance_key=INSTANCE_KEY)
else:
new_table = self.adata.copy()

return SpatialData(tables=new_table, images=images, shapes=shapes)

class VisiumHESTData(HESTData):
def __init__(self,
Expand Down
10 changes: 9 additions & 1 deletion tests/hest_tests.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os
import unittest
import warnings
from datetime import datetime
from os.path import join as _j

from hestcore.segmentation import get_path_relative
from hestcore.wsi import CucimWarningSingleton

MAX_HEST_IMPORT_S = 2
start_time = datetime.now()
import hest
end_time = datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
if elapsed_time > MAX_HEST_IMPORT_S:
raise ImportError(f"Importing 'hest' took too long ({elapsed_time:.2f} seconds). Maximum allowed time is {MAX_HEST_IMPORT_S} seconds. Please, keep large large imports conditional")

from hest.autoalign import autoalign_visium
from hest.readers import VisiumReader
from hest.utils import load_image
Expand Down Expand Up @@ -158,8 +166,8 @@ def test_spatialdata(self):
def test_patching(self):
""" Save patches as .h5 then load with H5HESTDataset """
from hestcore.datasets import H5HESTDataset
from torch.utils.data import DataLoader
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader
output_dir = os.path.join(self.output_dir, 'test_patching')

for idx, st in enumerate(self.sts):
Expand Down

0 comments on commit f24138a

Please sign in to comment.