Skip to content

Commit

Permalink
add __enter__ and __exit__ to Database
Browse files Browse the repository at this point in the history
  • Loading branch information
edisj committed Jul 10, 2024
1 parent cf05ce2 commit c5823e1
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 50 deletions.
103 changes: 54 additions & 49 deletions mdaadb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,40 @@ class Database:
"""

def __init__(self, dbfile: pathlib.Path | str):
self.dbfile = pathlib.Path(dbfile) if isinstance(dbfile, str) else dbfile
self.connection = sqlite3.connect(self.dbfile)
self.connection.row_factory = _namedtuple_factory
self.cursor = self.connection.cursor()

if isinstance(dbfile, str):
self.dbfile = pathlib.Path(dbfile)
else:
self.dbfile = dbfile
self.connection = None
self.cursor = None

self.open()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.close()

def __iter__(self):
return iter(self.tables)

def __contains__(self, table: Table) -> bool:
return table.db == self

def open(self):
if self.connection is None:
self.connection = sqlite3.connect(self.dbfile)
self.connection.row_factory = _namedtuple_factory
self.cursor = self.connection.cursor()

def close(self):
if self.connection is not None:
self.connection.close()
self.connection = None
self.cursor = None

@property
def schema(self) -> pd.DataFrame:
"""Get the database schema as a `pandas.DataFrame`."""
Expand Down Expand Up @@ -114,6 +137,10 @@ def create_table(self, tbl_schema: str, STRICT: bool = True) -> Table:
`self.connection.commit()` is automatically called.
"""

if self.connection is None:
raise ValueError("Database connection is closed.")

if STRICT:
tbl_schema += " STRICT"

Expand Down Expand Up @@ -295,17 +322,17 @@ def n_cols(self) -> int:
.ncol
)

@property
def rows(self) -> Iterator[Row]:
"""An iterable of Rows contained in this table."""
return (Row(id, self) for id in self._get_rowids())
# @property
# def rows(self) -> Iterator[NamedTuple]:
# """An iterable of Rows contained in this table."""
# return (Row(id, self) for id in self._get_rowids())

@property
def columns(self) -> Iterator[Column]:
"""An iterable of Columns contained in this table."""
return (Column(name, self) for name in self._get_column_names())

def get_row(self, id: int) -> Row:
def get_row(self, id: int) -> NamedTuple:
"""
Parameters
Expand All @@ -323,8 +350,14 @@ def get_row(self, id: int) -> Row:
"""
if id not in self._get_rowids():
raise ValueError(f"rowid {id} not in rowids")
return Row(id, self)
raise ValueError(f"{self.pk} {id} not found in '{self.name}' table")

return (
self.SELECT("*")
.WHERE(f"{self.pk}={id}")
.execute()
.fetchone()
)

def get_column(self, name: str) -> Column:
"""
Expand All @@ -344,7 +377,8 @@ def get_column(self, name: str) -> Column:
"""
if name not in self._get_column_names():
raise ValueError("name not in column names")
raise ValueError(f"{name} not a column in '{self.name}' table")

return Column(name, self)

def _get_rowids(self) -> List[int]:
Expand Down Expand Up @@ -385,6 +419,9 @@ def insert_row(self, row: tuple):
row : tuple
"""
if self.db.connection is None:
raise ValueError("Database connection is closed.")

with self.db.connection:
(
self.INSERT_INTO(self.name)
Expand All @@ -400,6 +437,9 @@ def insert_array(self, array: ArrayLike) -> None:
array : ArrayLike
"""
if self.db.connection is None:
raise ValueError("Database connection is closed.")

n_cols = self.n_cols
values = "(" + ", ".join(["?" for _ in range(n_cols)]) + ")"
with self.db.connection:
Expand Down Expand Up @@ -472,43 +512,8 @@ def to_df(self) -> pd.DataFrame:
def __len__(self) -> int:
return self.n_rows

def __iter__(self) -> Iterator[Row]:
return iter(self.rows)


@dataclass
class Row:
id: int
table: Table

@property
def db(self) -> Database:
"""Returns the database that this row belongs to."""
return self.table.db

@property
def data(self):
""""""
return (
self.table
.SELECT("*")
.WHERE(f"{self.table.pk}={self.id}")
.execute()
.fetchone()
)

def to_sql(self) -> str:
"""Returns the sqlite query that generates this row."""
return (
self.table
.SELECT("*")
.WHERE(f"{self.table.pk}={self.id}")
.to_sql()
)

def to_df(self) -> pd.DataFrame:
"""Returns a pandas DataFrame of this row."""
return pd.read_sql(self.to_sql(), self.db.connection)
# def __iter__(self) -> Iterator[NamedTuple]:
# return iter(self.rows)


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mdaadb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def SELECT_COUNT(self, column: str) -> Self:
return self


@register_command(dependency="INSERT")
@register_command(dependency="INSERT INTO")
def VALUES(self, values: str) -> Self:
"""Add VALUES command to the current Query.
Expand Down

0 comments on commit c5823e1

Please sign in to comment.