diff --git a/src/multiassayexperiment/MultiAssayExperiment.py b/src/multiassayexperiment/MultiAssayExperiment.py index b9379ba..1fbe37e 100644 --- a/src/multiassayexperiment/MultiAssayExperiment.py +++ b/src/multiassayexperiment/MultiAssayExperiment.py @@ -450,12 +450,12 @@ def experiment_names(self, names: List[str]): ######>> experiment accessor <<###### ##################################### - def experiment(self, name: str, with_sample_data: bool = False) -> Any: + def experiment(self, name: Union[int, str], with_sample_data: bool = False) -> Any: """Get an experiment by name. Args: name: - Experiment name. + Name or index position of the experiment. with_sample_data: Whether to merge column data of the experiment with @@ -463,19 +463,40 @@ def experiment(self, name: str, with_sample_data: bool = False) -> Any: Defaults to False. + Raises: + AttributeError: + If the experiment name does not exist. + IndexError: + If index is greater than the number of experiments. + Returns: The experiment object. If ``with_sample_data`` is `True`, a copy of the experiment object is returned. """ - if name not in self._experiments: - raise ValueError(f"'{name}' is not a valid experiment name.") + _name = name + if isinstance(name, int): + if name < 0: + raise IndexError("Index cannot be negative.") + + if name > len(self.experiment_names): + raise IndexError("Index greater than the number of assays.") + + _name = self.experiment_names[name] + expt = self._experiments[_name] + elif isinstance(name, str): + if name not in self._experiments: + raise ValueError(f"'{name}' is not a valid experiment name.") - expt = self.experiments[name] + expt = self.experiments[name] + else: + raise TypeError( + f"'experiment' must be a string or integer, provided '{type(name)}'." + ) if with_sample_data is True: assay_splits = self.sample_map.split("assay", only_indices=True) - subset_map = self.sample_map[assay_splits[name],] + subset_map = self.sample_map[assay_splits[_name],] subset_map = subset_map.set_row_names(subset_map.get_column("colname")) expt_column_data = expt.column_data diff --git a/tests/test_with_coldata.py b/tests/test_with_coldata.py index 8d59ac5..ff5d917 100644 --- a/tests/test_with_coldata.py +++ b/tests/test_with_coldata.py @@ -106,3 +106,14 @@ def test_access_expt_with_column_data(): assert sce.shape == tsce.shape assert len(sce.column_data.columns) >= len(tsce.column_data.columns) + +def test_access_expt_with_int_index(): + assert mae is not None + + se = mae.experiment(0) + assert se.shape == tse2.shape + + sce = mae.experiment(1, with_sample_data=True) + assert sce.shape == tsce.shape + + assert len(sce.column_data.columns) >= len(tsce.column_data.columns) \ No newline at end of file