Skip to content

Commit

Permalink
Speed up anndata writing speed after define_clonotype_clusters (#556)
Browse files Browse the repository at this point in the history
* convert cell_indices str->array dict to a csr matrix before storing the anndata result

* save cell_indices as json format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed unused conversion function

* Update CHANGELOG

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Gregor Sturm <[email protected]>
  • Loading branch information
3 people authored Nov 6, 2024
1 parent 339609d commit a9fe4d1
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 20 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ and this project adheres to [Semantic Versioning][].

## [Unreleased]

### Backwards-incompatible changes

- The format of storing the results of `tl.define_clonotypes`/`tl.define_clonotype_clusters` in `adata.uns` has changed.
Older versions of Scirpy won't be able to run downstream functions (e.g. `tl.clonotype_network`) on AnnData objects
created with Scirpy v0.20 or later. This change was necessary to speed up writing results to `h5ad` when working
with large datasets ([#556](https://github.com/scverse/scirpy/pull/556)).

### Documentation

- Add a tutorial for BCR analysis with Scirpy ([#542](https://github.com/scverse/scirpy/pull/542)).
- Fix typo in `pp.index_chains` methods description ([#570](https://github.com/scverse/scirpy/pull/570))

## v0.19.0

Expand Down
2 changes: 1 addition & 1 deletion src/scirpy/ir_dist/_clonotype_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _make_clonotype_table(self, params: DataHandler) -> tuple[Mapping, pd.DataFr
ct_tuple[0] if len(ct_tuple) == 1 else ct_tuple,
[],
)
].values
].values.tolist()
for i, ct_tuple in enumerate(clonotypes.itertuples(index=False, name=None))
}

Expand Down
4 changes: 3 additions & 1 deletion src/scirpy/pl/_clonotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from scipy.sparse import issparse

from scirpy.tl._clonotypes import _doc_clonotype_network, _graph_from_coordinates
from scirpy.util import DataHandler
from scirpy.util import DataHandler, read_cell_indices
from scirpy.util.graph import _distance_to_connectivity

from .styling import _get_colors, _init_ax
Expand Down Expand Up @@ -413,6 +413,8 @@ def _plot_clonotype_network_panel(
scale_by_n_cells,
color_by_n_cells,
):
cell_indices = read_cell_indices(cell_indices)

colorbar_title = "mean per dot"
pie_colors = None
cat_colors = None
Expand Down
9 changes: 7 additions & 2 deletions src/scirpy/tests/test_ir_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ir_query_annotate,
ir_query_annotate_df,
)
from scirpy.util import read_cell_indices


@pytest.mark.parametrize("metric", ["identity", "levenshtein"])
Expand All @@ -32,9 +33,13 @@ def test_ir_query(adata_cdr3, adata_cdr3_2, metric, key1, key2):

tmp_key2 = f"ir_query_TESTDB_aa_{metric}" if key2 is None else key2
tmp_ad = adata_cdr3.mod["airr"] if isinstance(adata_cdr3, MuData) else adata_cdr3

cell_indices = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices"])
cell_indices_reference = read_cell_indices(tmp_ad.uns[tmp_key2]["cell_indices_reference"])

assert tmp_ad.uns[tmp_key2]["distances"].shape == (4, 3)
assert len(tmp_ad.uns[tmp_key2]["cell_indices"]) == 4
assert len(tmp_ad.uns[tmp_key2]["cell_indices_reference"]) == 3
assert len(cell_indices) == 4
assert len(cell_indices_reference) == 3


@pytest.mark.parametrize(
Expand Down
16 changes: 8 additions & 8 deletions src/scirpy/tl/_clonotypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import json
import random
from collections.abc import Sequence
from typing import Literal, cast
Expand All @@ -13,7 +14,7 @@
from scirpy.ir_dist import MetricType, _get_metric_key
from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors
from scirpy.pp import ir_dist
from scirpy.util import DataHandler
from scirpy.util import DataHandler, read_cell_indices
from scirpy.util.graph import igraph_from_sparse_matrix, layout_components

_common_doc = """\
Expand Down Expand Up @@ -89,7 +90,7 @@
A dictionary containing
* `distances`: A sparse, pairwise distance matrix between unique
receptor configurations
* `cell_indices`: A dict of arrays, containing the `adata.obs_names`
* `cell_indices`: A dict of lists, containing the `adata.obs_names`
(cell indices) for each row in the distance matrix.
If `inplace` is `True`, this is added to `adata.uns[key_added]`.
Expand Down Expand Up @@ -335,7 +336,7 @@ def define_clonotype_clusters(
# Return or store results
clonotype_distance_res = {
"distances": clonotype_dist,
"cell_indices": ctn.cell_indices,
"cell_indices": json.dumps(ctn.cell_indices),
}
if inplace:
params.set_obs(key_added, clonotype_cluster_series)
Expand Down Expand Up @@ -533,7 +534,7 @@ def clonotype_network(
# explicitly annotate node ids to keep them after subsetting
graph.vs["node_id"] = np.arange(0, len(graph.vs))

cell_indices = clonotype_res["cell_indices"]
cell_indices = read_cell_indices(clonotype_res["cell_indices"])

# store size in graph to be accessed by layout algorithms
clonotype_size = np.array([len(idx) for idx in cell_indices.values()])
Expand Down Expand Up @@ -603,7 +604,7 @@ def clonotype_network(
# Expand to cell coordinates to store in adata.obsm
idx, coords = zip(
*itertools.chain.from_iterable(
zip(clonotype_res["cell_indices"][str(node_id)], itertools.repeat(coord))
zip(cell_indices[str(node_id)], itertools.repeat(coord))
for node_id, coord in zip(graph.vs["node_id"], coords, strict=False) # type: ignore
),
strict=False,
Expand Down Expand Up @@ -631,10 +632,9 @@ def _graph_from_coordinates(adata: AnnData, clonotype_key: str, basis: str) -> t
"""
clonotype_res = adata.uns[clonotype_key]
# map the cell-id to the corresponding row/col in the clonotype distance matrix
cell_indices = read_cell_indices(clonotype_res["cell_indices"])
dist_idx, obs_names = zip(
*itertools.chain.from_iterable(
zip(itertools.repeat(i), obs_names) for i, obs_names in clonotype_res["cell_indices"].items()
),
*itertools.chain.from_iterable(zip(itertools.repeat(i), obs_names) for i, obs_names in cell_indices.items()),
strict=False,
)
dist_idx_lookup = pd.DataFrame(index=obs_names, data=dist_idx, columns=["dist_idx"])
Expand Down
17 changes: 10 additions & 7 deletions src/scirpy/tl/_ir_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from scirpy.ir_dist import MetricType, _get_metric_key
from scirpy.ir_dist._clonotype_neighbors import ClonotypeNeighbors
from scirpy.util import DataHandler, _is_na, tqdm
from scirpy.util import DataHandler, _is_na, read_cell_indices, tqdm

from ._clonotypes import _common_doc, _common_doc_parallelism, _doc_clonotype_definition, _validate_parameters

Expand Down Expand Up @@ -166,9 +166,9 @@ def ir_query(
A dictionary containing
* `distances`: A sparse distance matrix between unique receptor configurations
in `adata` aund unique receptor configurations in `reference`.
* `cell_indices`: A dict of arrays, containing the the `adata.obs_names`
* `cell_indices`: A dict of lists, containing the the `adata.obs_names`
(cell indices) for each row in the distance matrix.
* `cell_indices_reference`: A dict of arrays, containing the `reference.obs_names`
* `cell_indices_reference`: A dict of lists, containing the `reference.obs_names`
for each column in the distance matrix.
If `inplace` is `True`, this is added to `adata.uns[key_added]`.
Expand Down Expand Up @@ -206,8 +206,8 @@ def ir_query(
# Return or store results
clonotype_distance_res = {
"distances": clonotype_dist,
"cell_indices": ctn.cell_indices,
"cell_indices_reference": ctn.cell_indices2,
"cell_indices": json.dumps(ctn.cell_indices),
"cell_indices_reference": json.dumps(ctn.cell_indices2),
}
if inplace:
params.adata.uns[key_added] = clonotype_distance_res
Expand Down Expand Up @@ -284,10 +284,13 @@ def ir_query_annotate_df(
res = params.adata.uns[query_key]
dist = res["distances"]

cell_indices = read_cell_indices(res["cell_indices"])
cell_indices_reference = read_cell_indices(res["cell_indices_reference"])

def get_pairs():
for i, query_cells in res["cell_indices"].items():
for i, query_cells in cell_indices.items():
reference_cells = itertools.chain.from_iterable(
res["cell_indices_reference"][str(k)] for k in dist[int(i), :].indices
cell_indices_reference[str(k)] for k in dist[int(i), :].indices
)
yield from itertools.product(query_cells, reference_cells)

Expand Down
21 changes: 20 additions & 1 deletion src/scirpy/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import contextlib
import json
import os
import warnings
from collections.abc import Callable, Mapping, Sequence
from textwrap import dedent
from typing import Any, Optional, Union, cast, overload
from typing import Any, Literal, Optional, Union, cast, overload

import awkward as ak
import numpy as np
Expand Down Expand Up @@ -605,3 +606,21 @@ def _get_usable_cpus(n_jobs: int = 0, use_numba: bool = False):
usable_cpus = min(usable_cpus, config.NUMBA_NUM_THREADS)

return usable_cpus


def read_cell_indices(cell_indices: dict[str, np.ndarray[str]] | str) -> dict[str, list[str]]:
"""
The datatype of the cell_indices Mapping (clonotype_id -> cell_ids) that is stored to the anndata.uns
attribute after the ´define_clonotype_clusters´ function has changed from dict[str, np.ndarray[str] to
str (json) due to performance considerations regarding the writing speed of the anndata object. But we still
want that older anndata objects with the dict[str, np.ndarray[str] datatype can be used. So we use this function
to read the cell_indices from the anndata object to support both formats.
"""
if isinstance(cell_indices, str): # new format
return json.loads(cell_indices)
elif isinstance(cell_indices, dict): # old format
return {k: v.tolist() for k, v in cell_indices.items()}
else: # unsupported format
raise TypeError(
f"Unsupported type for cell_indices: {type(cell_indices)}. Expected str (json) or dict[str, np.ndarray[str]]."
)

0 comments on commit a9fe4d1

Please sign in to comment.