Skip to content

Commit

Permalink
feat: Update tables_involved property to include tables from UPDATE…
Browse files Browse the repository at this point in the history
… operation (#717)

- Enhanced the `tables_involved` property of the `SQLQuery` model to correctly identify tables involved in `UPDATE` operation.
- Added new test cases to validate the changes:
  - Test for `UPDATE` queries.
  - Test for complex `UPDATE` query involving subqueries and joins.
  - Ensured existing tests for `SELECT`, `JOIN`, and `AS` tokens continue to pass.
  • Loading branch information
emregeldegul authored Aug 15, 2024
1 parent 0f22ce9 commit e8e7d46
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
26 changes: 26 additions & 0 deletions project/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,32 @@ def test_tables_involved_if_query_has_django_aliase_on_column_names(self):
self.obj.query = query
self.assertEqual(self.obj.tables_involved, ['bar', 'some_table;'])

def test_tables_involved_if_query_has_update_token(self):

query = """UPDATE Book SET title = 'New Title' WHERE id = 1;"""
self.obj.query = query
self.assertEqual(self.obj.tables_involved, ['Book'])

def test_tables_involved_in_complex_update_query(self):

query = '''UPDATE Person p
SET p.name = (SELECT c.name FROM Company c WHERE c.id = p.company_id),
p.salary = p.salary * 1.1
FROM Department d
WHERE p.department_id = d.id AND d.budget > 100000;
'''
self.obj.query = query
self.assertEqual(self.obj.tables_involved, ['Person', 'Company', 'Department'])

def test_tables_involved_in_update_with_subquery(self):

query = '''UPDATE Employee e
SET e.bonus = (SELECT AVG(salary) FROM Employee WHERE department_id = e.department_id)
WHERE e.performance = 'excellent';
'''
self.obj.query = query
self.assertEqual(self.obj.tables_involved, ['Employee', 'Employee'])

def test_save_if_no_end_and_start_time(self):

obj = SQLQueryFactory.create()
Expand Down
7 changes: 6 additions & 1 deletion silk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,12 @@ def tables_involved(self):
for idx, component in enumerate(components):
# TODO: If django uses aliases on column names they will be falsely
# identified as tables...
if component.lower() == 'from' or component.lower() == 'join' or component.lower() == 'as':
if (
component.lower() == "from"
or component.lower() == "join"
or component.lower() == "as"
or component.lower() == "update"
):
try:
_next = components[idx + 1]
if not _next.startswith('('): # Subquery
Expand Down

0 comments on commit e8e7d46

Please sign in to comment.