Skip to content

Commit

Permalink
add dask support and Valis alignment (#69)
Browse files Browse the repository at this point in the history
- add dask support for loading larger than RAM transcript dataframes
- refactor alignment methods (decouple alignment matrix loading from alignment)
- add Valis registration support to the processing pipeline
- cleanup of unused functions/imports
- hestcore 1.0.3 -> 1.0.4 (better documentation)
  • Loading branch information
pauldoucet authored Nov 5, 2024
1 parent 239b506 commit b87046f
Show file tree
Hide file tree
Showing 19 changed files with 3,018 additions and 688 deletions.
81 changes: 23 additions & 58 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,58 +1,23 @@
__pycache__
data
align_coord_img.ipynb
src/hiSTloader/yolov8n.pt
.vscode/launch.json
dist
src/hest.egg-info
make.bat
Makefile
*.parquet
bench_data
cell_seg
filtered
src/hest/bench/timm_ctp
my_notebooks
.gitattributes
cells_xenium.geojson
nuclei_xenium.geojson
nuclei.tif
cells.tif
src_slides
data64.h5
tests/assets
config
tests/output_tests
tissue_seg

results
atlas
figures/test_paul
bench_data.zip
old_bench_data
tutorials/downloads
tutorials/processed
bench_config/my_bench_config.yaml
src/hest/bench/private
str.csv
bench_data_old
ST_data_emb/
ST_pred_results/
hest_data
fm_v1
cufile.log
int.csv
docs/build
docs/source/generated
local
hest_vis
hest_vis2
hest_vis
vis
vis2
models/deeplabv3*
htmlcov
models/CellViT-SAM-H-x40.pth
debug_seg
replace_seg
test_vis
data
.vscode/launch.json
dist
src/hest.egg-info
bench_data
.gitattributes
tests/assets
config
tests/output_tests
HEST/

results
atlas
ST_data_emb/
ST_pred_results/
hest_data
fm_v1
docs/build
docs/source/generated
local
models/deeplabv3*
htmlcov
models/CellViT-SAM-H-x40.pth
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"spatial_image >= 0.3.0",
"datasets",
"mygene",
"hestcore == 1.0.3"
"hestcore == 1.0.4"
]

requires-python = ">=3.9"
Expand Down
101 changes: 69 additions & 32 deletions src/hest/HESTData.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
import os
import shutil
import warnings
from typing import Dict, Iterator, List, Union
from typing import Dict, List, Union

import cv2
import geopandas as gpd
import numpy as np
from loguru import logger
from hestcore.wsi import (WSI, CucimWarningSingleton, NumpyWSI,
contours_to_img, wsi_factory)
from loguru import logger



from hest.io.seg_readers import TissueContourReader
from hest.io.seg_readers import TissueContourReader, write_geojson
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
from hest.segmentation.TissueMask import TissueMask, load_tissue_mask

Expand All @@ -31,7 +30,7 @@
from tqdm import tqdm

from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated,
find_first_file_endswith, get_path_from_meta_row,
find_first_file_endswith, get_k_genes_from_df, get_path_from_meta_row,
plot_verify_pixel_size, tiff_save, verify_paths)


Expand Down Expand Up @@ -100,7 +99,7 @@ class representing a single ST profile + its associated WSI image
else:
self._tissue_contours = tissue_contours

if 'total_counts' not in self.adata.var_names:
if 'total_counts' not in self.adata.var_names and len(self.adata) > 0:
sc.pp.calculate_qc_metrics(self.adata, inplace=True)


Expand Down Expand Up @@ -133,7 +132,7 @@ def load_wsi(self) -> None:
self.wsi = NumpyWSI(self.wsi.numpy())


def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False):
def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False, **kwargs):
"""Save a HESTData object to `path` as follows:
- aligned_adata.h5ad (contains expressions for each spots + their location on the fullres image + a downscaled version of the fullres image)
- metrics.json (contains useful metrics)
Expand All @@ -155,6 +154,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
self.adata.write(os.path.join(path, 'aligned_adata.h5ad'))
except:
# workaround from https://github.com/theislab/scvelo/issues/255
import traceback
traceback.print_exc()
self.adata.__dict__['_raw'].__dict__['_var'] = self.adata.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})
self.adata.write(os.path.join(path, 'aligned_adata.h5ad'))

Expand All @@ -172,7 +173,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
downscaled_img = self.adata.uns['spatial']['ST']['images']['downscaled_fullres']
down_fact = self.adata.uns['spatial']['ST']['scalefactors']['tissue_downscaled_fullres_scalef']
down_img = Image.fromarray(downscaled_img)
down_img.save(os.path.join(path, 'downscaled_fullres.jpeg'))
if len(downscaled_img) > 0:
down_img.save(os.path.join(path, 'downscaled_fullres.jpeg'))


if plot_pxl_size:
Expand Down Expand Up @@ -748,7 +750,9 @@ def __init__(
xenium_nuc_seg: pd.DataFrame=None,
xenium_cell_seg: pd.DataFrame=None,
cell_adata: sc.AnnData=None, # type: ignore
transcript_df: pd.DataFrame=None
transcript_df: pd.DataFrame=None,
dapi_path: str=None,
alignment_file_path: str=None
):
"""
class representing a single ST profile + its associated WSI image
Expand All @@ -765,16 +769,31 @@ class representing a single ST profile + its associated WSI image
xenium_cell_seg (pd.DataFrame): content of a xenium cell contour file as a dataframe (cell_boundaries.parquet)
cell_adata (sc.AnnData): ST cell data, each row in adata.obs is a cell, each row in obsm is the cell location on the H&E image in pixels
transcript_df (pd.DataFrame): dataframe of transcripts, each row is a transcript, he_x and he_y is the transcript location on the H&E image in pixels
dapi_path (str): path to a dapi focus image
alignment_file_path (str): path to a xenium alignment file
"""
super().__init__(adata=adata, img=img, pixel_size=pixel_size, meta=meta, tissue_seg=tissue_seg, tissue_contours=tissue_contours, shapes=shapes)

self.xenium_nuc_seg = xenium_nuc_seg
self.xenium_cell_seg = xenium_cell_seg
self.cell_adata = cell_adata
self.transcript_df = transcript_df


def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False):
self.dapi_path = dapi_path
self.alignment_file_path = alignment_file_path


def save(
self,
path: str,
save_img=True,
pyramidal=True,
bigtiff=False,
plot_pxl_size=False,
save_transcripts=False,
save_cell_seg=False,
save_nuclei_seg=False,
**kwargs
):
"""Save a HESTData object to `path` as follows:
- aligned_adata.h5ad (contains expressions for each spots + their location on the fullres image + a downscaled version of the fullres image)
- metrics.json (contains useful metrics)
Expand All @@ -795,21 +814,18 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
if self.cell_adata is not None:
self.cell_adata.write_h5ad(os.path.join(path, 'aligned_cells.h5ad'))

if self.transcript_df is not None:
if save_transcripts and self.transcript_df is not None:
self.transcript_df.to_parquet(os.path.join(path, 'aligned_transcripts.parquet'))

if save_cell_seg:
he_cells = self.get_shapes('tenx_cell', 'he').shapes
he_cells.to_parquet(os.path.join(path, 'he_cell_seg.parquet'))
write_geojson(he_cells, os.path.join(path, f'he_cell_seg.geojson'), '', chunk=True)

if self.xenium_nuc_seg is not None:
print('Saving Xenium nucleus boundaries... (can be slow)')
with open(os.path.join(path, 'nuclei_xenium.geojson'), 'w') as f:
json.dump(self.xenium_nuc_seg, f, indent=4)

if self.xenium_cell_seg is not None:
print('Saving Xenium cells boundaries... (can be slow)')
with open(os.path.join(path, 'cells_xenium.geojson'), 'w') as f:
json.dump(self.xenium_cell_seg, f, indent=4)


# TODO save segmentation
if save_nuclei_seg:
he_nuclei = self.get_shapes('tenx_nucleus', 'he').shapes
he_nuclei.to_parquet(os.path.join(path, 'he_nucleus_seg.parquet'))
write_geojson(he_nuclei, os.path.join(path, f'he_nucleus_seg.geojson'), '', chunk=True)


def read_HESTData(
Expand Down Expand Up @@ -936,19 +952,33 @@ def mask_and_patchify_bench(meta_df: pd.DataFrame, save_dir: str, use_mask=True,
i += 1


def create_benchmark_data(meta_df, save_dir:str, K, adata_folder, use_mask, keep_largest=None):
def create_benchmark_data(meta_df, save_dir:str, K):
os.makedirs(save_dir, exist_ok=True)
if K is not None:
splits = meta_df.groupby('patient')['id'].agg(list).to_dict()
create_splits(os.path.join(save_dir, 'splits'), splits, K=K)

meta_df['patient'] = meta_df['patient'].fillna('Patient 1')

get_k_genes_from_df(meta_df, 50, 'var', os.path.join(save_dir, 'var_50genes.json'))

splits = meta_df.groupby(['dataset_title', 'patient'])['id'].agg(list).to_dict()
create_splits(os.path.join(save_dir, 'splits'), splits, K=K)

os.makedirs(os.path.join(save_dir, 'patches'), exist_ok=True)
mask_and_patchify_bench(meta_df, os.path.join(save_dir, 'patches'), use_mask=use_mask, keep_largest=keep_largest)
#mask_and_patchify_bench(meta_df, os.path.join(save_dir, 'patches'), use_mask=use_mask, keep_largest=keep_largest)

os.makedirs(os.path.join(save_dir, 'patches_vis'), exist_ok=True)
os.makedirs(os.path.join(save_dir, 'adata'), exist_ok=True)
for index, row in meta_df.iterrows():
for _, row in meta_df.iterrows():
id = row['id']
src_adata = os.path.join(adata_folder, id + '.h5ad')
path = os.path.join(get_path_from_meta_row(row), 'processed')
src_patch = os.path.join(path, 'patches.h5')
dst_patch = os.path.join(save_dir, 'patches', id + '.h5')
shutil.copy(src_patch, dst_patch)

src_vis = os.path.join(path, 'patches_patch_vis.png')
dst_vis = os.path.join(save_dir, 'patches_vis', id + '.png')
shutil.copy(src_vis, dst_vis)

src_adata = os.path.join(path, 'aligned_adata.h5ad')
dst_adata = os.path.join(save_dir, 'adata', id + '.h5ad')
shutil.copy(src_adata, dst_adata)

Expand Down Expand Up @@ -1200,6 +1230,13 @@ def unify_gene_names(adata: sc.AnnData, species="human", drop=False) -> sc.AnnDa
mask = ~adata.var_names.duplicated(keep='first')
adata = adata[:, mask]

duplicated_genes_after = adata.var_names[adata.var_names.duplicated()]
if len(duplicated_genes_after) > len(duplicated_genes_before):
logger.warning(f"duplicated genes increased from {len(duplicated_genes_before)} to {len(duplicated_genes_after)} after resolving aliases")
logger.info('deduplicating...')
mask = ~adata.var_names.duplicated(keep='first')
adata = adata[:, mask]

if drop:
adata = adata[:, ~remaining]

Expand Down
12 changes: 9 additions & 3 deletions src/hest/LazyShapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,30 @@
import pandas as pd
from shapely import Polygon

from hest.io.seg_readers import read_gdf
from hest.io.seg_readers import GDFReader, read_gdf
from hest.utils import verify_paths


class LazyShapes:

path: str = None

def __init__(self, path: str, name: str, coordinate_system: str):
def __init__(self, path: str, name: str, coordinate_system: str, reader: GDFReader=None, reader_kwargs = {}):
verify_paths([path])
self.path = path
self.name = name
self.coordinate_system = coordinate_system
self._shapes = None
self.reader_kwargs = reader_kwargs
self.reader = reader

def compute(self) -> None:
if self._shapes is None:
self._shapes = read_gdf(self.path)
if self.reader is None:
self._shapes = read_gdf(self.path, self.reader_kwargs)
else:
self._shapes = self.reader(**self.reader_kwargs).read_gdf(self.path)


@property
def shapes(self) -> gpd.GeoDataFrame:
Expand Down
46 changes: 46 additions & 0 deletions src/hest/SlideReaderAdapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Slide Adapter class for Valis compatibility
import os

import numpy as np
from valis import slide_tools
from valis.slide_io import PIXEL_UNIT, MetaData, SlideReader

from hestcore.wsi import wsi_factory


class SlideReaderAdapter(SlideReader):
    """Valis ``SlideReader`` that reads pixels through a hestcore WSI object.

    The slide at ``src_f`` is opened with ``wsi_factory`` and the metadata
    Valis expects is built from that WSI's pyramid level dimensions.
    """

    def __init__(self, src_f, *args, **kwargs):
        """Open ``src_f`` and build the slide metadata.

        Args:
            src_f: path to the slide; forwarded to ``SlideReader.__init__``
                and to ``wsi_factory``.
            *args: extra positional arguments forwarded to ``SlideReader``.
            **kwargs: extra keyword arguments forwarded to ``SlideReader``.
        """
        super().__init__(src_f, *args, **kwargs)
        self.wsi = wsi_factory(src_f)
        self.metadata = self.create_metadata()

    def create_metadata(self):
        """Build the ``MetaData`` object describing this slide.

        Returns:
            MetaData: metadata flagged as RGB, with one row of
            ``slide_dimensions`` per entry returned by
            ``self.wsi.level_dimensions()``.
        """
        # presumably `self.src_f` is set by SlideReader.__init__ -- confirm
        file_name = os.path.split(self.src_f)[1]
        meta_name = f"{file_name}_Series(0)".strip("_")
        slide_meta = MetaData(meta_name, 'SlideReaderAdapter')

        slide_meta.is_rgb = True
        slide_meta.channel_names = self._get_channel_names('NO_NAME')
        slide_meta.n_channels = 1
        # NOTE(review): the physical pixel size is hard-coded to 0.25 and is
        # not read from the slide itself -- confirm this is acceptable for
        # every slide passed through this adapter.
        slide_meta.pixel_physical_size_xyu = [0.25, 0.25, PIXEL_UNIT]
        level_dims = self.wsi.level_dimensions()
        slide_meta.slide_dimensions = np.array([list(dim) for dim in level_dims])

        return slide_meta

    def slide2vips(self, level, xywh=None, *args, **kwargs):
        """Return the requested level (optionally cropped) as a vips image.

        Args:
            level: index into ``self.wsi.level_dimensions()``.
            xywh: optional ``(x, y, width, height)`` crop, see ``slide2image``.
            *args: forwarded to ``slide2image``.
            **kwargs: forwarded to ``slide2image``.
        """
        # `xywh` is passed positionally on purpose: combining `xywh=xywh` with
        # `*args` raises "got multiple values for argument 'xywh'" as soon as
        # `args` is non-empty.
        img = self.slide2image(level, xywh, *args, **kwargs)
        return slide_tools.numpy2vips(img)

    def slide2image(self, level, xywh=None, *args, **kwargs):
        """Return the requested level (optionally cropped) as an array.

        The whole level is fetched with ``self.wsi.get_thumbnail`` and is
        cropped afterwards when ``xywh`` is given.

        Args:
            level: index into ``self.wsi.level_dimensions()``.
            xywh: optional ``(x, y, width, height)`` region; ``x``/``width``
                index columns and ``y``/``height`` index rows of the image.
            *args: accepted for interface compatibility; unused.
            **kwargs: accepted for interface compatibility; unused.
        """
        level_dim = self.wsi.level_dimensions()[level]
        img = self.wsi.get_thumbnail(level_dim[0], level_dim[1])

        if xywh is None:
            return img

        xywh = np.array(xywh)
        start_c, start_r = xywh[0:2]
        end_c, end_r = xywh[0:2] + xywh[2:]
        return img[start_r:end_r, start_c:end_c]
2 changes: 1 addition & 1 deletion src/hest/bench/st_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ def load_adata(expr_path, genes = None, barcodes = None, normalize=False):
adata = adata[:, genes]
if normalize:
adata = normalize_adata(adata)
return adata.to_df()
return adata.to_df()
Loading

0 comments on commit b87046f

Please sign in to comment.