-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
HESTData: update to_spatial_data()
``
#63
Changes from 3 commits
761deee
689a1d8
239901a
82f0779
98d0acf
d7b6075
53e8ed4
58298b3
6db5a9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,9 @@ | |
from hestcore.wsi import (WSI, CucimWarningSingleton, NumpyWSI, | ||
contours_to_img, wsi_factory) | ||
from loguru import logger | ||
from spatialdata import SpatialData | ||
|
||
|
||
|
||
from hest.io.seg_readers import TissueContourReader | ||
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new | ||
|
@@ -460,60 +463,211 @@ def save_tissue_vis(self, save_dir: str, name: str) -> None: | |
vis.save(os.path.join(save_dir, f'{name}_vis.jpg')) | ||
|
||
|
||
def to_spatial_data(self, lazy_img=True) -> SpatialData: # type: ignore | ||
"""Convert a HESTData sample to a scverse SpatialData object | ||
|
||
def to_spatial_data(self, fullres: bool = False) -> SpatialData: | ||
""" | ||
Convert a HESTData sample to a scverse SpatialData object. Note that a large part of this function is based on | ||
spatialdata-io's [``from_legacy_anndata``](https://spatialdata.scverse.org/projects/io/en/latest/generated/spatialdata_io.experimental.from_legacy_anndata.html) | ||
function with some adjustments for ``HESTData``. | ||
|
||
Args: | ||
lazy_img (bool, optional): whenever to lazily load the image if not already loaded (e.g. self.wsi is of type OpenSlide or CuImage). Defaults to True. | ||
fullres (bool, optional): Includes pyramidal full resolution whole slide image as a ``DataTree`` object for those dimensions compatible with | ||
Image2DModel's downsampling. Defaults to False. | ||
|
||
Returns: | ||
SpatialData: scverse SpatialData object | ||
SpatialData: scverse SpatialData oobject containing the ``hires`` and ``lowres`` downsampled versions | ||
of the image and their respective coordinate systems. | ||
|
||
Example: | ||
```python | ||
from hest import load_hest | ||
hest_data = load_hest('../hest_data', id_list=['TENX68']) | ||
st = hest_data[0] | ||
st.to_spatial_data(fullres=True) | ||
|
||
>>> | ||
|
||
``` | ||
SpatialData object | ||
├── Images | ||
│ ├── 'ST_downscaled_hires_image': SpatialImage[cyx] (3, 4779, 2586) | ||
│ ├── 'ST_downscaled_lowres_image': SpatialImage[cyx] (3, 1000, 541) | ||
│ └── 'ST_fullres_image': DataTree[cyx] (3, 38232, 20690), (3, 19116, 10345) | ||
├── Shapes | ||
│ └── 'locations': GeoDataFrame shape: (1657, 2) (2D shapes) | ||
└── Tables | ||
└── 'table': AnnData (1657, 18085) | ||
with coordinate systems: | ||
▸ 'ST_downscaled_hires', with elements: | ||
ST_downscaled_hires_image (Images), locations (Shapes) | ||
▸ 'ST_downscaled_lowres', with elements: | ||
ST_downscaled_lowres_image (Images), locations (Shapes) | ||
▸ 'ST_fullres', with elements: | ||
ST_fullres_image (Images), locations (Shapes) | ||
``` | ||
|
||
""" | ||
|
||
# imports specific to spatial data conversion | ||
import dask.array as da | ||
from dask import delayed | ||
from dask.array import from_delayed | ||
from spatialdata import SpatialData | ||
from spatial_image import SpatialImage | ||
from spatialdata.models import Image2DModel, ShapesModel, TableModel | ||
|
||
def read_hest_wsi(wsi: WSI): | ||
return wsi.numpy() | ||
|
||
if lazy_img: | ||
width, height = self.wsi.get_dimensions() | ||
arr = from_delayed(delayed(read_hest_wsi)(self.wsi), shape=(height, width, 3), dtype=np.int8) | ||
from spatialdata.transformations import Identity, Scale | ||
|
||
# AnnData keys | ||
SPATIAL = "spatial" | ||
SCALEFACTORS = "scalefactors" | ||
TISSUE_HIRES_SCALEF = "tissue_hires_scalef" | ||
# TISSUE_LOWRES_SCALEF = "tissue_lowres_scalef" | ||
TISSUE_LOWRES_SCALEF = "tissue_downscaled_fullres_scalef" | ||
SPOT_DIAMETER_FULLRES = "spot_diameter_fullres" | ||
|
||
IMAGES = "images" | ||
# HIRES = "hires" | ||
# LOWRES = "lowres" | ||
HIRES = "fullres" | ||
LOWRES = "downscaled_fullres" | ||
|
||
# SpatialData keys | ||
REGION = "locations" | ||
REGION_KEY = "region" | ||
INSTANCE_KEY = "instance_id" | ||
SPOT_DIAMETER_FULLRES_DEFAULT = 10 | ||
|
||
images = {} | ||
shapes = {} | ||
spot_diameter_fullres_list = [] | ||
shapes_transformations = {} | ||
if SPATIAL in self.adata.uns: | ||
dataset_ids = list(self.adata.uns[SPATIAL].keys()) | ||
for dataset_id in dataset_ids: | ||
# read the image data and the scale factors for the shapes | ||
keys = set(self.adata.uns[SPATIAL][dataset_id].keys()) | ||
tissue_hires_scalef = None | ||
tissue_lowres_scalef = None | ||
hires = None | ||
lowres = None | ||
if SCALEFACTORS in keys: | ||
scalefactors = self.adata.uns[SPATIAL][dataset_id][SCALEFACTORS] | ||
if TISSUE_HIRES_SCALEF in scalefactors: | ||
tissue_hires_scalef = scalefactors[TISSUE_HIRES_SCALEF] | ||
else: | ||
pixel_size=self.meta['pixel_size_um_estimated'] | ||
ds_factor = 4/pixel_size # proxy for visium hires scale factor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this part going to work for non-visium ST? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes - the |
||
ds_level = self.wsi.get_best_level_for_downsample(ds_factor) | ||
tissue_hires_scalef = 1/self.wsi.level_downsamples()[ds_level] | ||
|
||
if TISSUE_LOWRES_SCALEF in scalefactors: | ||
tissue_lowres_scalef = scalefactors[TISSUE_LOWRES_SCALEF] | ||
if SPOT_DIAMETER_FULLRES in scalefactors: | ||
spot_diameter_fullres_list.append(scalefactors[SPOT_DIAMETER_FULLRES]) | ||
if IMAGES in keys: | ||
image_data = self.adata.uns[SPATIAL][dataset_id][IMAGES] | ||
if HIRES in image_data: | ||
hires = image_data[HIRES] | ||
else: | ||
# load wsi | ||
def read_hest_wsi(wsi: WSI, width, height): | ||
return wsi.get_thumbnail(width, height) | ||
|
||
|
||
if fullres: | ||
full_width, full_height = self.wsi.get_dimensions() | ||
fullres = from_delayed(delayed(read_hest_wsi)(self.wsi, full_width, full_height), shape=(full_height, full_width, 3), dtype=np.int8) | ||
else: | ||
fullres=None | ||
hires_width, hires_height = self.wsi.level_dimensions()[ds_level] | ||
hires = from_delayed(delayed(read_hest_wsi)(self.wsi, hires_width, hires_height), shape=(hires_height, hires_width, 3), dtype=np.int8) | ||
|
||
if LOWRES in image_data: | ||
lowres = image_data[LOWRES] | ||
|
||
# construct the spatialdata elements | ||
if hires is not None: | ||
# prepare the hires image | ||
assert ( | ||
tissue_hires_scalef is not None | ||
), "tissue_hires_scalef is required when an the hires image is present" | ||
hires = hires.transpose(2, 0, 1) | ||
hires_image = Image2DModel.parse( | ||
hires, | ||
dims=("c", "y", "x"), | ||
# scale_factors=[int(ds_factor)], | ||
transformations={f"{dataset_id}_downscaled_hires": Identity()} | ||
) | ||
hires_image = SpatialImage(hires_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres_image") | ||
images[f"{dataset_id}_downscaled_hires_image"] = hires_image | ||
|
||
scale_hires = Scale([tissue_hires_scalef, tissue_hires_scalef], axes=("x", "y")) | ||
shapes_transformations[f"{dataset_id}_downscaled_hires"] = scale_hires | ||
|
||
if fullres is not None: | ||
fullres = fullres.transpose(2, 0, 1) | ||
fullres_image = Image2DModel.parse( | ||
fullres, | ||
dims=("c", "y", "x"), | ||
scale_factors=[int(l) for l in self.wsi.level_downsamples()[1:] if full_height % l == 0 and full_width % l == 0], | ||
transformations={f"{dataset_id}_fullres": Identity()} | ||
) | ||
images[f"{dataset_id}_fullres_image"] = fullres_image | ||
scale_fullres = Scale([1, 1], axes=("x", "y")) | ||
shapes_transformations[f"{dataset_id}_fullres"] = scale_fullres | ||
|
||
|
||
if lowres is not None: | ||
assert ( | ||
tissue_lowres_scalef is not None | ||
), "tissue_lowres_scalef is required when an the lowres image is present" | ||
lowres = lowres.transpose(2, 0, 1) | ||
lowres_image = Image2DModel.parse( | ||
lowres, dims=("c", "y", "x"), transformations={f"{dataset_id}_downscaled_lowres": Identity()} | ||
) | ||
lowres_image = SpatialImage(lowres_image, dims=("c", "y", "x"), name=f"{dataset_id}_downscaled_lowres") | ||
images[f"{dataset_id}_downscaled_lowres_image"] = lowres_image | ||
|
||
scale_lowres = Scale([tissue_lowres_scalef, tissue_lowres_scalef], axes=("x", "y")) | ||
shapes_transformations[f"{dataset_id}_downscaled_lowres"] = scale_lowres | ||
|
||
# validate the spot_diameter_fullres value | ||
if len(spot_diameter_fullres_list) > 0: | ||
d = np.array(spot_diameter_fullres_list) | ||
if not np.allclose(d, d[0]): | ||
warnings.warn( | ||
"spot_diameter_fullres is not constant across datasets. Using the average value.", | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
spot_diameter_fullres = d.mean() | ||
else: | ||
spot_diameter_fullres = d[0] | ||
else: | ||
img = self.wsi.numpy() | ||
arr = da.from_array(img) | ||
|
||
parsed_image = Image2DModel.parse(arr, dims=("y", "x", "c")) | ||
|
||
shape_validated = [] | ||
shape_names = [] | ||
for it in self.shapes: | ||
shapes = it.shapes | ||
if len(shapes) > 0 and isinstance(shapes.iloc[0], Point): | ||
if 'radius' not in shapes.columns: | ||
shapes['radius'] = 1 | ||
|
||
shape_validated.append(ShapesModel.parse(shapes)) | ||
shape_names.append(it.name) | ||
warnings.warn( | ||
f"spot_diameter_fullres is not present. Using {SPOT_DIAMETER_FULLRES_DEFAULT} as default value.", | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
spot_diameter_fullres = SPOT_DIAMETER_FULLRES_DEFAULT | ||
|
||
# parse and prepare the shapes | ||
if SPATIAL in self.adata.obsm: | ||
xy = self.adata.obsm[SPATIAL] | ||
radius = spot_diameter_fullres / 2 | ||
# breakpoint() | ||
shapes[REGION] = ShapesModel.parse(xy, geometry=0, radius=radius, transformations=shapes_transformations) | ||
|
||
# link the shapes to the table | ||
new_table = self.adata.copy() | ||
if TableModel.ATTRS_KEY in new_table.uns: | ||
del new_table.uns[TableModel.ATTRS_KEY] | ||
new_table.obs[REGION_KEY] = REGION | ||
new_table.obs[REGION_KEY] = new_table.obs[REGION_KEY].astype("category") | ||
new_table.obs[INSTANCE_KEY] = shapes[REGION].index.values | ||
new_table = TableModel.parse(new_table, region=REGION, region_key=REGION_KEY, instance_key=INSTANCE_KEY) | ||
else: | ||
new_table = self.adata.copy() | ||
|
||
if self._tissue_contours is not None: | ||
shape_validated.append(ShapesModel.parse(self.tissue_contours)) | ||
shape_names.append('tissue_contours') | ||
|
||
my_images = {"he": parsed_image} | ||
|
||
self.adata.obs['instance'] = self.adata.obs.index | ||
self.adata.obs['region'] = 'he' | ||
parsed_adata = TableModel.parse(self.adata, instance_key='instance', region_key='region', region='he') | ||
my_tables = {"anndata": parsed_adata} | ||
|
||
st = SpatialData(images=my_images, tables=my_tables, shapes=dict(zip(shape_names, shape_validated))) | ||
|
||
return st | ||
|
||
return SpatialData(tables=new_table, images=images, shapes=shapes) | ||
|
||
class VisiumHESTData(HESTData): | ||
def __init__(self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to move this import line inside the to_spatial_data() method, the SpatialData is huge
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved in 98d0acf - I had it outside to enable the type hint in the function, but appreciate that it comes with a lot of overhead