Skip to content

Commit

Permalink
Merge pull request #12 from theislab/concat-outer
Browse files Browse the repository at this point in the history
Added outer join for concatenate function
  • Loading branch information
falexwolf authored Feb 15, 2018
2 parents 40a24fb + 47ad1a1 commit 5a02063
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 40 deletions.
92 changes: 56 additions & 36 deletions anndata/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import warnings
import logging as logg
from enum import Enum
from collections import Mapping, Sequence, Sized
from collections import Mapping, Sequence, Sized, ChainMap
from functools import reduce
from typing import Union

import numpy as np
from numpy import ma
import pandas as pd
Expand Down Expand Up @@ -844,7 +847,7 @@ def shape(self):
return self.n_obs, self.n_vars

@property
def X(self):
def X(self) -> Union[np.ndarray, sparse.spmatrix]:
"""Data matrix of shape `n_obs` × `n_vars` (`np.ndarray`, `sp.sparse.spmatrix`)."""
if self.isbacked:
if not self.file.isopen: self.file.open()
Expand Down Expand Up @@ -1297,15 +1300,17 @@ def copy(self, filename=None):
copyfile(self.filename, filename)
return AnnData(filename=filename)

def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_unique='-'):
def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories=None, index_unique='-'):
"""Concatenate along the observations axis after intersecting the variables names.
The `.var`, `.varm`, and `.uns` attributes of the passed adatas are ignored.
Parameters
----------
adatas : :class:`~anndata.AnnData` or list of :class:`~anndata.AnnData`
adatas : :class:`~anndata.AnnData`
AnnData matrices to concatenate with.
join: `str` (default: 'inner')
Use intersection (``'inner'``) or union (``'outer'``) of variables?
batch_key : `str` (default: 'batch')
Add the batch annotation to `.obs` using this key.
batch_categories : list, optional (default: `range(len(adatas)+1)`)
Expand All @@ -1332,7 +1337,7 @@ def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_un
>>> {'anno2': ['d3', 'd4']},
>>> {'var_names': ['b', 'c', 'd']})
>>>
>>> adata = adata1.concatenate([adata2, adata3])
>>> adata = adata1.concatenate(adata2, adata3)
>>> adata.X
[[ 2. 3.]
[ 5. 6.]
Expand All @@ -1351,41 +1356,56 @@ def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_un
0-2 NaN d3 2
1-2 NaN d4 2
"""
if isinstance(adatas, AnnData): adatas = [adatas]
joint_variables = self.var_names
for adata2 in adatas:
joint_variables = np.intersect1d(
joint_variables, adata2.var_names, assume_unique=True)
if len(adatas) == 0:
return self
elif len(adatas) == 1 and not isinstance(adatas[0], AnnData):
adatas = adatas[0] # backwards compatibility
all_adatas = (self,) + adatas

mergers = dict(inner=set.intersection, outer=set.union)
var_names = pd.Index(reduce(mergers[join], (set(ad.var_names) for ad in all_adatas)))

if batch_categories is None:
categories = [str(i) for i in range(len(adatas)+1)]
elif len(batch_categories) == len(adatas)+1:
categories = [str(i) for i, _ in enumerate(all_adatas)]
elif len(batch_categories) == len(all_adatas):
categories = batch_categories
else:
raise ValueError('Provide as many `batch_categories` as `adatas`.')
adatas_to_concat = []
for i, ad in enumerate([self] + adatas):
ad.obs.index.values
ad = ad[:, joint_variables]
ad.obs[batch_key] = pd.Categorical(
ad.n_obs*[categories[i]], categories=categories)
ad.obs.index.values

out_shape = (sum(a.n_obs for a in all_adatas), len(var_names))

any_sparse = any(issparse(a.X) for a in all_adatas)
mat_cls = sparse.csc_matrix if any_sparse else np.ndarray
X = mat_cls(out_shape, dtype=self.X.dtype)
var = pd.DataFrame(index=var_names)

obs_i = 0 # start of next adata’s observations in X
out_obss = []
for i, ad in enumerate(all_adatas):
vars_ad_in_res = var_names.isin(ad.var_names)
vars_res_in_ad = ad.var_names.isin(var_names)

# X
X[obs_i:obs_i+ad.n_obs, vars_ad_in_res] = ad.X[:, vars_res_in_ad]
obs_i += ad.n_obs

# obs
obs = ad.obs.copy()
obs[batch_key] = pd.Categorical(ad.n_obs * [categories[i]], categories)
if index_unique is not None:
if not is_string_dtype(ad.obs.index):
ad.obs.index = ad.obs.index.astype(str)
ad.obs.index = ad.obs.index.values + index_unique + categories[i]
adatas_to_concat.append(ad)
Xs = [ad.X for ad in adatas_to_concat]
if issparse(self.X):
from scipy.sparse import vstack
X = vstack(Xs)
else:
X = np.concatenate(Xs)
obs = pd.concat([ad.obs for ad in adatas_to_concat])
obsm = np.concatenate([ad.obsm for ad in adatas_to_concat])
var = adatas_to_concat[0].var
varm = adatas_to_concat[0].varm
uns = adatas_to_concat[0].uns
return AnnData(X, obs, var, uns, obsm, varm, filename=self.filename)
obs.index = obs.index.astype(str)
obs.index = obs.index.values + index_unique + categories[i]
out_obss.append(obs)

# var
var.loc[vars_ad_in_res, ad.var.columns] = ad.var.loc[vars_res_in_ad, :]

obs = pd.concat(out_obss)
uns = dict(ChainMap({}, *[ad.obs for ad in all_adatas]))
obsm = np.concatenate([ad.obsm for ad in all_adatas])
varm = self.varm # TODO
return AnnData(X, obs, var, uns, obsm, None, filename=self.filename)

def var_names_make_unique(self, join='-'):
self.var.index = utils.make_index_unique(self.var.index, join)
Expand Down Expand Up @@ -1423,11 +1443,11 @@ def _check_dimensions(self, key=None):
if 'obsm' in key and len(self._obsm) != self._n_obs:
raise ValueError('Observations annot. `obsm` must have number of '
'rows of `X` ({}), but has {} rows.'
.format(self._n_obs, self._obs.shape[0]))
.format(self._n_obs, len(self._obsm)))
if 'varm' in key and len(self._varm) != self._n_vars:
raise ValueError('Variables annot. `varm` must have number of '
'columns of `X` ({}), but has {} rows.'
.format(self._n_vars, self._var.shape[0]))
.format(self._n_vars, len(self._varm)))

def write(self, filename=None, compression='gzip', compression_opts=None):
"""Write `.h5ad`-formatted hdf5 file and close a potential backing file.
Expand Down
21 changes: 17 additions & 4 deletions anndata/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,12 @@ def test_concatenate():
{'obs_names': ['s5', 's6'],
'anno2': ['d3', 'd4']},
{'var_names': ['b', 'c', 'd']})
adata = adata1.concatenate([adata2, adata3])
adata = adata1.concatenate(adata2, adata3)
assert adata.n_vars == 2
assert adata.obs_keys() == ['anno1', 'anno2', 'batch']
adata = adata1.concatenate([adata2, adata3], batch_key='batch1')
adata = adata1.concatenate(adata2, adata3, batch_key='batch1')
assert adata.obs_keys() == ['anno1', 'anno2', 'batch1']
adata = adata1.concatenate([adata2, adata3], batch_categories=['a1', 'a2', 'a3'])
adata = adata1.concatenate(adata2, adata3, batch_categories=['a1', 'a2', 'a3'])
assert adata.obs['batch'].cat.categories.tolist() == ['a1', 'a2', 'a3']


Expand All @@ -217,10 +217,23 @@ def test_concatenate_sparse():
{'obs_names': ['s5', 's6'],
'anno2': ['d3', 'd4']},
{'var_names': ['b', 'c', 'd']})
adata = adata1.concatenate([adata2, adata3])
adata = adata1.concatenate(adata2, adata3)
assert adata.n_vars == 2


def test_concatenate_outer():
adata1 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
{'obs_names': ['s1', 's2'],
'anno1': ['c1', 'c2']},
{'var_names': ['a', 'b', 'c']})
adata2 = AnnData(np.array([[1, 2, 3], [4, 5, 6], [7,8,9]]),
{'obs_names': ['s3', 's4', 's5'],
'anno2': ['c3', 'c4', 'c5']},
{'var_names': ['b', 'c', 'd']})
adata = adata1.concatenate(adata2, join='outer')
assert adata.n_vars == 4
assert adata.obs_keys() == ['anno1', 'anno2', 'batch']

# TODO: remove logging and actually test values
# from scanpy import logging as logg

Expand Down

0 comments on commit 5a02063

Please sign in to comment.