From ab8381f2998229e9316a50fe67719f465dd12ead Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Thu, 26 Dec 2024 13:08:15 +0530 Subject: [PATCH] Filter default records for insert during select action --- cumulusci/tasks/bulkdata/step.py | 28 +++++++++++++++ cumulusci/tasks/bulkdata/tests/test_step.py | 40 +++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 9dbbe40cd7..b95323ea43 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -16,6 +16,9 @@ from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.core.utils import process_bool_arg +from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( + DEFAULT_DECLARATIONS, +) from cumulusci.tasks.bulkdata.select_utils import ( SelectOperationExecutor, SelectRecordRetrievalMode, @@ -485,6 +488,13 @@ def select_records(self, records): self.logger.info(f"Retrieved {len(select_query_records)} from org") query_records.extend(select_query_records) + + # Filter out default declarations + if not self.selection_filter and self.sobject in DEFAULT_DECLARATIONS: + records = filter_records( + self.fields, records, DEFAULT_DECLARATIONS[self.sobject].where + ) + # Post-process the query results ( selected_records, @@ -901,6 +911,12 @@ def select_records(self, records): query_records = self._execute_soql_query(select_query, query_fields) self.logger.info(f"Retrieved {len(query_records)} from org") + # Filter out default declarations + if not self.selection_filter and self.sobject in DEFAULT_DECLARATIONS: + records = filter_records( + self.fields, records, DEFAULT_DECLARATIONS[self.sobject].where + ) + # Post-process the query results for this batch ( selected_records, @@ -1218,3 +1234,15 @@ def assign_weights( weights[i] = HIGH_PRIORITY_VALUE return weights + + +# Parsing the where condition for default records +def filter_records(fields: List[str], records, where_condition: str): + filtered_records = [] + for record in records: + # Create a dictionary mapping fields to record values + record_dict = {field: value for field, value in zip(fields, record)} + # Use eval to dynamically evaluate the where condition + if eval(where_condition, {}, record_dict): + filtered_records.append(record) + return filtered_records diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index 3887b270f3..ae46f5e613 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -25,6 +25,7 @@ assign_weights, download_file, extract_flattened_headers, + filter_records, flatten_record, get_dml_operation, get_query_operation, @@ -3875,3 +3876,42 @@ def test_flatten_record(record, headers, expected): def test_assign_weights(priority_fields, fields, expected): result = assign_weights(priority_fields, fields) assert result == expected + + +def test_filter_records(): + fields = ["Name", "AccountNumber", "Industry"] + records = [ + ["Sample Account for Entitlements", "123", "Technology"], + ["Acme Corp", "456", "Retail"], + ["Test Company", "789", "Finance"], + ] + + # Test 1: Exclude specific record + where_condition = "Name != 'Sample Account for Entitlements'" + expected_output = [ + ["Acme Corp", "456", "Retail"], + ["Test Company", "789", "Finance"], + ] + assert filter_records(fields, records, where_condition) == expected_output + + # Test 2: Include only specific Industry + where_condition = "Industry == 'Retail'" + expected_output = [["Acme Corp", "456", "Retail"]] + assert filter_records(fields, records, where_condition) == expected_output + + # Test 3: No match + where_condition = "AccountNumber == '999'" + expected_output = [] + assert filter_records(fields, records, where_condition) == expected_output + + # Test 4: Multiple conditions + where_condition = "Name == 'Acme Corp' and Industry == 'Retail'" + expected_output = [["Acme Corp", "456", "Retail"]] + assert filter_records(fields, records, where_condition) == expected_output + + # Test 5: Empty input + fields = [] + records = [] + where_condition = "Name != 'Nonexistent'" + expected_output = [] + assert filter_records(fields, records, where_condition) == expected_output