From e8e7d46f30f4295af99cd093ea7b24b48a784afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yunus=20Emre=20Geldeg=C3=BCl?= Date: Thu, 15 Aug 2024 10:53:43 +0300 Subject: [PATCH] feat: Update `tables_involved` property to include tables from UPDATE operation (#717) - Enhanced the `tables_involved` property of the `SQLQuery` model to correctly identify tables involved in `UPDATE` operation. - Added new test cases to validate the changes: - Test for `UPDATE` queries. - Test for complex `UPDATE` query involving subqueries and joins. - Ensured existing tests for `SELECT`, `JOIN`, and `AS` tokens continue to pass. --- project/tests/test_models.py | 26 ++++++++++++++++++++++++++ silk/models.py | 7 ++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/project/tests/test_models.py b/project/tests/test_models.py index efdb0d0a..0d0bc9e9 100644 --- a/project/tests/test_models.py +++ b/project/tests/test_models.py @@ -422,6 +422,32 @@ def test_tables_involved_if_query_has_django_aliase_on_column_names(self): self.obj.query = query self.assertEqual(self.obj.tables_involved, ['bar', 'some_table;']) + def test_tables_involved_if_query_has_update_token(self): + + query = """UPDATE Book SET title = 'New Title' WHERE id = 1;""" + self.obj.query = query + self.assertEqual(self.obj.tables_involved, ['Book']) + + def test_tables_involved_in_complex_update_query(self): + + query = '''UPDATE Person p + SET p.name = (SELECT c.name FROM Company c WHERE c.id = p.company_id), + p.salary = p.salary * 1.1 + FROM Department d + WHERE p.department_id = d.id AND d.budget > 100000; + ''' + self.obj.query = query + self.assertEqual(self.obj.tables_involved, ['Person', 'Company', 'Department']) + + def test_tables_involved_in_update_with_subquery(self): + + query = '''UPDATE Employee e + SET e.bonus = (SELECT AVG(salary) FROM Employee WHERE department_id = e.department_id) + WHERE e.performance = 'excellent'; + ''' + self.obj.query = query + self.assertEqual(self.obj.tables_involved, ['Employee', 'Employee']) + def test_save_if_no_end_and_start_time(self): obj = SQLQueryFactory.create() diff --git a/silk/models.py b/silk/models.py index 075c5e9a..0c28cad0 100644 --- a/silk/models.py +++ b/silk/models.py @@ -303,7 +303,12 @@ def tables_involved(self): for idx, component in enumerate(components): # TODO: If django uses aliases on column names they will be falsely # identified as tables... - if component.lower() == 'from' or component.lower() == 'join' or component.lower() == 'as': + if ( + component.lower() == "from" + or component.lower() == "join" + or component.lower() == "as" + or component.lower() == "update" + ): try: _next = components[idx + 1] if not _next.startswith('('): # Subquery