Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimized ToOne sqlalchemy aware field. #167

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion flask_potion/contrib/alchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from flask_potion.fields import Object
from werkzeug.utils import cached_property

from flask_potion.fields import Object, ToOne as GenericToOne
from flask_potion.utils import get_value, route_from


class InlineModel(Object):
Expand All @@ -15,3 +18,27 @@ def converter(self, instance):
if instance is not None:
instance = self.model(**instance)
return instance


class ToOne(GenericToOne):
"""
Same as flask_potion.fields.ToOne
except it will use the local id stored on the ForeignKey field to serialize the field.
This is an optimisation to avoid additional lookups to the database,
in order to prevent fetching the remote object, just to obtain its id,
that we already have.
Limitations:
- It works only if the foreign key is made of a single field.
- It works only if the serialization is using the ForeignKey as source of information to Identify the remote resource.
- `attribute` parameter is ignored.
"""
def output(self, key, obj):
column = getattr(obj.__class__, key)
local_columns = column.property.local_columns
assert len(local_columns) == 1
local_column = list(local_columns)[0]
key = local_column.key
return self.format(get_value(key, obj, self.default))

def formatter(self, item):
return self.formatter_key.format(item, is_local=True)
14 changes: 10 additions & 4 deletions flask_potion/natural_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


class Key(Schema, ResourceBound):
is_local = False

def matcher_type(self):
type_ = self.response['type']
Expand Down Expand Up @@ -43,11 +44,16 @@ def schema(self):
"additionalProperties": False
}

def _id_uri(self, resource, id_):
return '{}/{}'.format(resource.route_prefix, id_)

def _item_uri(self, resource, item):
# return url_for('{}.instance'.format(self.resource.meta.id_attribute, item, None), get_value(self.resource.meta.id_attribute, item, None))
return '{}/{}'.format(resource.route_prefix, get_value(resource.manager.id_attribute, item, None))

def format(self, item):
def format(self, item, is_local=False):
if is_local:
return {'$ref': self._id_uri(self.resource, item)}
return {"$ref": self._item_uri(self.resource, item)}

def convert(self, value):
Expand All @@ -71,7 +77,7 @@ def rebind(self, resource):
def schema(self):
return self.resource.schema.fields[self.property].request

def format(self, item):
def format(self, item, is_local=False):
return self.resource.schema.fields[self.property].output(self.property, item)

@cached_property
Expand Down Expand Up @@ -101,7 +107,7 @@ def schema(self):
"additionalItems": False
}

def format(self, item):
def format(self, item, is_local=False):
return [self.resource.schema.fields[p].output(p, item) for p in self.properties]

@cached_property
Expand All @@ -123,7 +129,7 @@ def _on_bind(self, resource):
def schema(self):
return self.id_field.request

def format(self, item):
def format(self, item, is_local=False):
return self.id_field.output(self.resource.manager.id_attribute, item)

def convert(self, value):
Expand Down
53 changes: 52 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pprint import pformat

from flask import json, Flask
from flask.testing import FlaskClient
from flask_testing import TestCase
import sqlalchemy


class ApiClient(FlaskClient):
Expand Down Expand Up @@ -49,4 +52,52 @@ def create_app(self):
return app

def pp(self, obj):
print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': ')))
print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': ')))


class DBQueryCounter:
"""
Use as a context manager to count the number of execute()'s performed
against the given sqlalchemy connection.

Usage:
with DBQueryCounter(db.session) as ctr:
db.session.execute("SELECT 1")
db.session.execute("SELECT 1")
ctr.assert_count(2)
"""

def __init__(self, session, reset=True):
self.session = session
self.reset = reset
self.statements = []

def __enter__(self):
if self.reset:
self.session.expire_all()
sqlalchemy.event.listen(
self.session.get_bind(), 'after_execute', self._callback
)
return self

def __exit__(self, *_):
sqlalchemy.event.remove(
self.session.get_bind(), 'after_execute', self._callback
)

def get_count(self):
return len(self.statements)

def _callback(self, conn, clause_element, multiparams, params, result):
self.statements.append((clause_element, multiparams, params))

def display_all(self):
for clause, multiparams, params in self.statements:
print(pformat(str(clause)), multiparams, params)
print('\n')
count = self.get_count()
return 'Counted: {count}'.format(count=count)

def assert_count(self, expected):
count = self.get_count()
assert count == expected, self.display_all()
62 changes: 62 additions & 0 deletions tests/contrib/alchemy/test_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from flask_sqlalchemy import SQLAlchemy

from flask_potion import Api, fields
from flask_potion.resource import ModelResource
from flask_potion.contrib.alchemy.fields import ToOne as SAToOne
from tests import BaseTestCase, DBQueryCounter


class SQLAlchemyToOneRemainNoPrefetchTestCase(BaseTestCase):
"""
"""

def setUp(self):
super(SQLAlchemyToOneRemainNoPrefetchTestCase, self).setUp()
self.app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
self.api = Api(self.app)
self.sa = sa = SQLAlchemy(
self.app, session_options={"autoflush": False})

class Type(sa.Model):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(60), nullable=False)

class Machine(sa.Model):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(60), nullable=False)

type_id = sa.Column(sa.Integer, sa.ForeignKey(Type.id))
type = sa.relationship(Type, foreign_keys=[type_id])

sa.create_all()

class MachineResource(ModelResource):
class Meta:
model = Machine

class Schema:
type = SAToOne('type')

class TypeResource(ModelResource):
class Meta:
model = Type

self.MachineResource = MachineResource
self.TypeResource = TypeResource

self.api.add_resource(MachineResource)
self.api.add_resource(TypeResource)

def test_relation_serialization_does_not_load_remote_object(self):
response = self.client.post('/type', data={"name": "aaa"})
aaa_uri = response.json["$uri"]
self.client.post(
'/machine', data={"name": "foo", "type": {"$ref": aaa_uri}})
with DBQueryCounter(self.sa.session) as counter:
response = self.client.get('/machine')
self.assert200(response)
self.assertJSONEqual(
[{'$uri': '/machine/1', 'type': {'$ref': aaa_uri}, 'name': 'foo'}],
response.json)
counter.assert_count(1)