From c11e5e09e0e19f85a5ad17d4c26351bae201cd25 Mon Sep 17 00:00:00 2001 From: Jayaram Kancherla Date: Tue, 18 Jun 2024 09:25:47 -0700 Subject: [PATCH] more modifications --- src/cellarr/CellArrDataset.py | 199 ++++++++++++++---- src/cellarr/__init__.py | 2 +- ...ataset.py => buildutils_cellarrdataset.py} | 19 +- ...db_array.py => buildutils_tiledb_array.py} | 0 ...db_frame.py => buildutils_tiledb_frame.py} | 0 src/cellarr/queryutils_tiledb_frame.py | 78 +++++++ src/cellarr/utils_anndata.py | 2 +- tests/test_query.py | 10 +- 8 files changed, 247 insertions(+), 63 deletions(-) rename src/cellarr/{build_cellarrdataset.py => buildutils_cellarrdataset.py} (95%) rename src/cellarr/{utils_tiledb_array.py => buildutils_tiledb_array.py} (100%) rename src/cellarr/{utils_tiledb_frame.py => buildutils_tiledb_frame.py} (100%) create mode 100644 src/cellarr/queryutils_tiledb_frame.py diff --git a/src/cellarr/CellArrDataset.py b/src/cellarr/CellArrDataset.py index 91d1f28..9da138f 100644 --- a/src/cellarr/CellArrDataset.py +++ b/src/cellarr/CellArrDataset.py @@ -1,20 +1,25 @@ import os -from typing import Union +from typing import List, Union, Sequence +import pandas as pd import tiledb +from . import queryutils_tiledb_frame as qtd + __author__ = "Jayaram Kancherla" __copyright__ = "Jayaram Kancherla" __license__ = "MIT" class CellArrDataset: - """A class that represent a collection of cells in TileDB.""" + """A class that represent a collection of cells and their associated metadata + in a TileDB backed store. + """ def __init__( self, dataset_path: str, - counts_tdb_uri: str = "counts", + matrix_tdb_uri: str = "counts", gene_metadata_uri: str = "gene_metadata", cell_metadata_uri: str = "cell_metadata", ): @@ -22,18 +27,18 @@ def __init__( Args: dataset_path: - Path to the directory containing the tiledb files. - Usually the output_path used in the + Path to the directory containing the tiledb stores. + Usually the ``output_path`` from the :py:func:`~cellarr.build_cellarrdataset.build_cellarrdataset`. counts_tdb_uri: - Path to counts TileDB. + Relative path to matrix store. gene_metadata_uri: - Path to gene metadata TileDB. + Relative path to gene metadata store. cell_metadata_uri: - Path to cell metadata TileDB. + Relative path to cell metadata store. """ if not os.path.isdir(dataset_path): @@ -41,7 +46,7 @@ def __init__( self._dataset_path = dataset_path # TODO: Maybe switch to on-demand loading of these objects - self._counts_tdb_tdb = tiledb.open(f"{dataset_path}/{counts_tdb_uri}", "r") + self._matrix_tdb_tdb = tiledb.open(f"{dataset_path}/{matrix_tdb_uri}", "r") self._gene_metadata_tdb = tiledb.open( f"{dataset_path}/{gene_metadata_uri}", "r" ) @@ -50,62 +55,122 @@ def __init__( ) def __del__(self): - self._counts_tdb_tdb.close() + self._matrix_tdb_tdb.close() self._gene_metadata_tdb.close() self._cell_metadata_tdb.close() - # TODO: - # Methods to implement - # search by gene - # search by cell metadata - # slice counts after search + def get_cell_metadata_columns(self) -> List[str]: + """Get column names from ``cell_metadata`` store. + + Returns: + List of available metadata columns. + """ + return qtd.get_schema_names_frame(self._cell_metadata_tdb) - def get_cell_metadata_columns(self): - columns = [] - for i in range(self._cell_metadata_tdb.schema.nattr): - columns.append(self._cell_metadata_tdb.schema.attr(i).name) + def get_cell_metadata_column(self, column_name: str) -> list: + """Access a column from the ``cell_metadata`` store. - return columns + Args: + column_name: + Name of the column or attribute. Usually one of the column names + from of :py:meth:`~get_cell_metadata_columns`. - def get_cell_metadata_column(self, column_name: str): - return self._cell_metadata_tdb.query(attrs=[column_name]).df[:] + Returns: + A list of values for this column. + """ + return qtd.get_a_column(self._cell_metadata_tdb, column_name=column_name) def get_cell_subset( self, subset: Union[slice, tiledb.QueryCondition], columns=None - ): - if columns is None: - columns = self.get_cell_metadata_columns() + ) -> pd.DataFrame: + """Slice the ``cell_metadata`` store. - query = self._cell_metadata_tdb.query(cond=subset, attrs=columns) - data = query.df[:] - result = data.dropna() - return result + Args: + subset: + A list of integer indices to subset the ``cell_metadata`` + store. + + Alternatively, may also provide a + :py:class:`tiledb.QueryCondition` to query the store. + + columns: + List of specific column names to access. - def get_gene_metadata_columns(self): - columns = [] - for i in range(self._gene_metadata_tdb.schema.nattr): - columns.append(self._gene_metadata_tdb.schema.attr(i).name) + Defaults to None, in which case all columns are extracted. + + Returns: + A pandas Dataframe of the subset. + """ + return qtd.subset_frame(self._cell_metadata_tdb, subset=subset, columns=columns) - return columns + def get_gene_metadata_columns(self) -> List[str]: + """Get annotation column names from ``gene_metadata`` store. + + Returns: + List of available annotations. + """ + return qtd.get_schema_names_frame(self._gene_metadata_tdb) def get_gene_metadata_column(self, column_name: str): - return self._gene_metadata_tdb.query(attrs=[column_name]).df[:] + """Access a column from the ``gene_metadata`` store. + + Args: + column_name: + Name of the column or attribute. Usually one of the column names + from of :py:meth:`~get_gene_metadata_columns`. + + Returns: + A list of values for this column. + """ + return qtd.get_a_column(self._gene_metadata_tdb, column_name=column_name) + + def get_gene_metadata_index(self): + """Get index of the ``gene_metadata`` store. This typically should store + all unique gene symbols. + + Returns: + List of unique symbols. + """ + return qtd.get_index(self._gene_metadata_tdb) + + def _get_indices_for_gene_list(self, query: list) -> List[int]: + _gene_index = self.get_gene_metadata_index() + return qtd._match_to_list(_gene_index, query=query) def get_gene_subset( - self, subset: Union[slice, tiledb.QueryCondition], columns=None + self, subset: Union[slice, List[str], tiledb.QueryCondition], columns=None ): - if columns is None: - columns = self.get_gene_metadata_columns() + """Slice the ``gene_metadata`` store. + + Args: + subset: + A list of integer indices to subset the ``gene_metadata`` + store. + + Alternatively, may provide a + :py:class:`tiledb.QueryCondition` to query the store. + + Alternatively, may provide a list of strings to match with + the index of ``gene_metadata`` store. + + columns: + List of specific column names to access. - query = self._gene_metadata_tdb.query(cond=subset, attrs=columns) - data = query.df[:] - result = data.dropna() - return result + Defaults to None, in which case all columns are extracted. + + Returns: + A pandas Dataframe of the subset. + """ + + if qtd._is_list_strings(subset): + subset = self._get_indices_for_gene_list(subset) + + return qtd.subset_frame(self._gene_metadata_tdb, subset=subset, columns=columns) def get_slice( self, cell_subset: Union[slice, tiledb.QueryCondition], - gene_subset: Union[slice, tiledb.QueryCondition], + gene_subset: Union[slice, List[str], tiledb.QueryCondition], ): _csubset = self.get_cell_subset(cell_subset) _cell_indices = _csubset.index.tolist() @@ -113,4 +178,50 @@ def get_slice( _gsubset = self.get_gene_subset(gene_subset) _gene_indices = _gsubset.index.tolist() - return self._counts_tdb_tdb.multi_index[_cell_indices, _gene_indices] + return self._matrix_tdb_tdb.multi_index[_cell_indices, _gene_indices] + + def __getitem__( + self, + args: Union[int, str, Sequence, tuple], + ): + """Subset a ``SummarizedExperiment``. + + Args: + args: + Integer indices, a boolean filter, or (if the current object is + named) names specifying the ranges to be extracted, see + :py:meth:`~biocutils.normalize_subscript.normalize_subscript`. + + Alternatively a tuple of length 1. The first entry specifies + the rows to retain based on their names or indices. + + Alternatively a tuple of length 2. The first entry specifies + the rows to retain, while the second entry specifies the + columns to retain, based on their names or indices. + + Raises: + ValueError: + If too many or too few slices provided. + + Returns: + Same type as caller with the sliced rows and columns. + """ + if isinstance(args, (str, int)): + return self.get_slice(args, slice(None)) + + if isinstance(args, tuple): + if len(args) == 0: + raise ValueError("At least one slicing argument must be provided.") + + if len(args) == 1: + return self.get_slice(args[0], slice(None)) + elif len(args) == 2: + return self.get_slice(args[0], args[1]) + else: + raise ValueError( + f"`{type(self).__name__}` only supports 2-dimensional slicing." + ) + + raise TypeError( + "args must be a sequence or a scalar integer or string or a tuple of atmost 2 values." + ) diff --git a/src/cellarr/__init__.py b/src/cellarr/__init__.py index e7c6f0c..2494da0 100644 --- a/src/cellarr/__init__.py +++ b/src/cellarr/__init__.py @@ -15,5 +15,5 @@ finally: del version, PackageNotFoundError -from .build_cellarrdataset import build_cellarrdataset +from .buildutils_cellarrdataset import build_cellarrdataset from .CellArrDataset import CellArrDataset diff --git a/src/cellarr/build_cellarrdataset.py b/src/cellarr/buildutils_cellarrdataset.py similarity index 95% rename from src/cellarr/build_cellarrdataset.py rename to src/cellarr/buildutils_cellarrdataset.py index b36d95d..ede5780 100644 --- a/src/cellarr/build_cellarrdataset.py +++ b/src/cellarr/buildutils_cellarrdataset.py @@ -7,7 +7,7 @@ - `gene_metadata`: A TileDB file containing gene metadata. - `cell_metadata`: A TileDB file containing cell metadata. -- A matrix TileDB file named as specified by the `layer_matrix_name` parameter. +- A matrix TileDB file named by the `layer_matrix_name` parameter. The TileDB matrix file is stored in a cell X gene orientation. This orientation is chosen because the fastest-changing dimension as new files are added to the @@ -62,8 +62,8 @@ import pandas as pd from . import utils_anndata as uad -from . import utils_tiledb_array as uta -from . import utils_tiledb_frame as utf +from . import buildutils_tiledb_array as uta +from . import buildutils_tiledb_frame as utf from .CellArrDataset import CellArrDataset __author__ = "Jayaram Kancherla" @@ -232,17 +232,16 @@ def build_cellarrdataset( gene_set = sorted(gene_set) - gene_metadata = pd.DataFrame({"genes": gene_set}, index=gene_set) + gene_metadata = pd.DataFrame({"cellarr_gene_index": gene_set}, index=gene_set) elif isinstance(gene_metadata, list): _gene_list = sorted(list(set(gene_metadata))) - gene_metadata = pd.DataFrame({"genes": _gene_list}, index=_gene_list) + gene_metadata = pd.DataFrame({"cellarr_gene_index": _gene_list}, index=_gene_list) elif isinstance(gene_metadata, dict): _gene_list = sorted(list(gene_metadata.keys())) - gene_metadata = pd.DataFrame({"genes": _gene_list}, index=_gene_list) + gene_metadata = pd.DataFrame({"cellarr_gene_index": _gene_list}, index=_gene_list) elif isinstance(gene_metadata, str): gene_metadata = pd.read_csv(gene_metadata, index=True, header=True) - - gene_metadata["genes_index"] = gene_metadata.index.tolist() + gene_metadata["cellarr_gene_index"] = gene_metadata.index.tolist() if not isinstance(gene_metadata, pd.DataFrame): raise TypeError("'gene_metadata' must be a pandas dataframe.") @@ -250,6 +249,8 @@ def build_cellarrdataset( if len(gene_metadata.index.unique()) != len(gene_metadata.index.tolist()): raise ValueError("'gene_metadata' must contain a unique index.") + gene_metadata.reset_index(drop=True, inplace=True) + if num_genes is None: num_genes = len(gene_metadata) @@ -367,7 +368,7 @@ def build_cellarrdataset( if optimize_tiledb: uta.optimize_tiledb_array(_counts_uri) - return CellArrDataset(dataset_path=output_path, counts_tdb_uri=layer_matrix_name) + return CellArrDataset(dataset_path=output_path, matrix_tdb_uri=layer_matrix_name) def generate_metadata_tiledb_frame( diff --git a/src/cellarr/utils_tiledb_array.py b/src/cellarr/buildutils_tiledb_array.py similarity index 100% rename from src/cellarr/utils_tiledb_array.py rename to src/cellarr/buildutils_tiledb_array.py diff --git a/src/cellarr/utils_tiledb_frame.py b/src/cellarr/buildutils_tiledb_frame.py similarity index 100% rename from src/cellarr/utils_tiledb_frame.py rename to src/cellarr/buildutils_tiledb_frame.py diff --git a/src/cellarr/queryutils_tiledb_frame.py b/src/cellarr/queryutils_tiledb_frame.py new file mode 100644 index 0000000..50af7fc --- /dev/null +++ b/src/cellarr/queryutils_tiledb_frame.py @@ -0,0 +1,78 @@ +from functools import lru_cache +from typing import List, Union +from warnings import warn + +import pandas as pd +import tiledb + +__author__ = "Jayaram Kancherla" +__copyright__ = "Jayaram Kancherla" +__license__ = "MIT" + + +@lru_cache +def get_schema_names_frame(tiledb_obj: tiledb.Array) -> List[str]: + columns = [] + for i in range(tiledb_obj.schema.nattr): + columns.append(tiledb_obj.schema.attr(i).name) + + return columns + + +def subset_frame( + tiledb_obj: tiledb.Array, + subset: Union[slice, tiledb.QueryCondition], + columns: Union[str, list] = None, +) -> pd.DataFrame: + + _avail_columns = get_schema_names_frame(tiledb_obj) + + if columns is None: + columns = _avail_columns + else: + _not_avail = [] + for col in columns: + if col not in _avail_columns: + _not_avail.append(col) + + if len(_not_avail) > 0: + raise ValueError(f"Columns '{', '.join(_not_avail)}' are not available.") + + if isinstance(columns, str): + warn( + "provided subset is string, its expected to be a 'query_condition'", + UserWarning, + ) + + query = tiledb_obj.query(cond=subset, attrs=columns) + data = query.df[:] + else: + data = query.df[subset, columns] + result = data.dropna() + return result + + +def get_a_column(tiledb_obj: tiledb.Array, column_name: str) -> list: + if column_name not in get_schema_names_frame(tiledb_obj): + raise ValueError(f"Column '{column_name}' does not exist.") + + return tiledb_obj.query(attrs=[column_name]).df[:] + + +@lru_cache +def get_index(tiledb_obj: tiledb.Array) -> list: + _index = tiledb_obj.unique_dim_values("__tiledb_rows") + return [x.decode() for x in _index] + + +def _match_to_list(x: list, query: list): + return sorted([x.index(x) for x in query]) + + +def _is_list_strings(x: list): + _ret = False + + if isinstance(x, (list, tuple)) and all(isinstance(y, str) for y in x): + _ret = True + + return _ret diff --git a/src/cellarr/utils_anndata.py b/src/cellarr/utils_anndata.py index 5921e08..0d57d2b 100644 --- a/src/cellarr/utils_anndata.py +++ b/src/cellarr/utils_anndata.py @@ -6,7 +6,7 @@ import numpy as np from scipy.sparse import coo_matrix, csr_array, csr_matrix -__author__ = "Jayaram Kancherla, Tony Kuo" +__author__ = "Jayaram Kancherla" __copyright__ = "Jayaram Kancherla" __license__ = "MIT" diff --git a/tests/test_query.py b/tests/test_query.py index 7f1ab4a..4363a5e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -44,17 +44,11 @@ def test_query_cellarrdataset(): cfp = tiledb.open(f"{tempdir}/counts", "r") gfp = tiledb.open(f"{tempdir}/gene_metadata", "r") - genes = gfp.df[:] + cd = CellArrDataset(dataset_path=tempdir) gene_list = ["gene_1", "gene_95", "gene_50"] - gene_indices_tdb = sorted([genes.index.tolist().index(x) for x in gene_list]) + result1 = cd[0, gene_list] - adata1_gene_indices = sorted( - [adata1.var.index.tolist().index(x) for x in gene_list] - ) - adata2_gene_indices = sorted( - [adata2.var.index.tolist().index(x) for x in gene_list] - ) assert np.allclose( cfp.multi_index[0, gene_indices_tdb]["counts"],