Skip to content

Commit

Permalink
ability to execute precompiled sqlalchemy queries (#294)
Browse files Browse the repository at this point in the history
* ability to execute precompiled sqlalchemy queries

* global cache for compiled queries

* update formatting

* use only query as a key in compiled cache
  • Loading branch information
vlanse authored and jettify committed Jun 3, 2018
1 parent a6f3ee9 commit aec31dd
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 10 deletions.
6 changes: 6 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changes
-------

0.0.16 (2018-05-21)
^^^^^^^^^^^^^^^^^^^

* Added ability to execute precompiled sqlalchemy queries


0.0.15 (2018-05-20)
^^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion aiomysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor
from .pool import create_pool, Pool

__version__ = '0.0.15'
__version__ = '0.0.16'

__all__ = [

Expand Down
17 changes: 14 additions & 3 deletions aiomysql/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@

class SAConnection:

def __init__(self, connection, engine):
def __init__(self, connection, engine, compiled_cache=None):
self._connection = connection
self._transaction = None
self._savepoint_seq = 0
self._weak_results = weakref.WeakSet()
self._engine = engine
self._dialect = engine.dialect
self._compiled_cache = compiled_cache

def execute(self, query, *multiparams, **params):
"""Executes a SQL query with optional parameters.
Expand Down Expand Up @@ -76,8 +77,18 @@ async def _execute(self, query, *multiparams, **params):
if isinstance(query, str):
await cursor.execute(query, dp or None)
elif isinstance(query, ClauseElement):
compiled = query.compile(dialect=self._dialect)
# parameters = compiled.params
if self._compiled_cache is not None:
key = query
compiled = self._compiled_cache.get(key)
if not compiled:
compiled = query.compile(dialect=self._dialect)
if dp and dp.keys() == compiled.params.keys() \
or not (dp or compiled.params):
# we only want queries with bound params in cache
self._compiled_cache[key] = compiled
else:
compiled = query.compile(dialect=self._dialect)

if not isinstance(query, DDLElement):
if dp and isinstance(dp, (list, tuple)):
if isinstance(query, UpdateBase):
Expand Down
16 changes: 10 additions & 6 deletions aiomysql/sa/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@


def create_engine(minsize=1, maxsize=10, loop=None,
dialect=_dialect, pool_recycle=-1, **kwargs):
dialect=_dialect, pool_recycle=-1, compiled_cache=None,
**kwargs):
"""A coroutine for Engine creation.
Returns Engine instance with embedded connection pool.
The pool has *minsize* opened connections to PostgreSQL server.
"""
coro = _create_engine(minsize=minsize, maxsize=maxsize, loop=loop,
dialect=dialect, pool_recycle=pool_recycle, **kwargs)
dialect=dialect, pool_recycle=pool_recycle,
compiled_cache=compiled_cache, **kwargs)
compatible_cursor_classes = [Cursor]
# Without provided kwarg, default is default cursor from Connection class
if kwargs.get('cursorclass', Cursor) not in compatible_cursor_classes:
Expand All @@ -38,7 +40,8 @@ def create_engine(minsize=1, maxsize=10, loop=None,


async def _create_engine(minsize=1, maxsize=10, loop=None,
dialect=_dialect, pool_recycle=-1, **kwargs):
dialect=_dialect, pool_recycle=-1,
compiled_cache=None, **kwargs):

if loop is None:
loop = asyncio.get_event_loop()
Expand All @@ -47,7 +50,7 @@ async def _create_engine(minsize=1, maxsize=10, loop=None,
pool_recycle=pool_recycle, **kwargs)
conn = await pool.acquire()
try:
return Engine(dialect, pool, **kwargs)
return Engine(dialect, pool, compiled_cache=compiled_cache, **kwargs)
finally:
pool.release(conn)

Expand All @@ -61,9 +64,10 @@ class Engine:
create_engine coroutine.
"""

def __init__(self, dialect, pool, **kwargs):
def __init__(self, dialect, pool, compiled_cache=None, **kwargs):
self._dialect = dialect
self._pool = pool
self._compiled_cache = compiled_cache
self._conn_kw = kwargs

@property
Expand Down Expand Up @@ -124,7 +128,7 @@ def acquire(self):

async def _acquire(self):
raw = await self._pool.acquire()
conn = SAConnection(raw, self)
conn = SAConnection(raw, self, compiled_cache=self._compiled_cache)
return conn

def release(self, conn):
Expand Down
138 changes: 138 additions & 0 deletions tests/sa/test_sa_compiled_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import asyncio
from aiomysql import sa
from sqlalchemy import bindparam

import os
import unittest

from sqlalchemy import MetaData, Table, Column, Integer, String

meta = MetaData()
tbl = Table('sa_tbl_cache_test', meta,
Column('id', Integer, nullable=False,
primary_key=True),
Column('val', String(255)))


class TestCompiledCache(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.host = os.environ.get('MYSQL_HOST', 'localhost')
self.port = int(os.environ.get('MYSQL_PORT', 3306))
self.user = os.environ.get('MYSQL_USER', 'root')
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
self.password = os.environ.get('MYSQL_PASSWORD', '')
self.engine = self.loop.run_until_complete(self.make_engine())
self.loop.run_until_complete(self.start())

def tearDown(self):
self.engine.terminate()
self.loop.run_until_complete(self.engine.wait_closed())
self.loop.close()

async def make_engine(self, **kwargs):
return (await sa.create_engine(db=self.db,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
loop=self.loop,
minsize=10,
**kwargs))

async def start(self):
async with self.engine.acquire() as conn:
tx = await conn.begin()
await conn.execute("DROP TABLE IF EXISTS "
"sa_tbl_cache_test")
await conn.execute("CREATE TABLE sa_tbl_cache_test"
"(id serial, val varchar(255))")
await conn.execute(tbl.insert().values(val='some_val_1'))
await conn.execute(tbl.insert().values(val='some_val_2'))
await conn.execute(tbl.insert().values(val='some_val_3'))
await tx.commit()

def test_cache(self):
async def go():
cache = dict()
engine = await self.make_engine(compiled_cache=cache)
async with engine.acquire() as conn:
# check select with params not added to cache
q = tbl.select().where(tbl.c.val == 'some_val_1')
cursor = await conn.execute(q)
row = await cursor.fetchone()
self.assertEqual('some_val_1', row.val)
self.assertEqual(0, len(cache))

# check select with bound params added to cache
select_by_val = tbl.select().where(
tbl.c.val == bindparam('value')
)
cursor = await conn.execute(
select_by_val, {'value': 'some_val_3'}
)
row = await cursor.fetchone()
self.assertEqual('some_val_3', row.val)
self.assertEqual(1, len(cache))

cursor = await conn.execute(
select_by_val, value='some_val_2'
)
row = await cursor.fetchone()
self.assertEqual('some_val_2', row.val)
self.assertEqual(1, len(cache))

select_all = tbl.select()
cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
self.assertEqual(3, len(rows))
self.assertEqual(2, len(cache))

# check insert with bound params not added to cache
await conn.execute(tbl.insert().values(val='some_val_4'))
self.assertEqual(2, len(cache))

# check insert with bound params added to cache
q = tbl.insert().values(val=bindparam('value'))
await conn.execute(q, value='some_val_5')
self.assertEqual(3, len(cache))

await conn.execute(q, value='some_val_6')
self.assertEqual(3, len(cache))

await conn.execute(q, {'value': 'some_val_7'})
self.assertEqual(3, len(cache))

cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
self.assertEqual(7, len(rows))
self.assertEqual(3, len(cache))

# check update with params not added to cache
q = tbl.update().where(
tbl.c.val == 'some_val_1'
).values(val='updated_val_1')
await conn.execute(q)
self.assertEqual(3, len(cache))
cursor = await conn.execute(
select_by_val, value='updated_val_1'
)
row = await cursor.fetchone()
self.assertEqual('updated_val_1', row.val)

# check update with bound params added to cache
q = tbl.update().where(
tbl.c.val == bindparam('value')
).values(val=bindparam('update'))
await conn.execute(
q, value='some_val_2', update='updated_val_2'
)
self.assertEqual(4, len(cache))
cursor = await conn.execute(
select_by_val, value='updated_val_2'
)
row = await cursor.fetchone()
self.assertEqual('updated_val_2', row.val)

self.loop.run_until_complete(go())

0 comments on commit aec31dd

Please sign in to comment.