diff --git a/anndata/base.py b/anndata/base.py
index 309cd8b4d..df8c071c2 100644
--- a/anndata/base.py
+++ b/anndata/base.py
@@ -4,7 +4,10 @@
 import warnings
 import logging as logg
 from enum import Enum
-from collections import Mapping, Sequence, Sized
+from collections import Mapping, Sequence, Sized, ChainMap
+from functools import reduce
+from typing import Union
+
 import numpy as np
 from numpy import ma
 import pandas as pd
@@ -844,7 +847,7 @@ def shape(self):
         return self.n_obs, self.n_vars
 
     @property
-    def X(self):
+    def X(self) -> Union[np.ndarray, sparse.spmatrix]:
         """Data matrix of shape `n_obs` × `n_vars` (`np.ndarray`, `sp.sparse.spmatrix`)."""
         if self.isbacked:
             if not self.file.isopen: self.file.open()
@@ -1297,15 +1300,17 @@ def copy(self, filename=None):
             copyfile(self.filename, filename)
             return AnnData(filename=filename)
 
-    def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_unique='-'):
+    def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories=None, index_unique='-'):
         """Concatenate along the observations axis after intersecting the variables names.
 
         The `.var`, `.varm`, and `.uns` attributes of the passed adatas are ignored.
 
         Parameters
         ----------
-        adatas : :class:`~anndata.AnnData` or list of :class:`~anndata.AnnData`
+        adatas : :class:`~anndata.AnnData`
             AnnData matrices to concatenate with.
+        join : `str` (default: 'inner')
+            Use intersection (``'inner'``) or union (``'outer'``) of variables?
         batch_key : `str` (default: 'batch')
             Add the batch annotation to `.obs` using this key.
         batch_categories : list, optional (default: `range(len(adatas)+1)`)
@@ -1332,7 +1337,7 @@ def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_un
         >>>     {'anno2': ['d3', 'd4']},
         >>>     {'var_names': ['b', 'c', 'd']})
         >>>
-        >>> adata = adata1.concatenate([adata2, adata3])
+        >>> adata = adata1.concatenate(adata2, adata3)
         >>> adata.X
         [[ 2.  3.]
          [ 5.  6.]
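Not part of the patch: a minimal usage sketch of the new call signature, built with the same dict-based `AnnData` constructors already used in the docstring example and in the tests below. The old list form (`adata1.concatenate([adata2])`) is kept working via the backwards-compatibility branch in the hunk that follows.

```python
import numpy as np
from anndata import AnnData

# Two small objects that share only the variables 'b' and 'c'.
adata1 = AnnData(np.array([[1., 2., 3.], [4., 5., 6.]]),
                 {'obs_names': ['s1', 's2'], 'anno1': ['c1', 'c2']},
                 {'var_names': ['a', 'b', 'c']})
adata2 = AnnData(np.array([[1., 2., 3.], [4., 5., 6.]]),
                 {'obs_names': ['s3', 's4'], 'anno2': ['d1', 'd2']},
                 {'var_names': ['b', 'c', 'd']})

# New: pass AnnData objects as positional arguments instead of a list.
inner = adata1.concatenate(adata2)                # intersection of var_names: 'b', 'c'
outer = adata1.concatenate(adata2, join='outer')  # union of var_names: 'a', 'b', 'c', 'd'
assert inner.n_vars == 2
assert outer.n_vars == 4
```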
@@ -1351,41 +1356,56 @@ def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_un
         0-2    NaN     d3      2
         1-2    NaN     d4      2
         """
-        if isinstance(adatas, AnnData): adatas = [adatas]
-        joint_variables = self.var_names
-        for adata2 in adatas:
-            joint_variables = np.intersect1d(
-                joint_variables, adata2.var_names, assume_unique=True)
+        if len(adatas) == 0:
+            return self
+        elif len(adatas) == 1 and not isinstance(adatas[0], AnnData):
+            adatas = adatas[0]  # backwards compatibility
+        all_adatas = (self,) + adatas
+
+        mergers = dict(inner=set.intersection, outer=set.union)
+        var_names = pd.Index(reduce(mergers[join], (set(ad.var_names) for ad in all_adatas)))
+
         if batch_categories is None:
-            categories = [str(i) for i in range(len(adatas)+1)]
-        elif len(batch_categories) == len(adatas)+1:
+            categories = [str(i) for i, _ in enumerate(all_adatas)]
+        elif len(batch_categories) == len(all_adatas):
             categories = batch_categories
         else:
             raise ValueError('Provide as many `batch_categories` as `adatas`.')
-        adatas_to_concat = []
-        for i, ad in enumerate([self] + adatas):
-            ad.obs.index.values
-            ad = ad[:, joint_variables]
-            ad.obs[batch_key] = pd.Categorical(
-                ad.n_obs*[categories[i]], categories=categories)
-            ad.obs.index.values
+
+        out_shape = (sum(a.n_obs for a in all_adatas), len(var_names))
+
+        any_sparse = any(issparse(a.X) for a in all_adatas)
+        mat_cls = sparse.csc_matrix if any_sparse else np.ndarray
+        X = mat_cls(out_shape, dtype=self.X.dtype)
+        var = pd.DataFrame(index=var_names)
+
+        obs_i = 0  # start of next adata's observations in X
+        out_obss = []
+        for i, ad in enumerate(all_adatas):
+            vars_ad_in_res = var_names.isin(ad.var_names)
+            vars_res_in_ad = ad.var_names.isin(var_names)
+
+            # X
+            X[obs_i:obs_i+ad.n_obs, vars_ad_in_res] = ad.X[:, vars_res_in_ad]
+            obs_i += ad.n_obs
+
+            # obs
+            obs = ad.obs.copy()
+            obs[batch_key] = pd.Categorical(ad.n_obs * [categories[i]], categories)
             if index_unique is not None:
                 if not is_string_dtype(ad.obs.index):
-                    ad.obs.index = ad.obs.index.astype(str)
-                ad.obs.index = ad.obs.index.values + index_unique + categories[i]
-            adatas_to_concat.append(ad)
-        Xs = [ad.X for ad in adatas_to_concat]
-        if issparse(self.X):
-            from scipy.sparse import vstack
-            X = vstack(Xs)
-        else:
-            X = np.concatenate(Xs)
-        obs = pd.concat([ad.obs for ad in adatas_to_concat])
-        obsm = np.concatenate([ad.obsm for ad in adatas_to_concat])
-        var = adatas_to_concat[0].var
-        varm = adatas_to_concat[0].varm
-        uns = adatas_to_concat[0].uns
-        return AnnData(X, obs, var, uns, obsm, varm, filename=self.filename)
+                    obs.index = obs.index.astype(str)
+                obs.index = obs.index.values + index_unique + categories[i]
+            out_obss.append(obs)
+
+            # var
+            var.loc[vars_ad_in_res, ad.var.columns] = ad.var.loc[vars_res_in_ad, :]
+
+        obs = pd.concat(out_obss)
+        uns = dict(ChainMap({}, *[ad.uns for ad in all_adatas]))
+        obsm = np.concatenate([ad.obsm for ad in all_adatas])
+        varm = self.varm  # TODO
+        return AnnData(X, obs, var, uns, obsm, None, filename=self.filename)
 
     def var_names_make_unique(self, join='-'):
         self.var.index = utils.make_index_unique(self.var.index, join)
@@ -1423,11 +1443,11 @@ def _check_dimensions(self, key=None):
         if 'obsm' in key and len(self._obsm) != self._n_obs:
             raise ValueError('Observations annot. `obsm` must have number of '
                              'rows of `X` ({}), but has {} rows.'
-                             .format(self._n_obs, self._obs.shape[0]))
+                             .format(self._n_obs, len(self._obsm)))
         if 'varm' in key and len(self._varm) != self._n_vars:
             raise ValueError('Variables annot. `varm` must have number of '
                              'columns of `X` ({}), but has {} rows.'
-                             .format(self._n_vars, self._var.shape[0]))
+                             .format(self._n_vars, len(self._varm)))
 
     def write(self, filename=None, compression='gzip', compression_opts=None):
         """Write `.h5ad`-formatted hdf5 file and close a potential backing file.
diff --git a/anndata/tests/base.py b/anndata/tests/base.py
index bcb9bd34f..d6b078a23 100644
--- a/anndata/tests/base.py
+++ b/anndata/tests/base.py
@@ -194,12 +194,12 @@ def test_concatenate():
         {'obs_names': ['s5', 's6'],
          'anno2': ['d3', 'd4']},
         {'var_names': ['b', 'c', 'd']})
-    adata = adata1.concatenate([adata2, adata3])
+    adata = adata1.concatenate(adata2, adata3)
     assert adata.n_vars == 2
     assert adata.obs_keys() == ['anno1', 'anno2', 'batch']
-    adata = adata1.concatenate([adata2, adata3], batch_key='batch1')
+    adata = adata1.concatenate(adata2, adata3, batch_key='batch1')
    assert adata.obs_keys() == ['anno1', 'anno2', 'batch1']
-    adata = adata1.concatenate([adata2, adata3], batch_categories=['a1', 'a2', 'a3'])
+    adata = adata1.concatenate(adata2, adata3, batch_categories=['a1', 'a2', 'a3'])
     assert adata.obs['batch'].cat.categories.tolist() == ['a1', 'a2', 'a3']
 
 
@@ -217,10 +217,23 @@ def test_concatenate_sparse():
         {'obs_names': ['s5', 's6'],
          'anno2': ['d3', 'd4']},
         {'var_names': ['b', 'c', 'd']})
-    adata = adata1.concatenate([adata2, adata3])
+    adata = adata1.concatenate(adata2, adata3)
     assert adata.n_vars == 2
 
 
+def test_concatenate_outer():
+    adata1 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
+                     {'obs_names': ['s1', 's2'],
+                      'anno1': ['c1', 'c2']},
+                     {'var_names': ['a', 'b', 'c']})
+    adata2 = AnnData(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+                     {'obs_names': ['s3', 's4', 's5'],
+                      'anno2': ['c3', 'c4', 'c5']},
+                     {'var_names': ['b', 'c', 'd']})
+    adata = adata1.concatenate(adata2, join='outer')
+    assert adata.n_vars == 4
+    assert adata.obs_keys() == ['anno1', 'anno2', 'batch']
+
+
 # TODO: remove logging and actually test values
 # from scanpy import logging as logg
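The core of the new implementation is the set-based merge of variable names. Below is an illustrative standalone sketch of that technique; the helper name is mine, not anndata API, and the `sorted()` call is only there to make the example's output deterministic (the patch itself passes the reduced set straight to `pd.Index`).

```python
from functools import reduce
import pandas as pd

def merged_var_names(indices, join='inner'):
    """Combine several var_names indices the way concatenate() now does:
    reduce over sets with either intersection ('inner') or union ('outer')."""
    mergers = dict(inner=set.intersection, outer=set.union)
    merged = reduce(mergers[join], (set(ix) for ix in indices))
    # sets are unordered, so fix an order before building the Index
    return pd.Index(sorted(merged))

names = [pd.Index(['a', 'b', 'c']), pd.Index(['b', 'c', 'd'])]
print(merged_var_names(names, 'inner').tolist())  # ['b', 'c']
print(merged_var_names(names, 'outer').tolist())  # ['a', 'b', 'c', 'd']
```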