From a9fe4d1d4be13cf287b58539268d8b412ccf1305 Mon Sep 17 00:00:00 2001 From: felixpetschko <48593591+felixpetschko@users.noreply.github.com> Date: Wed, 6 Nov 2024 20:34:55 +0100 Subject: [PATCH] Speed up anndata writing speed after define_clonotype_clusters (#556) * convert cell_indices str->array dict to a csr matrix before storing the anndata result * save cell_indices as json format * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * removed unused conversion function * Update CHANGELOG --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Gregor Sturm --- CHANGELOG.md | 8 ++++++++ src/scirpy/ir_dist/_clonotype_neighbors.py | 2 +- src/scirpy/pl/_clonotypes.py | 4 +++- src/scirpy/tests/test_ir_query.py | 9 +++++++-- src/scirpy/tl/_clonotypes.py | 16 ++++++++-------- src/scirpy/tl/_ir_query.py | 17 ++++++++++------- src/scirpy/util/__init__.py | 21 ++++++++++++++++++++- 7 files changed, 57 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 436c7dc0c..0436fcf8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,17 @@ and this project adheres to [Semantic Versioning][]. ## [Unreleased] +### Backwards-incompatible changes + +- The format of storing the results of `tl.define_clonotypes`/`tl.define_clonotype_clusters` in `adata.uns` has changed. + Older versions of Scirpy won't be able to run downstream functions (e.g. `tl.clonotype_network`) on AnnData objects + created with Scirpy v0.20 or later. This change was necessary to speed up writing results to `h5ad` when working + with large datasets ([#556](https://github.com/scverse/scirpy/pull/556)). + ### Documentation - Add a tutorial for BCR analysis with Scirpy ([#542](https://github.com/scverse/scirpy/pull/542)). +- Fix typo in `pp.index_chains` methods description ([#570](https://github.com/scverse/scirpy/pull/570)) ## v0.19.0 diff --git a/src/scirpy/ir_dist/_clonotype_neighbors.py b/src/scirpy/ir_dist/_clonotype_neighbors.py index c948c617e..e615734da 100644 --- a/src/scirpy/ir_dist/_clonotype_neighbors.py +++ b/src/scirpy/ir_dist/_clonotype_neighbors.py @@ -118,7 +118,7 @@ def _make_clonotype_table(self, params: DataHandler) -> tuple[Mapping, pd.DataFr ct_tuple[0] if len(ct_tuple) == 1 else ct_tuple, [], ) - ].values + ].values.tolist() for i, ct_tuple in enumerate(clonotypes.itertuples(index=False, name=None)) } diff --git a/src/scirpy/pl/_clonotypes.py b/src/scirpy/pl/_clonotypes.py index da493921c..5dc2a9b48 100644 --- a/src/scirpy/pl/_clonotypes.py +++ b/src/scirpy/pl/_clonotypes.py @@ -22,7 +22,7 @@ from scipy.sparse import issparse from scirpy.tl._clonotypes import _doc_clonotype_network, _graph_from_coordinates -from scirpy.util import DataHandler +from scirpy.util import DataHandler, read_cell_indices from scirpy.util.graph import _distance_to_connectivity from .styling import _get_colors, _init_ax @@ -413,6 +413,8 @@ def _plot_clonotype_network_panel( scale_by_n_cells, color_by_n_cells, ): + cell_indices = read_cell_indices(cell_indices) + colorbar_title = "mean per dot" pie_colors = None cat_colors = None diff --git a/src/scirpy/tests/test_ir_query.py b/src/scirpy/tests/test_ir_query.py index 2713f2718..7c33e69cf 100644 --- a/src/scirpy/tests/test_ir_query.py +++ b/src/scirpy/tests/test_ir_query.py @@ -14,6 +14,7 @@ ir_query_annotate, ir_query_annotate_df, ) +from scirpy.util import read_cell_indices @pytest.mark.parametrize("metric", ["identity", "levenshtein"]) @@ -32,9 +33,13 @@ def test_ir_query(adata_cdr3, adata_cdr3_2, metric, key1, key2): tmp_key2 = f"ir_query_TESTDB_aa_{metric}" if key2 is None else key2 tmp_ad = adata_cdr3.mod["airr"] if isinstance(adata_cdr3, MuData) else adata_cdr3 + + cell_indices = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices"]) + cell_indices_reference = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices_reference"]) + assert tmp_ad.uns[tmp_key2]["distances"].shape == (4, 3) - assert len(tmp_ad.uns[tmp_key2]["cell_indices"]) == 4 - assert len(tmp_ad.uns[tmp_key2]["cell_indices_reference"]) == 3 + assert len(cell_indices) == 4 + assert len(cell_indices_reference) == 3 @pytest.mark.parametrize( diff --git a/src/scirpy/tl/_clonotypes.py b/src/scirpy/tl/_clonotypes.py index f76a0c934..750734f19 100644 --- a/src/scirpy/tl/_clonotypes.py +++ b/src/scirpy/tl/_clonotypes.py @@ -1,4 +1,5 @@ import itertools +import json import random from collections.abc import Sequence from typing import Literal, cast @@ -13,7 +14,7 @@ from scirpy.ir_dist import MetricType, _get_metric_key from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors from scirpy.pp import ir_dist -from scirpy.util import DataHandler +from scirpy.util import DataHandler, read_cell_indices from scirpy.util.graph import igraph_from_sparse_matrix, layout_components _common_doc = """\ @@ -89,7 +90,7 @@ A dictionary containing * `distances`: A sparse, pairwise distance matrix between unique receptor configurations - * `cell_indices`: A dict of arrays, containing the `adata.obs_names` + * `cell_indices`: A dict of lists, containing the `adata.obs_names` (cell indices) for each row in the distance matrix. If `inplace` is `True`, this is added to `adata.uns[key_added]`. @@ -335,7 +336,7 @@ def define_clonotype_clusters( # Return or store results clonotype_distance_res = { "distances": clonotype_dist, - "cell_indices": ctn.cell_indices, + "cell_indices": json.dumps(ctn.cell_indices), } if inplace: params.set_obs(key_added, clonotype_cluster_series) @@ -533,7 +534,7 @@ def clonotype_network( # explicitly annotate node ids to keep them after subsetting graph.vs["node_id"] = np.arange(0, len(graph.vs)) - cell_indices = clonotype_res["cell_indices"] + cell_indices = read_cell_indices(clonotype_res["cell_indices"]) # store size in graph to be accessed by layout algorithms clonotype_size = np.array([len(idx) for idx in cell_indices.values()]) @@ -603,7 +604,7 @@ def clonotype_network( # Expand to cell coordinates to store in adata.obsm idx, coords = zip( *itertools.chain.from_iterable( - zip(clonotype_res["cell_indices"][str(node_id)], itertools.repeat(coord)) + zip(cell_indices[str(node_id)], itertools.repeat(coord)) for node_id, coord in zip(graph.vs["node_id"], coords, strict=False) # type: ignore ), strict=False, @@ -631,10 +632,9 @@ def _graph_from_coordinates(adata: AnnData, clonotype_key: str, basis: str) -> t """ clonotype_res = adata.uns[clonotype_key] # map the cell-id to the corresponding row/col in the clonotype distance matrix + cell_indices = read_cell_indices(clonotype_res["cell_indices"]) dist_idx, obs_names = zip( - *itertools.chain.from_iterable( - zip(itertools.repeat(i), obs_names) for i, obs_names in clonotype_res["cell_indices"].items() - ), + *itertools.chain.from_iterable(zip(itertools.repeat(i), obs_names) for i, obs_names in cell_indices.items()), strict=False, ) dist_idx_lookup = pd.DataFrame(index=obs_names, data=dist_idx, columns=["dist_idx"]) diff --git a/src/scirpy/tl/_ir_query.py b/src/scirpy/tl/_ir_query.py index 72a45679f..89bf5c975 100644 --- a/src/scirpy/tl/_ir_query.py +++ b/src/scirpy/tl/_ir_query.py @@ -10,7 +10,7 @@ from scirpy.ir_dist import MetricType, _get_metric_key from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors -from scirpy.util import DataHandler, _is_na, tqdm +from scirpy.util import DataHandler, _is_na, read_cell_indices, tqdm from ._clonotypes import _common_doc, _common_doc_parallelism, _doc_clonotype_definition, _validate_parameters @@ -166,9 +166,9 @@ def ir_query( A dictionary containing * `distances`: A sparse distance matrix between unique receptor configurations in `adata` aund unique receptor configurations in `reference`. - * `cell_indices`: A dict of arrays, containing the the `adata.obs_names` + * `cell_indices`: A dict of lists, containing the the `adata.obs_names` (cell indices) for each row in the distance matrix. - * `cell_indices_reference`: A dict of arrays, containing the `reference.obs_names` + * `cell_indices_reference`: A dict of lists, containing the `reference.obs_names` for each column in the distance matrix. If `inplace` is `True`, this is added to `adata.uns[key_added]`. @@ -206,8 +206,8 @@ def ir_query( # Return or store results clonotype_distance_res = { "distances": clonotype_dist, - "cell_indices": ctn.cell_indices, - "cell_indices_reference": ctn.cell_indices2, + "cell_indices": json.dumps(ctn.cell_indices), + "cell_indices_reference": json.dumps(ctn.cell_indices2), } if inplace: params.adata.uns[key_added] = clonotype_distance_res @@ -284,10 +284,13 @@ def ir_query_annotate_df( res = params.adata.uns[query_key] dist = res["distances"] + cell_indices = read_cell_indices(res["cell_indices"]) + cell_indices_reference = read_cell_indices(res["cell_indices_reference"]) + def get_pairs(): - for i, query_cells in res["cell_indices"].items(): + for i, query_cells in cell_indices.items(): reference_cells = itertools.chain.from_iterable( - res["cell_indices_reference"][str(k)] for k in dist[int(i), :].indices + cell_indices_reference[str(k)] for k in dist[int(i), :].indices ) yield from itertools.product(query_cells, reference_cells) diff --git a/src/scirpy/util/__init__.py b/src/scirpy/util/__init__.py index b46e0d336..d9924c60f 100644 --- a/src/scirpy/util/__init__.py +++ b/src/scirpy/util/__init__.py @@ -1,9 +1,10 @@ import contextlib +import json import os import warnings from collections.abc import Callable, Mapping, Sequence from textwrap import dedent -from typing import Any, Optional, Union, cast, overload +from typing import Any, Literal, Optional, Union, cast, overload import awkward as ak import numpy as np @@ -605,3 +606,21 @@ def _get_usable_cpus(n_jobs: int = 0, use_numba: bool = False): usable_cpus = min(usable_cpus, config.NUMBA_NUM_THREADS) return usable_cpus + + +def read_cell_indices(cell_indices: dict[str, np.ndarray[str]] | str) -> dict[str, list[str]]: + """ + The datatype of the cell_indices Mapping (clonotype_id -> cell_ids) that is stored to the anndata.uns + attribute after the ´define_clonotype_clusters´ function has changed from dict[str, np.ndarray[str] to + str (json) due to performance considerations regarding the writing speed of the anndata object. But we still + want that older anndata objects with the dict[str, np.ndarray[str] datatype can be used. So we use this function + to read the cell_indices from the anndata object to support both formats. + """ + if isinstance(cell_indices, str): # new format + return json.loads(cell_indices) + elif isinstance(cell_indices, dict): # old format + return {k: v.tolist() for k, v in cell_indices.items()} + else: # unsupported format + raise TypeError( + f"Unsupported type for cell_indices: {type(cell_indices)}. Expected str (json) or dict[str, np.ndarray[str]]." + )