diff --git a/mdaadb/analysis.py b/mdaadb/analysis.py index 9303a32..d39ea93 100644 --- a/mdaadb/analysis.py +++ b/mdaadb/analysis.py @@ -17,24 +17,34 @@ def __init__(self, Analysis, dbfile, hooks=None): Parameters ---------- - Analysis : - dbfile : - hooks : + Analysis : mda.analysis.base.AnalysisBase + dbfile : path-like + hooks : dict """ self.Analysis = Analysis self.db = Database(dbfile) + self._analysis = None + + self.hooks = { + "pre_run": None, + "post_run": None, + "get_universe": None, + "post_save": None, + } + if hooks is not None: + self.hooks.update(hooks) try: - self.analysis_name = self.Analysis.name + self._name = self.Analysis.name except AttributeError: - self.analysis_name = self.Analysis.__name__ + self._name = self.Analysis.__name__ try: - self.analysis_notes = self.Analysis.notes + self._notes = self.Analysis.notes except AttributeError: - self.analysis_notes = None - self.analysis_path = inspect.getfile(self.Analysis) + self._notes = None + self._path = inspect.getfile(self.Analysis) try: self.obsv = self.db.get_table("Observables") @@ -51,72 +61,79 @@ def __init__(self, Analysis, dbfile, hooks=None): STRICT=False ) - if self.analysis_name not in self.obsv.get_column("obsName").data: + if self._name not in self.obsv.get_column("obsName").data: self.obsv.insert_row( - (self.analysis_name, self.analysis_notes, self.analysis_path), + (self._name, self._notes, self._path), columns=["obsName, notes, creator"], ) - self._analysis = None - @property def results(self) -> Results: - """Analysis results.""" + """Analysis results""" if self._analysis is None: raise ValueError("Must call run() for results to exist.") return self._analysis.results - def _get_universe(self, simID: int, get_universe: Optional[Callable]): + def _get_universe(self, simID: int): - if get_universe is not None: + if self.hooks["get_universe"] is not None: #if self.hooks["get_universe"]: - return get_universe(self.db, simID) + return self.hooks["get_universe"](self.db, simID) row = self.db.get_table("Simulations").get_row(simID) - return mda.Universe(row.topology, row.trajecory) - - def run( - self, - simID: int, - get_universe: Optional[Callable] = None, - **kwargs: dict, - ) -> None: - """ + return mda.Universe(row.topology, row.trajectory) + + def run(self, simID: int, **kwargs: dict) -> None: + """Run the analysis for a simulation given by `simID`. + Parameters ---------- simID : int - get_universe : Callable[Database, int] **kwargs : dict additional keyword arguments to be passed to the Analysis class """ - u = self._get_universe(simID, get_universe) - self._analysis = self.Analysis(u, **kwargs) + univserse = self._get_universe(simID) + + self._analysis = self.Analysis(univserse, **kwargs) self._analysis._simID = simID + + if self.hooks["pre_run"] is not None: + self.hooks["pre_run"](simID, self.db) + self._analysis.run() + if self.hooks["post_run"] is not None: + self.hooks["post_run"](simID, self.db) + def save(self) -> None: """Save the results of the analysis to the database.""" + assert self._analysis is not None + if not self.results: raise ValueError("no results") - analysis_table = Table(self.analysis_name, self.db) - - if analysis_table not in self.db: - self.db.create_table(self.Analysis.schema) + try: + analysis_table = self.db.get_table(self._name) + except ValueError: + analysis_table = self.db.create_table(self.Analysis.schema) else: - simID = self._analysis._simID - if simID in analysis_table._get_rowids(): - raise ValueError( - f"{self.analysis_name} table already has data for simID {simID}" - ) + assert analysis_table.schema == self.Analysis.schema + + simID = self._analysis._simID + if simID in analysis_table._get_rowids(): + raise ValueError( + f"{self._name} table already has data for simID {simID}" + ) rows = self.results[self.Analysis.results_key] + analysis_table.insert_rows(rows) - self.db.get_table(self.analysis_name).insert_rows(rows) + if self.hooks["post_save"] is not None: + self.hooks["post_save"](simID, self.db) def __enter__(self): return self