Skip to content

Commit

Permalink
add analysis scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
edisj committed Jul 11, 2024
1 parent c5823e1 commit c1a34c6
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 11 deletions.
1 change: 1 addition & 0 deletions mdaadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .query import Query
from .database import Database, Table
from .analysis import DBAnalysisRunner


__version__ = version("mdaadb-kit")
78 changes: 78 additions & 0 deletions mdaadb/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import List, NamedTuple
from pathlib import Path

from napalib.system.universe import NapAUniverse
import MDAnalysis as mda

from . import Database


def get_NapA_universe_by_simID(db: Database, simID: int) -> NapAUniverse:
row = db.get_table("Simulations").get_row(simID)
topology = row.topology
trajectory = row.trajectory
u = NapAUniverse(topology)
u.load_new(trajectory)

return u


def get_universe_by_simID(db: Database, simID: int) -> mda.Universe:
row = db.get_table("Simulations").get_row(simID)
topology = row.topology
trajectory = row.trajectory

return mda.Universe(topology, trajectory)


class DBAnalysisRunner:

def __init__(self, db: Database, Analysis):

self.db = db
self.Analysis = Analysis
self.name = self.Analysis.name
self._analysis = None

try:
self.observables = self.db.get_table("Observables")
except ValueError:
self.observables = self.db.create_table(
"Observables (name TEXT, progenitor TEXT)"
)
finally:
self.observables.insert_array([
(self.name, self.Analysis._path),
])

def __enter__(self):
self.db.open()
return self

def __exit__(self, *args):
self.db.close()

@property
def results(self):
if self._analysis is not None:
return self._analysis.results

def run_for_simID(self, simID: int, **kwargs) -> None:
""""""
universe = get_NapA_universe_by_simID(self.db, simID)
self._analysis = self.Analysis(universe, **kwargs)
self._analysis._simID = simID
self._analysis.run()

def save(self) -> None:
if not self.results:
raise ValueError("no results")

if self.name not in self.db._get_table_names():
self.db.create_table(self.Analysis.schema)

rows = self.results[self.Analysis.results_key]

table = self.db.get_table(self.name)
table.insert_array(rows)

53 changes: 42 additions & 11 deletions mdaadb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,36 @@ def _namedtuple_factory(cursor, row):
return Row(*row)


class Tables(UserDict):
def __init__(self, db, *args, **kwargs):
self.db = db
def touch(dbfile: pathlib.Path | str) -> None:
"""Create a minimal database file.
A legal database includes a 'Simulations' and 'Observables' table
at a minimum. 'Simulations' must contain 'simID', 'topology', and
'trajectory' columns. 'Observables' must contain 'name' and 'progenitor'
columns.
Parameters
----------
dbfile : path-like
Path to database file.
Raises
------
ValueError
If database already exists.
"""
if isinstance(dbfile, str):
dbfile = pathlib.Path(dbfile)

if dbfile.exists():
raise ValueError(f"{dbfile} already exists")

with Database(dbfile) as db:
sim_schema = "Simulations (simID INT PRIMARY KEY, topology TEXT, trajectory TEXT)"
db.create_table(sim_schema)
obs_schema = "Observables (name TEXT, progenitor TEXT)"
db.create_table(obs_schema)


class Database:
Expand All @@ -39,11 +66,15 @@ def __init__(self, dbfile: pathlib.Path | str):
self.dbfile = pathlib.Path(dbfile)
else:
self.dbfile = dbfile

self.connection = None
self.cursor = None

self.open()

def __contains__(self, table: Table) -> bool:
return table.db == self

def __enter__(self):
return self

Expand All @@ -53,9 +84,6 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
def __iter__(self):
return iter(self.tables)

def __contains__(self, table: Table) -> bool:
return table.db == self

def open(self):
if self.connection is None:
self.connection = sqlite3.connect(self.dbfile)
Expand Down Expand Up @@ -166,7 +194,7 @@ def insert_row_into_table(self, table: Table | str, row: tuple) -> None:
table = Table(table, self)
table.insert_row(row)

def insert_array_into_table(self, table: Table | str, array: ArrayLike) -> None:
def insert_array_into_table(self, table: Table | str, array: List[tuple]) -> None:
"""...
Parameters
Expand Down Expand Up @@ -313,8 +341,7 @@ def n_rows(self) -> int:
def n_cols(self) -> int:
"""Total number of columns in this table."""
return (
self.db.
_table_list
self.db._table_list
.SELECT("ncol")
.WHERE(f"name='{self.name}'")
.execute()
Expand Down Expand Up @@ -411,7 +438,7 @@ def pk(self) -> str:

return result.name

def insert_row(self, row: tuple):
def insert_row(self, row: tuple) -> None:
"""
Parameters
Expand All @@ -429,7 +456,7 @@ def insert_row(self, row: tuple):
.execute()
)

def insert_array(self, array: ArrayLike) -> None:
def insert_array(self, array: List[tuple]) -> None:
"""
Parameters
Expand Down Expand Up @@ -556,3 +583,7 @@ def to_df(self) -> pd.DataFrame:
""""""
return pd.read_sql(self.to_sql(), self.db.connection)


class Tables(UserDict):
def __init__(self, db, *args, **kwargs):
self.db = db
85 changes: 85 additions & 0 deletions scripts/K305D156.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from collections import namedtuple
import pathlib

import numpy as np
from MDAnalysis.analysis.base import AnalysisBase
from MDAnalysis.analysis.distances import distance_array

from mdaadb.analysis import DBAnalysisRunner
from mdaadb import Database


class K305_D156(AnalysisBase):

name = "K305D156"
results_key = "distance"
Row = namedtuple("Row", ["simID", "frame", "time", "dA", "dB"])
schema = f"{name} (simID INT, frame INT, time REAL, dA REAL, dB REAL)"
_path = str(pathlib.Path(__file__).resolve())

def __init__(self, universe):
self._simID = None

super().__init__(universe.trajectory)
self.u = universe
self.n_frames = len(self.u.trajectory)

self.OD_A = self.u.select_atoms(
"resid 156 and name OD1 OD2 and segid A"
)
self.OD_B = self.u.select_atoms(
"resid 156 and name OD1 OD2 and segid B"
)
self.NZ_A = self.u.select_atoms(
"resid 305 and name NZ and segid A"
)
self.NZ_B = self.u.select_atoms(
"resid 305 and name NZ and segid B"
)

def _prepare(self):
self.results[self.results_key] = []

def _single_frame(self):
simID = self._simID
ts = self.u.trajectory.ts
frame = ts.frame
time = ts.time

d_A = np.min(
distance_array(
self.OD_A.positions,
self.NZ_A.positions,
box=self.u.dimensions
)
)
d_B = np.min(
distance_array(
self.OD_B.positions,
self.NZ_B.positions,
box=self.u.dimensions
)
)

row = self.Row(simID, frame, time, d_A, d_B)

self.results[self.results_key].append(row)


def main():

napa_dbfile = pathlib.Path("~/projects/napadb/napa.sqlite")
napa = Database(napa_dbfile)

analysis_runner = DBAnalysisRunner(napa, K305_D156)
with analysis_runner as analysis:
simulations = analysis.db.get_table("Simulations")
simids = simulations._get_rowids()

for simID in simids:
analysis.run_for_simID(simID)
analysis.save()


if __name__ == "__main__":
main()
82 changes: 82 additions & 0 deletions scripts/create_nabadb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pathlib

from napalib.system.traj import trajectories
from mdaadb import Database, DBAnalysisRunner


def main():

def get_topology(traj):
return traj.topology

def get_trajectory(traj):
return traj.trajectory

def get_repeat(traj):
return int(traj.name().split("_")[2])

def get_name(traj):
repeat = get_repeat(traj)
name = traj.name().split("_")
name.remove(f"{repeat}")
name = "_".join(name)
return name

def get_conformation(traj):
if traj.is_inward:
return "inward"
if traj.is_outward:
return "outward"
if "_occ_" in traj.name():
return "occluded"

def get_temperature(traj):
if traj.is_310:
return 310
if traj.is_358:
return 358

def get_protonation(traj):
protonation = []
if traj.has_s1:
protonation.append("s1")
if traj.has_s2:
protonation.append("s2")
if traj.has_s4:
protonation.append("s4")

return ",".join(protonation)


def get_row(idx, traj):
row = (
idx,
get_topology(traj),
get_trajectory(traj),
get_name(traj),
get_repeat(traj),
get_conformation(traj),
get_temperature(traj),
get_protonation(traj),
)
return row

sims_schema = (
"Simulations (simID INT PRIMARY KEY, topology TEXT, trajectory TEXT, name TEXT, repeat INT, conformation TEXT, temperature INT, protonation TEXT)"
)
obs_schema = (
"Observables (name TEXT, progenitor TEXT)"
)

rows = [get_row(idx, traj) for (idx, traj) in enumerate(trajectories)]

dbfile = pathlib.Path("~/projects/napadb/napa.sqlite")
with Database(dbfile) as db:
db.create_table(sims_schema)
db.create_table(obs_schema)

db.get_table("Simulations").insert_array(rows)


if __name__ == "__main__":
main()

0 comments on commit c1a34c6

Please sign in to comment.