Skip to content

Commit

Permalink
add analysis hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
edisj committed Jul 16, 2024
1 parent 68ff3e8 commit f3807d8
Showing 1 changed file with 55 additions and 38 deletions.
93 changes: 55 additions & 38 deletions mdaadb/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,34 @@ def __init__(self, Analysis, dbfile, hooks=None):
Parameters
----------
Analysis :
dbfile :
hooks :
Analysis : mda.analysis.base.AnalysisBase
dbfile : path-like
hooks : dict
"""

self.Analysis = Analysis
self.db = Database(dbfile)
self._analysis = None

self.hooks = {
"pre_run": None,
"post_run": None,
"get_universe": None,
"post_save": None,
}
if hooks is not None:
self.hooks.update(hooks)

try:
self.analysis_name = self.Analysis.name
self._name = self.Analysis.name
except AttributeError:
self.analysis_name = self.Analysis.__name__
self._name = self.Analysis.__name__
try:
self.analysis_notes = self.Analysis.notes
self._notes = self.Analysis.notes
except AttributeError:
self.analysis_notes = None
self.analysis_path = inspect.getfile(self.Analysis)
self._notes = None
self._path = inspect.getfile(self.Analysis)

try:
self.obsv = self.db.get_table("Observables")
Expand All @@ -51,72 +61,79 @@ def __init__(self, Analysis, dbfile, hooks=None):
STRICT=False
)

if self.analysis_name not in self.obsv.get_column("obsName").data:
if self._name not in self.obsv.get_column("obsName").data:
self.obsv.insert_row(
(self.analysis_name, self.analysis_notes, self.analysis_path),
(self._name, self._notes, self._path),
columns=["obsName, notes, creator"],
)

self._analysis = None

@property
def results(self) -> Results:
"""Analysis results."""
"""Analysis results"""

if self._analysis is None:
raise ValueError("Must call run() for results to exist.")
return self._analysis.results

def _get_universe(self, simID: int, get_universe: Optional[Callable]):
def _get_universe(self, simID: int):

if get_universe is not None:
if self.hooks["get_universe"] is not None:
#if self.hooks["get_universe"]:
return get_universe(self.db, simID)
return self.hooks["get_universe"](self.db, simID)

row = self.db.get_table("Simulations").get_row(simID)
return mda.Universe(row.topology, row.trajecory)

def run(
self,
simID: int,
get_universe: Optional[Callable] = None,
**kwargs: dict,
) -> None:
"""
return mda.Universe(row.topology, row.trajectory)

def run(self, simID: int, **kwargs: dict) -> None:
"""Run the analysis for a simulation given by `simID`.
Parameters
----------
simID : int
get_universe : Callable[Database, int]
**kwargs : dict
additional keyword arguments to be passed to the Analysis class
"""
u = self._get_universe(simID, get_universe)

self._analysis = self.Analysis(u, **kwargs)
univserse = self._get_universe(simID)

self._analysis = self.Analysis(univserse, **kwargs)
self._analysis._simID = simID

if self.hooks["pre_run"] is not None:
self.hooks["pre_run"](simID, self.db)

self._analysis.run()

if self.hooks["post_run"] is not None:
self.hooks["post_run"](simID, self.db)

def save(self) -> None:
"""Save the results of the analysis to the database."""

assert self._analysis is not None

if not self.results:
raise ValueError("no results")

analysis_table = Table(self.analysis_name, self.db)

if analysis_table not in self.db:
self.db.create_table(self.Analysis.schema)
try:
analysis_table = self.db.get_table(self._name)
except ValueError:
analysis_table = self.db.create_table(self.Analysis.schema)
else:
simID = self._analysis._simID
if simID in analysis_table._get_rowids():
raise ValueError(
f"{self.analysis_name} table already has data for simID {simID}"
)
assert analysis_table.schema == self.Analysis.schema

simID = self._analysis._simID
if simID in analysis_table._get_rowids():
raise ValueError(
f"{self._name} table already has data for simID {simID}"
)

rows = self.results[self.Analysis.results_key]
analysis_table.insert_rows(rows)

self.db.get_table(self.analysis_name).insert_rows(rows)
if self.hooks["post_save"] is not None:
self.hooks["post_save"](simID, self.db)

def __enter__(self):
return self
Expand Down

0 comments on commit f3807d8

Please sign in to comment.