From c5823e154a860f8f9d132a1d2e38bed052cf70c7 Mon Sep 17 00:00:00 2001 From: edisj Date: Tue, 9 Jul 2024 19:33:46 -0700 Subject: [PATCH] add __enter__ and __exit__ to Database --- mdaadb/database.py | 103 ++++++++++++++++++++++++--------------------- mdaadb/query.py | 2 +- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/mdaadb/database.py b/mdaadb/database.py index ee068e5..b84c6ec 100644 --- a/mdaadb/database.py +++ b/mdaadb/database.py @@ -34,10 +34,21 @@ class Database: """ def __init__(self, dbfile: pathlib.Path | str): - self.dbfile = pathlib.Path(dbfile) if isinstance(dbfile, str) else dbfile - self.connection = sqlite3.connect(self.dbfile) - self.connection.row_factory = _namedtuple_factory - self.cursor = self.connection.cursor() + + if isinstance(dbfile, str): + self.dbfile = pathlib.Path(dbfile) + else: + self.dbfile = dbfile + self.connection = None + self.cursor = None + + self.open() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.close() def __iter__(self): return iter(self.tables) @@ -45,6 +56,18 @@ def __iter__(self): def __contains__(self, table: Table) -> bool: return table.db == self + def open(self): + if self.connection is None: + self.connection = sqlite3.connect(self.dbfile) + self.connection.row_factory = _namedtuple_factory + self.cursor = self.connection.cursor() + + def close(self): + if self.connection is not None: + self.connection.close() + self.connection = None + self.cursor = None + @property def schema(self) -> pd.DataFrame: """Get the database schema as a `pandas.DataFrame`.""" @@ -114,6 +137,10 @@ def create_table(self, tbl_schema: str, STRICT: bool = True) -> Table: `self.connection.commit()` is automatically called. """ + + if self.connection is None: + raise ValueError("Database connection is closed.") + if STRICT: tbl_schema += " STRICT" @@ -295,17 +322,17 @@ def n_cols(self) -> int: .ncol ) - @property - def rows(self) -> Iterator[Row]: - """An iterable of Rows contained in this table.""" - return (Row(id, self) for id in self._get_rowids()) + # @property + # def rows(self) -> Iterator[NamedTuple]: + # """An iterable of Rows contained in this table.""" + # return (Row(id, self) for id in self._get_rowids()) @property def columns(self) -> Iterator[Column]: """An iterable of Columns contained in this table.""" return (Column(name, self) for name in self._get_column_names()) - def get_row(self, id: int) -> Row: + def get_row(self, id: int) -> NamedTuple: """ Parameters @@ -323,8 +350,14 @@ def get_row(self, id: int) -> Row: """ if id not in self._get_rowids(): - raise ValueError(f"rowid {id} not in rowids") - return Row(id, self) + raise ValueError(f"{self.pk} {id} not found in '{self.name}' table") + + return ( + self.SELECT("*") + .WHERE(f"{self.pk}={id}") + .execute() + .fetchone() + ) def get_column(self, name: str) -> Column: """ @@ -344,7 +377,8 @@ def get_column(self, name: str) -> Column: """ if name not in self._get_column_names(): - raise ValueError("name not in column names") + raise ValueError(f"{name} not a column in '{self.name}' table") + return Column(name, self) def _get_rowids(self) -> List[int]: @@ -385,6 +419,9 @@ def insert_row(self, row: tuple): row : tuple """ + if self.db.connection is None: + raise ValueError("Database connection is closed.") + with self.db.connection: ( self.INSERT_INTO(self.name) @@ -400,6 +437,9 @@ def insert_array(self, array: ArrayLike) -> None: array : ArrayLike """ + if self.db.connection is None: + raise ValueError("Database connection is closed.") + n_cols = self.n_cols values = "(" + ", ".join(["?" for _ in range(n_cols)]) + ")" with self.db.connection: @@ -472,43 +512,8 @@ def to_df(self) -> pd.DataFrame: def __len__(self) -> int: return self.n_rows - def __iter__(self) -> Iterator[Row]: - return iter(self.rows) - - -@dataclass -class Row: - id: int - table: Table - - @property - def db(self) -> Database: - """Returns the database that this row belongs to.""" - return self.table.db - - @property - def data(self): - """""" - return ( - self.table - .SELECT("*") - .WHERE(f"{self.table.pk}={self.id}") - .execute() - .fetchone() - ) - - def to_sql(self) -> str: - """Returns the sqlite query that generates this row.""" - return ( - self.table - .SELECT("*") - .WHERE(f"{self.table.pk}={self.id}") - .to_sql() - ) - - def to_df(self) -> pd.DataFrame: - """Returns a pandas DataFrame of this row.""" - return pd.read_sql(self.to_sql(), self.db.connection) + # def __iter__(self) -> Iterator[NamedTuple]: + # return iter(self.rows) @dataclass diff --git a/mdaadb/query.py b/mdaadb/query.py index ee6241a..36908f2 100644 --- a/mdaadb/query.py +++ b/mdaadb/query.py @@ -322,7 +322,7 @@ def SELECT_COUNT(self, column: str) -> Self: return self - @register_command(dependency="INSERT") + @register_command(dependency="INSERT INTO") def VALUES(self, values: str) -> Self: """Add VALUES command to the current Query.