-
-
Notifications
You must be signed in to change notification settings - Fork 297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sqla2 integration #420
base: master
Are you sure you want to change the base?
Sqla2 integration #420
Changes from all commits
4491da4
73f59e1
fc11118
7fdf348
3f6550b
2cd099f
c784d3b
d9ffb13
aa22f2b
c25d39b
33a2777
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
version: 2 | ||
updates: | ||
- package-ecosystem: pip | ||
directory: "/" | ||
schedule: | ||
interval: daily | ||
time: "04:00" | ||
open-pull-requests-limit: 100 | ||
- package-ecosystem: pip | ||
directory: "/" | ||
schedule: | ||
interval: weekly | ||
open-pull-requests-limit: 100 | ||
- package-ecosystem: "github-actions" | ||
open-pull-requests-limit: 99 | ||
directory: "/" | ||
schedule: | ||
interval: weekly |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ all: clean test dists | |
|
||
.PHONY: test | ||
test: | ||
pytest | ||
pytest -v | ||
|
||
dists: | ||
python setup.py sdist | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,7 +44,6 @@ def __init__( | |
|
||
self.lock = threading.RLock() | ||
self.local = threading.local() | ||
self.connections = {} | ||
|
||
if len(parsed_url.query): | ||
query = parse_qs(parsed_url.query) | ||
|
@@ -54,9 +53,10 @@ def __init__( | |
schema = schema_qs.pop() | ||
|
||
self.schema = schema | ||
self.engine = create_engine(url, **engine_kwargs) | ||
self.is_postgres = self.engine.dialect.name == "postgresql" | ||
self.is_sqlite = self.engine.dialect.name == "sqlite" | ||
self._engine = create_engine(url, **engine_kwargs) | ||
self.is_postgres = self._engine.dialect.name == "postgresql" | ||
self.is_sqlite = self._engine.dialect.name == "sqlite" | ||
self.is_mysql = "mysql" in self._engine.dialect.dbapi.__name__ | ||
if on_connect_statements is None: | ||
on_connect_statements = [] | ||
|
||
|
@@ -72,7 +72,7 @@ def _run_on_connect(dbapi_con, con_record): | |
on_connect_statements.append("PRAGMA journal_mode=WAL") | ||
|
||
if len(on_connect_statements): | ||
event.listen(self.engine, "connect", _run_on_connect) | ||
event.listen(self._engine, "connect", _run_on_connect) | ||
|
||
self.types = Types(is_postgres=self.is_postgres) | ||
self.url = url | ||
|
@@ -81,39 +81,38 @@ def _run_on_connect(dbapi_con, con_record): | |
self._tables = {} | ||
|
||
@property | ||
def executable(self): | ||
def conn(self): | ||
"""Connection against which statements will be executed.""" | ||
with self.lock: | ||
tid = threading.get_ident() | ||
if tid not in self.connections: | ||
self.connections[tid] = self.engine.connect() | ||
return self.connections[tid] | ||
try: | ||
return self.local.conn | ||
except AttributeError: | ||
self.local.conn = self._engine.connect() | ||
self.local.conn.begin() | ||
return self.local.conn | ||
|
||
@property | ||
def op(self): | ||
"""Get an alembic operations context.""" | ||
ctx = MigrationContext.configure(self.executable) | ||
ctx = MigrationContext.configure(self.conn) | ||
return Operations(ctx) | ||
|
||
@property | ||
def inspect(self): | ||
"""Get a SQLAlchemy inspector.""" | ||
return inspect(self.executable) | ||
return inspect(self.conn) | ||
|
||
def has_table(self, name): | ||
return self.inspect.has_table(name, schema=self.schema) | ||
|
||
@property | ||
def metadata(self): | ||
"""Return a SQLAlchemy schema cache object.""" | ||
return MetaData(schema=self.schema, bind=self.executable) | ||
return MetaData(schema=self.schema) | ||
|
||
@property | ||
def in_transaction(self): | ||
"""Check if this database is in a transactional context.""" | ||
if not hasattr(self.local, "tx"): | ||
return False | ||
return len(self.local.tx) > 0 | ||
return self.conn.in_transaction() | ||
|
||
def _flush_tables(self): | ||
"""Clear the table metadata after transaction rollbacks.""" | ||
|
@@ -125,32 +124,22 @@ def begin(self): | |
|
||
No data will be written until the transaction has been committed. | ||
""" | ||
if not hasattr(self.local, "tx"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wouldn't this remove support for nested transactions? as in https://dataset.readthedocs.io/en/latest/quickstart.html#using-transactions |
||
self.local.tx = [] | ||
self.local.tx.append(self.executable.begin()) | ||
if not self.in_transaction: | ||
self.conn.begin() | ||
|
||
def commit(self): | ||
"""Commit the current transaction. | ||
|
||
Make all statements executed since the transaction was begun permanent. | ||
""" | ||
if hasattr(self.local, "tx") and self.local.tx: | ||
tx = self.local.tx.pop() | ||
tx.commit() | ||
# Removed in 2020-12, I'm a bit worried this means that some DDL | ||
# operations in transactions won't cause metadata to refresh any | ||
# more: | ||
# self._flush_tables() | ||
self.conn.commit() | ||
|
||
def rollback(self): | ||
"""Roll back the current transaction. | ||
|
||
Discard all statements executed since the transaction was begun. | ||
""" | ||
if hasattr(self.local, "tx") and self.local.tx: | ||
tx = self.local.tx.pop() | ||
tx.rollback() | ||
self._flush_tables() | ||
self.conn.rollback() | ||
|
||
def __enter__(self): | ||
"""Start a transaction.""" | ||
|
@@ -170,13 +159,10 @@ def __exit__(self, error_type, error_value, traceback): | |
|
||
def close(self): | ||
"""Close database connections. Makes this object unusable.""" | ||
with self.lock: | ||
for conn in self.connections.values(): | ||
conn.close() | ||
self.connections.clear() | ||
self.engine.dispose() | ||
self.local = threading.local() | ||
self._engine.dispose() | ||
self._tables = {} | ||
self.engine = None | ||
self._engine = None | ||
|
||
@property | ||
def tables(self): | ||
|
@@ -322,7 +308,7 @@ def query(self, query, *args, **kwargs): | |
_step = kwargs.pop("_step", QUERY_STEP) | ||
if _step is False or _step == 0: | ||
_step = None | ||
rp = self.executable.execute(query, *args, **kwargs) | ||
rp = self.conn.execute(query, *args, **kwargs) | ||
return ResultIter(rp, row_type=self.row_type, step=_step) | ||
|
||
def __repr__(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -116,7 +116,9 @@ def insert(self, row, ensure=None, types=None): | |
Returns the inserted row's primary key. | ||
""" | ||
row = self._sync_columns(row, ensure, types=types) | ||
res = self.db.executable.execute(self.table.insert(row)) | ||
res = self.db.conn.execute(self.table.insert(), row) | ||
# if not self.db.in_transaction: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't |
||
# self.db.conn.commit() | ||
if len(res.inserted_primary_key) > 0: | ||
return res.inserted_primary_key[0] | ||
return True | ||
|
@@ -181,7 +183,7 @@ def insert_many(self, rows, chunk_size=1000, ensure=None, types=None): | |
# Insert when chunk_size is fulfilled or this is the last row | ||
if len(chunk) == chunk_size or index == len(rows) - 1: | ||
chunk = pad_chunk_columns(chunk, columns) | ||
self.table.insert().execute(chunk) | ||
self.db.conn.execute(self.table.insert(), chunk) | ||
chunk = [] | ||
|
||
def update(self, row, keys, ensure=None, types=None, return_count=False): | ||
|
@@ -206,8 +208,8 @@ def update(self, row, keys, ensure=None, types=None, return_count=False): | |
clause = self._args_to_clause(args) | ||
if not len(row): | ||
return self.count(clause) | ||
stmt = self.table.update(whereclause=clause, values=row) | ||
rp = self.db.executable.execute(stmt) | ||
stmt = self.table.update().where(clause).values(row) | ||
rp = self.db.conn.execute(stmt) | ||
if rp.supports_sane_rowcount(): | ||
return rp.rowcount | ||
if return_count: | ||
|
@@ -241,11 +243,12 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): | |
# Update when chunk_size is fulfilled or this is the last row | ||
if len(chunk) == chunk_size or index == len(rows) - 1: | ||
cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys] | ||
stmt = self.table.update( | ||
whereclause=and_(True, *cl), | ||
values={col: bindparam(col, required=False) for col in columns}, | ||
stmt = ( | ||
self.table.update() | ||
.where(and_(True, *cl)) | ||
.values({col: bindparam(col, required=False) for col in columns}) | ||
) | ||
self.db.executable.execute(stmt, chunk) | ||
self.db.conn.execute(stmt, chunk) | ||
chunk = [] | ||
|
||
def upsert(self, row, keys, ensure=None, types=None): | ||
|
@@ -293,8 +296,8 @@ def delete(self, *clauses, **filters): | |
if not self.exists: | ||
return False | ||
clause = self._args_to_clause(filters, clauses=clauses) | ||
stmt = self.table.delete(whereclause=clause) | ||
rp = self.db.executable.execute(stmt) | ||
stmt = self.table.delete().where(clause) | ||
rp = self.db.conn.execute(stmt) | ||
return rp.rowcount > 0 | ||
|
||
def _reflect_table(self): | ||
|
@@ -303,7 +306,10 @@ def _reflect_table(self): | |
self._columns = None | ||
try: | ||
self._table = SQLATable( | ||
self.name, self.db.metadata, schema=self.db.schema, autoload=True | ||
self.name, | ||
self.db.metadata, | ||
schema=self.db.schema, | ||
autoload_with=self.db.conn, | ||
) | ||
except NoSuchTableError: | ||
self._table = None | ||
|
@@ -345,15 +351,15 @@ def _sync_table(self, columns): | |
for column in columns: | ||
if not column.name == self._primary_id: | ||
self._table.append_column(column) | ||
self._table.create(self.db.executable, checkfirst=True) | ||
self._table.create(self.db.conn, checkfirst=True) | ||
self._columns = None | ||
elif len(columns): | ||
with self.db.lock: | ||
self._reflect_table() | ||
self._threading_warn() | ||
for column in columns: | ||
if not self.has_column(column.name): | ||
self.db.op.add_column(self.name, column, self.db.schema) | ||
self.db.op.add_column(self.name, column, schema=self.db.schema) | ||
self._reflect_table() | ||
|
||
def _sync_columns(self, row, ensure, types=None): | ||
|
@@ -378,6 +384,7 @@ def _sync_columns(self, row, ensure, types=None): | |
_type = self.db.types.guess(value) | ||
sync_columns[name] = Column(name, _type) | ||
out[name] = value | ||
|
||
self._sync_table(sync_columns.values()) | ||
return out | ||
|
||
|
@@ -500,7 +507,7 @@ def drop_column(self, name): | |
table.drop_column('created_at') | ||
|
||
""" | ||
if self.db.engine.dialect.name == "sqlite": | ||
if self.db.is_sqlite: | ||
raise RuntimeError("SQLite does not support dropping columns.") | ||
name = self._get_column_name(name) | ||
with self.db.lock: | ||
|
@@ -509,7 +516,7 @@ def drop_column(self, name): | |
return | ||
|
||
self._threading_warn() | ||
self.db.op.drop_column(self.table.name, name, self.table.schema) | ||
self.db.op.drop_column(self.table.name, name, schema=self.table.schema) | ||
self._reflect_table() | ||
|
||
def drop(self): | ||
|
@@ -520,7 +527,7 @@ def drop(self): | |
with self.db.lock: | ||
if self.exists: | ||
self._threading_warn() | ||
self.table.drop(self.db.executable, checkfirst=True) | ||
self.table.drop(self.db.conn, checkfirst=True) | ||
self._table = None | ||
self._columns = None | ||
self.db._tables.pop(self.name, None) | ||
|
@@ -581,7 +588,7 @@ def create_index(self, columns, name=None, **kw): | |
kw["mysql_length"] = mysql_length | ||
|
||
idx = Index(name, *columns, **kw) | ||
idx.create(self.db.executable) | ||
idx.create(self.db.conn) | ||
|
||
def find(self, *_clauses, **kwargs): | ||
"""Perform a simple search on the table. | ||
|
@@ -625,14 +632,13 @@ def find(self, *_clauses, **kwargs): | |
|
||
order_by = self._args_to_order_by(order_by) | ||
args = self._args_to_clause(kwargs, clauses=_clauses) | ||
query = self.table.select(whereclause=args, limit=_limit, offset=_offset) | ||
query = self.table.select().where(args).limit(_limit).offset(_offset) | ||
if len(order_by): | ||
query = query.order_by(*order_by) | ||
|
||
conn = self.db.executable | ||
conn = self.db.conn | ||
if _streamed: | ||
conn = self.db.engine.connect() | ||
conn = conn.execution_options(stream_results=True) | ||
conn = self.db._engine.connect().execution_options(stream_results=True) | ||
|
||
return ResultIter(conn.execute(query), row_type=self.db.row_type, step=_step) | ||
|
||
|
@@ -666,9 +672,9 @@ def count(self, *_clauses, **kwargs): | |
return 0 | ||
|
||
args = self._args_to_clause(kwargs, clauses=_clauses) | ||
query = select([func.count()], whereclause=args) | ||
query = select(func.count()).where(args) | ||
query = query.select_from(self.table) | ||
rp = self.db.executable.execute(query) | ||
rp = self.db.conn.execute(query) | ||
return rp.fetchone()[0] | ||
|
||
def __len__(self): | ||
|
@@ -703,11 +709,11 @@ def distinct(self, *args, **_filter): | |
if not len(columns): | ||
return iter([]) | ||
|
||
q = expression.select( | ||
columns, | ||
distinct=True, | ||
whereclause=clause, | ||
order_by=[c.asc() for c in columns], | ||
q = ( | ||
expression.select(*columns) | ||
.where(clause) | ||
.group_by(*columns) | ||
.order_by(*(c.asc() for c in columns)) | ||
) | ||
return self.db.query(q) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems calling
conn.begin()
here means that there is only ever one transaction opened, even in presence of multipletable.insert()
calls.IMHO it would be better to have an explicit semantic where
db.conn
just returns the connection and transaction context is handled by the respective operations (insert, etc.). There should be a distinction between the autocommit case (user did not request a transaction) and the explicit commit case (user started the transaction).