diff --git a/mdaadb/__init__.py b/mdaadb/__init__.py index d0879eb..6ea9690 100644 --- a/mdaadb/__init__.py +++ b/mdaadb/__init__.py @@ -8,7 +8,7 @@ from .query import Query from .database import Database, Table -from .analysis import DBAnalysisRunner +from .analysis import DBAnalysisManager __version__ = version("mdaadb-kit") diff --git a/mdaadb/database.py b/mdaadb/database.py index 0a0a3f8..af910c1 100644 --- a/mdaadb/database.py +++ b/mdaadb/database.py @@ -4,8 +4,15 @@ import pathlib from dataclasses import dataclass from collections import namedtuple, UserDict -from typing import Optional, List, NamedTuple, Iterator, Any -from numpy.typing import ArrayLike +from typing import ( + Optional, + List, + Tuple, + NamedTuple, + Iterator, + Sequence, + Any, +) import pandas as pd from . import query @@ -17,147 +24,69 @@ def _namedtuple_factory(cursor, row): return Row(*row) -def touch(dbfile: pathlib.Path | str) -> None: - """Create a minimal database file. - - A legal database includes a 'Simulations' and 'Observables' table - at a minimum. 'Simulations' must contain 'simID', 'topology', and - 'trajectory' columns. 'Observables' must contain 'name' and 'progenitor' - columns. - - Parameters - ---------- - dbfile : path-like - Path to database file. - - Raises - ------ - ValueError - If database already exists. - - """ - if isinstance(dbfile, str): - dbfile = pathlib.Path(dbfile) - - if dbfile.exists(): - raise ValueError(f"{dbfile} already exists") - - with Database(dbfile) as db: - sim_schema = "Simulations (simID INT PRIMARY KEY, topology TEXT, trajectory TEXT)" - db.create_table(sim_schema) - obs_schema = "Observables (name TEXT, progenitor TEXT)" - db.create_table(obs_schema) - - class Database: - """ - - Parameters - ---------- - dbfile : path-like - Path to the database file. - Use ':memory:' for a temporary in-memory database. - - """ def __init__(self, dbfile: pathlib.Path | str): + """Database constructor. + + Parameters + ---------- + dbfile : path-like + Path to the database file. + Use ':memory:' for a temporary in-memory database. + """ if isinstance(dbfile, str): self.dbfile = pathlib.Path(dbfile) else: self.dbfile = dbfile - self.connection = None self.cursor = None - self.open() - def __contains__(self, table: Table) -> bool: - return table.db == self - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.close() - - def __iter__(self): - return iter(self.tables) - def open(self): + """Open a connection to `self.dbfile` if it is not already open.""" if self.connection is None: self.connection = sqlite3.connect(self.dbfile) self.connection.row_factory = _namedtuple_factory self.cursor = self.connection.cursor() def close(self): + """Close the connection to `self.dbfile` if it is open.""" if self.connection is not None: self.connection.close() self.connection = None self.cursor = None - @property - def schema(self) -> pd.DataFrame: - """Get the database schema as a `pandas.DataFrame`.""" - return self._schema.to_df() - - @property - def _schema(self) -> Table: - return Table("sqlite_schema", self) - - @property - def table_list(self) -> pd.DataFrame: - """Get a list of tables as a `pandas.DataFrame`.""" - return ( - self._table_list - .SELECT("*") - .WHERE("schema='main'") - .to_df() - ) - - @property - def _table_list(self) -> Table: - return Table("pragma_table_list", self) - - @property - def tables(self): - """... - - Returns - ------- - - """ - return (Table(name, self) for name in self._get_table_names()) - - def _get_table_names(self) -> List[str]: - """Helper function to get table names in this database.""" - return list(self.table_list["name"]) - - def get_table(self, tbl_name) -> Table: - """... + def get_table(self, table_name: str) -> Table: + """Get a table from this database by name. Parameters ---------- - tbl_name : str - Table name + table_name : str Returns ------- Table """ - if tbl_name not in self._get_table_names(): - raise ValueError("invalid table") - return Table(tbl_name, self) - - def create_table(self, tbl_schema: str, STRICT: bool = True) -> Table: + if table_name not in self._get_table_names(): + raise ValueError(f"'{table_name}' not in database") + return Table(table_name, self) + + def create_table( + self, + schema: Schema | str, + STRICT: Optional[bool] = True, + ) -> Table: """Create a table inside this database. Parameters ---------- - tbl_schema : str + schema : Schema or str Table schema that defines the table and column names. - STRICT : bool + STRICT : bool (default True) + Set the SQLite 'STRICT' flag for CREATE TABLE command. Note ---- @@ -169,75 +98,285 @@ def create_table(self, tbl_schema: str, STRICT: bool = True) -> Table: if self.connection is None: raise ValueError("Database connection is closed.") + if isinstance(schema, str): + schema = Schema.from_str(schema) + + table_name = schema.table_name + schema_str = schema.to_sql() + if STRICT: - tbl_schema += " STRICT" + schema_str += "\nSTRICT" with self.connection: - self.CREATE_TABLE(tbl_schema).execute() - - tbl_name = tbl_schema.split()[0] + self.CREATE_TABLE(schema_str).execute() - return self.get_table(tbl_name) + return self.get_table(table_name) - def insert_row_into_table(self, table: Table | str, row: tuple) -> None: - """... + def insert_row_into_table( + self, + table: Table | str, + row: Tuple, + columns: Optional[Sequence[str]] = None, + ) -> None: + """Insert a single row into a table within this database. Parameters ---------- table : Table or str - ... row : tuple - ... + columns : sequence of strings + Column names to insert rows into. + Must be defined if not inserting an element into every column. """ + if self.connection is None: + raise ValueError("Database connection is closed.") + if isinstance(table, str): table = Table(table, self) - table.insert_row(row) - def insert_array_into_table(self, table: Table | str, array: List[tuple]) -> None: - """... + if columns: + column_names = ", ".join(columns) + tbl_with_cols = f"{table.name} ({column_names})" + with self.connection: + ( + self.INSERT_INTO(tbl_with_cols) + .VALUES(str(row)) + .execute() + ) + return + + with self.connection: + ( + self.INSERT_INTO(table.name) + .VALUES(str(row)) + .execute() + ) + + def insert_rows_into_table( + self, + table: Table | str, + rows: List[Tuple], + columns: Optional[Sequence[str]] = None, + ) -> None: + """Insert multiple rows into a table within this database. Parameters ---------- table : Table or str - ... - array : ArrayLike - ... + rows : list of tuples + Sequence of rows to be inserted. + columns : sequence of strings + Column names to insert rows into. + Must be defined if not inserting an element into every column. + + """ + if self.connection is None: + raise ValueError("Database connection is closed.") + + if isinstance(table, str): + table = Table(table, self) + + if columns: + place_holders = "(" + ", ".join((len(columns) * "?").split()) + ")" + column_names = ", ".join(columns) + tbl_with_cols = f"{table.name} ({column_names})" + with self.connection: + ( + self.INSERT_INTO(tbl_with_cols) + .VALUES(place_holders) + .executemany(rows) + ) + + return + + place_holders = "(" + ", ".join((table.n_cols * "?").split()) + ")" + with self.connection: + ( + self.INSERT_INTO(table.name) + .VALUES(place_holders) + .executemany(rows) + ) + + def add_column_to_table( + self, + table: Table | str, + column_def: str, + ) -> None: + """Add a column to a table within this database. + + Parameters + ---------- + table : Table or str + column_def : str + Column definition with fields ' '. + must be a valid SQLite type. + + """ + if self.connection is None: + raise ValueError("Database connection is closed.") + + if isinstance(table, str): + table = Table(table, self) + + with self.connection: + ( + self.ALTER_TABLE(table.name) + .ADD_COLUMN(column_def) + .execute() + ) + + def update_column_in_table( + self, + table: Table | str, + column_name: str, + values: Sequence, + ids: Optional[Sequence[int]] = None, + ) -> None: + """Update a column of a table within this database. + + Parameters + ---------- + table : Table or str + column_name : str + values : sequence of values + ids : sequence of int """ + if self.connection is None: + raise ValueError("Database connection is closed.") + if isinstance(table, str): table = Table(table, self) - table.insert_array(array) + + if ids is None: + assert len(values) == len(table) + # better to use generator or list comprehension here ?? + data = ( + (val, id) + for val, id in zip(values, table._get_rowids()) + ) + else: + assert len(values) == len(ids) + data = ( + (val, id) + for val, id in zip(values, ids) + ) + + with self.connection: + ( + self.UPDATE(table.name) + .SET(f"{column_name} = ?") + .WHERE(f"{table.pk} = ?") + .executemany(data) + ) + + def insert_column_into_table( + self, + table: Table | str, + column_def: str, + values: Sequence, + ) -> None: + """Insert a column into a table within this database. + + Parameters + ---------- + table : Table or str + column_def : str + values : sequence of values + + """ + if self.connection is None: + raise ValueError("Database connection is closed.") + + if isinstance(table, str): + table = Table(table, self) + + col_name, col_type = column_def.strip().split() + if col_name in table._get_column_names(): + raise ValueError( + f"Column '{col_name}' is already in table '{table.name}'." + ) + + self.add_column_to_table(table, column_def) + self.update_column_in_table(table, col_name, values) + + @property + def schema(self) -> pd.DataFrame: + """Get the database schema as a `pandas.DataFrame`.""" + return self._schema.to_df() + + @property + def _schema(self) -> Table: + return Table("sqlite_schema", self) + + @property + def table_list(self) -> pd.DataFrame: + """Get a list of tables as a `pandas.DataFrame`.""" + return ( + self._table_list + .SELECT("*") + .WHERE("schema='main'") + .to_df() + ) + + @property + def _table_list(self) -> Table: + return Table("pragma_table_list", self) + + @property + def tables(self): + """An iterable of `Table` in this database + + TODO: implement this as a UserDict container? + + """ + return (Table(name, self) for name in self._get_table_names()) + + def _get_table_names(self) -> List[str]: + """Helper function to get table names in this database.""" + return list(self.table_list["name"]) + + def ALTER_TABLE(self, table_name: str) -> query.Query: + """SQLite ALTER TABLE statement entry point via database. + + Parameters + ---------- + table_name : str + + Returns + ------- + query.Query + + """ + return query.Query(db=self).ALTER_TABLE(table_name) def CREATE_TABLE(self, tbl_schema: str) -> query.Query: - """SQLite CREATE_TABLE Statment entry point for `query.Query`. + """SQLite CREATE TABLE statement entry point via database. Parameters ---------- tbl_schema : str - ... Returns ------- query.Query - ... """ return query.Query(db=self).CREATE_TABLE(tbl_schema) def DELETE(self) -> query.Query: - """SQLite DELETE Statment entry point for `query.Query`. + """SQLite DELETE statement entry point via database. Returns ------- query.Query - ... """ return query.Query(db=self).DELETE() def INSERT_INTO(self, table: str) -> query.Query: - """SQLite INSERT INTO Statment entry point for `query.Query`. + """SQLite INSERT INTO statement entry point via database. Parameters ---------- @@ -253,7 +392,7 @@ def INSERT_INTO(self, table: str) -> query.Query: return query.Query(db=self).INSERT_INTO(table) def PRAGMA(self, schema_name: str) -> query.Query: - """SQLite PRAGMA Statment entry point for `query.Query`. + """SQLite PRAGMA statement entry point via database. Parameters ---------- @@ -269,7 +408,7 @@ def PRAGMA(self, schema_name: str) -> query.Query: return query.Query(db=self).PRAGMA(schema_name) def SELECT(self, *columns: str) -> query.Query: - """SQLite SELECT Statment entry point for `query.Query`. + """SQLite SELECT statement entry point via database. Parameters ---------- @@ -285,7 +424,7 @@ def SELECT(self, *columns: str) -> query.Query: return query.Query(db=self).SELECT(*columns) def SELECT_COUNT(self, column: str) -> query.Query: - """SQLite SELECT COUNT Statment entry point for `query.Query`. + """SQLite SELECT COUNT statement entry point via database. Parameters ---------- @@ -300,15 +439,192 @@ def SELECT_COUNT(self, column: str) -> query.Query: """ return query.Query(db=self).SELECT_COUNT(column) + def SELECT_DISTINCT(self, *columns: str) -> query.Query: + """SQLite SELECT DISTINCT statement entry point via database. + + Parameters + ---------- + *columns: str + ... + + Returns + ------- + query.Query + ... + + """ + return query.Query(db=self).SELECT_DISTINCT(*columns) + + def UPDATE(self, table_name: str) -> query.Query: + """SQlite UPDATE statement entry point via database. + + Parameters + ---------- + table_name : str + + Returns + ------- + query.Query + + """ + return query.Query(db=self).UPDATE(table_name) + + def __contains__(self, table: Table | str) -> bool: + if isinstance(table, str): + table = Table(table, self) + return (table.name in self._get_table_names()) and (table.db == self) + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.close() + + def __iter__(self): + return iter(self.tables) + @dataclass class Table: name: str db: Database + def get_row(self, id: int) -> Row: + """Get a row from this table by integer index. + + Parameters + ---------- + id : int + + Returns + ------- + Row + + Raises + ------ + ValueError + + """ + if id not in self._get_rowids(): + raise ValueError(f"{self.pk} {id} not found in '{self.name}' table") + + return Row(id, self) + + def get_column(self, name: str) -> Column: + """Get a column from this table by name. + + Parameters + ---------- + name : str + + Returns + ------- + Column + + Raises + ------ + ValueError + + """ + if name not in self._get_column_names(): + raise ValueError(f"{name} not a column in '{self.name}' table") + + return Column(name, self) + + def insert_row( + self, + row: Tuple, + columns: Optional[Sequence[str]] = None, + ) -> None: + """Insert a single row into this table. + + Parameters + ---------- + row : tuple + columns : sequence of strings + Column names to insert rows into. + Must be defined if not inserting an element into every column. + + """ + self.db.insert_row_into_table(self, row, columns=columns) + + def insert_rows( + self, + rows: List[Tuple], + columns: Optional[Sequence[str]] = None, + ) -> None: + """Insert multiple rows into this table. + + Parameters + ---------- + rows : list of tuples + Sequence of rows to be inserted. + columns : sequence of strings + Column names to insert rows into. + Must be defined if not inserting an element into every column. + + """ + self.db.insert_rows_into_table(self, rows, columns=columns) + + def add_column( + self, + column_def: str + ) -> None: + """Add a column to this table. + + Parameters + ---------- + column_def : str + Column definition with fields ' '. + must be a valid SQLite type. + + """ + self.db.add_column_to_table(self, column_def) + + def update_column( + self, + column: str, + values: Sequence, + ids: Optional[Sequence[int]] = None, + ) -> None: + """Update a column in this table. + + Parameters + ---------- + column : str + values : + ids : + + """ + self.db.update_column_in_table(self, column, values, ids=ids) + + def insert_column( + self, + column_name: str, + values: Sequence, + ) -> None: + """Insert a column into this table. + + Parameters + ---------- + column_name : str + values : sequence of values + + """ + self.db.insert_column_into_table(self, column_name, values) + + def to_sql(self) -> str: + """Return the SQLite query that generates this table.""" + return self.SELECT("*").to_sql() + + def to_df(self) -> pd.DataFrame: + """Return a `pandas.DataFrame` of this table.""" + return pd.read_sql(self.to_sql(), self.db.connection) + @property def info(self) -> pd.DataFrame: - """Get the PRAGMA info table as a `pandas.DataFrame`.""" + """The PRAGMA info table as a `pandas.DataFrame`.""" return ( self.db .PRAGMA(f"table_info('{self.name}')") @@ -321,7 +637,7 @@ def _info(self) -> Table: @property def schema(self) -> str: - """Get the schema/structure of this table.""" + """The schema/structure of this table.""" _schema = ( self.db._schema .SELECT("sql") @@ -349,64 +665,23 @@ def n_cols(self) -> int: .ncol ) - # @property - # def rows(self) -> Iterator[NamedTuple]: - # """An iterable of Rows contained in this table.""" - # return (Row(id, self) for id in self._get_rowids()) - @property - def columns(self) -> Iterator[Column]: - """An iterable of Columns contained in this table.""" - return (Column(name, self) for name in self._get_column_names()) + def rows(self) -> Iterator[Row]: + """An iterable of Rows contained in this table. - def get_row(self, id: int) -> NamedTuple: - """ - - Parameters - ---------- - id : int - INT PRIMARY KEY rowid - - Returns - ------- - Row - - Raises - ------ - ValueError + TODO: implement this as a UserDict container? """ - if id not in self._get_rowids(): - raise ValueError(f"{self.pk} {id} not found in '{self.name}' table") + return (Row(id, self) for id in self._get_rowids()) - return ( - self.SELECT("*") - .WHERE(f"{self.pk}={id}") - .execute() - .fetchone() - ) - - def get_column(self, name: str) -> Column: - """ - - Parameters - ---------- - name : str - Name of the column - - Returns - ------- - Column + @property + def columns(self) -> Iterator[Column]: + """An iterable of Columns contained in this table. - Raises - ------ - ValueError + TODO: implement this as a UserDict container? """ - if name not in self._get_column_names(): - raise ValueError(f"{name} not a column in '{self.name}' table") - - return Column(name, self) + return (Column(name, self) for name in self._get_column_names()) def _get_rowids(self) -> List[int]: """Helper function that returns a list of integer indices @@ -423,7 +698,7 @@ def _get_column_names(self) -> List[str]: @property def pk(self) -> str: - """Get the primary key of this table.""" + """The primary key of this table.""" result = ( self._info .SELECT("name") @@ -438,70 +713,46 @@ def pk(self) -> str: return result.name - def insert_row(self, row: tuple) -> None: - """ + def DELETE(self) -> query.Query: + """SQLite DELETE statement entry point via table. - Parameters - ---------- - row : tuple + Returns + ------- + query.Query """ - if self.db.connection is None: - raise ValueError("Database connection is closed.") - - with self.db.connection: - ( - self.INSERT_INTO(self.name) - .VALUES(str(row)) - .execute() - ) + return self.db.DELETE().FROM(self.name) - def insert_array(self, array: List[tuple]) -> None: - """ + def SELECT(self, *columns: str) -> query.Query: + """SQlite SELECT statement entry point via table. Parameters ---------- - array : ArrayLike - - """ - if self.db.connection is None: - raise ValueError("Database connection is closed.") - - n_cols = self.n_cols - values = "(" + ", ".join(["?" for _ in range(n_cols)]) + ")" - with self.db.connection: - ( - self.INSERT_INTO(self.name) - .VALUES(values) - .executemany(array) - ) - - def DELETE(self) -> query.Query: - """ + *columns : str Returns ------- query.Query """ - return self.db.DELETE().FROM(self.name) + return self.db.SELECT(*columns).FROM(self.name) - def INSERT_INTO(self, table: str) -> query.Query: - """ + def SELECT_COUNT(self, column: str) -> query.Query: + """SQLite SELECT COUNT statement entry point via table. Parameters ---------- - table : str + column : str Returns ------- query.Query """ - return self.db.INSERT_INTO(self.name) + return self.db.SELECT_COUNT(column).FROM(self.name) - def SELECT(self, *columns: str) -> query.Query: - """ + def SELECT_DISTINCT(self, *columns: str) -> query.Query: + """SQlite SELECT DISTINCT statement entry point via table. Parameters ---------- @@ -512,35 +763,81 @@ def SELECT(self, *columns: str) -> query.Query: query.Query """ - return self.db.SELECT(*columns).FROM(self.name) + return self.db.SELECT_DISTINCT(*columns).FROM(self.name) - def SELECT_COUNT(self, column: str) -> query.Query: - """ + def SET(self, *column_exprs) -> query.Query: + """SQLite SET statement entry point via table. Parameters ---------- - column : str + *column_exprs : str Returns ------- query.Query """ - return self.db.SELECT_COUNT(column).FROM(self.name) + return self.db.UPDATE(self.name).SET(*column_exprs) + + def __len__(self) -> int: + return self.n_rows + + # def __iter__(self) -> Iterator[NamedTuple]: + # return iter(self.rows) + + +class Schema: + def __init__(self, table_name: str, *column_defs: str): + self.table_name = table_name + self.column_defs = column_defs + + def __repr__(self): + return self.to_sql() + + def to_sql(self): + col_defs = ",\n ".join(self.column_defs) + statement = f"{self.table_name} (\n {col_defs}\n)" + return statement + + @classmethod + def from_str(cls, string): + table_name = string[:string.find("(")].strip() + column_defs = [ + s.strip() + for s in string[string.find("(")+1:string.rfind(")")].replace("\n", "").split(",") + ] + return cls(table_name, *column_defs) + + +@dataclass +class Row: + id: int + table: Table + + @property + def db(self) -> Database: + return self.table.db + + @property + def data(self): + return self._q.execute().fetchone() def to_sql(self) -> str: - """Returns the sqlite query that generates this table.""" - return self.SELECT("*").to_sql() + return self._q.to_sql() def to_df(self) -> pd.DataFrame: - """Returns a pandas DataFrame of this table.""" return pd.read_sql(self.to_sql(), self.db.connection) - def __len__(self) -> int: - return self.n_rows + @property + def _q(self) -> query.Query: + return ( + self.table + .SELECT("*") + .WHERE(f"{self.table.pk}={self.id}") + ) - # def __iter__(self) -> Iterator[NamedTuple]: - # return iter(self.rows) + def __getattr__(self, name): + return getattr(self.data, name) @dataclass @@ -550,12 +847,10 @@ class Column: @property def db(self) -> Database: - """""" return self.table.db @property def type_(self) -> str: - """""" return ( self.table._info .SELECT("type") @@ -567,23 +862,63 @@ def type_(self) -> str: @property def data(self) -> List[Any]: - """""" - _data = self.table.SELECT(self.name).execute().fetchall() - return [data[0] for data in _data] + return [row[0] for row in self._q.execute().fetchall()] def to_sql(self) -> str: - """""" - return ( - self.table - .SELECT(self.name) - .to_sql() - ) + return self._q.to_sql() def to_df(self) -> pd.DataFrame: - """""" return pd.read_sql(self.to_sql(), self.db.connection) + @property + def _q(self) -> query.Query: + return self.table.SELECT(self.name) + + +def touch(dbfile: pathlib.Path | str) -> None: + """Create a minimal database file. + + A legal database includes a 'Simulations' and 'Observables' table + at a minimum. + + Parameters + ---------- + dbfile : path-like + Path to database file. + + Raises + ------ + ValueError + If database already exists. + + """ + if isinstance(dbfile, str): + dbfile = pathlib.Path(dbfile) + + if dbfile.exists(): + raise ValueError(f"{dbfile} already exists") + + sims_schema = """ + Simulations ( + simID INT PRIMARY KEY, + topology TEXT, + trajectory TEXT + ) + """ + obsv_schema = """ + Observables ( + obsName TEXT, + notes TEXT, + creator TEXT, + timestamp DATETIME DEFAULT (strftime('%m-%d-%Y %H:%M', 'now', 'localtime')) + ) + """ + + with Database(dbfile) as db: + db.create_table(sims_schema) + db.create_table(obsv_schema, STRICT=False) + -class Tables(UserDict): - def __init__(self, db, *args, **kwargs): - self.db = db +# class Tables(UserDict): +# def __init__(self, db, *args, **kwargs): +# self.db = db diff --git a/mdaadb/query.py b/mdaadb/query.py index 36908f2..b143a52 100644 --- a/mdaadb/query.py +++ b/mdaadb/query.py @@ -13,9 +13,14 @@ def register_command(dependency=None): - """""" + """SQLite keyword decorator for `Query` methods. + When a method is called, the keyword is bundled with input fields + into a `Command` object, then registered into `Query.registered_commands`. + + """ def decorator(func): + @wraps(func) def wrapper(self, *args): keyword = wrapper.__name__.replace("_", " ") @@ -26,7 +31,6 @@ def wrapper(self, *args): self.registered_commands.append( Command(keyword, fields, dependency) ) - return func(self, *args) return wrapper @@ -46,18 +50,20 @@ def to_sql(self): class Query: - """ + """SQLite query object that implements the "builder" design pattern + to build a full query statement from individual SQLite commands.""" - Parameters - ---------- - db : `Database` or path-like - Database to execute queries on. If path-like, will - call `Database(db)`. Use ":memory:" to execute queries - on a temporary in memory database. + def __init__(self, db: database.Database | pathlib.Path | str) -> None: + """`Query` constructor. - """ + Parameters + ---------- + db : `Database` or path-like + Database to execute queries on. If path-like, will + call `Database(db)`. Use ":memory:" to execute queries + on a temporary in memory database. - def __init__(self, db): + """ if isinstance(db, pathlib.Path) or isinstance(db, str): self.db = database.Database(db) else: @@ -138,13 +144,44 @@ def commit(self): """""" self.db.connection.commit() + @register_command(dependency="ALTER TABLE") + def ADD_COLUMN(self, column_def: str) -> Self: + """Add ADD COLUMN command to the current Query. + + Parameters + ---------- + column_def : str + + Returns + ------- + self : `Query` + ... + + """ + return self + @register_command() - def CREATE_TABLE(self, tbl_schema: str) -> Self: + def ALTER_TABLE(self, table_name: str) -> Self: + """Add ALTER TABLE command to the current Query. + + Parameters + ---------- + table_name : str + + Returns + ------- + self : `Query` + + """ + return self + + @register_command() + def CREATE_TABLE(self, table_schema: str) -> Self: """Add CREATE TABLE command to the current Query. Parameters ---------- - tbl_schema : str + table_schema : str The table schema/structure. Returns @@ -218,7 +255,6 @@ def INSERT_INTO(self, table: str) -> Self: """ return self - @register_command() def LIMIT(self, limit: str) -> Self: """Add LIMIT command to the current Query. @@ -321,6 +357,56 @@ def SELECT_COUNT(self, column: str) -> Self: """ return self + @register_command(dependency="FROM") + def SELECT_DISTINCT(self, *columns: str) -> Self: + """Add SELECT DISTINCT command to the current Query. + + Parameters + ---------- + *columns : str + ... + + Returns + ------- + self : Query + ... + + """ + return self + + @register_command(dependency="UPDATE") + def SET(self, *column_exprs) -> Self: + """Add SET command to the current Query. + + Parameters + ---------- + *column_exprs : str + ... + + Returns + ------- + self : Query + ... + + """ + return self + + @register_command() + def UPDATE(self, table_name: str) -> Self: + """Add UPDATE command to the current Query. + + Parameters + ---------- + table_name : str + ... + + Returns + ------- + self : Query + ... + + """ + return self @register_command(dependency="INSERT INTO") def VALUES(self, values: str) -> Self: diff --git a/mdaadb/tests/test_database.py b/mdaadb/tests/test_database.py index e69de29..646bca0 100644 --- a/mdaadb/tests/test_database.py +++ b/mdaadb/tests/test_database.py @@ -0,0 +1,110 @@ +import pytest +from numpy.testing import assert_equal +from mdaadb.database import Database, Table, Schema + + +class TestMemoryDatabase: + + @pytest.fixture() + def memdb(self): + db = Database(":memory:") + tbl1 = db.create_table( + """ + Table1 ( + ID INT PRIMARY KEY, + text_col TEXT, + real_col REAL + ) + """ + ) + tbl2 = db.create_table( + """ + Table2 ( + ID INT PRIMARY KEY, + text_col TEXT, + real_col REAL + ) + """ + ) + rows = [(i, f"text{i}", i*10.0) for i in range(1,11)] + tbl1.insert_rows(rows) + for row in rows: + tbl2.insert_row(row) + + return db + + def test_get_table(self, memdb): + tbl1 = memdb.get_table("Table1") + tbl2 = memdb.get_table("Table2") + assert tbl1 in memdb + assert tbl2 in memdb + + def test_nrows(self, memdb): + assert memdb.get_table("Table1").n_rows == 10 + assert memdb.get_table("Table2").n_rows == 10 + + def test_ncols(self, memdb): + assert memdb.get_table("Table1").n_cols == 3 + assert memdb.get_table("Table2").n_cols == 3 + + def test_table_names(self, memdb): + actual = memdb._get_table_names() + desired = ["Table2", "Table1", "sqlite_schema"] + assert_equal(actual, desired) + + def test_insert_rows(self, memdb): + rows = [(i, f"text{i}", float(i*10)) for i in range(1,11)] + rows_from_db = [ + tuple(row) + for row in memdb.get_table("Table1").SELECT("*").execute().fetchall() + ] + + assert_equal(rows_from_db, rows) + + def test_insert_row_vs_rows(self, memdb): + tbl1 = memdb.get_table("Table1") + tbl2 = memdb.get_table("Table2") + result1 = tbl1.SELECT("*").execute().fetchall() + result2 = tbl2.SELECT("*").execute().fetchall() + + assert_equal(result1, result2) + + def test_add_column(self, memdb): + ... + + def test_update_column(self, memdb): + ... + + def test_insert_column(self, memdb): + ... + + def test_Rows(self, memdb): + tbl1 = memdb.get_table("Table1") + tbl2 = memdb.get_table("Table2") + + for i, row in enumerate(tbl1.rows): + assert row.ID == i + 1 + assert row.text_col == f"text{i+1}" + assert row.real_col == (i+1) * 10.0 + for i, row in enumerate(tbl2.rows): + assert row.ID == i + 1 + assert row.text_col == f"text{i+1}" + assert row.real_col == (i+1) * 10.0 + + def test_Columns(self, memdb): + tbl1 = memdb.get_table("Table1") + tbl2 = memdb.get_table("Table2") + + for name, type_, col in zip( + ["ID", "text_col", "real_col"], + ["INT", "TEXT", "REAL"], + tbl1.columns, + ): + assert col.type_ == tbl1.get_column(name).type_ + for name, type_, col in zip( + ["ID", "text_col", "real_col"], + ["INT", "TEXT", "REAL"], + tbl2.columns, + ): + assert col.type_ == tbl1.get_column(name).type_ +