Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HESTData: update to_spatial_data()`` #63

Merged
merged 9 commits into from
Oct 30, 2024
240 changes: 197 additions & 43 deletions src/hest/HESTData.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from hestcore.wsi import (WSI, CucimWarningSingleton, NumpyWSI,
contours_to_img, wsi_factory)
from loguru import logger
from spatialdata import SpatialData
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to move this import line inside the to_spatial_data() method, the SpatialData is huge

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved in 98d0acf - I had it outside to enable the type hint in the function, but appreciate that it comes with a lot of overhead




from hest.io.seg_readers import TissueContourReader
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
Expand Down Expand Up @@ -460,60 +463,211 @@ def save_tissue_vis(self, save_dir: str, name: str) -> None:
vis.save(os.path.join(save_dir, f'{name}_vis.jpg'))


def to_spatial_data(self, lazy_img=True) -> SpatialData: # type: ignore
"""Convert a HESTData sample to a scverse SpatialData object

def to_spatial_data(self, fullres: bool = False) -> SpatialData:
"""
Convert a HESTData sample to a scverse SpatialData object. Note that a large part of this function is based on
spatialdata-io's [``from_legacy_anndata``](https://spatialdata.scverse.org/projects/io/en/latest/generated/spatialdata_io.experimental.from_legacy_anndata.html)
function with some adjustments for ``HESTData``.

Args:
lazy_img (bool, optional): whenever to lazily load the image if not already loaded (e.g. self.wsi is of type OpenSlide or CuImage). Defaults to True.
fullres (bool, optional): Includes pyramidal full resolution whole slide image as a ``DataTree`` object for those dimensions compatible with
Image2DModel's downsampling. Defaults to False.

Returns:
SpatialData: scverse SpatialData object
SpatialData: scverse SpatialData oobject containing the ``hires`` and ``lowres`` downsampled versions
of the image and their respective coordinate systems.

Example:
```python
from hest import load_hest
hest_data = load_hest('../hest_data', id_list=['TENX68'])
st = hest_data[0]
st.to_spatial_data(fullres=True)

>>>

```
SpatialData object
├── Images
│ ├── 'ST_downscaled_hires_image': SpatialImage[cyx] (3, 4779, 2586)
│ ├── 'ST_downscaled_lowres_image': SpatialImage[cyx] (3, 1000, 541)
│ └── 'ST_fullres_image': DataTree[cyx] (3, 38232, 20690), (3, 19116, 10345)
├── Shapes
│ └── 'locations': GeoDataFrame shape: (1657, 2) (2D shapes)
└── Tables
└── 'table': AnnData (1657, 18085)
with coordinate systems:
▸ 'ST_downscaled_hires', with elements:
ST_downscaled_hires_image (Images), locations (Shapes)
▸ 'ST_downscaled_lowres', with elements:
ST_downscaled_lowres_image (Images), locations (Shapes)
▸ 'ST_fullres', with elements:
ST_fullres_image (Images), locations (Shapes)
```

"""

# imports specific to spatial data conversion
import dask.array as da
from dask import delayed
from dask.array import from_delayed
from spatialdata import SpatialData
from spatial_image import SpatialImage
from spatialdata.models import Image2DModel, ShapesModel, TableModel

def read_hest_wsi(wsi: WSI):
return wsi.numpy()

if lazy_img:
width, height = self.wsi.get_dimensions()
arr = from_delayed(delayed(read_hest_wsi)(self.wsi), shape=(height, width, 3), dtype=np.int8)
from spatialdata.transformations import Identity, Scale

# AnnData keys
SPATIAL = "spatial"
SCALEFACTORS = "scalefactors"
TISSUE_HIRES_SCALEF = "tissue_hires_scalef"
# TISSUE_LOWRES_SCALEF = "tissue_lowres_scalef"
TISSUE_LOWRES_SCALEF = "tissue_downscaled_fullres_scalef"
SPOT_DIAMETER_FULLRES = "spot_diameter_fullres"

IMAGES = "images"
# HIRES = "hires"
# LOWRES = "lowres"
HIRES = "fullres"
LOWRES = "downscaled_fullres"

# SpatialData keys
REGION = "locations"
REGION_KEY = "region"
INSTANCE_KEY = "instance_id"
SPOT_DIAMETER_FULLRES_DEFAULT = 10

images = {}
shapes = {}
spot_diameter_fullres_list = []
shapes_transformations = {}
if SPATIAL in self.adata.uns:
dataset_ids = list(self.adata.uns[SPATIAL].keys())
for dataset_id in dataset_ids:
# read the image data and the scale factors for the shapes
keys = set(self.adata.uns[SPATIAL][dataset_id].keys())
tissue_hires_scalef = None
tissue_lowres_scalef = None
hires = None
lowres = None
if SCALEFACTORS in keys:
scalefactors = self.adata.uns[SPATIAL][dataset_id][SCALEFACTORS]
if TISSUE_HIRES_SCALEF in scalefactors:
tissue_hires_scalef = scalefactors[TISSUE_HIRES_SCALEF]
else:
pixel_size=self.meta['pixel_size_um_estimated']
ds_factor = 4/pixel_size # proxy for visium hires scale factor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this part going to work for non-visium ST?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes - the hires image here is an approximation for ~4 MPPs which is standard with Visium. I have tested it with non-visium slides and the same logic works and accurately downsamples the image, although the hires terminology is not normally given for these slides to my knowledge.

ds_level = self.wsi.get_best_level_for_downsample(ds_factor)
tissue_hires_scalef = 1/self.wsi.level_downsamples()[ds_level]

if TISSUE_LOWRES_SCALEF in scalefactors:
tissue_lowres_scalef = scalefactors[TISSUE_LOWRES_SCALEF]
if SPOT_DIAMETER_FULLRES in scalefactors:
spot_diameter_fullres_list.append(scalefactors[SPOT_DIAMETER_FULLRES])
if IMAGES in keys:
image_data = self.adata.uns[SPATIAL][dataset_id][IMAGES]
if HIRES in image_data:
hires = image_data[HIRES]
else:
# load wsi
def read_hest_wsi(wsi: WSI, width, height):
return wsi.get_thumbnail(width, height)


if fullres:
full_width, full_height = self.wsi.get_dimensions()
fullres = from_delayed(delayed(read_hest_wsi)(self.wsi, full_width, full_height), shape=(full_height, full_width, 3), dtype=np.int8)
else:
fullres=None
hires_width, hires_height = self.wsi.level_dimensions()[ds_level]
hires = from_delayed(delayed(read_hest_wsi)(self.wsi, hires_width, hires_height), shape=(hires_height, hires_width, 3), dtype=np.int8)

if LOWRES in image_data:
lowres = image_data[LOWRES]

# construct the spatialdata elements
if hires is not None:
# prepare the hires image
assert (
tissue_hires_scalef is not None
), "tissue_hires_scalef is required when an the hires image is present"
hires = hires.transpose(2, 0, 1)
hires_image = Image2DModel.parse(
hires,
dims=("c", "y", "x"),
# scale_factors=[int(ds_factor)],
transformations={f"{dataset_id}_downscaled_hires": Identity()}
)
hires_image = SpatialImage(hires_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres_image")
images[f"{dataset_id}_downscaled_hires_image"] = hires_image

scale_hires = Scale([tissue_hires_scalef, tissue_hires_scalef], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_downscaled_hires"] = scale_hires

if fullres is not None:
fullres = fullres.transpose(2, 0, 1)
fullres_image = Image2DModel.parse(
fullres,
dims=("c", "y", "x"),
scale_factors=[int(l) for l in self.wsi.level_downsamples()[1:] if full_height % l == 0 and full_width % l == 0],
transformations={f"{dataset_id}_fullres": Identity()}
)
images[f"{dataset_id}_fullres_image"] = fullres_image
scale_fullres = Scale([1, 1], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_fullres"] = scale_fullres


if lowres is not None:
assert (
tissue_lowres_scalef is not None
), "tissue_lowres_scalef is required when an the lowres image is present"
lowres = lowres.transpose(2, 0, 1)
lowres_image = Image2DModel.parse(
lowres, dims=("c", "y", "x"), transformations={f"{dataset_id}_downscaled_lowres": Identity()}
)
lowres_image = SpatialImage(lowres_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres")
images[f"{dataset_id}_downscaled_lowres_image"] = lowres_image

scale_lowres = Scale([tissue_lowres_scalef, tissue_lowres_scalef], axes=("x", "y"))
shapes_transformations[f"{dataset_id}_downscaled_lowres"] = scale_lowres

# validate the spot_diameter_fullres value
if len(spot_diameter_fullres_list) > 0:
d = np.array(spot_diameter_fullres_list)
if not np.allclose(d, d[0]):
warnings.warn(
"spot_diameter_fullres is not constant across datasets. Using the average value.",
UserWarning,
stacklevel=2,
)
spot_diameter_fullres = d.mean()
else:
spot_diameter_fullres = d[0]
else:
img = self.wsi.numpy()
arr = da.from_array(img)

parsed_image = Image2DModel.parse(arr, dims=("y", "x", "c"))

shape_validated = []
shape_names = []
for it in self.shapes:
shapes = it.shapes
if len(shapes) > 0 and isinstance(shapes.iloc[0], Point):
if 'radius' not in shapes.columns:
shapes['radius'] = 1

shape_validated.append(ShapesModel.parse(shapes))
shape_names.append(it.name)
warnings.warn(
f"spot_diameter_fullres is not present. Using {SPOT_DIAMETER_FULLRES_DEFAULT} as default value.",
UserWarning,
stacklevel=2,
)
spot_diameter_fullres = SPOT_DIAMETER_FULLRES_DEFAULT

# parse and prepare the shapes
if SPATIAL in self.adata.obsm:
xy = self.adata.obsm[SPATIAL]
radius = spot_diameter_fullres / 2
# breakpoint()
shapes[REGION] = ShapesModel.parse(xy, geometry=0, radius=radius, transformations=shapes_transformations)

# link the shapes to the table
new_table = self.adata.copy()
if TableModel.ATTRS_KEY in new_table.uns:
del new_table.uns[TableModel.ATTRS_KEY]
new_table.obs[REGION_KEY] = REGION
new_table.obs[REGION_KEY] = new_table.obs[REGION_KEY].astype("category")
new_table.obs[INSTANCE_KEY] = shapes[REGION].index.values
new_table = TableModel.parse(new_table, region=REGION, region_key=REGION_KEY, instance_key=INSTANCE_KEY)
else:
new_table = self.adata.copy()

if self._tissue_contours is not None:
shape_validated.append(ShapesModel.parse(self.tissue_contours))
shape_names.append('tissue_contours')

my_images = {"he": parsed_image}

self.adata.obs['instance'] = self.adata.obs.index
self.adata.obs['region'] = 'he'
parsed_adata = TableModel.parse(self.adata, instance_key='instance', region_key='region', region='he')
my_tables = {"anndata": parsed_adata}

st = SpatialData(images=my_images, tables=my_tables, shapes=dict(zip(shape_names, shape_validated)))

return st

return SpatialData(tables=new_table, images=images, shapes=shapes)

class VisiumHESTData(HESTData):
def __init__(self,
Expand Down
Loading