Skip to content

Commit

Permalink
Fix for #27 'relationship with secondary generates incorrect query'
Browse files Browse the repository at this point in the history
  • Loading branch information
mrevutskyi committed Aug 14, 2021
1 parent 0c9d4e7 commit f62f9ef
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Changelog
Changes in Flask-Restless-NG
============================

Next Version
-------------
- Fix for #27 'relationship with secondary generates incorrect query'


Version 2.2.4
-------------
- Do not log exceptions for user related errors (bad query, etc)
Expand Down
13 changes: 9 additions & 4 deletions flask_restless/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,20 +1044,25 @@ def _selectinload_included_relationships(
filters=None
) -> Query:

def is_unsafe_to_selectload(attribute):
def is_safe_to_selectload(attribute):
# SQLAlchemy does not build correct `selectinload` queries for models that have special select join
try:
inspected_relationship = inspect(attribute)
if inspected_relationship.property.secondary:
return False
if not isinstance(inspected_relationship.property.primaryjoin, BinaryExpression):
return True
return False
except Exception:
# we do not have enough information, assume it's not safe
return False

return True

join_paths = {path.split('.')[0] for path in include}

for path in join_paths:
attribute = getattr(self.model, path)
if is_unsafe_to_selectload(attribute):
if not is_safe_to_selectload(attribute):
continue
if not is_proxy(attribute) and not isinstance(attribute.impl, DynamicAttributeImpl):
query = query.options(selectinload(attribute))
Expand All @@ -1071,7 +1076,7 @@ def is_unsafe_to_selectload(attribute):

for path in relationship_columns:
attribute = getattr(self.model, path)
if is_unsafe_to_selectload(attribute):
if not is_safe_to_selectload(attribute):
continue
if path not in join_paths and not is_proxy(attribute) and not isinstance(attribute.impl, DynamicAttributeImpl):
options = selectinload(attribute)
Expand Down
63 changes: 62 additions & 1 deletion tests/integration/test_mysq_flask_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from flask_restless import APIManager

pytestmark = pytest.mark.integration

app = Flask(__name__)
app.testing = True
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://db_user:password@localhost/flask_restless'
Expand Down Expand Up @@ -39,6 +41,39 @@ class Order(db.Model):
client_id = db.Column(db.Integer, db.ForeignKey('client.id'))


# Models for Many-to-Many test case

class Sheet(db.Model):
__tablename__ = "sheet"

id = db.Column(db.Integer, primary_key=True)

report = db.relationship(
"Report",
cascade="all, delete",
passive_deletes=True,
single_parent=True,
secondary="join(Report, Device, Report.device_id == Device.id)",
order_by="Device.parent_device_id",
)


class Report(db.Model):
__tablename__ = "report"

id = db.Column(db.Integer, primary_key=True)

device_id = db.Column(db.Integer, db.ForeignKey("device.id", ondelete="CASCADE"), nullable=False)
sheet_id = db.Column(db.Integer, db.ForeignKey("sheet.id", ondelete="CASCADE"), nullable=False)


class Device(db.Model):
__tablename__ = "device"

id = db.Column(db.Integer, primary_key=True)
parent_device_id = db.Column(db.Integer, db.ForeignKey("device.id", ondelete="CASCADE"), nullable=True)


@pytest.fixture(scope='module')
def api():

Expand All @@ -47,6 +82,10 @@ def api():
api_manager.create_api(Client, collection_name='clients', page_size=0)
api_manager.create_api(Order, collection_name='orders', page_size=0)

api_manager.create_api(Report, collection_name='reports', page_size=0)
api_manager.create_api(Device, collection_name='devices', page_size=0)
api_manager.create_api(Sheet, collection_name='sheets', page_size=0)

db.drop_all()
db.create_all()

Expand All @@ -61,9 +100,31 @@ def api():
yield app.test_client()


@pytest.mark.integration
def test_responses(api):
response = api.get('/clients/1?include=orders,starred_orders')
assert response.status_code == 200
document = response.json
assert len(document['included']) == 5


def test_selectin_for_many_to_many(api):
"""
Test case to catch https://github.com/mrevutskyi/flask-restless-ng/issues/27
"""
db.session.add_all([
Device(id=1),
Device(id=2, parent_device_id=1),
Sheet(id=1),
Sheet(id=2)
])
db.session.commit()
db.session.add_all([
Report(id=1, device_id=2, sheet_id=1),
Report(id=2, device_id=1, sheet_id=1),
Report(id=3, device_id=1, sheet_id=2),
])
db.session.commit()

response = api.get('/sheets/1')
assert response.status_code == 200
assert len(response.json['data']['relationships']['report']['data']) == 2
7 changes: 1 addition & 6 deletions tests/test_jsonapi/test_server_responsibilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,9 @@ def test_no_response_media_type_params(self):
.. _Server Responsibilities: https://jsonapi.org/format/#content-negotiation-servers
"""
data = {
'data': {
'type': 'person',
}
}
headers = {'Content-Type': f'{CONTENT_TYPE}; version=1'}
# flask 1.0.1 overrides headers when `json` parameter is used, so have to use json.dumps
response = self.app.post('/api/person', data=json.dumps(data), headers=headers)
response = self.app.post('/api/person', data=json.dumps({}), headers=headers)
check_sole_error(response, 415, ['Content-Type', 'media type parameters'])

def test_empty_accept_header(self):
Expand Down

0 comments on commit f62f9ef

Please sign in to comment.