Single NumPy array option #503

Draft · wants to merge 36 commits into base: main

Changes from 1 commit of 36

Commits
08ef851
Change RowFeatureIndex and RowFeatureIndex tests to use a list of dic…
Nov 20, 2024
c9ff683
Update load_h5ad to append features in dict format to the row feat…
Nov 20, 2024
5f86da4
Modify Single Cell Memmap Dataset unit tests to reflect changes
Nov 20, 2024
896bad0
remove conversion to np.array in get_row for now
Nov 20, 2024
eb17845
Convert values and col indices to np array so that we're not returnin…
Nov 21, 2024
1663903
Revert conversion to np array, and refactor num_vars_at_row to use in…
Nov 22, 2024
0497c98
Merge branch 'main' into savitha/scdl-performance-improvements
savitha-eng Nov 22, 2024
7a43706
Made changes requested in review.
Nov 25, 2024
da395b4
Merge branch 'savitha/scdl-performance-improvements' of github.com:NV…
Nov 25, 2024
9e11ab8
Integrate SCDL into Geneformer, rebased on the latest changes in main
Nov 26, 2024
37de5d1
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Nov 26, 2024
546f84e
Tests for Geneformer SingleCellDataset
Nov 26, 2024
0dcc56b
Merge branch 'savitha/integrate-scdl-geneformer-rebased' of github.co…
Nov 26, 2024
9d4c6a4
Data directory fixtures needed for pytest
Nov 26, 2024
eea6b42
Add bypass_tokenize_vocab to the arguments for this script
Nov 26, 2024
e642bc9
Changes to Inference tutorial notebook to support SCDL integrated Gen…
Dec 2, 2024
d755901
modify dataset dir creation
Dec 2, 2024
507e31b
all scdl integration changes
Dec 2, 2024
f6d9380
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Dec 2, 2024
9846894
Updated documentation, removed refs to sc_memmap, & made changes requ…
Dec 4, 2024
6afed04
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Dec 4, 2024
c63f8d9
single array
polinabinder1 Dec 5, 2024
c1491bd
better fix
polinabinder1 Dec 5, 2024
31bd648
remove notebook for tests
polinabinder1 Dec 5, 2024
84d53f2
more correct
polinabinder1 Dec 5, 2024
c4efa70
string formatting
polinabinder1 Dec 5, 2024
2e18cbb
Merge branch 'main' into savitha/integrate-scdl-geneformer-rebased
savitha-eng Dec 5, 2024
2971338
wandb version
polinabinder1 Dec 5, 2024
e780e66
Merge branch 'savitha/integrate-scdl-geneformer-rebased' into polinab…
polinabinder1 Dec 5, 2024
64f6ead
not using row feature index
polinabinder1 Dec 5, 2024
1a5a284
subclasses
polinabinder1 Dec 5, 2024
7fe0409
adding the length caching more
polinabinder1 Dec 6, 2024
d7940fd
cast as int
polinabinder1 Dec 6, 2024
73dddc9
profiling
polinabinder1 Dec 9, 2024
fed2c0c
full loading
polinabinder1 Dec 10, 2024
5e8c3bd
whole dataset
polinabinder1 Dec 10, 2024
single array
polinabinder1 committed Dec 5, 2024
commit c63f8d9834435f1d28776ba627745c8a32198cf7
@@ -107,7 +107,7 @@ def __init__( # noqa: D107
         self._seed = seed
         self.eos_token = eos_token

-        self.scdl = SingleCellMemMapDataset(str(data_path))
+        self.scdl = SingleCellMemMapDataset(str(data_path), feature_list=["feature_id"])

         # - median dict
         self.gene_medians = median_dict
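
For reference, a minimal sketch of the two construction modes this change introduces (the import path is assumed from the bionemo.scdl package layout; directory and h5ad paths are placeholders):

```python
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

# Default behavior: keep a per-dataset dict of feature columns (ragged lookups).
ds_full = SingleCellMemMapDataset("memmap_dir_a", h5ad_path="sample.h5ad", feature_list=None)

# Fast path: assert that every row shares one "feature_id" column, so the
# feature index can collapse to a single shared NumPy array.
ds_fast = SingleCellMemMapDataset("memmap_dir_b", h5ad_path="sample.h5ad", feature_list=["feature_id"])
```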
@@ -28,6 +28,14 @@
 __all__: Sequence[str] = ("RowFeatureIndex",)


+def _are_dicts_equal(d1, d2):
+    # Check if the dictionaries have the same keys
+    if d1.keys() != d2.keys():
+        return False
+    # Check if the corresponding arrays are equal
+    return all(np.array_equal(d1[key], d2[key]) for key in d1)

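Why a helper instead of plain `d1 == d2`: dict equality compares values with `==`, and comparing two multi-element NumPy arrays that way yields an element-wise array whose truth value is ambiguous, so the comparison raises. A quick illustration with made-up values, assuming the helper above is in scope:

```python
import numpy as np

d1 = {"feature_id": np.array(["ENSG01", "ENSG02"])}
d2 = {"feature_id": np.array(["ENSG01", "ENSG02"])}

# d1 == d2 raises "ValueError: The truth value of an array with more than
# one element is ambiguous"; np.array_equal sidesteps that.
assert _are_dicts_equal(d1, d2)
assert not _are_dicts_equal(d1, {"gene_name": d2["feature_id"]})  # keys differ
```
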

class RowFeatureIndex:
"""Maintains a mapping between a row and its features.

@@ -47,13 +55,13 @@ class RowFeatureIndex:
         _version: The version of the dataset
     """

-    def __init__(self) -> None:
+    def __init__(self, feature_list=[]) -> None:
         """Instantiates the index."""
         self._cumulative_sum_index: np.array = np.array([-1])
         self._feature_arr: list[dict[str, np.ndarray]] = []
         self._num_genes_per_row: list[int] = []
         self._version = importlib.metadata.version("bionemo.scdl")
-        self._labels: list[str] = []
+        self.all_same = feature_list == ["feature_id"]

     def _get_dataset_id(self, row) -> int:
         """Gets the dataset id for a specified row index.
@@ -100,12 +108,20 @@ def append_features(
         if isinstance(features, pd.DataFrame):
             raise TypeError("Expected a dictionary, but received a Pandas DataFrame.")
         csum = max(self._cumulative_sum_index[-1], 0)
-        self._cumulative_sum_index = np.append(self._cumulative_sum_index, csum + n_obs)
-        self._feature_arr.append(features)
-        self._num_genes_per_row.append(num_genes)
-        self._labels.append(label)
+        if len(self._feature_arr) > 0 and _are_dicts_equal(features, self._feature_arr[-1]):
+            self._cumulative_sum_index[-1] = csum + n_obs
+            self._num_genes_per_row[-1] += num_genes
+        else:
+            self._cumulative_sum_index = np.append(self._cumulative_sum_index, csum + n_obs)
+            self._feature_arr.append(features)
+            self._num_genes_per_row.append(num_genes)
+            # self._labels.append(str(label))
+        if self.all_same:
+            self._feature_ids = np.array([self._feature_arr[0]["feature_ids"]]).astype(np.string_)
+            self._feature_arr = None

-    def lookup(self, row: int, select_features: Optional[list[str]] = None) -> Tuple[list[np.ndarray], str]:
+    def lookup(self, row: int, select_features: Optional[list[str]] = None) -> Tuple[list[np.ndarray]]:
         """Find the features at a given row.

         It is assumed that the row is
@@ -133,23 +149,28 @@ def lookup(self, row: int, select_features: Optional[list[str]] = None) -> Tuple
             raise IndexError(
                 f"Row index {row} is larger than number of rows in FeatureIndex ({self._cumulative_sum_index[-1]})."
             )
-        d_id = self._get_dataset_id(row)
-
-        # Retrieve the features for the identified value.
-        features_dict = self._feature_arr[d_id]
-
-        # If specific features are to be selected, filter the features.
-        if select_features is not None:
-            features = []
-            for feature in select_features:
-                if feature not in features_dict:
-                    raise ValueError(f"Provided feature column {feature} in select_features not present in dataset.")
-                features.append(features_dict[feature])
-        else:
-            features = [features_dict[f] for f in features_dict]
-
-        # Return the features for the identified range.
-        return features, self._labels[d_id]
+        if self.all_same:
+            return self._feature_ids
+        d_id = self._get_dataset_id(row)
+
+        # Retrieve the features for the identified value.
+        features_dict = self._feature_arr[d_id]
+
+        # If specific features are to be selected, filter the features.
+        if select_features is not None:
+            features = []
+            for feature in select_features:
+                if feature not in features_dict:
+                    raise ValueError(
+                        f"Provided feature column {feature} in select_features not present in dataset."
+                    )
+                features.append(features_dict[feature])
+        else:
+            features = [features_dict[f] for f in features_dict]
+
+        # Return the features for the identified range.
+        return features

     def number_vars_at_row(self, row: int) -> int:
         """Return number of variables in a given row.
@@ -227,9 +248,9 @@ def concat(self, other_row_index: RowFeatureIndex, fail_on_empty_index: bool = T
             raise ValueError("Error: Cannot append empty FeatureIndex.")
         for i, feats in enumerate(list(other_row_index._feature_arr)):
             c_span = other_row_index._cumulative_sum_index[i + 1]
-            label = other_row_index._labels[i]
+            # label = other_row_index._labels[i]
             num_genes = other_row_index._num_genes_per_row[i]
-            self.append_features(c_span, feats, num_genes, label)
+            self.append_features(c_span, feats, num_genes)

         return self

@@ -247,7 +268,7 @@ def save(self, datapath: str) -> None:
             pq.write_table(table, f"{datapath}/dataframe_{dataframe_str_index}.parquet")

         np.save(Path(datapath) / "cumulative_sum_index.npy", self._cumulative_sum_index)
-        np.save(Path(datapath) / "labels.npy", self._labels)
+        # np.save(Path(datapath) / "labels.npy", self._labels)
         np.save(Path(datapath) / "version.npy", np.array(self._version))

     @staticmethod
@@ -263,13 +284,14 @@ def load(datapath: str) -> RowFeatureIndex:
         parquet_data_paths = sorted(Path(datapath).rglob("*.parquet"))
         data_tables = [pq.read_table(csv_path) for csv_path in parquet_data_paths]
         new_row_feat_index._feature_arr = [
-            {column: table[column].to_numpy() for column in table.column_names} for table in data_tables
+            {column: table[column].to_numpy().astype(np.string_) for column in table.column_names}
+            for table in data_tables
         ]
         new_row_feat_index._num_genes_per_row = [
             len(feats[next(iter(feats.keys()))]) for feats in new_row_feat_index._feature_arr
         ]

         new_row_feat_index._cumulative_sum_index = np.load(Path(datapath) / "cumulative_sum_index.npy")
-        new_row_feat_index._labels = np.load(Path(datapath) / "labels.npy", allow_pickle=True)
+        # new_row_feat_index._labels = np.load(Path(datapath) / "labels.npy", allow_pickle=False)
         new_row_feat_index._version = np.load(Path(datapath) / "version.npy").item()
         return new_row_feat_index
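
The load path now casts every reloaded column to fixed-width bytes, matching what `lookup` returns on the fast path. A round-trip sketch of just the parquet persistence (paths are placeholders; `np.string_` is NumPy 1.x's alias for `np.bytes_`):

```python
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

features = {"feature_id": np.array(["ENSG01", "ENSG02", "ENSG03"])}

# save(): one parquet table per feature dict.
pq.write_table(pa.Table.from_pydict({k: v.tolist() for k, v in features.items()}),
               "dataframe_0.parquet")

# load(): read back and cast unicode -> bytes, as in the new list comprehension.
table = pq.read_table("dataframe_0.parquet")
restored = {col: table[col].to_numpy().astype(np.string_) for col in table.column_names}
assert restored["feature_id"][0] == b"ENSG01"
```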
@@ -239,6 +239,7 @@ def __init__(
         mode: Mode = Mode.READ_APPEND,
         paginated_load_cutoff: int = 10_000,
         load_block_row_size: int = 1_000_000,
+        feature_list=None,
     ) -> None:
         """Instantiate the class.

@@ -251,6 +252,7 @@
             mode: Whether to read or write from the data_path.
             paginated_load_cutoff: MB size on disk at which to load the h5ad structure with paginated load.
             load_block_row_size: Number of rows to load into memory with paginated load
+            feature_list: features to use
         """
         self._version: str = importlib.metadata.version("bionemo.scdl")
         self.data_path: str = data_path
@@ -268,7 +270,7 @@
         # Stores the Feature Index, which tracks
         # the original AnnData features (e.g., gene names)
         # and allows us to store ragged arrays in our SCMMAP structure.
-        self._feature_index: RowFeatureIndex = RowFeatureIndex()
+        self._feature_index: RowFeatureIndex = RowFeatureIndex(feature_list)

         # Variables for int packing / reduced precision
         self.dtypes: Dict[FileNames, str] = {
@@ -356,7 +358,7 @@ def get_row(
             columns = self.col_index[start:end]
         ret = (values, columns)
         if return_features:
-            return ret, self._feature_index.lookup(index, select_features=feature_vars)[0]
+            return ret, self._feature_index.lookup(index, select_features=feature_vars)
         else:
             return ret, None

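With the trailing `[0]` dropped, `get_row` now forwards whatever `lookup` returns: a list of per-column arrays in the general case, or the single shared feature-ID array when the index was built with `feature_list=["feature_id"]`. A hedged usage sketch (paths and the h5ad file are placeholders):

```python
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

ds = SingleCellMemMapDataset("memmap_dir", h5ad_path="sample.h5ad", feature_list=None)

# ret is the row's (values, columns) pair; features is lookup()'s raw result,
# no longer a (features, label) tuple that needs unpacking.
(values, columns), features = ds.get_row(0, return_features=True, feature_vars=None)
```
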
@@ -102,15 +102,13 @@ def test_feature_index_internals_on_append(create_first_RowFeatureIndex):
     assert sum(create_first_RowFeatureIndex.number_of_values()) == (12 * 3) + (8 * 5)
     assert create_first_RowFeatureIndex.number_of_values()[1] == (8 * 5)
     assert create_first_RowFeatureIndex.number_of_rows() == 20
-    feats, label = create_first_RowFeatureIndex.lookup(row=3, select_features=None)
+    feats = create_first_RowFeatureIndex.lookup(row=3, select_features=None)
     assert np.all(feats[0] == one_feats["feature_name"])
     assert np.all(feats[1] == one_feats["feature_int"])
-    assert label is None
-    feats, label = create_first_RowFeatureIndex.lookup(row=15, select_features=None)
+    feats = create_first_RowFeatureIndex.lookup(row=15, select_features=None)
     assert np.all(feats[0] == two_feats["feature_name"])
     assert np.all(feats[1] == two_feats["gene_name"])
     assert np.all(feats[2] == two_feats["spare"])
-    assert label == "MY_DATAFRAME"


 def test_concat_length(
@@ -154,15 +152,13 @@ def test_concat_lookup_results(
         "spare": np.array([None, None, None, None, None]),
     }
     create_first_RowFeatureIndex.concat(create_second_RowFeatureIndex)
-    feats, label = create_first_RowFeatureIndex.lookup(row=3, select_features=None)
+    feats = create_first_RowFeatureIndex.lookup(row=3, select_features=None)
     assert np.all(feats[0] == one_feats["feature_name"])
     assert np.all(feats[1] == one_feats["feature_int"])
-    assert label is None
-    feats, label = create_first_RowFeatureIndex.lookup(row=15, select_features=None)
+    feats = create_first_RowFeatureIndex.lookup(row=15, select_features=None)
     assert np.all(feats[0] == two_feats["feature_name"])
     assert np.all(feats[1] == two_feats["gene_name"])
     assert np.all(feats[2] == two_feats["spare"])
-    assert label == "MY_DATAFRAME"


 def test_feature_lookup_empty():
@@ -197,7 +193,6 @@ def test_save_reload_row_feature_index_identical(
     assert create_first_RowFeatureIndex.number_of_values() == index_reload.number_of_values()

     for row in range(create_first_RowFeatureIndex.number_of_rows()):
-        features_one, labels_one = create_first_RowFeatureIndex.lookup(row=row, select_features=None)
-        features_reload, labels_reload = index_reload.lookup(row=row, select_features=None)
-        assert labels_one == labels_reload
-        assert np.all(np.array(features_one, dtype=object) == np.array(features_reload))
+        features_one = create_first_RowFeatureIndex.lookup(row=row, select_features=None)
+        features_reload = index_reload.lookup(row=row, select_features=None)
+        assert np.all(np.array(features_one).astype(np.string_) == np.array(features_reload).astype(np.string_))
@@ -36,10 +36,10 @@ def generate_dataset(tmp_path, test_directory) -> SingleCellMemMapDataset:
     Returns:
         A SingleCellMemMapDataset
     """
-    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad")
+    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad", feature_list=None)
     ds.save()
     del ds
-    reloaded = SingleCellMemMapDataset(tmp_path / "scy")
+    reloaded = SingleCellMemMapDataset(tmp_path / "scy", feature_list=None)
     return reloaded


@@ -87,10 +87,10 @@ def _compare(dns: SingleCellMemMapDataset, dt: SingleCellMemMapDataset) -> bool:


 def test_empty_dataset_save_and_reload(tmp_path):
-    ds = SingleCellMemMapDataset(data_path=tmp_path / "scy", num_rows=2, num_elements=10)
+    ds = SingleCellMemMapDataset(data_path=tmp_path / "scy", num_rows=2, num_elements=10, feature_list=None)
     ds.save()
     del ds
-    reloaded = SingleCellMemMapDataset(tmp_path / "scy")
+    reloaded = SingleCellMemMapDataset(tmp_path / "scy", feature_list=None)
     assert reloaded.number_of_rows() == 0
     assert reloaded.number_of_variables() == [0]
     assert reloaded.number_of_values() == 0
@@ -102,11 +102,11 @@ def test_wrong_arguments_for_dataset(tmp_path):
     with pytest.raises(
         ValueError, match=r"An np.memmap path, an h5ad path, or the number of elements and rows is required"
     ):
-        SingleCellMemMapDataset(data_path=tmp_path / "scy")
+        SingleCellMemMapDataset(data_path=tmp_path / "scy", feature_list=None)


 def test_load_h5ad(tmp_path, test_directory):
-    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad")
+    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad", feature_list=None)
     assert ds.number_of_rows() == 8
     assert ds.number_of_variables() == [10]
     assert len(ds) == 8
@@ -117,7 +117,7 @@ def test_load_h5ad(tmp_path, test_directory):


 def test_h5ad_no_file(tmp_path):
-    ds = SingleCellMemMapDataset(data_path=tmp_path / "scy", num_rows=2, num_elements=10)
+    ds = SingleCellMemMapDataset(data_path=tmp_path / "scy", num_rows=2, num_elements=10, feature_list=None)
     with pytest.raises(FileNotFoundError, match=rf"Error: could not find h5ad path {tmp_path}/a"):
         ds.load_h5ad(anndata_path=tmp_path / "a")

@@ -202,8 +202,8 @@ def test_SingleCellMemMapDataset_get_row_padded(generate_dataset):


 def test_concat_SingleCellMemMapDatasets_same(tmp_path, test_directory):
-    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad")
-    dt = SingleCellMemMapDataset(tmp_path / "sct", h5ad_path=test_directory / "adata_sample0.h5ad")
+    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad", feature_list=None)
+    dt = SingleCellMemMapDataset(tmp_path / "sct", h5ad_path=test_directory / "adata_sample0.h5ad", feature_list=None)
     dt.concat(ds)

     assert dt.number_of_rows() == 2 * ds.number_of_rows()
@@ -212,8 +212,8 @@ def test_concat_SingleCellMemMapDatasets_same(tmp_path, test_directory):


 def test_concat_SingleCellMemMapDatasets_diff(tmp_path, test_directory):
-    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad")
-    dt = SingleCellMemMapDataset(tmp_path / "sct", h5ad_path=test_directory / "adata_sample1.h5ad")
+    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad", feature_list=None)
+    dt = SingleCellMemMapDataset(tmp_path / "sct", h5ad_path=test_directory / "adata_sample1.h5ad", feature_list=None)

     exp_number_of_rows = ds.number_of_rows() + dt.number_of_rows()
     exp_n_val = ds.number_of_values() + dt.number_of_values()
@@ -225,26 +225,31 @@ def test_concat_SingleCellMemMapDatasets_diff(tmp_path, test_directory):


 def test_concat_SingleCellMemMapDatasets_multi(tmp_path, compare_fn, test_directory):
-    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad")
-    dt = SingleCellMemMapDataset(tmp_path / "sct", h5ad_path=test_directory / "adata_sample1.h5ad")
-    dx = SingleCellMemMapDataset(tmp_path / "sccx", h5ad_path=test_directory / "adata_sample2.h5ad")
+    ds = SingleCellMemMapDataset(tmp_path / "scy", h5ad_path=test_directory / "adata_sample0.h5ad", feature_list=None)
+    dt = SingleCellMemMapDataset(tmp_path / "sct", h5ad_path=test_directory / "adata_sample1.h5ad", feature_list=None)
+    dx = SingleCellMemMapDataset(tmp_path / "sccx", h5ad_path=test_directory / "adata_sample2.h5ad", feature_list=None)
     exp_n_obs = ds.number_of_rows() + dt.number_of_rows() + dx.number_of_rows()
     dt.concat(ds)
     dt.concat(dx)
     assert dt.number_of_rows() == exp_n_obs
-    dns = SingleCellMemMapDataset(tmp_path / "scdns", h5ad_path=test_directory / "adata_sample1.h5ad")
+    dns = SingleCellMemMapDataset(
+        tmp_path / "scdns", h5ad_path=test_directory / "adata_sample1.h5ad", feature_list=None
+    )

     dns.concat([ds, dx])
     compare_fn(dns, dt)


 def test_lazy_load_SingleCellMemMapDatasets_one_dataset(tmp_path, compare_fn, test_directory):
-    ds_regular = SingleCellMemMapDataset(tmp_path / "sc1", h5ad_path=test_directory / "adata_sample1.h5ad")
+    ds_regular = SingleCellMemMapDataset(
+        tmp_path / "sc1", h5ad_path=test_directory / "adata_sample1.h5ad", feature_list=None
+    )
     ds_lazy = SingleCellMemMapDataset(
         tmp_path / "sc2",
         h5ad_path=test_directory / "adata_sample1.h5ad",
         paginated_load_cutoff=0,
         load_block_row_size=2,
+        feature_list=None,
     )
     compare_fn(ds_regular, ds_lazy)

@@ -256,5 +261,6 @@ def test_lazy_load_SingleCellMemMapDatasets_another_dataset(tmp_path, compare_fn
         h5ad_path=test_directory / "adata_sample0.h5ad",
         paginated_load_cutoff=0,
         load_block_row_size=3,
+        feature_list=None,
     )
     compare_fn(ds_regular, ds_lazy)