diff --git a/flask_potion/contrib/alchemy/fields.py b/flask_potion/contrib/alchemy/fields.py index 1e02955..d7de7db 100644 --- a/flask_potion/contrib/alchemy/fields.py +++ b/flask_potion/contrib/alchemy/fields.py @@ -1,4 +1,7 @@ -from flask_potion.fields import Object +from werkzeug.utils import cached_property + +from flask_potion.fields import Object, ToOne as GenericToOne +from flask_potion.utils import get_value, route_from class InlineModel(Object): @@ -15,3 +18,27 @@ def converter(self, instance): if instance is not None: instance = self.model(**instance) return instance + + +class ToOne(GenericToOne): + """ + Same as flask_potion.fields.ToOne + except it will use the local id stored on the ForeignKey field to serialize the field. + This is an optimisation to avoid additional lookups to the database, + in order to prevent fetching the remote object, just to obtain its id, + that we already have. + Limitations: + - It works only if the foreign key is made of a single field. + - It works only if the serialization is using the ForeignKey as source of information to Identify the remote resource. + - `attribute` parameter is ignored. + """ + def output(self, key, obj): + column = getattr(obj.__class__, key) + local_columns = column.property.local_columns + assert len(local_columns) == 1 + local_column = list(local_columns)[0] + key = local_column.key + return self.format(get_value(key, obj, self.default)) + + def formatter(self, item): + return self.formatter_key.format(item, is_local=True) diff --git a/flask_potion/natural_keys.py b/flask_potion/natural_keys.py index 1556b03..d7e26cc 100644 --- a/flask_potion/natural_keys.py +++ b/flask_potion/natural_keys.py @@ -11,6 +11,7 @@ class Key(Schema, ResourceBound): + is_local = False def matcher_type(self): type_ = self.response['type'] @@ -43,11 +44,16 @@ def schema(self): "additionalProperties": False } + def _id_uri(self, resource, id_): + return '{}/{}'.format(resource.route_prefix, id_) + def _item_uri(self, resource, item): # return url_for('{}.instance'.format(self.resource.meta.id_attribute, item, None), get_value(self.resource.meta.id_attribute, item, None)) return '{}/{}'.format(resource.route_prefix, get_value(resource.manager.id_attribute, item, None)) - def format(self, item): + def format(self, item, is_local=False): + if is_local: + return {'$ref': self._id_uri(self.resource, item)} return {"$ref": self._item_uri(self.resource, item)} def convert(self, value): @@ -71,7 +77,7 @@ def rebind(self, resource): def schema(self): return self.resource.schema.fields[self.property].request - def format(self, item): + def format(self, item, is_local=False): return self.resource.schema.fields[self.property].output(self.property, item) @cached_property @@ -101,7 +107,7 @@ def schema(self): "additionalItems": False } - def format(self, item): + def format(self, item, is_local=False): return [self.resource.schema.fields[p].output(p, item) for p in self.properties] @cached_property @@ -123,7 +129,7 @@ def _on_bind(self, resource): def schema(self): return self.id_field.request - def format(self, item): + def format(self, item, is_local=False): return self.id_field.output(self.resource.manager.id_attribute, item) def convert(self, value): diff --git a/tests/__init__.py b/tests/__init__.py index fa27a4d..d08873f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,9 @@ +from pprint import pformat + from flask import json, Flask from flask.testing import FlaskClient from flask_testing import TestCase +import sqlalchemy class ApiClient(FlaskClient): @@ -49,4 +52,52 @@ def create_app(self): return app def pp(self, obj): - print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': '))) \ No newline at end of file + print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': '))) + + +class DBQueryCounter: + """ + Use as a context manager to count the number of execute()'s performed + against the given sqlalchemy connection. + + Usage: + with DBQueryCounter(db.session) as ctr: + db.session.execute("SELECT 1") + db.session.execute("SELECT 1") + ctr.assert_count(2) + """ + + def __init__(self, session, reset=True): + self.session = session + self.reset = reset + self.statements = [] + + def __enter__(self): + if self.reset: + self.session.expire_all() + sqlalchemy.event.listen( + self.session.get_bind(), 'after_execute', self._callback + ) + return self + + def __exit__(self, *_): + sqlalchemy.event.remove( + self.session.get_bind(), 'after_execute', self._callback + ) + + def get_count(self): + return len(self.statements) + + def _callback(self, conn, clause_element, multiparams, params, result): + self.statements.append((clause_element, multiparams, params)) + + def display_all(self): + for clause, multiparams, params in self.statements: + print(pformat(str(clause)), multiparams, params) + print('\n') + count = self.get_count() + return 'Counted: {count}'.format(count=count) + + def assert_count(self, expected): + count = self.get_count() + assert count == expected, self.display_all() diff --git a/tests/contrib/alchemy/test_fields.py b/tests/contrib/alchemy/test_fields.py new file mode 100644 index 0000000..af54b4d --- /dev/null +++ b/tests/contrib/alchemy/test_fields.py @@ -0,0 +1,62 @@ +from flask_sqlalchemy import SQLAlchemy + +from flask_potion import Api, fields +from flask_potion.resource import ModelResource +from flask_potion.contrib.alchemy.fields import ToOne as SAToOne +from tests import BaseTestCase, DBQueryCounter + + +class SQLAlchemyToOneRemainNoPrefetchTestCase(BaseTestCase): + """ + """ + + def setUp(self): + super(SQLAlchemyToOneRemainNoPrefetchTestCase, self).setUp() + self.app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + self.api = Api(self.app) + self.sa = sa = SQLAlchemy( + self.app, session_options={"autoflush": False}) + + class Type(sa.Model): + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(60), nullable=False) + + class Machine(sa.Model): + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(60), nullable=False) + + type_id = sa.Column(sa.Integer, sa.ForeignKey(Type.id)) + type = sa.relationship(Type, foreign_keys=[type_id]) + + sa.create_all() + + class MachineResource(ModelResource): + class Meta: + model = Machine + + class Schema: + type = SAToOne('type') + + class TypeResource(ModelResource): + class Meta: + model = Type + + self.MachineResource = MachineResource + self.TypeResource = TypeResource + + self.api.add_resource(MachineResource) + self.api.add_resource(TypeResource) + + def test_relation_serialization_does_not_load_remote_object(self): + response = self.client.post('/type', data={"name": "aaa"}) + aaa_uri = response.json["$uri"] + self.client.post( + '/machine', data={"name": "foo", "type": {"$ref": aaa_uri}}) + with DBQueryCounter(self.sa.session) as counter: + response = self.client.get('/machine') + self.assert200(response) + self.assertJSONEqual( + [{'$uri': '/machine/1', 'type': {'$ref': aaa_uri}, 'name': 'foo'}], + response.json) + counter.assert_count(1) +