diff --git a/src/cellarr/CellArrDataset.py b/src/cellarr/CellArrDataset.py index 524963e..2473e17 100644 --- a/src/cellarr/CellArrDataset.py +++ b/src/cellarr/CellArrDataset.py @@ -442,7 +442,10 @@ def __getitem__( #### @property def shape(self): - return (self._cell_metadata_tdb.shape[0], self._gene_annotation_tdb.shape[0]) + return ( + self._cell_metadata_tdb.nonempty_domain()[0][1] + 1, + self._gene_annotation_tdb.nonempty_domain()[0][1] + 1, + ) def __len__(self): return self.shape[0] diff --git a/src/cellarr/dataloader.py b/src/cellarr/dataloader.py index 4f99c63..c979bca 100644 --- a/src/cellarr/dataloader.py +++ b/src/cellarr/dataloader.py @@ -254,8 +254,8 @@ def __init__( ) self.matrix_shape = ( - self.cell_metadata_tdb.shape[0], - self.gene_annotation_tdb.shape[0], + self.cell_metadata_tdb.nonempty_domain()[0][1] + 1, + self.gene_annotation_tdb.nonempty_domain()[0][1] + 1, ) # limit to cells with labels