HESTData: update `to_spatial_data()` #63

Merged
9 commits merged on Oct 30, 2024
250 changes: 207 additions & 43 deletions src/hest/HESTData.py
@@ -13,6 +13,8 @@
contours_to_img, wsi_factory)
from loguru import logger



from hest.io.seg_readers import TissueContourReader
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
from hest.segmentation.TissueMask import TissueMask, load_tissue_mask
@@ -459,61 +461,223 @@ def save_tissue_vis(self, save_dir: str, name: str) -> None:
        vis = self.get_tissue_vis()
        vis.save(os.path.join(save_dir, f'{name}_vis.jpg'))

    def to_spatial_data(self, fullres: bool = False) -> SpatialData:
        """
        Convert a HESTData sample to a scverse SpatialData object. Note that a large part of this function is based on
        spatialdata-io's [``from_legacy_anndata``](https://spatialdata.scverse.org/projects/io/en/latest/generated/spatialdata_io.experimental.from_legacy_anndata.html)
        function with some adjustments for ``HESTData``.

        Args:
            fullres (bool, optional): Include the pyramidal full-resolution whole slide image as a ``DataTree`` object
                for those dimensions compatible with Image2DModel's downsampling. Defaults to False.

        Returns:
            SpatialData: scverse SpatialData object containing the ``hires`` and ``lowres`` downsampled versions
                of the image and their respective coordinate systems.

        Example:
            ```python
            from hest import load_hest

            hest_data = load_hest('../hest_data', id_list=['TENX68'])
            st = hest_data[0]
            st.to_spatial_data(fullres=True)

            >>>
            SpatialData object
            ├── Images
            │     ├── 'ST_downscaled_hires_image': SpatialImage[cyx] (3, 4779, 2586)
            │     ├── 'ST_downscaled_lowres_image': SpatialImage[cyx] (3, 1000, 541)
            │     └── 'ST_fullres_image': DataTree[cyx] (3, 38232, 20690), (3, 19116, 10345)
            ├── Shapes
            │     └── 'locations': GeoDataFrame shape: (1657, 2) (2D shapes)
            └── Tables
                  └── 'table': AnnData (1657, 18085)
            with coordinate systems:
                ▸ 'ST_downscaled_hires', with elements:
                        ST_downscaled_hires_image (Images), locations (Shapes)
                ▸ 'ST_downscaled_lowres', with elements:
                        ST_downscaled_lowres_image (Images), locations (Shapes)
                ▸ 'ST_fullres', with elements:
                        ST_fullres_image (Images), locations (Shapes)
            ```
        """

        # imports specific to spatial data conversion
        import dask.array as da
        from dask import delayed
        from dask.array import from_delayed
        from spatial_image import SpatialImage
        from spatialdata import SpatialData
        from spatialdata.models import Image2DModel, ShapesModel, TableModel
        from spatialdata.transformations import Identity, Scale

        # AnnData keys
        SPATIAL = "spatial"
        SCALEFACTORS = "scalefactors"
        TISSUE_HIRES_SCALEF = "tissue_hires_scalef"
        TISSUE_LOWRES_SCALEF = "tissue_downscaled_fullres_scalef"
        SPOT_DIAMETER_FULLRES = "spot_diameter_fullres"

        IMAGES = "images"
        HIRES = "fullres"
        LOWRES = "downscaled_fullres"

        # SpatialData keys
        REGION = "locations"
        REGION_KEY = "region"
        INSTANCE_KEY = "instance_id"
        SPOT_DIAMETER_FULLRES_DEFAULT = 10
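        # Sketch of the legacy metadata layout read below (assumed from the keys
        # above, following the squidpy/Visium convention; any key may be absent):
        # adata.uns['spatial'][<dataset_id>] = {
        #     'scalefactors': {'tissue_hires_scalef': ..., 'spot_diameter_fullres': ...},
        #     'images': {'fullres': <ndarray>, 'downscaled_fullres': <ndarray>},
        # }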

        images = {}
        shapes = {}
        spot_diameter_fullres_list = []
        shapes_transformations = {}
        if SPATIAL in self.adata.uns:
            dataset_ids = list(self.adata.uns[SPATIAL].keys())
            for dataset_id in dataset_ids:
                # read the image data and the scale factors for the shapes
                keys = set(self.adata.uns[SPATIAL][dataset_id].keys())
                tissue_hires_scalef = None
                tissue_lowres_scalef = None
                hires = None
                lowres = None
                if SCALEFACTORS in keys:
                    scalefactors = self.adata.uns[SPATIAL][dataset_id][SCALEFACTORS]
                    if TISSUE_HIRES_SCALEF in scalefactors:
                        tissue_hires_scalef = scalefactors[TISSUE_HIRES_SCALEF]
                    else:
                        pixel_size = self.meta['pixel_size_um_estimated']
                        ds_factor = 4 / pixel_size  # proxy for visium hires scale factor
Review comment (Collaborator): Is this part going to work for non-Visium ST?

Reply (Author): Yes - the hires image here is an approximation for ~4 MPP, which is standard with Visium. I have tested it with non-Visium slides and the same logic works and accurately downsamples the image, although the hires terminology is not normally given for these slides to my knowledge.
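To make the heuristic under discussion concrete, here is a minimal sketch with made-up numbers (the pixel size and pyramid levels are hypothetical, and `get_best_level_for_downsample` is assumed to behave like its OpenSlide namesake, returning the deepest level whose downsample does not exceed the requested factor):

```python
# Illustrative only: a hypothetical 20x slide (0.5 um/px) with pyramid
# downsamples (1, 4, 16). Targeting ~4 MPP means downsampling by a factor of 8.
pixel_size = 0.5                  # um per pixel at level 0
ds_factor = 4 / pixel_size        # 8.0
level_downsamples = [1, 4, 16]

# emulate get_best_level_for_downsample (assumed OpenSlide-like semantics):
# pick the deepest level whose downsample does not exceed ds_factor
ds_level = max(i for i, d in enumerate(level_downsamples) if d <= ds_factor)
tissue_hires_scalef = 1 / level_downsamples[ds_level]

print(ds_level, tissue_hires_scalef)  # 1 0.25: hires is level 1, a 4x downsample
```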

                        ds_level = self.wsi.get_best_level_for_downsample(ds_factor)
                        tissue_hires_scalef = 1 / self.wsi.level_downsamples()[ds_level]

                    if TISSUE_LOWRES_SCALEF in scalefactors:
                        tissue_lowres_scalef = scalefactors[TISSUE_LOWRES_SCALEF]
                    if SPOT_DIAMETER_FULLRES in scalefactors:
                        spot_diameter_fullres_list.append(scalefactors[SPOT_DIAMETER_FULLRES])
                if IMAGES in keys:
                    image_data = self.adata.uns[SPATIAL][dataset_id][IMAGES]
                    if HIRES in image_data:
                        hires = image_data[HIRES]
                    else:
                        # load wsi
                        def read_hest_wsi(wsi: WSI, width, height):
                            return wsi.get_thumbnail(width, height)

                        if fullres:
                            full_width, full_height = self.wsi.get_dimensions()
                            fullres = from_delayed(delayed(read_hest_wsi)(self.wsi, full_width, full_height), shape=(full_height, full_width, 3), dtype=np.uint8)
                        else:
                            fullres = None
                        hires_width, hires_height = self.wsi.level_dimensions()[ds_level]
                        hires = from_delayed(delayed(read_hest_wsi)(self.wsi, hires_width, hires_height), shape=(hires_height, hires_width, 3), dtype=np.uint8)

                    if LOWRES in image_data:
                        lowres = image_data[LOWRES]

                # construct the spatialdata elements
                if hires is not None:
                    # prepare the hires image
                    assert (
                        tissue_hires_scalef is not None
                    ), "tissue_hires_scalef is required when the hires image is present"
                    hires = hires.transpose(2, 0, 1)
                    hires_image = Image2DModel.parse(
                        hires,
                        dims=("c", "y", "x"),
                        transformations={f"{dataset_id}_downscaled_hires": Identity()}
                    )
                    hires_image = SpatialImage(hires_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_hires_image")
                    images[f"{dataset_id}_downscaled_hires_image"] = hires_image

                    scale_hires = Scale([tissue_hires_scalef, tissue_hires_scalef], axes=("x", "y"))
                    shapes_transformations[f"{dataset_id}_downscaled_hires"] = scale_hires

                if fullres is not None:
                    fullres = fullres.transpose(2, 0, 1)

                    # compute scale factors: each scale level is relative to the previous level
                    scale_factors = np.array([int(l) for l in self.wsi.level_downsamples()[1:] if full_height % l == 0 and full_width % l == 0])
                    scale_factors[1:] = scale_factors[1:] / scale_factors[:-1]
                    scale_factors = scale_factors.tolist()
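                    # Worked example (illustrative, not from this sample): if level_downsamples()
                    # returns (1, 2, 4, 8) and the full resolution is divisible by each, the
                    # absolute factors [2, 4, 8] become relative factors [2, 2, 2], i.e. each
                    # pyramid level halves the previous one, which is the form that
                    # Image2DModel.parse expects for scale_factors.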

fullres_image = Image2DModel.parse(
fullres,
dims=("c", "y", "x"),
scale_factors=scale_factors,
transformations={f"{dataset_id}_fullres": Identity()}
)
images[f"{dataset_id}_fullres_image"] = fullres_image
scale_fullres = Scale([1, 1], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_fullres"] = scale_fullres


if lowres is not None:
assert (
tissue_lowres_scalef is not None
), "tissue_lowres_scalef is required when an the lowres image is present"
lowres = lowres.transpose(2, 0, 1)
lowres_image = Image2DModel.parse(
lowres, dims=("c", "y", "x"), transformations={f"{dataset_id}_downscaled_lowres": Identity()}
)
lowres_image = SpatialImage(lowres_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres")
images[f"{dataset_id}_downscaled_lowres_image"] = lowres_image

scale_lowres = Scale([tissue_lowres_scalef, tissue_lowres_scalef], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_downscaled_lowres"] = scale_lowres

        # add cellvit and tissue contours
        for it in self.shapes:
            shape = it.shapes
            key = it.name
            if len(shape) > 0 and isinstance(shape.iloc[0], Point):
                shape['radius'] = 1
            val = ShapesModel.parse(shape, transformations=shapes_transformations)
            shapes[key] = val
        if self._tissue_contours is not None:
            shapes['tissue_contours'] = ShapesModel.parse(self._tissue_contours, transformations=shapes_transformations)

        # validate the spot_diameter_fullres value
        if len(spot_diameter_fullres_list) > 0:
            d = np.array(spot_diameter_fullres_list)
            if not np.allclose(d, d[0]):
                warnings.warn(
                    "spot_diameter_fullres is not constant across datasets. Using the average value.",
                    UserWarning,
                    stacklevel=2,
                )
                spot_diameter_fullres = d.mean()
            else:
                spot_diameter_fullres = d[0]
        else:
            warnings.warn(
                f"spot_diameter_fullres is not present. Using {SPOT_DIAMETER_FULLRES_DEFAULT} as default value.",
                UserWarning,
                stacklevel=2,
            )
            spot_diameter_fullres = SPOT_DIAMETER_FULLRES_DEFAULT

        # parse and prepare the shapes
        if SPATIAL in self.adata.obsm:
            xy = self.adata.obsm[SPATIAL]
            radius = spot_diameter_fullres / 2
            shapes[REGION] = ShapesModel.parse(xy, geometry=0, radius=radius, transformations=shapes_transformations)

            # link the shapes to the table
            new_table = self.adata.copy()
            if TableModel.ATTRS_KEY in new_table.uns:
                del new_table.uns[TableModel.ATTRS_KEY]
            new_table.obs[REGION_KEY] = REGION
            new_table.obs[REGION_KEY] = new_table.obs[REGION_KEY].astype("category")
            new_table.obs[INSTANCE_KEY] = shapes[REGION].index.values
            new_table = TableModel.parse(new_table, region=REGION, region_key=REGION_KEY, instance_key=INSTANCE_KEY)
        else:
            new_table = self.adata.copy()

        return SpatialData(tables=new_table, images=images, shapes=shapes)
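A short usage sketch for the returned object (the path and sample id mirror the docstring example above and are assumptions, as are the exact element names, which depend on the sample's dataset id):

```python
from hest import load_hest

st = load_hest('../hest_data', id_list=['TENX68'])[0]
sdata = st.to_spatial_data(fullres=True)

# Each downsampled image lives in its own coordinate system, and the
# 'locations' shapes carry a Scale transformation into each of them,
# so spots overlay correctly at every resolution.
print(sdata.images.keys())               # e.g. 'ST_downscaled_hires_image', ...
print(sdata.shapes['locations'].head())
sdata.write('TENX68.zarr')               # persist in the standard SpatialData Zarr format
```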

class VisiumHESTData(HESTData):
    def __init__(self,
10 changes: 9 additions & 1 deletion tests/hest_tests.py
@@ -1,12 +1,20 @@
import os
import unittest
import warnings
from datetime import datetime
from os.path import join as _j

from hestcore.segmentation import get_path_relative
from hestcore.wsi import CucimWarningSingleton

MAX_HEST_IMPORT_S = 2
start_time = datetime.now()
import hest
end_time = datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
if elapsed_time > MAX_HEST_IMPORT_S:
    raise ImportError(f"Importing 'hest' took too long ({elapsed_time:.2f} seconds). Maximum allowed time is {MAX_HEST_IMPORT_S} seconds. Please keep large imports conditional.")

from hest.autoalign import autoalign_visium
from hest.readers import VisiumReader
from hest.utils import load_image
@@ -158,8 +166,8 @@ def test_spatialdata(self):
    def test_patching(self):
        """ Save patches as .h5 then load with H5HESTDataset """
        from hestcore.datasets import H5HESTDataset
        from PIL import Image, ImageDraw
        from torch.utils.data import DataLoader
        output_dir = os.path.join(self.output_dir, 'test_patching')

        for idx, st in enumerate(self.sts):