From d67fc6b44b35daec8354f71b2156ebe4c22d85af Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 3 Dec 2024 15:12:13 +0530 Subject: [PATCH] Fix issue where zero threshold was selecting everything. Added tests as well --- cumulusci/tasks/bulkdata/select_utils.py | 4 +- cumulusci/tasks/bulkdata/tests/test_step.py | 189 +++++++++++++++++++- 2 files changed, 189 insertions(+), 4 deletions(-) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 2d2728dadb..fedc1398bb 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -397,7 +397,7 @@ def annoy_post_process( # Retrieve the corresponding record from the database record = query_record_data[neighbor_index] closest_record_id = record_to_id_map[tuple(record)] - if threshold and (neighbor_distances[idx] >= threshold): + if threshold is not None and (neighbor_distances[idx] >= threshold): selected_records.append(None) insertion_candidates.append(load_shaped_records[i]) else: @@ -445,7 +445,7 @@ def levenshtein_post_process( select_record, target_records, similarity_weights ) - if distance_threshold and match_distance > distance_threshold: + if distance_threshold is not None and match_distance > distance_threshold: # Append load record for insertion if distance exceeds threshold insertion_candidates.append(load_record) selected_records.append(None) diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index e94e91f226..3887b270f3 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -1232,7 +1232,9 @@ def test_process_insert_records_failure(self, download_mock): ) @mock.patch("cumulusci.tasks.bulkdata.step.download_file") - def test_select_records_similarity_strategy__insert_records(self, download_mock): + def test_select_records_similarity_strategy__insert_records__non_zero_threshold( + self, download_mock + ): # Set up mock context and BulkApiDmlOperation context = mock.Mock() # Add step with threshold @@ -1325,6 +1327,102 @@ def test_select_records_similarity_strategy__insert_records(self, download_mock) == 1 ) + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy__insert_records__zero_threshold( + self, download_mock + ): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + # Add step with threshold + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + select_results = io.StringIO( + """[{"Id":"003000000000001", "Name":"Jawad", "Email":"mjawadtp@example.com"}]""" + ) + insert_results = io.StringIO( + "Id,Success,Created\n003000000000002,true,true\n003000000000003,true,true\n" + ) + download_mock.side_effect = [select_results, insert_results] + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=True + ) + ) + == 1 + ) + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") def test_select_records_similarity_strategy__insert_records__no_select_records( self, download_mock @@ -2807,7 +2905,9 @@ def test_process_insert_records_failure(self): mock_rest_api_dml_operation.end.assert_not_called() @responses.activate - def test_select_records_similarity_strategy__insert_records(self): + def test_select_records_similarity_strategy__insert_records__non_zero_threshold( + self, + ): mock_describe_calls() task = _make_task( LoadData, @@ -2891,6 +2991,91 @@ def test_select_records_similarity_strategy__insert_records(self): == 1 ) + @responses.activate + def test_select_records_similarity_strategy__insert_records__zero_threshold(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + # Create step with threshold + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10}, + context=task, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0, + ) + + results_select_call = { + "records": [ + { + "Id": "003000000000001", + "Name": "Jawad", + "Email": "mjawadtp@example.com", + }, + ], + "done": True, + } + + results_insert_call = [ + {"id": "003000000000002", "success": True, "created": True}, + {"id": "003000000000003", "success": True, "created": True}, + ] + + step.sf.restful = mock.Mock( + side_effect=[results_select_call, results_insert_call] + ) + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom Cruise", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=True + ) + ) + == 1 + ) + @responses.activate def test_insert_dml_operation__boolean_conversion(self): mock_describe_calls()