From 1de93bd6a4adb8920472818c1fc24abe7c8cf164 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Mon, 19 Aug 2024 11:10:28 +0530 Subject: [PATCH 01/34] Add select records to step with random selection --- cumulusci/tasks/bulkdata/load.py | 16 ++++--- cumulusci/tasks/bulkdata/step.py | 72 +++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index 4ae0dcf31a..51c222ee55 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -289,7 +289,12 @@ def _execute_step( self, step, self._stream_queried_data(mapping, local_ids, query) ) step.start() - step.load_records(self._stream_queried_data(mapping, local_ids, query)) + if mapping.action == DataOperationType.SELECT: + step.select_records( + self._stream_queried_data(mapping, local_ids, query) + ) + else: + step.load_records(self._stream_queried_data(mapping, local_ids, query)) step.end() # Process Job Results @@ -481,10 +486,11 @@ def _process_job_results(self, mapping, step, local_ids): """Get the job results and process the results. If we're raising for row-level errors, do so; if we're inserting, store the new Ids.""" - is_insert_or_upsert = mapping.action in ( + is_insert_upsert_or_select = mapping.action in ( DataOperationType.INSERT, DataOperationType.UPSERT, DataOperationType.ETL_UPSERT, + DataOperationType.SELECT, ) conn = self.session.connection() @@ -500,7 +506,7 @@ def _process_job_results(self, mapping, step, local_ids): break # If we know we have no successful inserts, don't attempt to persist Ids. # Do, however, drain the generator to get error-checking behavior. - if is_insert_or_upsert and ( + if is_insert_upsert_or_select and ( step.job_result.records_processed - step.job_result.total_row_errors ): table = self.metadata.tables[self.ID_TABLE_NAME] @@ -516,7 +522,7 @@ def _process_job_results(self, mapping, step, local_ids): # person account Contact records so lookups to # person account Contact records get populated downstream as expected. if ( - is_insert_or_upsert + is_insert_upsert_or_select and mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping) ): @@ -531,7 +537,7 @@ def _process_job_results(self, mapping, step, local_ids): ), ) - if is_insert_or_upsert: + if is_insert_upsert_or_select: self.session.commit() def _generate_results_id_map(self, step, local_ids): diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index edcb62afbb..5da84930d2 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -15,6 +15,9 @@ from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.core.utils import process_bool_arg +from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( + DEFAULT_DECLARATIONS, +) from cumulusci.tasks.bulkdata.utils import iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict from cumulusci.utils.xml import lxml_parse_string @@ -36,6 +39,7 @@ class DataOperationType(StrEnum): UPSERT = "upsert" ETL_UPSERT = "etl_upsert" SMART_UPSERT = "smart_upsert" # currently undocumented + SELECT = "select" class DataApi(StrEnum): @@ -320,6 +324,11 @@ def get_prev_record_values(self, records): """Get the previous records values in case of UPSERT and UPDATE to prepare for rollback""" pass + @abstractmethod + def select_records(self, records): + """Perform the requested DML operation on the supplied row iterator.""" + pass + @abstractmethod def load_records(self, records): """Perform the requested DML operation on the supplied row iterator.""" @@ -424,6 +433,9 @@ def load_records(self, records): self.context.logger.info(f"Uploading batch {count + 1}") self.batch_ids.append(self.bulk.post_batch(self.job_id, iter(csv_batch))) + def select_records(self, records): + return super().select_records(records) + def _batch(self, records, n, char_limit=10000000): """Given an iterator of records, yields batches of records serialized in .csv format. @@ -631,6 +643,64 @@ def load_records(self, records): row_errors, ) + def select_records(self, records): + """Executes a SOQL query to select records and adds them to results""" + self.results = [] + num_records = sum(1 for _ in records) + selected_records = self.random_selection(num_records) + self.results.extend(selected_records) + self.job_result = DataOperationJobResult( + DataOperationStatus.SUCCESS + if not len(selected_records) + else DataOperationStatus.JOB_FAILURE, + [], + len(self.results), + 0, + ) + + def random_selection(self, num_records): + try: + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(self.sobject) + if declaration: + where_clause = declaration.where + else: + where_clause = None # Explicitly set to None if not found + # Construct the query with the WHERE clause (if it exists) + query = f"SELECT Id FROM {self.sobject}" + if where_clause: + query += f" WHERE {where_clause}" + query += f" LIMIT {num_records}" + query_results = self.sf.query(query) + + # Handle case where query returns 0 records + if not query_results["records"]: + error_message = ( + f"No records found for {self.sobject} in the target org." + ) + self.logger.error(error_message) + return [], error_message + + # Add 'success: True' to each record to emulate records have been inserted + selected_records = [ + {"success": True, "id": record["Id"]} + for record in query_results["records"] + ] + + # If fewer records than requested, repeat existing records to match num_records + if len(selected_records) < num_records: + original_records = selected_records.copy() + while len(selected_records) < num_records: + selected_records.extend(original_records) + selected_records = selected_records[:num_records] + + return selected_records + + except Exception as e: + error_message = f"Error executing SOQL query for {self.sobject}: {e}" + self.logger.error(error_message) + return [], error_message + def get_results(self): """Return a generator of DataOperationResult objects.""" @@ -646,7 +716,7 @@ def _convert(res): if self.operation == DataOperationType.INSERT: created = True - elif self.operation == DataOperationType.UPDATE: + elif self.operation in [DataOperationType.UPDATE, DataOperationType.SELECT]: created = False else: created = res.get("created") From 95d6414749268de50b4d2e6d2ffb62bbfa8d1700 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Mon, 19 Aug 2024 17:01:56 +0530 Subject: [PATCH 02/34] Move random selection strategy outside DML Operation Class --- .../extract_dataset_utils/extract_yml.py | 9 +- cumulusci/tasks/bulkdata/load.py | 2 + cumulusci/tasks/bulkdata/step.py | 241 +++++++++++++----- 3 files changed, 189 insertions(+), 63 deletions(-) diff --git a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py index 95d6b9ff97..9679da5a1a 100644 --- a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py +++ b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py @@ -5,7 +5,6 @@ from pydantic import Field, validator from cumulusci.core.enums import StrEnum -from cumulusci.tasks.bulkdata.step import DataApi from cumulusci.utils.yaml.model_parser import CCIDictModel, HashableBaseModel object_decl = re.compile(r"objects\((\w+)\)", re.IGNORECASE) @@ -25,6 +24,14 @@ class SFFieldGroupTypes(StrEnum): required = "required" +class DataApi(StrEnum): + """Enum defining requested Salesforce data API for an operation.""" + + BULK = "bulk" + REST = "rest" + SMART = "smart" + + class ExtractDeclaration(HashableBaseModel): where: T.Optional[str] = None fields_: T.Union[T.List[str], str] = Field(["FIELDS(ALL)"], alias="fields") diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index 51c222ee55..d6adf1395a 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -341,6 +341,8 @@ def configure_step(self, mapping): self.check_simple_upsert(mapping) api_options["update_key"] = mapping.update_key[0] action = DataOperationType.UPSERT + elif mapping.action == DataOperationType.SELECT: + action = DataOperationType.QUERY else: action = mapping.action diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 5da84930d2..eb9d11023c 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -434,7 +434,60 @@ def load_records(self, records): self.batch_ids.append(self.bulk.post_batch(self.job_id, iter(csv_batch))) def select_records(self, records): - return super().select_records(records) + """Executes a SOQL query to select records and adds them to results""" + + self.select_results = [] # Store selected records + + # Count total number of records to fetch + total_num_records = sum(1 for _ in records) + + # Process in batches based on batch_size from api_options + for offset in range( + 0, total_num_records, self.api_options.get("batch_size", 500) + ): + # Calculate number of records to fetch in this batch + num_records = min( + self.api_options.get("batch_size", 500), total_num_records - offset + ) + + # Generate and execute SOQL query + query = random_generate_query(self.sobject, num_records) + self.batch_id = self.bulk.query(self.job_id, query) + self._wait_for_job(self.job_id) + + # Get and process query results + result_ids = self.bulk.get_query_batch_result_ids( + self.batch_id, job_id=self.job_id + ) + query_records = [] + for result_id in result_ids: + uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}" + with download_file(uri, self.bulk) as f: + reader = csv.reader(f) + self.headers = next(reader) + if "Records not found for this query" in self.headers: + break # Stop if no records found + for row in reader: + query_records.append([row[0]]) + + # Post-process the query results + selected_records, error_message = random_post_process( + query_records, num_records, self.sobject + ) + if error_message: + break # Stop if there's an error during post-processing + + self.select_results.extend(selected_records) + + # Update job result based on selection outcome + self.job_result = DataOperationJobResult( + DataOperationStatus.SUCCESS + if len(self.select_results) + else DataOperationStatus.JOB_FAILURE, + [error_message] if error_message else [], + len(self.select_results), + 0, + ) def _batch(self, records, n, char_limit=10000000): """Given an iterator of records, yields batches of @@ -484,6 +537,29 @@ def _serialize_csv_record(self, record): return serialized def get_results(self): + """ + Retrieves and processes the results of a Bulk API operation. + """ + + if self.operation is DataOperationType.QUERY: + yield from self._get_query_results() + else: + yield from self._get_batch_results() + + def _get_query_results(self): + """Handles results for QUERY (select) operations""" + for row in self.select_results: + success = process_bool_arg(row["success"]) + created = process_bool_arg(row["created"]) + yield DataOperationResult( + row["id"] if success else None, + success, + None, + created, + ) + + def _get_batch_results(self): + """Handles results for other DataOperationTypes (insert, update, etc.)""" for batch_id in self.batch_ids: try: results_url = ( @@ -493,24 +569,28 @@ def get_results(self): # to avoid the server dropping connections with download_file(results_url, self.bulk) as f: self.logger.info(f"Downloaded results for batch {batch_id}") + yield from self._parse_batch_results(f) - reader = csv.reader(f) - next(reader) # skip header - - for row in reader: - success = process_bool_arg(row[1]) - created = process_bool_arg(row[2]) - yield DataOperationResult( - row[0] if success else None, - success, - row[3] if not success else None, - created, - ) except Exception as e: raise BulkDataException( f"Failed to download results for batch {batch_id} ({str(e)})" ) + def _parse_batch_results(self, f): + """Parses batch results from the downloaded file""" + reader = csv.reader(f) + next(reader) # Skip header row + + for row in reader: + success = process_bool_arg(row[1]) + created = process_bool_arg(row[2]) + yield DataOperationResult( + row[0] if success else None, + success, + row[3] if not success else None, + created, + ) + class RestApiDmlOperation(BaseDmlOperation): """Operation class for all DML operations run using the REST API.""" @@ -645,62 +725,55 @@ def load_records(self, records): def select_records(self, records): """Executes a SOQL query to select records and adds them to results""" + + def convert(rec, fields): + """Helper function to convert record values to strings, handling None values""" + return [str(rec[f]) if rec[f] is not None else "" for f in fields] + self.results = [] - num_records = sum(1 for _ in records) - selected_records = self.random_selection(num_records) - self.results.extend(selected_records) + # Count the number of records to fetch + total_num_records = sum(1 for _ in records) + + # Process in batches + for offset in range(0, total_num_records, self.api_options.get("batch_size")): + num_records = min( + self.api_options.get("batch_size"), total_num_records - offset + ) + # Generate the SOQL query with and LIMIT + query = random_generate_query(self.sobject, num_records) + + # Execute the query and extract results + response = self.sf.query(query) + # Extract and convert 'Id' fields from the query results + query_records = list(convert(rec, ["Id"]) for rec in response["records"]) + # Handle pagination if there are more records within this batch + while not response["done"]: + response = self.sf.query_more( + response["nextRecordsUrl"], identifier_is_url=True + ) + query_records.extend( + list(convert(rec, ["Id"]) for rec in response["records"]) + ) + + # Post-process the query results for this batch + selected_records, error_message = random_post_process( + query_records, num_records, self.sobject + ) + if error_message: + break + # Add selected records from this batch to the overall results + self.results.extend(selected_records) + + # Update the job result based on the overall selection outcome self.job_result = DataOperationJobResult( DataOperationStatus.SUCCESS - if not len(selected_records) + if len(self.results) # Check the overall results length else DataOperationStatus.JOB_FAILURE, - [], + [error_message] if error_message else [], len(self.results), 0, ) - def random_selection(self, num_records): - try: - # Get the WHERE clause from DEFAULT_DECLARATIONS if available - declaration = DEFAULT_DECLARATIONS.get(self.sobject) - if declaration: - where_clause = declaration.where - else: - where_clause = None # Explicitly set to None if not found - # Construct the query with the WHERE clause (if it exists) - query = f"SELECT Id FROM {self.sobject}" - if where_clause: - query += f" WHERE {where_clause}" - query += f" LIMIT {num_records}" - query_results = self.sf.query(query) - - # Handle case where query returns 0 records - if not query_results["records"]: - error_message = ( - f"No records found for {self.sobject} in the target org." - ) - self.logger.error(error_message) - return [], error_message - - # Add 'success: True' to each record to emulate records have been inserted - selected_records = [ - {"success": True, "id": record["Id"]} - for record in query_results["records"] - ] - - # If fewer records than requested, repeat existing records to match num_records - if len(selected_records) < num_records: - original_records = selected_records.copy() - while len(selected_records) < num_records: - selected_records.extend(original_records) - selected_records = selected_records[:num_records] - - return selected_records - - except Exception as e: - error_message = f"Error executing SOQL query for {self.sobject}: {e}" - self.logger.error(error_message) - return [], error_message - def get_results(self): """Return a generator of DataOperationResult objects.""" @@ -716,7 +789,7 @@ def _convert(res): if self.operation == DataOperationType.INSERT: created = True - elif self.operation in [DataOperationType.UPDATE, DataOperationType.SELECT]: + elif self.operation == DataOperationType.UPDATE: created = False else: created = res.get("created") @@ -816,3 +889,47 @@ def get_dml_operation( context=context, fields=fields, ) + + +def random_generate_query(sobject: str, num_records: float) -> str: + """Generates the SOQL query for the random selection strategy""" + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + where_clause = declaration.where + else: + where_clause = None + # Construct the query with the WHERE clause (if it exists) + query = f"SELECT Id FROM {sobject}" + if where_clause: + query += f" WHERE {where_clause}" + query += f" LIMIT {num_records}" + + return query + + +def random_post_process(records, num_records: float, sobject: str): + """Processes the query results for the random selection strategy""" + try: + # Handle case where query returns 0 records + if not records: + error_message = f"No records found for {sobject} in the target org." + return [], error_message + + # Add 'success: True' to each record to emulate records have been inserted + selected_records = [ + {"id": record[0], "success": True, "created": False} for record in records + ] + + # If fewer records than requested, repeat existing records to match num_records + if len(selected_records) < num_records: + original_records = selected_records.copy() + while len(selected_records) < num_records: + selected_records.extend(original_records) + selected_records = selected_records[:num_records] + + return selected_records, None # Return selected records and None for error + + except Exception as e: + error_message = f"Error processing query results for {sobject}: {e}" + return [], error_message From 0fca3f4ddb4533552bb738fd4d2004191df4ce8a Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Mon, 19 Aug 2024 17:36:21 +0530 Subject: [PATCH 03/34] Separate select related functions into select_utils file --- cumulusci/tasks/bulkdata/select_utils.py | 54 ++++++++++++++ cumulusci/tasks/bulkdata/step.py | 95 ++++++++++-------------- datasets/mapping.yml | 1 + 3 files changed, 95 insertions(+), 55 deletions(-) create mode 100644 cumulusci/tasks/bulkdata/select_utils.py diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py new file mode 100644 index 0000000000..b712970035 --- /dev/null +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -0,0 +1,54 @@ +from cumulusci.core.enums import StrEnum +from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( + DEFAULT_DECLARATIONS, +) + + +class SelectStrategy(StrEnum): + """Enum defining the different selection strategies requested.""" + + RANDOM = "random" + + +def random_generate_query(sobject: str, num_records: float): + """Generates the SOQL query for the random selection strategy""" + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + where_clause = declaration.where + else: + where_clause = None + # Construct the query with the WHERE clause (if it exists) + query = f"SELECT Id FROM {sobject}" + if where_clause: + query += f" WHERE {where_clause}" + query += f" LIMIT {num_records}" + + return query, ["Id"] + + +def random_post_process(records, num_records: float, sobject: str): + """Processes the query results for the random selection strategy""" + try: + # Handle case where query returns 0 records + if not records: + error_message = f"No records found for {sobject} in the target org." + return [], error_message + + # Add 'success: True' to each record to emulate records have been inserted + selected_records = [ + {"id": record[0], "success": True, "created": False} for record in records + ] + + # If fewer records than requested, repeat existing records to match num_records + if len(selected_records) < num_records: + original_records = selected_records.copy() + while len(selected_records) < num_records: + selected_records.extend(original_records) + selected_records = selected_records[:num_records] + + return selected_records, None # Return selected records and None for error + + except Exception as e: + error_message = f"Error processing query results for {sobject}: {e}" + return [], error_message diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index eb9d11023c..b0d4e31b44 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -15,8 +15,10 @@ from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.core.utils import process_bool_arg -from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( - DEFAULT_DECLARATIONS, +from cumulusci.tasks.bulkdata.select_utils import ( + SelectStrategy, + random_generate_query, + random_post_process, ) from cumulusci.tasks.bulkdata.utils import iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict @@ -347,7 +349,16 @@ def get_results(self): class BulkApiDmlOperation(BaseDmlOperation, BulkJobMixin): """Operation class for all DML operations run using the Bulk API.""" - def __init__(self, *, sobject, operation, api_options, context, fields): + def __init__( + self, + *, + sobject, + operation, + api_options, + context, + fields, + selection_strategy=SelectStrategy.RANDOM, + ): super().__init__( sobject=sobject, operation=operation, @@ -362,6 +373,10 @@ def __init__(self, *, sobject, operation, api_options, context, fields): self.csv_buff = io.StringIO(newline="") self.csv_writer = csv.writer(self.csv_buff, quoting=csv.QUOTE_ALL) + if selection_strategy is SelectStrategy.RANDOM: + self.select_generate_query = random_generate_query + self.select_post_process = random_post_process + def start(self): self.job_id = self.bulk.create_job( self.sobject, @@ -451,7 +466,7 @@ def select_records(self, records): ) # Generate and execute SOQL query - query = random_generate_query(self.sobject, num_records) + query, query_fields = self.select_generate_query(self.sobject, num_records) self.batch_id = self.bulk.query(self.job_id, query) self._wait_for_job(self.job_id) @@ -468,10 +483,10 @@ def select_records(self, records): if "Records not found for this query" in self.headers: break # Stop if no records found for row in reader: - query_records.append([row[0]]) + query_records.append([row[: len(query_fields)]]) # Post-process the query results - selected_records, error_message = random_post_process( + selected_records, error_message = self.select_post_process( query_records, num_records, self.sobject ) if error_message: @@ -595,7 +610,16 @@ def _parse_batch_results(self, f): class RestApiDmlOperation(BaseDmlOperation): """Operation class for all DML operations run using the REST API.""" - def __init__(self, *, sobject, operation, api_options, context, fields): + def __init__( + self, + *, + sobject, + operation, + api_options, + context, + fields, + selection_strategy=SelectStrategy.RANDOM, + ): super().__init__( sobject=sobject, operation=operation, @@ -617,6 +641,9 @@ def __init__(self, *, sobject, operation, api_options, context, fields): self.api_options["batch_size"] = min( self.api_options["batch_size"], MAX_REST_BATCH_SIZE ) + if selection_strategy is SelectStrategy.RANDOM: + self.select_generate_query = random_generate_query + self.select_post_process = random_post_process def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) @@ -740,23 +767,25 @@ def convert(rec, fields): self.api_options.get("batch_size"), total_num_records - offset ) # Generate the SOQL query with and LIMIT - query = random_generate_query(self.sobject, num_records) + query, query_fields = self.select_generate_query(self.sobject, num_records) # Execute the query and extract results response = self.sf.query(query) # Extract and convert 'Id' fields from the query results - query_records = list(convert(rec, ["Id"]) for rec in response["records"]) + query_records = list( + convert(rec, query_fields) for rec in response["records"] + ) # Handle pagination if there are more records within this batch while not response["done"]: response = self.sf.query_more( response["nextRecordsUrl"], identifier_is_url=True ) query_records.extend( - list(convert(rec, ["Id"]) for rec in response["records"]) + list(convert(rec, query_fields) for rec in response["records"]) ) # Post-process the query results for this batch - selected_records, error_message = random_post_process( + selected_records, error_message = self.select_post_process( query_records, num_records, self.sobject ) if error_message: @@ -889,47 +918,3 @@ def get_dml_operation( context=context, fields=fields, ) - - -def random_generate_query(sobject: str, num_records: float) -> str: - """Generates the SOQL query for the random selection strategy""" - # Get the WHERE clause from DEFAULT_DECLARATIONS if available - declaration = DEFAULT_DECLARATIONS.get(sobject) - if declaration: - where_clause = declaration.where - else: - where_clause = None - # Construct the query with the WHERE clause (if it exists) - query = f"SELECT Id FROM {sobject}" - if where_clause: - query += f" WHERE {where_clause}" - query += f" LIMIT {num_records}" - - return query - - -def random_post_process(records, num_records: float, sobject: str): - """Processes the query results for the random selection strategy""" - try: - # Handle case where query returns 0 records - if not records: - error_message = f"No records found for {sobject} in the target org." - return [], error_message - - # Add 'success: True' to each record to emulate records have been inserted - selected_records = [ - {"id": record[0], "success": True, "created": False} for record in records - ] - - # If fewer records than requested, repeat existing records to match num_records - if len(selected_records) < num_records: - original_records = selected_records.copy() - while len(selected_records) < num_records: - selected_records.extend(original_records) - selected_records = selected_records[:num_records] - - return selected_records, None # Return selected records and None for error - - except Exception as e: - error_message = f"Error processing query results for {sobject}: {e}" - return [], error_message diff --git a/datasets/mapping.yml b/datasets/mapping.yml index ae7952b22c..838b8b4597 100644 --- a/datasets/mapping.yml +++ b/datasets/mapping.yml @@ -1,6 +1,7 @@ Account: sf_object: Account api: bulk + action: select fields: - Name - Description From 580a6e67961cfc70ef4a68e1af095353b9aabd88 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 20 Aug 2024 00:21:20 +0530 Subject: [PATCH 04/34] Add test cases for select_records functionality --- cumulusci/tasks/bulkdata/select_utils.py | 39 +-- cumulusci/tasks/bulkdata/step.py | 3 +- .../tasks/bulkdata/tests/test_select_utils.py | 63 ++++ cumulusci/tasks/bulkdata/tests/test_step.py | 309 ++++++++++++++++++ cumulusci/tasks/bulkdata/tests/utils.py | 3 + datasets/mapping.yml | 1 - 6 files changed, 394 insertions(+), 24 deletions(-) create mode 100644 cumulusci/tasks/bulkdata/tests/test_select_utils.py diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index b712970035..3521fa3c8e 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -29,26 +29,21 @@ def random_generate_query(sobject: str, num_records: float): def random_post_process(records, num_records: float, sobject: str): """Processes the query results for the random selection strategy""" - try: - # Handle case where query returns 0 records - if not records: - error_message = f"No records found for {sobject} in the target org." - return [], error_message - - # Add 'success: True' to each record to emulate records have been inserted - selected_records = [ - {"id": record[0], "success": True, "created": False} for record in records - ] - - # If fewer records than requested, repeat existing records to match num_records - if len(selected_records) < num_records: - original_records = selected_records.copy() - while len(selected_records) < num_records: - selected_records.extend(original_records) - selected_records = selected_records[:num_records] - - return selected_records, None # Return selected records and None for error - - except Exception as e: - error_message = f"Error processing query results for {sobject}: {e}" + # Handle case where query returns 0 records + if not records: + error_message = f"No records found for {sobject} in the target org." return [], error_message + + # Add 'success: True' to each record to emulate records have been inserted + selected_records = [ + {"id": record[0], "success": True, "created": False} for record in records + ] + + # If fewer records than requested, repeat existing records to match num_records + if len(selected_records) < num_records: + original_records = selected_records.copy() + while len(selected_records) < num_records: + selected_records.extend(original_records) + selected_records = selected_records[:num_records] + + return selected_records, None # Return selected records and None for error diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index b0d4e31b44..1844c4caeb 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -388,7 +388,8 @@ def start(self): def end(self): self.bulk.close_job(self.job_id) - self.job_result = self._wait_for_job(self.job_id) + if not self.job_result: + self.job_result = self._wait_for_job(self.job_id) def get_prev_record_values(self, records): """Get the previous values of the records based on the update key diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py new file mode 100644 index 0000000000..c649871217 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -0,0 +1,63 @@ +from cumulusci.tasks.bulkdata.select_utils import ( + random_generate_query, + random_post_process, +) + + +# Test Cases for random_generate_query +def test_random_generate_query_with_default_record_declaration(): + sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS + num_records = 5 + query, fields = random_generate_query(sobject, num_records) + + assert "WHERE" in query # Ensure WHERE clause is included + assert f"LIMIT {num_records}" in query + assert fields == ["Id"] + + +def test_random_generate_query_without_default_record_declaration(): + sobject = "Contact" # Assuming no declaration for this object + num_records = 3 + query, fields = random_generate_query(sobject, num_records) + + assert "WHERE" not in query # No WHERE clause should be present + assert f"LIMIT {num_records}" in query + assert fields == ["Id"] + + +# Test Cases for random_post_process +def test_random_post_process_with_records(): + records = [["001"], ["002"], ["003"]] + num_records = 3 + sobject = "Contact" + selected_records, error_message = random_post_process(records, num_records, sobject) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + assert all(record["id"] in ["001", "002", "003"] for record in selected_records) + + +def test_random_post_process_with_fewer_records(): + records = [["001"]] + num_records = 3 + sobject = "Opportunity" + selected_records, error_message = random_post_process(records, num_records, sobject) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + # Check if records are repeated to match num_records + assert selected_records.count({"id": "001", "success": True, "created": False}) == 3 + + +def test_random_post_process_with_no_records(): + records = [] + num_records = 2 + sobject = "Lead" + selected_records, error_message = random_post_process(records, num_records, sobject) + + assert selected_records == [] + assert error_message == f"No records found for {sobject} in the target org." diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index fc8cea7013..6459edb6d0 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -7,6 +7,7 @@ from cumulusci.core.exceptions import BulkDataException from cumulusci.tasks.bulkdata.load import LoadData +from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import ( BulkApiDmlOperation, BulkApiQueryOperation, @@ -534,6 +535,104 @@ def test_get_prev_record_values(self): ) step.bulk.get_all_results_for_query_batch.assert_called_once_with("BATCH_ID") + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_random_strategy_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.RANDOM, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO( + """Id +003000000000001""" + ) + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id=["003000000000001"], success=True, error=None, created=False + ) + ) + == 3 + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_random_strategy_failure__no_records(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.RANDOM, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content indicating no records found + download_mock.return_value = io.StringIO("""Records not found for this query""") + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." + ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + def test_batch(self): context = mock.Mock() @@ -879,6 +978,216 @@ def test_get_prev_record_values(self): ) assert set(relevant_fields) == set(expected_relevant_fields) + @responses.activate + def test_select_records_random_strategy_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + ) + + results = { + "records": [ + {"Id": "003000000000001"}, + ], + "done": True, + } + step.sf.query = mock.Mock() + step.sf.query.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 + ) + + @responses.activate + def test_select_records_random_strategy_success__pagination(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + ) + + results = { + "records": [ + {"Id": "003000000000001"}, + ], + "done": False, + "nextRecordsUrl": "https://example.com", + } + results_more = { + "records": [ + {"Id": "003000000000002"}, + {"Id": "003000000000003"}, + ], + "done": True, + } + step.sf.query = mock.Mock() + step.sf.query.return_value = results + step.sf.query_more = mock.Mock() + step.sf.query_more.return_value = results_more + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False + ) + ) + == 1 + ) + + @responses.activate + def test_select_records_random_strategy_failure__no_records(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + ) + + results = {"records": [], "done": True} + step.sf.query = mock.Mock() + step.sf.query.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." + ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + @responses.activate def test_insert_dml_operation__boolean_conversion(self): mock_describe_calls() diff --git a/cumulusci/tasks/bulkdata/tests/utils.py b/cumulusci/tasks/bulkdata/tests/utils.py index 173f4c6122..c0db0f9515 100644 --- a/cumulusci/tasks/bulkdata/tests/utils.py +++ b/cumulusci/tasks/bulkdata/tests/utils.py @@ -98,6 +98,9 @@ def get_prev_record_values(self, records): def load_records(self, records): self.records.extend(records) + def select_records(self, records): + pass + def get_results(self): return iter(self.results) diff --git a/datasets/mapping.yml b/datasets/mapping.yml index 838b8b4597..ae7952b22c 100644 --- a/datasets/mapping.yml +++ b/datasets/mapping.yml @@ -1,7 +1,6 @@ Account: sf_object: Account api: bulk - action: select fields: - Name - Description From e230159ffefca688d881d9270a138ce7bda19b39 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 20 Aug 2024 00:25:30 +0530 Subject: [PATCH 05/34] Undo load changes for select record functioanlity --- cumulusci/tasks/bulkdata/load.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index d6adf1395a..4ae0dcf31a 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -289,12 +289,7 @@ def _execute_step( self, step, self._stream_queried_data(mapping, local_ids, query) ) step.start() - if mapping.action == DataOperationType.SELECT: - step.select_records( - self._stream_queried_data(mapping, local_ids, query) - ) - else: - step.load_records(self._stream_queried_data(mapping, local_ids, query)) + step.load_records(self._stream_queried_data(mapping, local_ids, query)) step.end() # Process Job Results @@ -341,8 +336,6 @@ def configure_step(self, mapping): self.check_simple_upsert(mapping) api_options["update_key"] = mapping.update_key[0] action = DataOperationType.UPSERT - elif mapping.action == DataOperationType.SELECT: - action = DataOperationType.QUERY else: action = mapping.action @@ -488,11 +481,10 @@ def _process_job_results(self, mapping, step, local_ids): """Get the job results and process the results. If we're raising for row-level errors, do so; if we're inserting, store the new Ids.""" - is_insert_upsert_or_select = mapping.action in ( + is_insert_or_upsert = mapping.action in ( DataOperationType.INSERT, DataOperationType.UPSERT, DataOperationType.ETL_UPSERT, - DataOperationType.SELECT, ) conn = self.session.connection() @@ -508,7 +500,7 @@ def _process_job_results(self, mapping, step, local_ids): break # If we know we have no successful inserts, don't attempt to persist Ids. # Do, however, drain the generator to get error-checking behavior. - if is_insert_upsert_or_select and ( + if is_insert_or_upsert and ( step.job_result.records_processed - step.job_result.total_row_errors ): table = self.metadata.tables[self.ID_TABLE_NAME] @@ -524,7 +516,7 @@ def _process_job_results(self, mapping, step, local_ids): # person account Contact records so lookups to # person account Contact records get populated downstream as expected. if ( - is_insert_upsert_or_select + is_insert_or_upsert and mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping) ): @@ -539,7 +531,7 @@ def _process_job_results(self, mapping, step, local_ids): ), ) - if is_insert_upsert_or_select: + if is_insert_or_upsert: self.session.commit() def _generate_results_id_map(self, step, local_ids): From b15945203dcd185bb07a294d439fe0c3ed5669eb Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Mon, 19 Aug 2024 11:10:28 +0530 Subject: [PATCH 06/34] Core Logic for Selecting Records from Target Org --- .../extract_dataset_utils/extract_yml.py | 9 +- cumulusci/tasks/bulkdata/select_utils.py | 49 +++ cumulusci/tasks/bulkdata/step.py | 203 +++++++++++- .../tasks/bulkdata/tests/test_select_utils.py | 63 ++++ cumulusci/tasks/bulkdata/tests/test_step.py | 309 ++++++++++++++++++ cumulusci/tasks/bulkdata/tests/utils.py | 3 + 6 files changed, 620 insertions(+), 16 deletions(-) create mode 100644 cumulusci/tasks/bulkdata/select_utils.py create mode 100644 cumulusci/tasks/bulkdata/tests/test_select_utils.py diff --git a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py index 95d6b9ff97..9679da5a1a 100644 --- a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py +++ b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py @@ -5,7 +5,6 @@ from pydantic import Field, validator from cumulusci.core.enums import StrEnum -from cumulusci.tasks.bulkdata.step import DataApi from cumulusci.utils.yaml.model_parser import CCIDictModel, HashableBaseModel object_decl = re.compile(r"objects\((\w+)\)", re.IGNORECASE) @@ -25,6 +24,14 @@ class SFFieldGroupTypes(StrEnum): required = "required" +class DataApi(StrEnum): + """Enum defining requested Salesforce data API for an operation.""" + + BULK = "bulk" + REST = "rest" + SMART = "smart" + + class ExtractDeclaration(HashableBaseModel): where: T.Optional[str] = None fields_: T.Union[T.List[str], str] = Field(["FIELDS(ALL)"], alias="fields") diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py new file mode 100644 index 0000000000..3521fa3c8e --- /dev/null +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -0,0 +1,49 @@ +from cumulusci.core.enums import StrEnum +from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( + DEFAULT_DECLARATIONS, +) + + +class SelectStrategy(StrEnum): + """Enum defining the different selection strategies requested.""" + + RANDOM = "random" + + +def random_generate_query(sobject: str, num_records: float): + """Generates the SOQL query for the random selection strategy""" + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + where_clause = declaration.where + else: + where_clause = None + # Construct the query with the WHERE clause (if it exists) + query = f"SELECT Id FROM {sobject}" + if where_clause: + query += f" WHERE {where_clause}" + query += f" LIMIT {num_records}" + + return query, ["Id"] + + +def random_post_process(records, num_records: float, sobject: str): + """Processes the query results for the random selection strategy""" + # Handle case where query returns 0 records + if not records: + error_message = f"No records found for {sobject} in the target org." + return [], error_message + + # Add 'success: True' to each record to emulate records have been inserted + selected_records = [ + {"id": record[0], "success": True, "created": False} for record in records + ] + + # If fewer records than requested, repeat existing records to match num_records + if len(selected_records) < num_records: + original_records = selected_records.copy() + while len(selected_records) < num_records: + selected_records.extend(original_records) + selected_records = selected_records[:num_records] + + return selected_records, None # Return selected records and None for error diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index edcb62afbb..1844c4caeb 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -15,6 +15,11 @@ from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.core.utils import process_bool_arg +from cumulusci.tasks.bulkdata.select_utils import ( + SelectStrategy, + random_generate_query, + random_post_process, +) from cumulusci.tasks.bulkdata.utils import iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict from cumulusci.utils.xml import lxml_parse_string @@ -36,6 +41,7 @@ class DataOperationType(StrEnum): UPSERT = "upsert" ETL_UPSERT = "etl_upsert" SMART_UPSERT = "smart_upsert" # currently undocumented + SELECT = "select" class DataApi(StrEnum): @@ -320,6 +326,11 @@ def get_prev_record_values(self, records): """Get the previous records values in case of UPSERT and UPDATE to prepare for rollback""" pass + @abstractmethod + def select_records(self, records): + """Perform the requested DML operation on the supplied row iterator.""" + pass + @abstractmethod def load_records(self, records): """Perform the requested DML operation on the supplied row iterator.""" @@ -338,7 +349,16 @@ def get_results(self): class BulkApiDmlOperation(BaseDmlOperation, BulkJobMixin): """Operation class for all DML operations run using the Bulk API.""" - def __init__(self, *, sobject, operation, api_options, context, fields): + def __init__( + self, + *, + sobject, + operation, + api_options, + context, + fields, + selection_strategy=SelectStrategy.RANDOM, + ): super().__init__( sobject=sobject, operation=operation, @@ -353,6 +373,10 @@ def __init__(self, *, sobject, operation, api_options, context, fields): self.csv_buff = io.StringIO(newline="") self.csv_writer = csv.writer(self.csv_buff, quoting=csv.QUOTE_ALL) + if selection_strategy is SelectStrategy.RANDOM: + self.select_generate_query = random_generate_query + self.select_post_process = random_post_process + def start(self): self.job_id = self.bulk.create_job( self.sobject, @@ -364,7 +388,8 @@ def start(self): def end(self): self.bulk.close_job(self.job_id) - self.job_result = self._wait_for_job(self.job_id) + if not self.job_result: + self.job_result = self._wait_for_job(self.job_id) def get_prev_record_values(self, records): """Get the previous values of the records based on the update key @@ -424,6 +449,62 @@ def load_records(self, records): self.context.logger.info(f"Uploading batch {count + 1}") self.batch_ids.append(self.bulk.post_batch(self.job_id, iter(csv_batch))) + def select_records(self, records): + """Executes a SOQL query to select records and adds them to results""" + + self.select_results = [] # Store selected records + + # Count total number of records to fetch + total_num_records = sum(1 for _ in records) + + # Process in batches based on batch_size from api_options + for offset in range( + 0, total_num_records, self.api_options.get("batch_size", 500) + ): + # Calculate number of records to fetch in this batch + num_records = min( + self.api_options.get("batch_size", 500), total_num_records - offset + ) + + # Generate and execute SOQL query + query, query_fields = self.select_generate_query(self.sobject, num_records) + self.batch_id = self.bulk.query(self.job_id, query) + self._wait_for_job(self.job_id) + + # Get and process query results + result_ids = self.bulk.get_query_batch_result_ids( + self.batch_id, job_id=self.job_id + ) + query_records = [] + for result_id in result_ids: + uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}" + with download_file(uri, self.bulk) as f: + reader = csv.reader(f) + self.headers = next(reader) + if "Records not found for this query" in self.headers: + break # Stop if no records found + for row in reader: + query_records.append([row[: len(query_fields)]]) + + # Post-process the query results + selected_records, error_message = self.select_post_process( + query_records, num_records, self.sobject + ) + if error_message: + break # Stop if there's an error during post-processing + + self.select_results.extend(selected_records) + + # Update job result based on selection outcome + self.job_result = DataOperationJobResult( + DataOperationStatus.SUCCESS + if len(self.select_results) + else DataOperationStatus.JOB_FAILURE, + [error_message] if error_message else [], + len(self.select_results), + 0, + ) + def _batch(self, records, n, char_limit=10000000): """Given an iterator of records, yields batches of records serialized in .csv format. @@ -472,6 +553,29 @@ def _serialize_csv_record(self, record): return serialized def get_results(self): + """ + Retrieves and processes the results of a Bulk API operation. + """ + + if self.operation is DataOperationType.QUERY: + yield from self._get_query_results() + else: + yield from self._get_batch_results() + + def _get_query_results(self): + """Handles results for QUERY (select) operations""" + for row in self.select_results: + success = process_bool_arg(row["success"]) + created = process_bool_arg(row["created"]) + yield DataOperationResult( + row["id"] if success else None, + success, + None, + created, + ) + + def _get_batch_results(self): + """Handles results for other DataOperationTypes (insert, update, etc.)""" for batch_id in self.batch_ids: try: results_url = ( @@ -481,29 +585,42 @@ def get_results(self): # to avoid the server dropping connections with download_file(results_url, self.bulk) as f: self.logger.info(f"Downloaded results for batch {batch_id}") + yield from self._parse_batch_results(f) - reader = csv.reader(f) - next(reader) # skip header - - for row in reader: - success = process_bool_arg(row[1]) - created = process_bool_arg(row[2]) - yield DataOperationResult( - row[0] if success else None, - success, - row[3] if not success else None, - created, - ) except Exception as e: raise BulkDataException( f"Failed to download results for batch {batch_id} ({str(e)})" ) + def _parse_batch_results(self, f): + """Parses batch results from the downloaded file""" + reader = csv.reader(f) + next(reader) # Skip header row + + for row in reader: + success = process_bool_arg(row[1]) + created = process_bool_arg(row[2]) + yield DataOperationResult( + row[0] if success else None, + success, + row[3] if not success else None, + created, + ) + class RestApiDmlOperation(BaseDmlOperation): """Operation class for all DML operations run using the REST API.""" - def __init__(self, *, sobject, operation, api_options, context, fields): + def __init__( + self, + *, + sobject, + operation, + api_options, + context, + fields, + selection_strategy=SelectStrategy.RANDOM, + ): super().__init__( sobject=sobject, operation=operation, @@ -525,6 +642,9 @@ def __init__(self, *, sobject, operation, api_options, context, fields): self.api_options["batch_size"] = min( self.api_options["batch_size"], MAX_REST_BATCH_SIZE ) + if selection_strategy is SelectStrategy.RANDOM: + self.select_generate_query = random_generate_query + self.select_post_process = random_post_process def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) @@ -631,6 +751,59 @@ def load_records(self, records): row_errors, ) + def select_records(self, records): + """Executes a SOQL query to select records and adds them to results""" + + def convert(rec, fields): + """Helper function to convert record values to strings, handling None values""" + return [str(rec[f]) if rec[f] is not None else "" for f in fields] + + self.results = [] + # Count the number of records to fetch + total_num_records = sum(1 for _ in records) + + # Process in batches + for offset in range(0, total_num_records, self.api_options.get("batch_size")): + num_records = min( + self.api_options.get("batch_size"), total_num_records - offset + ) + # Generate the SOQL query with and LIMIT + query, query_fields = self.select_generate_query(self.sobject, num_records) + + # Execute the query and extract results + response = self.sf.query(query) + # Extract and convert 'Id' fields from the query results + query_records = list( + convert(rec, query_fields) for rec in response["records"] + ) + # Handle pagination if there are more records within this batch + while not response["done"]: + response = self.sf.query_more( + response["nextRecordsUrl"], identifier_is_url=True + ) + query_records.extend( + list(convert(rec, query_fields) for rec in response["records"]) + ) + + # Post-process the query results for this batch + selected_records, error_message = self.select_post_process( + query_records, num_records, self.sobject + ) + if error_message: + break + # Add selected records from this batch to the overall results + self.results.extend(selected_records) + + # Update the job result based on the overall selection outcome + self.job_result = DataOperationJobResult( + DataOperationStatus.SUCCESS + if len(self.results) # Check the overall results length + else DataOperationStatus.JOB_FAILURE, + [error_message] if error_message else [], + len(self.results), + 0, + ) + def get_results(self): """Return a generator of DataOperationResult objects.""" diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py new file mode 100644 index 0000000000..c649871217 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -0,0 +1,63 @@ +from cumulusci.tasks.bulkdata.select_utils import ( + random_generate_query, + random_post_process, +) + + +# Test Cases for random_generate_query +def test_random_generate_query_with_default_record_declaration(): + sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS + num_records = 5 + query, fields = random_generate_query(sobject, num_records) + + assert "WHERE" in query # Ensure WHERE clause is included + assert f"LIMIT {num_records}" in query + assert fields == ["Id"] + + +def test_random_generate_query_without_default_record_declaration(): + sobject = "Contact" # Assuming no declaration for this object + num_records = 3 + query, fields = random_generate_query(sobject, num_records) + + assert "WHERE" not in query # No WHERE clause should be present + assert f"LIMIT {num_records}" in query + assert fields == ["Id"] + + +# Test Cases for random_post_process +def test_random_post_process_with_records(): + records = [["001"], ["002"], ["003"]] + num_records = 3 + sobject = "Contact" + selected_records, error_message = random_post_process(records, num_records, sobject) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + assert all(record["id"] in ["001", "002", "003"] for record in selected_records) + + +def test_random_post_process_with_fewer_records(): + records = [["001"]] + num_records = 3 + sobject = "Opportunity" + selected_records, error_message = random_post_process(records, num_records, sobject) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + # Check if records are repeated to match num_records + assert selected_records.count({"id": "001", "success": True, "created": False}) == 3 + + +def test_random_post_process_with_no_records(): + records = [] + num_records = 2 + sobject = "Lead" + selected_records, error_message = random_post_process(records, num_records, sobject) + + assert selected_records == [] + assert error_message == f"No records found for {sobject} in the target org." diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index fc8cea7013..6459edb6d0 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -7,6 +7,7 @@ from cumulusci.core.exceptions import BulkDataException from cumulusci.tasks.bulkdata.load import LoadData +from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import ( BulkApiDmlOperation, BulkApiQueryOperation, @@ -534,6 +535,104 @@ def test_get_prev_record_values(self): ) step.bulk.get_all_results_for_query_batch.assert_called_once_with("BATCH_ID") + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_random_strategy_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.RANDOM, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO( + """Id +003000000000001""" + ) + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id=["003000000000001"], success=True, error=None, created=False + ) + ) + == 3 + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_random_strategy_failure__no_records(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.RANDOM, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content indicating no records found + download_mock.return_value = io.StringIO("""Records not found for this query""") + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." + ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + def test_batch(self): context = mock.Mock() @@ -879,6 +978,216 @@ def test_get_prev_record_values(self): ) assert set(relevant_fields) == set(expected_relevant_fields) + @responses.activate + def test_select_records_random_strategy_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + ) + + results = { + "records": [ + {"Id": "003000000000001"}, + ], + "done": True, + } + step.sf.query = mock.Mock() + step.sf.query.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 + ) + + @responses.activate + def test_select_records_random_strategy_success__pagination(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + ) + + results = { + "records": [ + {"Id": "003000000000001"}, + ], + "done": False, + "nextRecordsUrl": "https://example.com", + } + results_more = { + "records": [ + {"Id": "003000000000002"}, + {"Id": "003000000000003"}, + ], + "done": True, + } + step.sf.query = mock.Mock() + step.sf.query.return_value = results + step.sf.query_more = mock.Mock() + step.sf.query_more.return_value = results_more + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False + ) + ) + == 1 + ) + + @responses.activate + def test_select_records_random_strategy_failure__no_records(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + ) + + results = {"records": [], "done": True} + step.sf.query = mock.Mock() + step.sf.query.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." + ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + @responses.activate def test_insert_dml_operation__boolean_conversion(self): mock_describe_calls() diff --git a/cumulusci/tasks/bulkdata/tests/utils.py b/cumulusci/tasks/bulkdata/tests/utils.py index 173f4c6122..c0db0f9515 100644 --- a/cumulusci/tasks/bulkdata/tests/utils.py +++ b/cumulusci/tasks/bulkdata/tests/utils.py @@ -98,6 +98,9 @@ def get_prev_record_values(self, records): def load_records(self, records): self.records.extend(records) + def select_records(self, records): + pass + def get_results(self): return iter(self.results) From b11abc30898d0d68167cf22a7f0981d6f31de436 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 20 Aug 2024 14:12:30 +0530 Subject: [PATCH 07/34] Refactor select utility file and generalize arguments --- cumulusci/tasks/bulkdata/select_utils.py | 14 ++++++++++---- cumulusci/tasks/bulkdata/step.py | 17 +++++++++++------ .../tasks/bulkdata/tests/test_select_utils.py | 12 +++++++++--- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 3521fa3c8e..48bac23578 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -1,3 +1,5 @@ +import typing as T + from cumulusci.core.enums import StrEnum from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( DEFAULT_DECLARATIONS, @@ -10,7 +12,9 @@ class SelectStrategy(StrEnum): RANDOM = "random" -def random_generate_query(sobject: str, num_records: float): +def random_generate_query( + sobject: str, num_records: float +) -> T.Tuple[str, T.List[str]]: """Generates the SOQL query for the random selection strategy""" # Get the WHERE clause from DEFAULT_DECLARATIONS if available declaration = DEFAULT_DECLARATIONS.get(sobject) @@ -27,16 +31,18 @@ def random_generate_query(sobject: str, num_records: float): return query, ["Id"] -def random_post_process(records, num_records: float, sobject: str): +def random_post_process( + load_records, query_records: list, num_records: float, sobject: str +) -> T.Tuple[T.List[dict], T.Union[str, None]]: """Processes the query results for the random selection strategy""" # Handle case where query returns 0 records - if not records: + if not query_records: error_message = f"No records found for {sobject} in the target org." return [], error_message # Add 'success: True' to each record to emulate records have been inserted selected_records = [ - {"id": record[0], "success": True, "created": False} for record in records + {"id": record[0], "success": True, "created": False} for record in query_records ] # If fewer records than requested, repeat existing records to match num_records diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 1844c4caeb..1fe0cc80d4 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -7,6 +7,7 @@ import time from abc import ABCMeta, abstractmethod from contextlib import contextmanager +from itertools import tee from typing import Any, Dict, List, NamedTuple, Optional import requests @@ -454,8 +455,10 @@ def select_records(self, records): self.select_results = [] # Store selected records - # Count total number of records to fetch - total_num_records = sum(1 for _ in records) + # Create a copy of the generator using tee + records, records_copy = tee(records) + # Count total number of records to fetch using the copy + total_num_records = sum(1 for _ in records_copy) # Process in batches based on batch_size from api_options for offset in range( @@ -488,7 +491,7 @@ def select_records(self, records): # Post-process the query results selected_records, error_message = self.select_post_process( - query_records, num_records, self.sobject + records, query_records, num_records, self.sobject ) if error_message: break # Stop if there's an error during post-processing @@ -759,8 +762,10 @@ def convert(rec, fields): return [str(rec[f]) if rec[f] is not None else "" for f in fields] self.results = [] - # Count the number of records to fetch - total_num_records = sum(1 for _ in records) + # Create a copy of the generator using tee + records, records_copy = tee(records) + # Count total number of records to fetch using the copy + total_num_records = sum(1 for _ in records_copy) # Process in batches for offset in range(0, total_num_records, self.api_options.get("batch_size")): @@ -787,7 +792,7 @@ def convert(rec, fields): # Post-process the query results for this batch selected_records, error_message = self.select_post_process( - query_records, num_records, self.sobject + records, query_records, num_records, self.sobject ) if error_message: break diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index c649871217..43c39d63bd 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -30,7 +30,9 @@ def test_random_post_process_with_records(): records = [["001"], ["002"], ["003"]] num_records = 3 sobject = "Contact" - selected_records, error_message = random_post_process(records, num_records, sobject) + selected_records, error_message = random_post_process( + None, records, num_records, sobject + ) assert error_message is None assert len(selected_records) == num_records @@ -43,7 +45,9 @@ def test_random_post_process_with_fewer_records(): records = [["001"]] num_records = 3 sobject = "Opportunity" - selected_records, error_message = random_post_process(records, num_records, sobject) + selected_records, error_message = random_post_process( + None, records, num_records, sobject + ) assert error_message is None assert len(selected_records) == num_records @@ -57,7 +61,9 @@ def test_random_post_process_with_no_records(): records = [] num_records = 2 sobject = "Lead" - selected_records, error_message = random_post_process(records, num_records, sobject) + selected_records, error_message = random_post_process( + None, records, num_records, sobject + ) assert selected_records == [] assert error_message == f"No records found for {sobject} in the target org." From a8b1aedd12135f860720fbb19529cf0b251bdf9f Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 20 Aug 2024 16:21:19 +0530 Subject: [PATCH 08/34] Add fields argument to select query utility function --- cumulusci/tasks/bulkdata/select_utils.py | 2 +- cumulusci/tasks/bulkdata/step.py | 14 +++++++++----- .../tasks/bulkdata/tests/test_select_utils.py | 4 ++-- cumulusci/tasks/bulkdata/tests/test_step.py | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 48bac23578..808d2f8a2a 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -13,7 +13,7 @@ class SelectStrategy(StrEnum): def random_generate_query( - sobject: str, num_records: float + sobject: str, fields: T.List[str], num_records: float ) -> T.Tuple[str, T.List[str]]: """Generates the SOQL query for the random selection strategy""" # Get the WHERE clause from DEFAULT_DECLARATIONS if available diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 1fe0cc80d4..14f37db181 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -470,7 +470,9 @@ def select_records(self, records): ) # Generate and execute SOQL query - query, query_fields = self.select_generate_query(self.sobject, num_records) + query, query_fields = self.select_generate_query( + self.sobject, self.fields, num_records + ) self.batch_id = self.bulk.query(self.job_id, query) self._wait_for_job(self.job_id) @@ -487,7 +489,7 @@ def select_records(self, records): if "Records not found for this query" in self.headers: break # Stop if no records found for row in reader: - query_records.append([row[: len(query_fields)]]) + query_records.append(row[: len(query_fields)]) # Post-process the query results selected_records, error_message = self.select_post_process( @@ -571,9 +573,9 @@ def _get_query_results(self): success = process_bool_arg(row["success"]) created = process_bool_arg(row["created"]) yield DataOperationResult( - row["id"] if success else None, + row["id"] if success else "", success, - None, + "", created, ) @@ -773,7 +775,9 @@ def convert(rec, fields): self.api_options.get("batch_size"), total_num_records - offset ) # Generate the SOQL query with and LIMIT - query, query_fields = self.select_generate_query(self.sobject, num_records) + query, query_fields = self.select_generate_query( + self.sobject, self.fields, num_records + ) # Execute the query and extract results response = self.sf.query(query) diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index 43c39d63bd..29abc845e7 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -8,7 +8,7 @@ def test_random_generate_query_with_default_record_declaration(): sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS num_records = 5 - query, fields = random_generate_query(sobject, num_records) + query, fields = random_generate_query(sobject, [], num_records) assert "WHERE" in query # Ensure WHERE clause is included assert f"LIMIT {num_records}" in query @@ -18,7 +18,7 @@ def test_random_generate_query_with_default_record_declaration(): def test_random_generate_query_without_default_record_declaration(): sobject = "Contact" # Assuming no declaration for this object num_records = 3 - query, fields = random_generate_query(sobject, num_records) + query, fields = random_generate_query(sobject, [], num_records) assert "WHERE" not in query # No WHERE clause should be present assert f"LIMIT {num_records}" in query diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index 6459edb6d0..8f6f34ad90 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -581,7 +581,7 @@ def test_select_records_random_strategy_success(self, download_mock): assert ( results.count( DataOperationResult( - id=["003000000000001"], success=True, error=None, created=False + id="003000000000001", success=True, error="", created=False ) ) == 3 From 720d484018861626ff0249301464c7d3dbfb516b Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 20 Aug 2024 16:52:20 +0530 Subject: [PATCH 09/34] Move DataApi import to utils --- .../bulkdata/extract_dataset_utils/extract_yml.py | 9 +-------- cumulusci/tasks/bulkdata/step.py | 10 +--------- cumulusci/tasks/bulkdata/utils.py | 9 +++++++++ 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py index 9679da5a1a..cec42d0bd9 100644 --- a/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py +++ b/cumulusci/tasks/bulkdata/extract_dataset_utils/extract_yml.py @@ -5,6 +5,7 @@ from pydantic import Field, validator from cumulusci.core.enums import StrEnum +from cumulusci.tasks.bulkdata.utils import DataApi from cumulusci.utils.yaml.model_parser import CCIDictModel, HashableBaseModel object_decl = re.compile(r"objects\((\w+)\)", re.IGNORECASE) @@ -24,14 +25,6 @@ class SFFieldGroupTypes(StrEnum): required = "required" -class DataApi(StrEnum): - """Enum defining requested Salesforce data API for an operation.""" - - BULK = "bulk" - REST = "rest" - SMART = "smart" - - class ExtractDeclaration(HashableBaseModel): where: T.Optional[str] = None fields_: T.Union[T.List[str], str] = Field(["FIELDS(ALL)"], alias="fields") diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 14f37db181..4770cab37e 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -21,7 +21,7 @@ random_generate_query, random_post_process, ) -from cumulusci.tasks.bulkdata.utils import iterate_in_chunks +from cumulusci.tasks.bulkdata.utils import DataApi, iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict from cumulusci.utils.xml import lxml_parse_string @@ -45,14 +45,6 @@ class DataOperationType(StrEnum): SELECT = "select" -class DataApi(StrEnum): - """Enum defining requested Salesforce data API for an operation.""" - - BULK = "bulk" - REST = "rest" - SMART = "smart" - - class DataOperationStatus(StrEnum): """Enum defining outcome values for a data operation.""" diff --git a/cumulusci/tasks/bulkdata/utils.py b/cumulusci/tasks/bulkdata/utils.py index 082277fb16..b5c195a817 100644 --- a/cumulusci/tasks/bulkdata/utils.py +++ b/cumulusci/tasks/bulkdata/utils.py @@ -10,10 +10,19 @@ from sqlalchemy.engine.base import Connection from sqlalchemy.orm import Session, mapper +from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.utils.iterators import iterate_in_chunks +class DataApi(StrEnum): + """Enum defining requested Salesforce data API for an operation.""" + + BULK = "bulk" + REST = "rest" + SMART = "smart" + + class SqlAlchemyMixin: logger: logging.Logger metadata: MetaData From 2ade0c0d9a36e96e262ad8c6584d76f5b4801e71 Mon Sep 17 00:00:00 2001 From: Jawadtp Date: Fri, 23 Aug 2024 15:11:49 +0530 Subject: [PATCH 10/34] Add similarity select algorithm --- cumulusci/tasks/bulkdata/select_utils.py | 109 +++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 808d2f8a2a..4c40600b53 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -10,6 +10,7 @@ class SelectStrategy(StrEnum): """Enum defining the different selection strategies requested.""" RANDOM = "random" + SIMILARITY = "similarity" def random_generate_query( @@ -53,3 +54,111 @@ def random_post_process( selected_records = selected_records[:num_records] return selected_records, None # Return selected records and None for error + + +def similarity_generate_query( + sobject: str, + fields: T.List[str], + num_records: float, +) -> T.Tuple[str, T.List[str]]: + """Generates the SOQL query for the random selection strategy""" + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + where_clause = declaration.where + else: + where_clause = None + # Construct the query with the WHERE clause (if it exists) + + fields.insert(0, "Id") + fields_to_query = ", ".join(field for field in fields if field) + + query = f"SELECT {fields_to_query} FROM {sobject}" + if where_clause: + query += f" WHERE {where_clause}" + + return query, fields + + +def similarity_post_process( + load_records: list, query_records: list, num_records: float, sobject: str +) -> T.Tuple[T.List[dict], T.Union[str, None]]: + """Processes the query results for the similarity selection strategy""" + # Handle case where query returns 0 records + if not query_records: + error_message = f"No records found for {sobject} in the target org." + return [], error_message + + closest_records = [] + + for record in load_records: + closest_record = find_closest_record(record, query_records) + closest_records.append( + {"id": closest_record[0], "success": True, "created": False} + ) + + return closest_records, None + + +def find_closest_record(load_record: list, query_records: list): + closest_distance = float("inf") + closest_record = query_records[0] + + for record in query_records: + distance = calculate_levenshtein_distance(load_record, record[1:]) + if distance < closest_distance: + closest_distance = distance + closest_record = record + + return closest_record + + +def levenshtein_distance(str1: str, str2: str): + """Calculate the Levenshtein distance between two strings""" + len_str1 = len(str1) + 1 + len_str2 = len(str2) + 1 + + dp = [[0 for _ in range(len_str2)] for _ in range(len_str1)] + + for i in range(len_str1): + dp[i][0] = i + for j in range(len_str2): + dp[0][j] = j + + for i in range(1, len_str1): + for j in range(1, len_str2): + cost = 0 if str1[i - 1] == str2[j - 1] else 1 + dp[i][j] = min( + dp[i - 1][j] + 1, # Deletion + dp[i][j - 1] + 1, # Insertion + dp[i - 1][j - 1] + cost, + ) # Substitution + + return dp[-1][-1] + + +def calculate_levenshtein_distance(record1: list, record2: list): + if len(record1) != len(record2): + raise ValueError("Records must have the same number of fields.") + + total_distance = 0 + total_fields = 0 + + for field1, field2 in zip(record1, record2): + + field1 = field1.lower() + field2 = field2.lower() + + if len(field1) == 0 and len(field2) == 0: + # If both fields are blank, distance is 0 + distance = 0 + else: + distance = levenshtein_distance(field1, field2) + if len(field1) == 0 or len(field2) == 0: + # If one field is blank, reduce the impact of the distance + distance = distance * 0.05 # Fixed value for blank vs non-blank + + total_distance += distance + total_fields += 1 + + return total_distance / total_fields if total_fields > 0 else 0 From 5f6da64d94953b25f14ab29af6a9439dbddb0566 Mon Sep 17 00:00:00 2001 From: Jawadtp Date: Fri, 23 Aug 2024 15:13:54 +0530 Subject: [PATCH 11/34] Add support for similarity algorithm in step.py --- cumulusci/tasks/bulkdata/step.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 4770cab37e..8d9f39638c 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -20,6 +20,8 @@ SelectStrategy, random_generate_query, random_post_process, + similarity_generate_query, + similarity_post_process, ) from cumulusci.tasks.bulkdata.utils import DataApi, iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict @@ -369,6 +371,9 @@ def __init__( if selection_strategy is SelectStrategy.RANDOM: self.select_generate_query = random_generate_query self.select_post_process = random_post_process + elif selection_strategy is SelectStrategy.SIMILARITY: + self.select_generate_query = similarity_generate_query + self.select_post_process = similarity_post_process def start(self): self.job_id = self.bulk.create_job( @@ -616,7 +621,7 @@ def __init__( api_options, context, fields, - selection_strategy=SelectStrategy.RANDOM, + selection_strategy=SelectStrategy.SIMILARITY, ): super().__init__( sobject=sobject, @@ -642,6 +647,9 @@ def __init__( if selection_strategy is SelectStrategy.RANDOM: self.select_generate_query = random_generate_query self.select_post_process = random_post_process + elif selection_strategy is SelectStrategy.SIMILARITY: + self.select_generate_query = similarity_generate_query + self.select_post_process = similarity_post_process def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) From 30246d99f54d71bf36527440efa9071540a3a07e Mon Sep 17 00:00:00 2001 From: Jawadtp Date: Fri, 23 Aug 2024 15:46:47 +0530 Subject: [PATCH 12/34] Add unit tests for changes made for similarity selection strategy --- .../tasks/bulkdata/tests/test_select_utils.py | 161 ++++++++ cumulusci/tasks/bulkdata/tests/test_step.py | 367 ++++++++++++++++++ 2 files changed, 528 insertions(+) diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index 29abc845e7..4d084d5391 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -1,6 +1,11 @@ from cumulusci.tasks.bulkdata.select_utils import ( + calculate_levenshtein_distance, + find_closest_record, + levenshtein_distance, random_generate_query, random_post_process, + similarity_generate_query, + similarity_post_process, ) @@ -67,3 +72,159 @@ def test_random_post_process_with_no_records(): assert selected_records == [] assert error_message == f"No records found for {sobject} in the target org." + + +# Test Cases for random_generate_query +def test_similarity_generate_query_with_default_record_declaration(): + sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS + num_records = 5 + query, fields = similarity_generate_query(sobject, ["Name"], num_records) + + assert "WHERE" in query # Ensure WHERE clause is included + assert fields == ["Id", "Name"] + + +def test_similarity_generate_query_without_default_record_declaration(): + sobject = "Contact" # Assuming no declaration for this object + num_records = 3 + query, fields = similarity_generate_query(sobject, ["Name"], num_records) + + assert "WHERE" not in query # No WHERE clause should be present + assert fields == ["Id", "Name"] + + +def test_levenshtein_distance(): + assert levenshtein_distance("kitten", "kitten") == 0 # Identical strings + assert levenshtein_distance("kitten", "sitten") == 1 # One substitution + assert levenshtein_distance("kitten", "kitte") == 1 # One deletion + assert levenshtein_distance("kitten", "sittin") == 2 # Two substitutions + assert levenshtein_distance("kitten", "dog") == 6 # Completely different strings + assert levenshtein_distance("kitten", "") == 6 # One string is empty + assert levenshtein_distance("", "") == 0 # Both strings are empty + assert levenshtein_distance("Kitten", "kitten") == 1 # Case sensitivity + assert levenshtein_distance("kit ten", "kitten") == 1 # Strings with spaces + assert ( + levenshtein_distance("levenshtein", "meilenstein") == 4 + ) # Longer strings with multiple differences + + +def test_calculate_levenshtein_distance(): + # Identical records + record1 = ["Tom Cruise", "24", "Actor"] + record2 = ["Tom Cruise", "24", "Actor"] + assert calculate_levenshtein_distance(record1, record2) == 0 # Distance should be 0 + + # Records with one different field + record1 = ["Tom Cruise", "24", "Actor"] + record2 = ["Tom Hanks", "24", "Actor"] + assert calculate_levenshtein_distance(record1, record2) > 0 # Non-zero distance + + # One record has an empty field + record1 = ["Tom Cruise", "24", "Actor"] + record2 = ["Tom Cruise", "", "Actor"] + assert ( + calculate_levenshtein_distance(record1, record2) > 0 + ) # Distance should reflect the empty field + + # Completely empty records + record1 = ["", "", ""] + record2 = ["", "", ""] + assert calculate_levenshtein_distance(record1, record2) == 0 # Distance should be 0 + + +def test_find_closest_record(): + # Test case 1: Exact match + load_record = ["Tom Cruise", "62", "Actor"] + query_records = [ + [1, "Tom Hanks", "30", "Actor"], + [2, "Tom Cruise", "62", "Actor"], # Exact match + [3, "Jennifer Aniston", "30", "Actress"], + ] + assert find_closest_record(load_record, query_records) == [ + 2, + "Tom Cruise", + "62", + "Actor", + ] # Should return the exact match + + # Test case 2: Closest match with slight differences + load_record = ["Tom Cruise", "62", "Actor"] + query_records = [ + [1, "Tom Hanks", "62", "Actor"], + [2, "Tom Cruise", "63", "Actor"], # Slight difference + [3, "Jennifer Aniston", "30", "Actress"], + ] + assert find_closest_record(load_record, query_records) == [ + 2, + "Tom Cruise", + "63", + "Actor", + ] # Should return the closest match + + # Test case 3: All records are significantly different + load_record = ["Tom Cruise", "62", "Actor"] + query_records = [ + [1, "Brad Pitt", "30", "Producer"], + [2, "Leonardo DiCaprio", "40", "Director"], + [3, "Jennifer Aniston", "30", "Actress"], + ] + assert ( + find_closest_record(load_record, query_records) == query_records[0] + ) # Should return the first record as the closest (though none are close) + + # Test case 4: Closest match is the last in the list + load_record = ["Tom Cruise", "62", "Actor"] + query_records = [ + [1, "Johnny Depp", "50", "Actor"], + [2, "Brad Pitt", "30", "Producer"], + [3, "Tom Cruise", "62", "Actor"], # Exact match as the last record + ] + assert find_closest_record(load_record, query_records) == [ + 3, + "Tom Cruise", + "62", + "Actor", + ] # Should return the last record + + # Test case 5: Single record in query_records + load_record = ["Tom Cruise", "62", "Actor"] + query_records = [[1, "Johnny Depp", "50", "Actor"]] + assert find_closest_record(load_record, query_records) == [ + 1, + "Johnny Depp", + "50", + "Actor", + ] # Should return the only record available + + +def test_similarity_post_process_with_records(): + num_records = 1 + sobject = "Contact" + load_records = [["Tom Cruise", "62", "Actor"]] + query_records = [ + ["001", "Tom Hanks", "62", "Actor"], + ["002", "Tom Cruise", "63", "Actor"], # Slight difference + ["003", "Jennifer Aniston", "30", "Actress"], + ] + + selected_records, error_message = similarity_post_process( + load_records, query_records, num_records, sobject + ) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + assert all(record["id"] in ["002"] for record in selected_records) + + +def test_similarity_post_process_with_no_records(): + records = [] + num_records = 2 + sobject = "Lead" + selected_records, error_message = similarity_post_process( + None, records, num_records, sobject + ) + + assert selected_records == [] + assert error_message == f"No records found for {sobject} in the target org." diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index 8f6f34ad90..66f464006f 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -633,6 +633,137 @@ def test_select_records_random_strategy_failure__no_records(self, download_mock) assert job_result.records_processed == 0 assert job_result.total_row_errors == 0 + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Id", "Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO( + """Id,Name,Email +003000000000001,Jawad,mjawadtp@example.com +003000000000002,Aditya,aditya@example.com +003000000000003,Tom,tom@example.com""" + ) + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False + ) + ) + == 1 + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_failure__no_records( + self, download_mock + ): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Id", "Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content indicating no records found + download_mock.return_value = io.StringIO("""Records not found for this query""") + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." + ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + def test_batch(self): context = mock.Mock() @@ -1014,6 +1145,7 @@ def test_select_records_random_strategy_success(self): api_options={"batch_size": 10, "update_key": "LastName"}, context=task, fields=["LastName"], + selection_strategy=SelectStrategy.RANDOM, ) results = { @@ -1078,6 +1210,7 @@ def test_select_records_random_strategy_success__pagination(self): api_options={"batch_size": 10, "update_key": "LastName"}, context=task, fields=["LastName"], + selection_strategy=SelectStrategy.RANDOM, ) results = { @@ -1168,6 +1301,7 @@ def test_select_records_random_strategy_failure__no_records(self): api_options={"batch_size": 10, "update_key": "LastName"}, context=task, fields=["LastName"], + selection_strategy=SelectStrategy.RANDOM, ) results = {"records": [], "done": True} @@ -1188,6 +1322,239 @@ def test_select_records_random_strategy_failure__no_records(self): assert job_result.records_processed == 0 assert job_result.total_row_errors == 0 + @responses.activate + def test_select_records_similarity_strategy_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["Id", "Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + results = { + "records": [ + { + "Id": "003000000000001", + "Name": "Jawad", + "Email": "mjawadtp@example.com", + }, + { + "Id": "003000000000002", + "Name": "Aditya", + "Email": "aditya@example.com", + }, + { + "Id": "003000000000003", + "Name": "Tom Cruise", + "Email": "tomcruise@example.com", + }, + ], + "done": True, + } + step.sf.query = mock.Mock() + step.sf.query.return_value = results + records = iter( + [ + ["Id: 1", "Jawad", "mjawadtp@example.com"], + ["Id: 2", "Aditya", "aditya@example.com"], + ["Id: 2", "Tom", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + + @responses.activate + def test_select_records_random_similarity_success__pagination(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["Id", "Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + results = { + "records": [ + { + "Id": "003000000000001", + "Name": "Jawad", + "Email": "mjawadtp@example.com", + }, + ], + "done": False, + "nextRecordsUrl": "https://example.com", + } + results_more = { + "records": [ + { + "Id": "003000000000002", + "Name": "Aditya", + "Email": "aditya@example.com", + }, + { + "Id": "003000000000003", + "Name": "Tom Cruise", + "Email": "tomcruise@example.com", + }, + ], + "done": True, + } + step.sf.query = mock.Mock() + step.sf.query.return_value = results + step.sf.query_more = mock.Mock() + step.sf.query_more.return_value = results_more + records = iter( + [ + ["Id: 1", "Jawad", "mjawadtp@example.com"], + ["Id: 2", "Aditya", "aditya@example.com"], + ["Id: 2", "Tom", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + + @responses.activate + def test_select_records_similarity_strategy_failure__no_records(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + results = {"records": [], "done": True} + step.sf.query = mock.Mock() + step.sf.query.return_value = results + records = iter( + [ + ["Id: 1", "Jawad", "mjawadtp@example.com"], + ["Id: 2", "Aditya", "aditya@example.com"], + ["Id: 2", "Tom", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE + assert ( + job_result.job_errors[0] + == "No records found for Contact in the target org." + ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 + @responses.activate def test_insert_dml_operation__boolean_conversion(self): mock_describe_calls() From a368803e5f4b85dcf59d08f3265de4433706f1ad Mon Sep 17 00:00:00 2001 From: Jawadtp Date: Tue, 27 Aug 2024 15:34:56 +0530 Subject: [PATCH 13/34] Add more assertions in tests and remote list typing for load_Records --- cumulusci/tasks/bulkdata/select_utils.py | 2 +- cumulusci/tasks/bulkdata/tests/test_step.py | 44 ++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 4c40600b53..f40ae8d431 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -81,7 +81,7 @@ def similarity_generate_query( def similarity_post_process( - load_records: list, query_records: list, num_records: float, sobject: str + load_records, query_records: list, num_records: float, sobject: str ) -> T.Tuple[T.List[dict], T.Union[str, None]]: """Processes the query results for the similarity selection strategy""" # Handle case where query returns 0 records diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index 66f464006f..9fdee3adb0 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -1387,7 +1387,7 @@ def test_select_records_similarity_strategy_success(self): [ ["Id: 1", "Jawad", "mjawadtp@example.com"], ["Id: 2", "Aditya", "aditya@example.com"], - ["Id: 2", "Tom", "tom@example.com"], + ["Id: 3", "Tom Cruise", "tom@example.com"], ] ) step.start() @@ -1406,6 +1406,22 @@ def test_select_records_similarity_strategy_success(self): ) == 1 ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False + ) + ) + == 1 + ) @responses.activate def test_select_records_random_similarity_success__pagination(self): @@ -1480,7 +1496,7 @@ def test_select_records_random_similarity_success__pagination(self): [ ["Id: 1", "Jawad", "mjawadtp@example.com"], ["Id: 2", "Aditya", "aditya@example.com"], - ["Id: 2", "Tom", "tom@example.com"], + ["Id: 3", "Tom Cruise", "tom@example.com"], ] ) step.start() @@ -1491,6 +1507,30 @@ def test_select_records_random_similarity_success__pagination(self): results = list(step.get_results()) assert len(results) == 3 # Expect 3 results (matching the input records count) # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) @responses.activate def test_select_records_similarity_strategy_failure__no_records(self): From 6fad397839d8b7c6fe4312d8693614004dc6d5a2 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 3 Sep 2024 03:11:23 +0530 Subject: [PATCH 14/34] Adds selection_filter and RANDOM selection strategy --- cumulusci/tasks/bulkdata/load.py | 20 +- cumulusci/tasks/bulkdata/mapping_parser.py | 9 +- cumulusci/tasks/bulkdata/select_utils.py | 78 ++- cumulusci/tasks/bulkdata/step.py | 281 ++++++++--- .../tasks/bulkdata/tests/test_select_utils.py | 107 +++- cumulusci/tasks/bulkdata/tests/test_step.py | 455 ++++++++++++------ 6 files changed, 715 insertions(+), 235 deletions(-) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index 4ae0dcf31a..9435dfd183 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -289,7 +289,12 @@ def _execute_step( self, step, self._stream_queried_data(mapping, local_ids, query) ) step.start() - step.load_records(self._stream_queried_data(mapping, local_ids, query)) + if mapping.action == DataOperationType.SELECT: + step.select_records( + self._stream_queried_data(mapping, local_ids, query) + ) + else: + step.load_records(self._stream_queried_data(mapping, local_ids, query)) step.end() # Process Job Results @@ -336,6 +341,8 @@ def configure_step(self, mapping): self.check_simple_upsert(mapping) api_options["update_key"] = mapping.update_key[0] action = DataOperationType.UPSERT + elif mapping.action == DataOperationType.SELECT: + action = DataOperationType.QUERY else: action = mapping.action @@ -349,6 +356,8 @@ def configure_step(self, mapping): fields=fields, api=mapping.api, volume=query.count(), + selection_strategy=mapping.selection_strategy, + selection_filter=mapping.selection_filter, ) return step, query @@ -481,10 +490,11 @@ def _process_job_results(self, mapping, step, local_ids): """Get the job results and process the results. If we're raising for row-level errors, do so; if we're inserting, store the new Ids.""" - is_insert_or_upsert = mapping.action in ( + is_insert_upsert_or_select = mapping.action in ( DataOperationType.INSERT, DataOperationType.UPSERT, DataOperationType.ETL_UPSERT, + DataOperationType.SELECT, ) conn = self.session.connection() @@ -500,7 +510,7 @@ def _process_job_results(self, mapping, step, local_ids): break # If we know we have no successful inserts, don't attempt to persist Ids. # Do, however, drain the generator to get error-checking behavior. - if is_insert_or_upsert and ( + if is_insert_upsert_or_select and ( step.job_result.records_processed - step.job_result.total_row_errors ): table = self.metadata.tables[self.ID_TABLE_NAME] @@ -516,7 +526,7 @@ def _process_job_results(self, mapping, step, local_ids): # person account Contact records so lookups to # person account Contact records get populated downstream as expected. if ( - is_insert_or_upsert + is_insert_upsert_or_select and mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping) ): @@ -531,7 +541,7 @@ def _process_job_results(self, mapping, step, local_ids): ), ) - if is_insert_or_upsert: + if is_insert_upsert_or_select: self.session.commit() def _generate_results_id_map(self, step, local_ids): diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index bb59fc6647..e812ca7d16 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -15,6 +15,7 @@ from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.tasks.bulkdata.dates import iso_to_date +from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import DataApi, DataOperationType from cumulusci.utils import convert_to_snake_case from cumulusci.utils.yaml.model_parser import CCIDictModel @@ -84,7 +85,7 @@ class BulkMode(StrEnum): ENUM_VALUES = { v.value.lower(): v.value - for enum in [BulkMode, DataApi, DataOperationType] + for enum in [BulkMode, DataApi, DataOperationType, SelectStrategy] for v in enum.__members__.values() } @@ -107,9 +108,13 @@ class MappingStep(CCIDictModel): ] = None # default should come from task options anchor_date: Optional[Union[str, date]] = None soql_filter: Optional[str] = None # soql_filter property + selection_strategy: SelectStrategy = SelectStrategy.STANDARD # selection strategy + selection_filter: Optional[ + str + ] = None # filter to be added at the end of select query update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts - @validator("bulk_mode", "api", "action", pre=True) + @validator("bulk_mode", "api", "action", "selection_strategy", pre=True) def case_normalize(cls, val): if isinstance(val, Enum): return val diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index f40ae8d431..976f852540 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -1,3 +1,4 @@ +import random import typing as T from cumulusci.core.enums import StrEnum @@ -9,14 +10,52 @@ class SelectStrategy(StrEnum): """Enum defining the different selection strategies requested.""" - RANDOM = "random" + STANDARD = "standard" SIMILARITY = "similarity" + RANDOM = "random" -def random_generate_query( - sobject: str, fields: T.List[str], num_records: float +class SelectOperationExecutor: + def __init__(self, strategy: SelectStrategy): + self.strategy = strategy + + def select_generate_query( + self, sobject: str, fields: T.List[str], num_records: int + ): + # For STANDARD strategy + if self.strategy == SelectStrategy.STANDARD: + return standard_generate_query(sobject=sobject, num_records=num_records) + # For SIMILARITY strategy + elif self.strategy == SelectStrategy.SIMILARITY: + return similarity_generate_query(sobject=sobject, fields=fields) + # For RANDOM strategy + elif self.strategy == SelectStrategy.RANDOM: + return standard_generate_query(sobject=sobject, num_records=num_records) + + def select_post_process( + self, load_records, query_records: list, num_records: int, sobject: str + ): + # For STANDARD strategy + if self.strategy == SelectStrategy.STANDARD: + return standard_post_process( + query_records=query_records, num_records=num_records, sobject=sobject + ) + # For SIMILARITY strategy + elif self.strategy == SelectStrategy.SIMILARITY: + return similarity_post_process( + load_records=load_records, query_records=query_records, sobject=sobject + ) + # For RANDOM strategy + elif self.strategy == SelectStrategy.RANDOM: + return random_post_process( + query_records=query_records, num_records=num_records, sobject=sobject + ) + + +def standard_generate_query( + sobject: str, num_records: int ) -> T.Tuple[str, T.List[str]]: - """Generates the SOQL query for the random selection strategy""" + """Generates the SOQL query for the standard (as well as random) selection strategy""" # Get the WHERE clause from DEFAULT_DECLARATIONS if available declaration = DEFAULT_DECLARATIONS.get(sobject) if declaration: @@ -32,10 +71,10 @@ def random_generate_query( return query, ["Id"] -def random_post_process( - load_records, query_records: list, num_records: float, sobject: str +def standard_post_process( + query_records: list, num_records: int, sobject: str ) -> T.Tuple[T.List[dict], T.Union[str, None]]: - """Processes the query results for the random selection strategy""" + """Processes the query results for the standard selection strategy""" # Handle case where query returns 0 records if not query_records: error_message = f"No records found for {sobject} in the target org." @@ -59,9 +98,8 @@ def random_post_process( def similarity_generate_query( sobject: str, fields: T.List[str], - num_records: float, ) -> T.Tuple[str, T.List[str]]: - """Generates the SOQL query for the random selection strategy""" + """Generates the SOQL query for the similarity selection strategy""" # Get the WHERE clause from DEFAULT_DECLARATIONS if available declaration = DEFAULT_DECLARATIONS.get(sobject) if declaration: @@ -81,7 +119,7 @@ def similarity_generate_query( def similarity_post_process( - load_records, query_records: list, num_records: float, sobject: str + load_records: list, query_records: list, sobject: str ) -> T.Tuple[T.List[dict], T.Union[str, None]]: """Processes the query results for the similarity selection strategy""" # Handle case where query returns 0 records @@ -100,6 +138,26 @@ def similarity_post_process( return closest_records, None +def random_post_process( + query_records: list, num_records: int, sobject: str +) -> T.Tuple[T.List[dict], T.Union[str, None]]: + """Processes the query results for the random selection strategy""" + + if not query_records: + error_message = f"No records found for {sobject} in the target org." + return [], error_message + + selected_records = [] + for _ in range(num_records): # Loop 'num_records' times + # Randomly select one record from query_records + random_record = random.choice(query_records) + selected_records.append( + {"id": random_record[0], "success": True, "created": False} + ) + + return selected_records, None + + def find_closest_record(load_record: list, query_records: list): closest_distance = float("inf") closest_record = query_records[0] diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 8d9f39638c..fd25f0e19d 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -3,25 +3,23 @@ import json import os import pathlib +import re import tempfile import time from abc import ABCMeta, abstractmethod from contextlib import contextmanager from itertools import tee -from typing import Any, Dict, List, NamedTuple, Optional +from typing import Any, Dict, List, NamedTuple, Optional, Union import requests import salesforce_bulk from cumulusci.core.enums import StrEnum -from cumulusci.core.exceptions import BulkDataException +from cumulusci.core.exceptions import BulkDataException, SOQLQueryException from cumulusci.core.utils import process_bool_arg from cumulusci.tasks.bulkdata.select_utils import ( + SelectOperationExecutor, SelectStrategy, - random_generate_query, - random_post_process, - similarity_generate_query, - similarity_post_process, ) from cumulusci.tasks.bulkdata.utils import DataApi, iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict @@ -352,7 +350,8 @@ def __init__( api_options, context, fields, - selection_strategy=SelectStrategy.RANDOM, + selection_strategy=SelectStrategy.STANDARD, + selection_filter=None, ): super().__init__( sobject=sobject, @@ -368,12 +367,8 @@ def __init__( self.csv_buff = io.StringIO(newline="") self.csv_writer = csv.writer(self.csv_buff, quoting=csv.QUOTE_ALL) - if selection_strategy is SelectStrategy.RANDOM: - self.select_generate_query = random_generate_query - self.select_post_process = random_post_process - elif selection_strategy is SelectStrategy.SIMILARITY: - self.select_generate_query = similarity_generate_query - self.select_post_process = similarity_post_process + self.select_operation_executor = SelectOperationExecutor(selection_strategy) + self.selection_filter = selection_filter def start(self): self.job_id = self.bulk.create_job( @@ -451,7 +446,7 @@ def select_records(self, records): """Executes a SOQL query to select records and adds them to results""" self.select_results = [] # Store selected records - + query_records = [] # Create a copy of the generator using tee records, records_copy = tee(records) # Count total number of records to fetch using the copy @@ -467,34 +462,55 @@ def select_records(self, records): ) # Generate and execute SOQL query - query, query_fields = self.select_generate_query( + ( + select_query, + query_fields, + ) = self.select_operation_executor.select_generate_query( self.sobject, self.fields, num_records ) - self.batch_id = self.bulk.query(self.job_id, query) - self._wait_for_job(self.job_id) + if self.selection_filter: + # Generate user filter query if selection_filter is present (offset clause not supported) + user_query = generate_user_filter_query( + self.selection_filter, self.sobject, ["Id"], num_records, None + ) + # Execute the user query using Bulk API + user_query_executor = get_query_operation( + sobject=self.sobject, + fields=["Id"], + api_options=self.api_options, + context=self, + query=user_query, + api=DataApi.BULK, + ) + user_query_executor.query() + user_query_records = user_query_executor.get_results() - # Get and process query results - result_ids = self.bulk.get_query_batch_result_ids( - self.batch_id, job_id=self.job_id - ) - query_records = [] - for result_id in result_ids: - uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}" - with download_file(uri, self.bulk) as f: - reader = csv.reader(f) - self.headers = next(reader) - if "Records not found for this query" in self.headers: - break # Stop if no records found - for row in reader: - query_records.append(row[: len(query_fields)]) - - # Post-process the query results - selected_records, error_message = self.select_post_process( - records, query_records, num_records, self.sobject + # Find intersection based on 'Id' + user_query_ids = set(record[0] for record in user_query_records) + + # Execute the main select query using Bulk API + select_query_records = self._execute_select_query( + select_query=select_query, query_fields=query_fields ) - if error_message: - break # Stop if there's an error during post-processing + # If user_query_ids exist, filter select_query_records based on the intersection of Ids + if self.selection_filter: + query_records.extend( + record + for record in select_query_records + if record[query_fields.index("Id")] in user_query_ids + ) + else: + query_records.extend(select_query_records) + + # Post-process the query results + ( + selected_records, + error_message, + ) = self.select_operation_executor.select_post_process( + records, query_records, num_records, self.sobject + ) + if not error_message: self.select_results.extend(selected_records) # Update job result based on selection outcome @@ -507,6 +523,25 @@ def select_records(self, records): 0, ) + def _execute_select_query(self, select_query: str, query_fields: List[str]): + """Executes the select Bulk API query and retrieves the results.""" + self.batch_id = self.bulk.query(self.job_id, select_query) + self._wait_for_job(self.job_id) + result_ids = self.bulk.get_query_batch_result_ids( + self.batch_id, job_id=self.job_id + ) + select_query_records = [] + for result_id in result_ids: + uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}" + with download_file(uri, self.bulk) as f: + reader = csv.reader(f) + self.headers = next(reader) + if "Records not found for this query" in self.headers: + break + for row in reader: + select_query_records.append(row[: len(query_fields)]) + return select_query_records + def _batch(self, records, n, char_limit=10000000): """Given an iterator of records, yields batches of records serialized in .csv format. @@ -622,6 +657,7 @@ def __init__( context, fields, selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, ): super().__init__( sobject=sobject, @@ -644,12 +680,9 @@ def __init__( self.api_options["batch_size"] = min( self.api_options["batch_size"], MAX_REST_BATCH_SIZE ) - if selection_strategy is SelectStrategy.RANDOM: - self.select_generate_query = random_generate_query - self.select_post_process = random_post_process - elif selection_strategy is SelectStrategy.SIMILARITY: - self.select_generate_query = similarity_generate_query - self.select_post_process = similarity_post_process + + self.select_operation_executor = SelectOperationExecutor(selection_strategy) + self.selection_filter = selection_filter def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) @@ -764,6 +797,7 @@ def convert(rec, fields): return [str(rec[f]) if rec[f] is not None else "" for f in fields] self.results = [] + query_records = [] # Create a copy of the generator using tee records, records_copy = tee(records) # Count total number of records to fetch using the copy @@ -774,32 +808,44 @@ def convert(rec, fields): num_records = min( self.api_options.get("batch_size"), total_num_records - offset ) - # Generate the SOQL query with and LIMIT - query, query_fields = self.select_generate_query( + + # Generate the SOQL query based on the selection strategy + ( + select_query, + query_fields, + ) = self.select_operation_executor.select_generate_query( self.sobject, self.fields, num_records ) - # Execute the query and extract results - response = self.sf.query(query) - # Extract and convert 'Id' fields from the query results - query_records = list( - convert(rec, query_fields) for rec in response["records"] - ) - # Handle pagination if there are more records within this batch - while not response["done"]: - response = self.sf.query_more( - response["nextRecordsUrl"], identifier_is_url=True + # If user given selection filter present, create composite request + if self.selection_filter: + user_query = generate_user_filter_query( + self.selection_filter, self.sobject, ["Id"], num_records, offset + ) + query_records.extend( + self._execute_composite_query( + select_query=select_query, + user_query=user_query, + query_fields=query_fields, + ) + ) + else: + # Handle the case where self.selection_query is None (and hence user_query is also None) + response = self.sf.restful( + requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" ) query_records.extend( list(convert(rec, query_fields) for rec in response["records"]) ) - # Post-process the query results for this batch - selected_records, error_message = self.select_post_process( - records, query_records, num_records, self.sobject - ) - if error_message: - break + # Post-process the query results for this batch + ( + selected_records, + error_message, + ) = self.select_operation_executor.select_post_process( + records, query_records, total_num_records, self.sobject + ) + if not error_message: # Add selected records from this batch to the overall results self.results.extend(selected_records) @@ -813,6 +859,65 @@ def convert(rec, fields): 0, ) + def _execute_composite_query(self, select_query, user_query, query_fields): + """Executes a composite request with two queries and returns the intersected results.""" + + def convert(rec, fields): + """Helper function to convert record values to strings, handling None values""" + return [str(rec[f]) if rec[f] is not None else "" for f in fields] + + composite_request_json = { + "compositeRequest": [ + { + "method": "GET", + "url": requests.utils.requote_uri( + f"/services/data/v{self.sf.sf_version}/query/?q={select_query}" + ), + "referenceId": "select_query", + }, + { + "method": "GET", + "url": requests.utils.requote_uri( + f"/services/data/v{self.sf.sf_version}/query/?q={user_query}" + ), + "referenceId": "user_query", + }, + ] + } + response = self.sf.restful( + "composite", method="POST", json=composite_request_json + ) + + # Extract results based on referenceId + for sub_response in response["compositeResponse"]: + if ( + sub_response["referenceId"] == "select_query" + and sub_response["httpStatusCode"] == 200 + ): + select_query_records = list( + convert(rec, query_fields) + for rec in sub_response["body"]["records"] + ) + elif ( + sub_response["referenceId"] == "user_query" + and sub_response["httpStatusCode"] == 200 + ): + user_query_records = list( + convert(rec, ["Id"]) for rec in sub_response["body"]["records"] + ) + else: + raise SOQLQueryException( + f"{sub_response['body'][0]['errorCode']}: {sub_response['body'][0]['message']}" + ) + # Find intersection based on 'Id' + user_query_ids = set(record[0] for record in user_query_records) + + return [ + record + for record in select_query_records + if record[query_fields.index("Id")] in user_query_ids + ] + def get_results(self): """Return a generator of DataOperationResult objects.""" @@ -894,6 +999,8 @@ def get_dml_operation( context: Any, volume: int, api: Optional[DataApi] = DataApi.SMART, + selection_strategy: SelectStrategy = SelectStrategy.STANDARD, + selection_filter: Union[str, None] = None, ) -> BaseDmlOperation: """Create an appropriate DmlOperation instance for the given parameters, selecting between REST and Bulk APIs based upon volume (Bulk used at volumes over 2000 records, @@ -927,4 +1034,56 @@ def get_dml_operation( api_options=api_options, context=context, fields=fields, + selection_strategy=selection_strategy, + selection_filter=selection_filter, ) + + +def generate_user_filter_query( + filter_clause: str, + sobject: str, + fields: list, + limit_clause: Union[int, None] = None, + offset_clause: Union[int, None] = None, +) -> str: + """ + Generates a SOQL query with the provided filter, object, fields, limit, and offset clauses. + Handles cases where the filter clause already contains LIMIT or OFFSET, and avoids multiple spaces. + """ + + # Extract existing LIMIT and OFFSET from filter_clause if present + existing_limit_match = re.search(r"LIMIT\s+(\d+)", filter_clause, re.IGNORECASE) + existing_offset_match = re.search(r"OFFSET\s+(\d+)", filter_clause, re.IGNORECASE) + + if existing_limit_match: + existing_limit = int(existing_limit_match.group(1)) + if limit_clause is not None: # Only apply limit_clause if it's provided + limit_clause = min(existing_limit, limit_clause) + else: + limit_clause = existing_limit + + if existing_offset_match: + existing_offset = int(existing_offset_match.group(1)) + if offset_clause is not None: + offset_clause = existing_offset + offset_clause + else: + offset_clause = existing_offset + + # Remove existing LIMIT and OFFSET from filter_clause, handling potential extra spaces + filter_clause = re.sub( + r"\s+OFFSET\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE + ).strip() + filter_clause = re.sub( + r"\s+LIMIT\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE + ).strip() + + # Construct the SOQL query + fields_str = ", ".join(fields) + query = f"SELECT {fields_str} FROM {sobject} {filter_clause}" + + if limit_clause is not None: + query += f" LIMIT {limit_clause}" + if offset_clause is not None: + query += f" OFFSET {offset_clause}" + + return query diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index 4d084d5391..0ae97acb46 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -1,19 +1,47 @@ from cumulusci.tasks.bulkdata.select_utils import ( + SelectOperationExecutor, + SelectStrategy, calculate_levenshtein_distance, find_closest_record, levenshtein_distance, - random_generate_query, - random_post_process, - similarity_generate_query, - similarity_post_process, ) -# Test Cases for random_generate_query +# Test Cases for standard_generate_query +def test_standard_generate_query_with_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS + num_records = 5 + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], num_records=num_records + ) + + assert "WHERE" in query # Ensure WHERE clause is included + assert f"LIMIT {num_records}" in query + assert fields == ["Id"] + + +def test_standard_generate_query_without_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + sobject = "Contact" # Assuming no declaration for this object + num_records = 3 + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], num_records=num_records + ) + + assert "WHERE" not in query # No WHERE clause should be present + assert f"LIMIT {num_records}" in query + assert fields == ["Id"] + + +# Test Cases for random generate query def test_random_generate_query_with_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS num_records = 5 - query, fields = random_generate_query(sobject, [], num_records) + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], num_records=num_records + ) assert "WHERE" in query # Ensure WHERE clause is included assert f"LIMIT {num_records}" in query @@ -21,21 +49,25 @@ def test_random_generate_query_with_default_record_declaration(): def test_random_generate_query_without_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) sobject = "Contact" # Assuming no declaration for this object num_records = 3 - query, fields = random_generate_query(sobject, [], num_records) + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], num_records=num_records + ) assert "WHERE" not in query # No WHERE clause should be present assert f"LIMIT {num_records}" in query assert fields == ["Id"] -# Test Cases for random_post_process -def test_random_post_process_with_records(): +# Test Cases for standard_post_process +def test_standard_post_process_with_records(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) records = [["001"], ["002"], ["003"]] num_records = 3 sobject = "Contact" - selected_records, error_message = random_post_process( + selected_records, error_message = select_operator.select_post_process( None, records, num_records, sobject ) @@ -46,11 +78,12 @@ def test_random_post_process_with_records(): assert all(record["id"] in ["001", "002", "003"] for record in selected_records) -def test_random_post_process_with_fewer_records(): +def test_standard_post_process_with_fewer_records(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) records = [["001"]] num_records = 3 sobject = "Opportunity" - selected_records, error_message = random_post_process( + selected_records, error_message = select_operator.select_post_process( None, records, num_records, sobject ) @@ -62,11 +95,41 @@ def test_random_post_process_with_fewer_records(): assert selected_records.count({"id": "001", "success": True, "created": False}) == 3 +def test_standard_post_process_with_no_records(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + records = [] + num_records = 2 + sobject = "Lead" + selected_records, error_message = select_operator.select_post_process( + None, records, num_records, sobject + ) + + assert selected_records == [] + assert error_message == f"No records found for {sobject} in the target org." + + +# Test cases for Random Post Process +def test_random_post_process_with_records(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) + records = [["001"], ["002"], ["003"]] + num_records = 3 + sobject = "Contact" + selected_records, error_message = select_operator.select_post_process( + None, records, num_records, sobject + ) + + assert error_message is None + assert len(selected_records) == num_records + assert all(record["success"] for record in selected_records) + assert all(record["created"] is False for record in selected_records) + + def test_random_post_process_with_no_records(): + select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) records = [] num_records = 2 sobject = "Lead" - selected_records, error_message = random_post_process( + selected_records, error_message = select_operator.select_post_process( None, records, num_records, sobject ) @@ -74,20 +137,26 @@ def test_random_post_process_with_no_records(): assert error_message == f"No records found for {sobject} in the target org." -# Test Cases for random_generate_query +# Test Cases for Similarity Generate Query def test_similarity_generate_query_with_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS num_records = 5 - query, fields = similarity_generate_query(sobject, ["Name"], num_records) + query, fields = select_operator.select_generate_query( + sobject, ["Name"], num_records + ) assert "WHERE" in query # Ensure WHERE clause is included assert fields == ["Id", "Name"] def test_similarity_generate_query_without_default_record_declaration(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) sobject = "Contact" # Assuming no declaration for this object num_records = 3 - query, fields = similarity_generate_query(sobject, ["Name"], num_records) + query, fields = select_operator.select_generate_query( + sobject, ["Name"], num_records + ) assert "WHERE" not in query # No WHERE clause should be present assert fields == ["Id", "Name"] @@ -198,6 +267,7 @@ def test_find_closest_record(): def test_similarity_post_process_with_records(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) num_records = 1 sobject = "Contact" load_records = [["Tom Cruise", "62", "Actor"]] @@ -207,7 +277,7 @@ def test_similarity_post_process_with_records(): ["003", "Jennifer Aniston", "30", "Actress"], ] - selected_records, error_message = similarity_post_process( + selected_records, error_message = select_operator.select_post_process( load_records, query_records, num_records, sobject ) @@ -219,10 +289,11 @@ def test_similarity_post_process_with_records(): def test_similarity_post_process_with_no_records(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) records = [] num_records = 2 sobject = "Lead" - selected_records, error_message = similarity_post_process( + selected_records, error_message = select_operator.select_post_process( None, records, num_records, sobject ) diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index 9fdee3adb0..b2904ae9c5 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -5,7 +5,7 @@ import pytest import responses -from cumulusci.core.exceptions import BulkDataException +from cumulusci.core.exceptions import BulkDataException, SOQLQueryException from cumulusci.tasks.bulkdata.load import LoadData from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import ( @@ -20,6 +20,7 @@ RestApiDmlOperation, RestApiQueryOperation, download_file, + generate_user_filter_query, get_dml_operation, get_query_operation, ) @@ -536,7 +537,7 @@ def test_get_prev_record_values(self): step.bulk.get_all_results_for_query_batch.assert_called_once_with("BATCH_ID") @mock.patch("cumulusci.tasks.bulkdata.step.download_file") - def test_select_records_random_strategy_success(self, download_mock): + def test_select_records_standard_strategy_success(self, download_mock): # Set up mock context and BulkApiDmlOperation context = mock.Mock() step = BulkApiDmlOperation( @@ -545,7 +546,7 @@ def test_select_records_random_strategy_success(self, download_mock): api_options={"batch_size": 10, "update_key": "LastName"}, context=context, fields=["LastName"], - selection_strategy=SelectStrategy.RANDOM, + selection_strategy=SelectStrategy.STANDARD, ) # Mock Bulk API responses @@ -588,7 +589,7 @@ def test_select_records_random_strategy_success(self, download_mock): ) @mock.patch("cumulusci.tasks.bulkdata.step.download_file") - def test_select_records_random_strategy_failure__no_records(self, download_mock): + def test_select_records_standard_strategy_failure__no_records(self, download_mock): # Set up mock context and BulkApiDmlOperation context = mock.Mock() step = BulkApiDmlOperation( @@ -597,7 +598,7 @@ def test_select_records_random_strategy_failure__no_records(self, download_mock) api_options={"batch_size": 10, "update_key": "LastName"}, context=context, fields=["LastName"], - selection_strategy=SelectStrategy.RANDOM, + selection_strategy=SelectStrategy.STANDARD, ) # Mock Bulk API responses @@ -633,6 +634,118 @@ def test_select_records_random_strategy_failure__no_records(self, download_mock) assert job_result.records_processed == 0 assert job_result.total_row_errors == 0 + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_user_selection_filter_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter='WHERE LastName in ("Sample Name")', + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO( + """Id +003000000000001 +003000000000002 +003000000000003""" + ) + # Mock the query operation + with mock.patch( + "cumulusci.tasks.bulkdata.step.get_query_operation" + ) as query_operation_mock: + query_operation_mock.return_value = mock.Mock() + query_operation_mock.return_value.query = mock.Mock() + query_operation_mock.return_value.get_results = mock.Mock() + query_operation_mock.return_value.get_results.return_value = [ + ["003000000000001"] + ] + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert ( + len(results) == 3 + ) # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_user_selection_filter_failure(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter='WHERE LastName in ("Sample Name")', + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO( + """Id +003000000000001 +003000000000002 +003000000000003""" + ) + # Mock the query operation + with mock.patch( + "cumulusci.tasks.bulkdata.step.get_query_operation" + ) as query_operation_mock: + query_operation_mock.return_value = mock.Mock() + query_operation_mock.return_value.query = mock.Mock() + query_operation_mock.return_value.query.side_effect = BulkDataException( + "MALFORMED QUERY" + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + with pytest.raises(BulkDataException): + step.select_records(records) + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") def test_select_records_similarity_strategy_success(self, download_mock): # Set up mock context and BulkApiDmlOperation @@ -1110,7 +1223,7 @@ def test_get_prev_record_values(self): assert set(relevant_fields) == set(expected_relevant_fields) @responses.activate - def test_select_records_random_strategy_success(self): + def test_select_records_standard_strategy_success(self): mock_describe_calls() task = _make_task( LoadData, @@ -1145,7 +1258,7 @@ def test_select_records_random_strategy_success(self): api_options={"batch_size": 10, "update_key": "LastName"}, context=task, fields=["LastName"], - selection_strategy=SelectStrategy.RANDOM, + selection_strategy=SelectStrategy.STANDARD, ) results = { @@ -1154,8 +1267,8 @@ def test_select_records_random_strategy_success(self): ], "done": True, } - step.sf.query = mock.Mock() - step.sf.query.return_value = results + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results records = iter([["Test1"], ["Test2"], ["Test3"]]) step.start() step.select_records(records) @@ -1175,7 +1288,7 @@ def test_select_records_random_strategy_success(self): ) @responses.activate - def test_select_records_random_strategy_success__pagination(self): + def test_select_records_standard_strategy_failure__no_records(self): mock_describe_calls() task = _make_task( LoadData, @@ -1210,63 +1323,29 @@ def test_select_records_random_strategy_success__pagination(self): api_options={"batch_size": 10, "update_key": "LastName"}, context=task, fields=["LastName"], - selection_strategy=SelectStrategy.RANDOM, + selection_strategy=SelectStrategy.STANDARD, ) - results = { - "records": [ - {"Id": "003000000000001"}, - ], - "done": False, - "nextRecordsUrl": "https://example.com", - } - results_more = { - "records": [ - {"Id": "003000000000002"}, - {"Id": "003000000000003"}, - ], - "done": True, - } - step.sf.query = mock.Mock() - step.sf.query.return_value = results - step.sf.query_more = mock.Mock() - step.sf.query_more.return_value = results_more + results = {"records": [], "done": True} + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results records = iter([["Test1"], ["Test2"], ["Test3"]]) step.start() step.select_records(records) step.end() - # Get the results and assert their properties - results = list(step.get_results()) - assert len(results) == 3 # Expect 3 results (matching the input records count) - # Assert that all results have the expected ID, success, and created values - assert ( - results.count( - DataOperationResult( - id="003000000000001", success=True, error="", created=False - ) - ) - == 1 - ) - assert ( - results.count( - DataOperationResult( - id="003000000000002", success=True, error="", created=False - ) - ) - == 1 - ) + # Get the job result and assert its properties for failure scenario + job_result = step.job_result + assert job_result.status == DataOperationStatus.JOB_FAILURE assert ( - results.count( - DataOperationResult( - id="003000000000003", success=True, error="", created=False - ) - ) - == 1 + job_result.job_errors[0] + == "No records found for Contact in the target org." ) + assert job_result.records_processed == 0 + assert job_result.total_row_errors == 0 @responses.activate - def test_select_records_random_strategy_failure__no_records(self): + def test_select_records_user_selection_filter_success(self): mock_describe_calls() task = _make_task( LoadData, @@ -1301,29 +1380,56 @@ def test_select_records_random_strategy_failure__no_records(self): api_options={"batch_size": 10, "update_key": "LastName"}, context=task, fields=["LastName"], - selection_strategy=SelectStrategy.RANDOM, + selection_strategy=SelectStrategy.STANDARD, + selection_filter='WHERE LastName IN ("Sample Name")', ) - results = {"records": [], "done": True} - step.sf.query = mock.Mock() - step.sf.query.return_value = results + results = { + "compositeResponse": [ + { + "body": { + "records": [ + {"Id": "003000000000001"}, + {"Id": "003000000000002"}, + {"Id": "003000000000003"}, + ] + }, + "referenceId": "select_query", + "httpStatusCode": 200, + }, + { + "body": { + "records": [ + {"Id": "003000000000001"}, + ] + }, + "referenceId": "user_query", + "httpStatusCode": 200, + }, + ] + } + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results records = iter([["Test1"], ["Test2"], ["Test3"]]) step.start() step.select_records(records) step.end() - # Get the job result and assert its properties for failure scenario - job_result = step.job_result - assert job_result.status == DataOperationStatus.JOB_FAILURE + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values assert ( - job_result.job_errors[0] - == "No records found for Contact in the target org." + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 ) - assert job_result.records_processed == 0 - assert job_result.total_row_errors == 0 @responses.activate - def test_select_records_similarity_strategy_success(self): + def test_select_records_user_selection_filter_failure(self): mock_describe_calls() task = _make_task( LoadData, @@ -1357,74 +1463,46 @@ def test_select_records_similarity_strategy_success(self): operation=DataOperationType.UPSERT, api_options={"batch_size": 10, "update_key": "LastName"}, context=task, - fields=["Id", "Name", "Email"], - selection_strategy=SelectStrategy.SIMILARITY, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter="MALFORMED FILTER", # Applying malformed filter ) results = { - "records": [ + "compositeResponse": [ { - "Id": "003000000000001", - "Name": "Jawad", - "Email": "mjawadtp@example.com", - }, - { - "Id": "003000000000002", - "Name": "Aditya", - "Email": "aditya@example.com", + "body": { + "records": [ + {"Id": "003000000000001"}, + {"Id": "003000000000002"}, + {"Id": "003000000000003"}, + ] + }, + "referenceId": "select_query", + "httpStatusCode": 200, }, { - "Id": "003000000000003", - "Name": "Tom Cruise", - "Email": "tomcruise@example.com", + "body": [ + { + "message": "Error in MALFORMED FILTER", + "errorCode": "MALFORMED QUERY", + } + ], + "referenceId": "user_query", + "httpStatusCode": 400, }, - ], - "done": True, - } - step.sf.query = mock.Mock() - step.sf.query.return_value = results - records = iter( - [ - ["Id: 1", "Jawad", "mjawadtp@example.com"], - ["Id: 2", "Aditya", "aditya@example.com"], - ["Id: 3", "Tom Cruise", "tom@example.com"], ] - ) + } + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) step.start() - step.select_records(records) - step.end() - - # Get the results and assert their properties - results = list(step.get_results()) - assert len(results) == 3 # Expect 3 results (matching the input records count) - # Assert that all results have the expected ID, success, and created values - assert ( - results.count( - DataOperationResult( - id="003000000000001", success=True, error="", created=False - ) - ) - == 1 - ) - assert ( - results.count( - DataOperationResult( - id="003000000000002", success=True, error="", created=False - ) - ) - == 1 - ) - assert ( - results.count( - DataOperationResult( - id="003000000000003", success=True, error="", created=False - ) - ) - == 1 - ) + with pytest.raises(SOQLQueryException) as e: + step.select_records(records) + assert "MALFORMED QUERY" in str(e.value) @responses.activate - def test_select_records_random_similarity_success__pagination(self): + def test_select_records_similarity_strategy_success(self): mock_describe_calls() task = _make_task( LoadData, @@ -1469,12 +1547,6 @@ def test_select_records_random_similarity_success__pagination(self): "Name": "Jawad", "Email": "mjawadtp@example.com", }, - ], - "done": False, - "nextRecordsUrl": "https://example.com", - } - results_more = { - "records": [ { "Id": "003000000000002", "Name": "Aditya", @@ -1488,10 +1560,8 @@ def test_select_records_random_similarity_success__pagination(self): ], "done": True, } - step.sf.query = mock.Mock() - step.sf.query.return_value = results - step.sf.query_more = mock.Mock() - step.sf.query_more.return_value = results_more + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results records = iter( [ ["Id: 1", "Jawad", "mjawadtp@example.com"], @@ -1518,7 +1588,7 @@ def test_select_records_random_similarity_success__pagination(self): assert ( results.count( DataOperationResult( - id="003000000000001", success=True, error="", created=False + id="003000000000002", success=True, error="", created=False ) ) == 1 @@ -1526,7 +1596,7 @@ def test_select_records_random_similarity_success__pagination(self): assert ( results.count( DataOperationResult( - id="003000000000001", success=True, error="", created=False + id="003000000000003", success=True, error="", created=False ) ) == 1 @@ -1572,8 +1642,8 @@ def test_select_records_similarity_strategy_failure__no_records(self): ) results = {"records": [], "done": True} - step.sf.query = mock.Mock() - step.sf.query.return_value = results + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results records = iter( [ ["Id: 1", "Jawad", "mjawadtp@example.com"], @@ -2071,6 +2141,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): context=context, api=DataApi.BULK, volume=1, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, ) assert op == bulk_dml.return_value @@ -2080,6 +2152,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): fields=["Name"], api_options={}, context=context, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, ) op = get_dml_operation( @@ -2090,6 +2164,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): context=context, api=DataApi.REST, volume=1, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, ) assert op == rest_dml.return_value @@ -2099,6 +2175,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): fields=["Name"], api_options={}, context=context, + selection_strategy=SelectStrategy.SIMILARITY, + selection_filter=None, ) @mock.patch("cumulusci.tasks.bulkdata.step.BulkApiDmlOperation") @@ -2261,3 +2339,102 @@ def test_cleanup_date_strings__upsert_update(self, operation): "Name": "Bill", "attributes": {"type": "Test__c"}, }, json_out + + +import pytest + + +def test_generate_user_filter_query_basic(): + """Tests basic query generation without existing LIMIT or OFFSET.""" + filter_clause = "WHERE Name = 'John'" + sobject = "Account" + fields = ["Id", "Name"] + limit_clause = 10 + offset_clause = 5 + + expected_query = ( + "SELECT Id, Name FROM Account WHERE Name = 'John' LIMIT 10 OFFSET 5" + ) + assert ( + generate_user_filter_query( + filter_clause, sobject, fields, limit_clause, offset_clause + ) + == expected_query + ) + + +def test_generate_user_filter_query_existing_limit(): + """Tests handling of existing LIMIT in the filter clause.""" + filter_clause = "WHERE Name = 'John' LIMIT 20" + sobject = "Contact" + fields = ["Id", "FirstName"] + limit_clause = 5 # Should override the existing LIMIT + offset_clause = None + + expected_query = "SELECT Id, FirstName FROM Contact WHERE Name = 'John' LIMIT 5" + assert ( + generate_user_filter_query( + filter_clause, sobject, fields, limit_clause, offset_clause + ) + == expected_query + ) + + +def test_generate_user_filter_query_existing_offset(): + """Tests handling of existing OFFSET in the filter clause.""" + filter_clause = "WHERE Name = 'John' OFFSET 15" + sobject = "Opportunity" + fields = ["Id", "Name"] + limit_clause = None + offset_clause = 10 # Should add to the existing OFFSET + + expected_query = "SELECT Id, Name FROM Opportunity WHERE Name = 'John' OFFSET 25" + assert ( + generate_user_filter_query( + filter_clause, sobject, fields, limit_clause, offset_clause + ) + == expected_query + ) + + +def test_generate_user_filter_query_no_limit_or_offset(): + """Tests when no limit or offset is provided or present in the filter.""" + filter_clause = "WHERE Name = 'John' LIMIT 5 OFFSET 20" + sobject = "Lead" + fields = ["Id", "Name", "Email"] + limit_clause = None + offset_clause = None + + expected_query = ( + "SELECT Id, Name, Email FROM Lead WHERE Name = 'John' LIMIT 5 OFFSET 20" + ) + print( + generate_user_filter_query( + filter_clause, sobject, fields, limit_clause, offset_clause + ) + ) + assert ( + generate_user_filter_query( + filter_clause, sobject, fields, limit_clause, offset_clause + ) + == expected_query + ) + + +def test_generate_user_filter_query_case_insensitivity(): + """Tests case-insensitivity for LIMIT and OFFSET.""" + filter_clause = "where name = 'John' offset 5 limit 20" + sobject = "Task" + fields = ["Id", "Subject"] + limit_clause = 15 + offset_clause = 20 + + expected_query = ( + "SELECT Id, Subject FROM Task where name = 'John' LIMIT 15 OFFSET 25" + ) + assert ( + generate_user_filter_query( + filter_clause, sobject, fields, limit_clause, offset_clause + ) + == expected_query + ) From 196247a27a93e69a5356c9ad9a57468e3e59f751 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 3 Sep 2024 11:57:37 +0530 Subject: [PATCH 15/34] Add limit and offset to queries under batch processing --- cumulusci/tasks/bulkdata/select_utils.py | 25 ++++++---- cumulusci/tasks/bulkdata/step.py | 46 +++++++++++++------ .../tasks/bulkdata/tests/test_select_utils.py | 46 ++++++++++++------- 3 files changed, 79 insertions(+), 38 deletions(-) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 976f852540..6b9623eb59 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -20,17 +20,23 @@ def __init__(self, strategy: SelectStrategy): self.strategy = strategy def select_generate_query( - self, sobject: str, fields: T.List[str], num_records: int + self, + sobject: str, + fields: T.List[str], + limit: T.Union[int, None], + offset: T.Union[int, None], ): # For STANDARD strategy if self.strategy == SelectStrategy.STANDARD: - return standard_generate_query(sobject=sobject, num_records=num_records) + return standard_generate_query(sobject=sobject, limit=limit, offset=offset) # For SIMILARITY strategy elif self.strategy == SelectStrategy.SIMILARITY: - return similarity_generate_query(sobject=sobject, fields=fields) + return similarity_generate_query( + sobject=sobject, fields=fields, limit=limit, offset=offset + ) # For RANDOM strategy elif self.strategy == SelectStrategy.RANDOM: - return standard_generate_query(sobject=sobject, num_records=num_records) + return standard_generate_query(sobject=sobject, limit=limit, offset=offset) def select_post_process( self, load_records, query_records: list, num_records: int, sobject: str @@ -53,7 +59,7 @@ def select_post_process( def standard_generate_query( - sobject: str, num_records: int + sobject: str, limit: T.Union[int, None], offset: T.Union[int, None] ) -> T.Tuple[str, T.List[str]]: """Generates the SOQL query for the standard (as well as random) selection strategy""" # Get the WHERE clause from DEFAULT_DECLARATIONS if available @@ -66,8 +72,8 @@ def standard_generate_query( query = f"SELECT Id FROM {sobject}" if where_clause: query += f" WHERE {where_clause}" - query += f" LIMIT {num_records}" - + query += f" LIMIT {limit}" if limit else "" + query += f" OFFSET {offset}" if offset else "" return query, ["Id"] @@ -98,6 +104,8 @@ def standard_post_process( def similarity_generate_query( sobject: str, fields: T.List[str], + limit: T.Union[int, None], + offset: T.Union[int, None], ) -> T.Tuple[str, T.List[str]]: """Generates the SOQL query for the similarity selection strategy""" # Get the WHERE clause from DEFAULT_DECLARATIONS if available @@ -114,7 +122,8 @@ def similarity_generate_query( query = f"SELECT {fields_to_query} FROM {sobject}" if where_clause: query += f" WHERE {where_clause}" - + query += f" LIMIT {limit}" if limit else "" + query += f" OFFSET {offset}" if offset else "" return query, fields diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index fd25f0e19d..61f23a5808 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -462,16 +462,21 @@ def select_records(self, records): ) # Generate and execute SOQL query + # (not passing offset as it is not supported in Bulk) ( select_query, query_fields, ) = self.select_operation_executor.select_generate_query( - self.sobject, self.fields, num_records + sobject=self.sobject, fields=self.fields, limit=num_records, offset=None ) if self.selection_filter: # Generate user filter query if selection_filter is present (offset clause not supported) user_query = generate_user_filter_query( - self.selection_filter, self.sobject, ["Id"], num_records, None + filter_clause=self.selection_filter, + sobject=self.sobject, + fields=["Id"], + limit_clause=num_records, + offset_clause=None, ) # Execute the user query using Bulk API user_query_executor = get_query_operation( @@ -508,19 +513,22 @@ def select_records(self, records): selected_records, error_message, ) = self.select_operation_executor.select_post_process( - records, query_records, num_records, self.sobject + load_records=records, + query_records=query_records, + num_records=num_records, + sobject=self.sobject, ) if not error_message: self.select_results.extend(selected_records) # Update job result based on selection outcome self.job_result = DataOperationJobResult( - DataOperationStatus.SUCCESS + status=DataOperationStatus.SUCCESS if len(self.select_results) else DataOperationStatus.JOB_FAILURE, - [error_message] if error_message else [], - len(self.select_results), - 0, + job_errors=[error_message] if error_message else [], + records_processed=len(self.select_results), + total_row_errors=0, ) def _execute_select_query(self, select_query: str, query_fields: List[str]): @@ -814,13 +822,20 @@ def convert(rec, fields): select_query, query_fields, ) = self.select_operation_executor.select_generate_query( - self.sobject, self.fields, num_records + sobject=self.sobject, + fields=self.fields, + limit=num_records, + offset=offset, ) # If user given selection filter present, create composite request if self.selection_filter: user_query = generate_user_filter_query( - self.selection_filter, self.sobject, ["Id"], num_records, offset + filter_clause=self.selection_filter, + sobject=self.sobject, + fields=["Id"], + limit_clause=num_records, + offset_clause=offset, ) query_records.extend( self._execute_composite_query( @@ -843,7 +858,10 @@ def convert(rec, fields): selected_records, error_message, ) = self.select_operation_executor.select_post_process( - records, query_records, total_num_records, self.sobject + load_records=records, + query_records=query_records, + num_records=total_num_records, + sobject=self.sobject, ) if not error_message: # Add selected records from this batch to the overall results @@ -851,12 +869,12 @@ def convert(rec, fields): # Update the job result based on the overall selection outcome self.job_result = DataOperationJobResult( - DataOperationStatus.SUCCESS + status=DataOperationStatus.SUCCESS if len(self.results) # Check the overall results length else DataOperationStatus.JOB_FAILURE, - [error_message] if error_message else [], - len(self.results), - 0, + job_errors=[error_message] if error_message else [], + records_processed=len(self.results), + total_row_errors=0, ) def _execute_composite_query(self, select_query, user_query, query_fields): diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index 0ae97acb46..755af5d009 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -11,26 +11,30 @@ def test_standard_generate_query_with_default_record_declaration(): select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS - num_records = 5 + limit = 5 + offset = 2 query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], num_records=num_records + sobject=sobject, fields=[], limit=limit, offset=offset ) assert "WHERE" in query # Ensure WHERE clause is included - assert f"LIMIT {num_records}" in query + assert f"LIMIT {limit}" in query + assert f"OFFSET {offset}" in query assert fields == ["Id"] def test_standard_generate_query_without_default_record_declaration(): select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) sobject = "Contact" # Assuming no declaration for this object - num_records = 3 + limit = 3 + offset = None query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], num_records=num_records + sobject=sobject, fields=[], limit=limit, offset=offset ) assert "WHERE" not in query # No WHERE clause should be present - assert f"LIMIT {num_records}" in query + assert f"LIMIT {limit}" in query + assert "OFFSET" not in query assert fields == ["Id"] @@ -38,26 +42,30 @@ def test_standard_generate_query_without_default_record_declaration(): def test_random_generate_query_with_default_record_declaration(): select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS - num_records = 5 + limit = 5 + offset = 2 query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], num_records=num_records + sobject=sobject, fields=[], limit=limit, offset=offset ) assert "WHERE" in query # Ensure WHERE clause is included - assert f"LIMIT {num_records}" in query + assert f"LIMIT {limit}" in query + assert f"OFFSET {offset}" in query assert fields == ["Id"] def test_random_generate_query_without_default_record_declaration(): select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) sobject = "Contact" # Assuming no declaration for this object - num_records = 3 + limit = 3 + offset = None query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], num_records=num_records + sobject=sobject, fields=[], limit=limit, offset=offset ) assert "WHERE" not in query # No WHERE clause should be present - assert f"LIMIT {num_records}" in query + assert f"LIMIT {limit}" in query + assert "OFFSET" not in query assert fields == ["Id"] @@ -141,25 +149,31 @@ def test_random_post_process_with_no_records(): def test_similarity_generate_query_with_default_record_declaration(): select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) sobject = "Account" # Assuming Account has a declaration in DEFAULT_DECLARATIONS - num_records = 5 + limit = 5 + offset = 2 query, fields = select_operator.select_generate_query( - sobject, ["Name"], num_records + sobject, ["Name"], limit, offset ) assert "WHERE" in query # Ensure WHERE clause is included assert fields == ["Id", "Name"] + assert f"LIMIT {limit}" in query + assert f"OFFSET {offset}" in query def test_similarity_generate_query_without_default_record_declaration(): select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) sobject = "Contact" # Assuming no declaration for this object - num_records = 3 + limit = 3 + offset = None query, fields = select_operator.select_generate_query( - sobject, ["Name"], num_records + sobject, ["Name"], limit, offset ) assert "WHERE" not in query # No WHERE clause should be present assert fields == ["Id", "Name"] + assert f"LIMIT {limit}" in query + assert "OFFSET" not in query def test_levenshtein_distance(): From 6eca45517557b87f3008483bf512d452d148c0ce Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 3 Sep 2024 12:37:53 +0530 Subject: [PATCH 16/34] Add failure scenario for calculate_levenshtein_distance --- cumulusci/tasks/bulkdata/load.py | 1 + .../tasks/bulkdata/tests/test_select_utils.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index 9435dfd183..d416fa1f63 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -342,6 +342,7 @@ def configure_step(self, mapping): api_options["update_key"] = mapping.update_key[0] action = DataOperationType.UPSERT elif mapping.action == DataOperationType.SELECT: + # Bulk process expects DataOpertionType to be QUERY action = DataOperationType.QUERY else: action = mapping.action diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index 755af5d009..fe037a0177 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -1,3 +1,5 @@ +import pytest + from cumulusci.tasks.bulkdata.select_utils import ( SelectOperationExecutor, SelectStrategy, @@ -215,6 +217,20 @@ def test_calculate_levenshtein_distance(): assert calculate_levenshtein_distance(record1, record2) == 0 # Distance should be 0 +def test_calculate_levenshtein_distance_error(): + # Identical records + record1 = ["Tom Cruise", "24", "Actor"] + record2 = [ + "Tom Cruise", + "24", + "Actor", + "SomethingElse", + ] # Record Length does not match + with pytest.raises(ValueError) as e: + calculate_levenshtein_distance(record1, record2) + assert "Records must have the same number of fields" in str(e.value) + + def test_find_closest_record(): # Test case 1: Exact match load_record = ["Tom Cruise", "62", "Actor"] From ebd5f08503e856501daf9c34c581be195250c2ba Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Wed, 4 Sep 2024 11:58:51 +0530 Subject: [PATCH 17/34] Modify functionality to return records in order of the user query to support ORDER BY operation --- cumulusci/tasks/bulkdata/step.py | 27 +++- cumulusci/tasks/bulkdata/tests/test_step.py | 145 ++++++++++++++++++++ 2 files changed, 166 insertions(+), 6 deletions(-) diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 61f23a5808..7dd025c345 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -491,7 +491,11 @@ def select_records(self, records): user_query_records = user_query_executor.get_results() # Find intersection based on 'Id' - user_query_ids = set(record[0] for record in user_query_records) + user_query_ids = ( + list(record[0] for record in user_query_records) + if user_query_records + else [] + ) # Execute the main select query using Bulk API select_query_records = self._execute_select_query( @@ -500,10 +504,16 @@ def select_records(self, records): # If user_query_ids exist, filter select_query_records based on the intersection of Ids if self.selection_filter: + # Create a dictionary to map IDs to their corresponding records + id_to_record_map = { + record[query_fields.index("Id")]: record + for record in select_query_records + } + # Extend query_records in the order of user_query_ids query_records.extend( record - for record in select_query_records - if record[query_fields.index("Id")] in user_query_ids + for id in user_query_ids + if (record := id_to_record_map.get(id)) is not None ) else: query_records.extend(select_query_records) @@ -928,12 +938,17 @@ def convert(rec, fields): f"{sub_response['body'][0]['errorCode']}: {sub_response['body'][0]['message']}" ) # Find intersection based on 'Id' - user_query_ids = set(record[0] for record in user_query_records) + user_query_ids = list(record[0] for record in user_query_records) + # Create a dictionary to map IDs to their corresponding records + id_to_record_map = { + record[query_fields.index("Id")]: record for record in select_query_records + } + # Extend query_records in the order of user_query_ids return [ record - for record in select_query_records - if record[query_fields.index("Id")] in user_query_ids + for id in user_query_ids + if (record := id_to_record_map.get(id)) is not None ] def get_results(self): diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index b2904ae9c5..046c6d3a5a 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -701,6 +701,70 @@ def test_select_records_user_selection_filter_success(self, download_mock): == 3 ) + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_user_selection_filter_order_success(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter="ORDER BY CreatedDate", + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + download_mock.return_value = io.StringIO( + """Id +003000000000001 +003000000000002 +003000000000003""" + ) + # Mock the query operation + with mock.patch( + "cumulusci.tasks.bulkdata.step.get_query_operation" + ) as query_operation_mock: + query_operation_mock.return_value = mock.Mock() + query_operation_mock.return_value.query = mock.Mock() + query_operation_mock.return_value.get_results = mock.Mock() + query_operation_mock.return_value.get_results.return_value = [ + ["003000000000003"], + ["003000000000001"], + ["003000000000002"], + ] + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) + + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert ( + len(results) == 3 + ) # Expect 3 results (matching the input records count) + # Assert that all results are in the order given by user query + assert results[0].id == "003000000000003" + assert results[1].id == "003000000000001" + assert results[2].id == "003000000000002" + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") def test_select_records_user_selection_filter_failure(self, download_mock): # Set up mock context and BulkApiDmlOperation @@ -1428,6 +1492,87 @@ def test_select_records_user_selection_filter_success(self): == 3 ) + @responses.activate + def test_select_records_user_selection_filter_order_success(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + selection_filter="ORDER BY CreatedDate", + ) + + results = { + "compositeResponse": [ + { + "body": { + "records": [ + {"Id": "003000000000001"}, + {"Id": "003000000000002"}, + {"Id": "003000000000003"}, + ] + }, + "referenceId": "select_query", + "httpStatusCode": 200, + }, + { + "body": { + "records": [ + {"Id": "003000000000003"}, + {"Id": "003000000000001"}, + {"Id": "003000000000002"}, + ] + }, + "referenceId": "user_query", + "httpStatusCode": 200, + }, + ] + } + step.sf.restful = mock.Mock() + step.sf.restful.return_value = results + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results are in the order of user_query + assert results[0].id == "003000000000003" + assert results[1].id == "003000000000001" + assert results[2].id == "003000000000002" + @responses.activate def test_select_records_user_selection_filter_failure(self): mock_describe_calls() From 8c2bb3adb04ee2a0710b08c38b2ebc67ae997492 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Fri, 6 Sep 2024 15:39:22 +0530 Subject: [PATCH 18/34] Add documentation for 'select' functionality --- docs/data.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/data.md b/docs/data.md index 063e3f33f5..c81ba44c90 100644 --- a/docs/data.md +++ b/docs/data.md @@ -250,6 +250,43 @@ Insert Accounts: Whenever `update_key` is supplied, the action must be `upsert` and vice versa. +### Selects + +The "select" functionality enhances the mapping process by enabling direct record selection from the target Salesforce org for lookups. This is achieved by specifying the `select` action in the mapping file, particularly useful when dealing with objects dependent on non-insertable Salesforce objects. + +```yaml +Select Accounts: + sf_object: Account + action: select + selection_strategy: standard + selection_filter: WHERE Name IN ('Bluth Company', 'Camacho PLC') + fields: + - Name + - AccountNumber +Insert Contacts: + sf_object: Contact + action: insert + fields: + - LastName + lookups: + AccountId: + table: Account +``` + +The `Select Accounts` section in this YAML demonstrates how to fetch specific records from your Salesforce org. These selected Account records will then be referenced by the subsequent `Insert Contacts` section via lookups, ensuring that new Contacts are linked to the pre-existing Accounts chosen in the `select` step rather than relying on any newly inserted Account records. + +#### Selection Strategy + +The `selection_strategy` dictates how these records are chosen: + +- `standard`: This strategy fetches records from the org in the same order as they appear, respecting any filtering applied via `selection_filter`. +- `similarity`: This strategy is employed when you want to find records in the org that closely resemble those defined in your SQL file. +- `random`: As the name suggests, this strategy randomly selects records from the org. + +#### Selection Filter + +The `selection_filter` acts as a versatile SOQL clause, providing fine-grained control over record selection. It allows filtering with `WHERE`, sorting with `ORDER BY`, limiting with `LIMIT`, and potentially utilizing other SOQL capabilities, ensuring you select the precise records needed for your chosen `selection_strategy`. + ### Database Mapping CumulusCI's definition format includes considerable flexibility for use From bb72bfb6c227d713867165ad49ef7832793f6151 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Mon, 30 Sep 2024 16:20:22 +0530 Subject: [PATCH 19/34] Fixes issue for improper batching and intersection --- cumulusci/tasks/bulkdata/select_utils.py | 19 +- cumulusci/tasks/bulkdata/step.py | 221 +++++++++++--------- cumulusci/tasks/bulkdata/tests/test_step.py | 15 +- 3 files changed, 152 insertions(+), 103 deletions(-) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 6b9623eb59..315e2ae349 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -15,9 +15,22 @@ class SelectStrategy(StrEnum): RANDOM = "random" +class SelectRecordRetrievalMode(StrEnum): + """Enum defining whether you need all records or match the + number of records of the local sql file""" + + ALL = "all" + MATCH = "match" + + class SelectOperationExecutor: def __init__(self, strategy: SelectStrategy): self.strategy = strategy + self.retrieval_mode = ( + SelectRecordRetrievalMode.ALL + if strategy == SelectStrategy.SIMILARITY + else SelectRecordRetrievalMode.MATCH + ) def select_generate_query( self, @@ -96,7 +109,7 @@ def standard_post_process( original_records = selected_records.copy() while len(selected_records) < num_records: selected_records.extend(original_records) - selected_records = selected_records[:num_records] + selected_records = selected_records[:num_records] return selected_records, None # Return selected records and None for error @@ -115,8 +128,8 @@ def similarity_generate_query( else: where_clause = None # Construct the query with the WHERE clause (if it exists) - - fields.insert(0, "Id") + if "Id" not in fields: + fields.insert(0, "Id") fields_to_query = ", ".join(field for field in fields if field) query = f"SELECT {fields_to_query} FROM {sobject}" diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 7dd025c345..5f100aa88d 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -19,6 +19,7 @@ from cumulusci.core.utils import process_bool_arg from cumulusci.tasks.bulkdata.select_utils import ( SelectOperationExecutor, + SelectRecordRetrievalMode, SelectStrategy, ) from cumulusci.tasks.bulkdata.utils import DataApi, iterate_in_chunks @@ -452,71 +453,66 @@ def select_records(self, records): # Count total number of records to fetch using the copy total_num_records = sum(1 for _ in records_copy) - # Process in batches based on batch_size from api_options - for offset in range( - 0, total_num_records, self.api_options.get("batch_size", 500) - ): - # Calculate number of records to fetch in this batch - num_records = min( - self.api_options.get("batch_size", 500), total_num_records - offset + # Since OFFSET is not supported in bulk, we can run only over 1 api_batch_size + # Generate and execute SOQL query + # (not passing offset as it is not supported in Bulk) + ( + select_query, + query_fields, + ) = self.select_operation_executor.select_generate_query( + sobject=self.sobject, + fields=self.fields, + limit=self.api_options.get("batch_size", 500), + offset=None, + ) + if self.selection_filter: + # Generate user filter query if selection_filter is present (offset clause not supported) + user_query = generate_user_filter_query( + filter_clause=self.selection_filter, + sobject=self.sobject, + fields=["Id"], + limit_clause=self.api_options.get("batch_size", 500), + offset_clause=None, ) - - # Generate and execute SOQL query - # (not passing offset as it is not supported in Bulk) - ( - select_query, - query_fields, - ) = self.select_operation_executor.select_generate_query( - sobject=self.sobject, fields=self.fields, limit=num_records, offset=None + # Execute the user query using Bulk API + user_query_executor = get_query_operation( + sobject=self.sobject, + fields=["Id"], + api_options=self.api_options, + context=self, + query=user_query, + api=DataApi.BULK, ) - if self.selection_filter: - # Generate user filter query if selection_filter is present (offset clause not supported) - user_query = generate_user_filter_query( - filter_clause=self.selection_filter, - sobject=self.sobject, - fields=["Id"], - limit_clause=num_records, - offset_clause=None, - ) - # Execute the user query using Bulk API - user_query_executor = get_query_operation( - sobject=self.sobject, - fields=["Id"], - api_options=self.api_options, - context=self, - query=user_query, - api=DataApi.BULK, - ) - user_query_executor.query() - user_query_records = user_query_executor.get_results() - - # Find intersection based on 'Id' - user_query_ids = ( - list(record[0] for record in user_query_records) - if user_query_records - else [] - ) - - # Execute the main select query using Bulk API - select_query_records = self._execute_select_query( - select_query=select_query, query_fields=query_fields + user_query_executor.query() + user_query_records = user_query_executor.get_results() + + # Find intersection based on 'Id' + user_query_ids = ( + list(record[0] for record in user_query_records) + if user_query_records + else [] ) - # If user_query_ids exist, filter select_query_records based on the intersection of Ids - if self.selection_filter: - # Create a dictionary to map IDs to their corresponding records - id_to_record_map = { - record[query_fields.index("Id")]: record - for record in select_query_records - } - # Extend query_records in the order of user_query_ids - query_records.extend( - record - for id in user_query_ids - if (record := id_to_record_map.get(id)) is not None - ) - else: - query_records.extend(select_query_records) + # Execute the main select query using Bulk API + select_query_records = self._execute_select_query( + select_query=select_query, query_fields=query_fields + ) + + # If user_query_ids exist, filter select_query_records based on the intersection of Ids + if self.selection_filter: + # Create a dictionary to map IDs to their corresponding records + id_to_record_map = { + record[query_fields.index("Id")]: record + for record in select_query_records + } + # Extend query_records in the order of user_query_ids + query_records.extend( + record + for id in user_query_ids + if (record := id_to_record_map.get(id)) is not None + ) + else: + query_records.extend(select_query_records) # Post-process the query results ( @@ -525,7 +521,7 @@ def select_records(self, records): ) = self.select_operation_executor.select_post_process( load_records=records, query_records=query_records, - num_records=num_records, + num_records=total_num_records, sobject=self.sobject, ) if not error_message: @@ -674,7 +670,7 @@ def __init__( api_options, context, fields, - selection_strategy=SelectStrategy.SIMILARITY, + selection_strategy=SelectStrategy.STANDARD, selection_filter=None, ): super().__init__( @@ -816,17 +812,25 @@ def convert(rec, fields): self.results = [] query_records = [] + user_query_records = [] # Create a copy of the generator using tee records, records_copy = tee(records) # Count total number of records to fetch using the copy total_num_records = sum(1 for _ in records_copy) + # Set offset + offset = 0 - # Process in batches - for offset in range(0, total_num_records, self.api_options.get("batch_size")): - num_records = min( - self.api_options.get("batch_size"), total_num_records - offset - ) + # Define condition + def condition(retrieval_mode, offset, total_num_records): + if retrieval_mode == SelectRecordRetrievalMode.ALL: + return True + elif retrieval_mode == SelectRecordRetrievalMode.MATCH: + return offset < total_num_records + # Process in batches + while condition( + self.select_operation_executor.retrieval_mode, offset, total_num_records + ): # Generate the SOQL query based on the selection strategy ( select_query, @@ -834,34 +838,74 @@ def convert(rec, fields): ) = self.select_operation_executor.select_generate_query( sobject=self.sobject, fields=self.fields, - limit=num_records, + limit=self.api_options.get("batch_size"), offset=offset, ) # If user given selection filter present, create composite request if self.selection_filter: + # Generate user query user_query = generate_user_filter_query( filter_clause=self.selection_filter, sobject=self.sobject, fields=["Id"], - limit_clause=num_records, + limit_clause=self.api_options.get("batch_size"), offset_clause=offset, ) - query_records.extend( - self._execute_composite_query( - select_query=select_query, - user_query=user_query, - query_fields=query_fields, - ) + # Execute composite query + ( + current_user_query_records, + current_query_records, + ) = self._execute_composite_query( + select_query=select_query, + user_query=user_query, + query_fields=query_fields, ) + # Break if org has no more records + if ( + len(current_user_query_records) == 0 + and len(current_query_records) == 0 + ): + break + + # Extend to each + user_query_records.extend(current_user_query_records) + query_records.extend(current_query_records) + else: # Handle the case where self.selection_query is None (and hence user_query is also None) response = self.sf.restful( requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" ) - query_records.extend( - list(convert(rec, query_fields) for rec in response["records"]) + current_query_records = list( + convert(rec, query_fields) for rec in response["records"] ) + # Break if nothing is returned + if len(current_query_records) == 0: + break + # Extend the query records + query_records.extend(current_query_records) + + # Update offset + offset += self.api_options.get("batch_size") + + # Find intersection if filter given + if self.selection_filter: + # Find intersection based on 'Id' + user_query_ids = list(record[0] for record in user_query_records) + # Create a dictionary to map IDs to their corresponding records + id_to_record_map = { + record[query_fields.index("Id")]: record for record in query_records + } + + # Extend insersection_query_records in the order of user_query_ids + insersection_query_records = [ + record + for id in user_query_ids + if (record := id_to_record_map.get(id)) is not None + ] + else: + insersection_query_records = query_records # Post-process the query results for this batch ( @@ -869,7 +913,7 @@ def convert(rec, fields): error_message, ) = self.select_operation_executor.select_post_process( load_records=records, - query_records=query_records, + query_records=insersection_query_records, num_records=total_num_records, sobject=self.sobject, ) @@ -888,7 +932,7 @@ def convert(rec, fields): ) def _execute_composite_query(self, select_query, user_query, query_fields): - """Executes a composite request with two queries and returns the intersected results.""" + """Executes a composite request with two queries and returns the results.""" def convert(rec, fields): """Helper function to convert record values to strings, handling None values""" @@ -937,19 +981,8 @@ def convert(rec, fields): raise SOQLQueryException( f"{sub_response['body'][0]['errorCode']}: {sub_response['body'][0]['message']}" ) - # Find intersection based on 'Id' - user_query_ids = list(record[0] for record in user_query_records) - # Create a dictionary to map IDs to their corresponding records - id_to_record_map = { - record[query_fields.index("Id")]: record for record in select_query_records - } - # Extend query_records in the order of user_query_ids - return [ - record - for id in user_query_ids - if (record := id_to_record_map.get(id)) is not None - ] + return user_query_records, select_query_records def get_results(self): """Return a generator of DataOperationResult objects.""" @@ -1076,8 +1109,8 @@ def generate_user_filter_query( filter_clause: str, sobject: str, fields: list, - limit_clause: Union[int, None] = None, - offset_clause: Union[int, None] = None, + limit_clause: Union[float, None] = None, + offset_clause: Union[float, None] = None, ) -> str: """ Generates a SOQL query with the provided filter, object, fields, limit, and offset clauses. diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index 046c6d3a5a..c182a92996 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -1685,7 +1685,7 @@ def test_select_records_similarity_strategy_success(self): selection_strategy=SelectStrategy.SIMILARITY, ) - results = { + results_first_call = { "records": [ { "Id": "003000000000001", @@ -1705,13 +1705,16 @@ def test_select_records_similarity_strategy_success(self): ], "done": True, } - step.sf.restful = mock.Mock() - step.sf.restful.return_value = results + + # First call returns `results_first_call`, second call returns an empty list + step.sf.restful = mock.Mock( + side_effect=[results_first_call, {"records": [], "done": True}] + ) records = iter( [ - ["Id: 1", "Jawad", "mjawadtp@example.com"], - ["Id: 2", "Aditya", "aditya@example.com"], - ["Id: 3", "Tom Cruise", "tom@example.com"], + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom Cruise", "tom@example.com"], ] ) step.start() From ada400f9d9cfd57193a33f9b5e45b57ee41f9879 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 1 Oct 2024 12:45:29 +0530 Subject: [PATCH 20/34] Override user filter with our filters Also solve issue where offset if greater than 2000, was causing an issue --- cumulusci/tasks/bulkdata/load.py | 19 +- cumulusci/tasks/bulkdata/select_utils.py | 108 ++++++++--- cumulusci/tasks/bulkdata/step.py | 235 +++++------------------ 3 files changed, 154 insertions(+), 208 deletions(-) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index d416fa1f63..f83199050a 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -313,6 +313,7 @@ def configure_step(self, mapping): """Create a step appropriate to the action""" bulk_mode = mapping.bulk_mode or self.bulk_mode or "Parallel" api_options = {"batch_size": mapping.batch_size, "bulk_mode": bulk_mode} + num_records_in_target = None fields = mapping.get_load_field_list() @@ -344,11 +345,27 @@ def configure_step(self, mapping): elif mapping.action == DataOperationType.SELECT: # Bulk process expects DataOpertionType to be QUERY action = DataOperationType.QUERY + # Determine number of records in the target org + record_count_response = self.sf.restful( + f"limits/recordCount?sObjects={mapping.sf_object}" + ) + sobject_map = { + entry["name"]: entry["count"] + for entry in record_count_response["sObjects"] + } + num_records_in_target = sobject_map.get(mapping.sf_object, None) else: action = mapping.action query = self._query_db(mapping) + # Set volume + volume = ( + num_records_in_target + if num_records_in_target is not None + else query.count() + ) + step = get_dml_operation( sobject=mapping.sf_object, operation=action, @@ -356,7 +373,7 @@ def configure_step(self, mapping): context=self, fields=fields, api=mapping.api, - volume=query.count(), + volume=volume, selection_strategy=mapping.selection_strategy, selection_filter=mapping.selection_filter, ) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 315e2ae349..daa993d045 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -1,4 +1,5 @@ import random +import re import typing as T from cumulusci.core.enums import StrEnum @@ -36,20 +37,29 @@ def select_generate_query( self, sobject: str, fields: T.List[str], + user_filter: str, limit: T.Union[int, None], offset: T.Union[int, None], ): # For STANDARD strategy if self.strategy == SelectStrategy.STANDARD: - return standard_generate_query(sobject=sobject, limit=limit, offset=offset) + return standard_generate_query( + sobject=sobject, user_filter=user_filter, limit=limit, offset=offset + ) # For SIMILARITY strategy elif self.strategy == SelectStrategy.SIMILARITY: return similarity_generate_query( - sobject=sobject, fields=fields, limit=limit, offset=offset + sobject=sobject, + fields=fields, + user_filter=user_filter, + limit=limit, + offset=offset, ) # For RANDOM strategy elif self.strategy == SelectStrategy.RANDOM: - return standard_generate_query(sobject=sobject, limit=limit, offset=offset) + return standard_generate_query( + sobject=sobject, user_filter=user_filter, limit=limit, offset=offset + ) def select_post_process( self, load_records, query_records: list, num_records: int, sobject: str @@ -72,21 +82,26 @@ def select_post_process( def standard_generate_query( - sobject: str, limit: T.Union[int, None], offset: T.Union[int, None] + sobject: str, + user_filter: str, + limit: T.Union[int, None], + offset: T.Union[int, None], ) -> T.Tuple[str, T.List[str]]: """Generates the SOQL query for the standard (as well as random) selection strategy""" - # Get the WHERE clause from DEFAULT_DECLARATIONS if available - declaration = DEFAULT_DECLARATIONS.get(sobject) - if declaration: - where_clause = declaration.where - else: - where_clause = None - # Construct the query with the WHERE clause (if it exists) + query = f"SELECT Id FROM {sobject}" - if where_clause: - query += f" WHERE {where_clause}" - query += f" LIMIT {limit}" if limit else "" - query += f" OFFSET {offset}" if offset else "" + # If user specifies user_filter + if user_filter: + query += add_limit_offset_to_user_filter( + filter_clause=user_filter, limit_clause=limit, offset_clause=offset + ) + else: + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + query += f" WHERE {declaration.where}" + query += f" LIMIT {limit}" if limit else "" + query += f" OFFSET {offset}" if offset else "" return query, ["Id"] @@ -117,26 +132,29 @@ def standard_post_process( def similarity_generate_query( sobject: str, fields: T.List[str], + user_filter: str, limit: T.Union[int, None], offset: T.Union[int, None], ) -> T.Tuple[str, T.List[str]]: """Generates the SOQL query for the similarity selection strategy""" - # Get the WHERE clause from DEFAULT_DECLARATIONS if available - declaration = DEFAULT_DECLARATIONS.get(sobject) - if declaration: - where_clause = declaration.where - else: - where_clause = None # Construct the query with the WHERE clause (if it exists) if "Id" not in fields: fields.insert(0, "Id") fields_to_query = ", ".join(field for field in fields if field) query = f"SELECT {fields_to_query} FROM {sobject}" - if where_clause: - query += f" WHERE {where_clause}" - query += f" LIMIT {limit}" if limit else "" - query += f" OFFSET {offset}" if offset else "" + + if user_filter: + query += add_limit_offset_to_user_filter( + filter_clause=user_filter, limit_clause=limit, offset_clause=offset + ) + else: + # Get the WHERE clause from DEFAULT_DECLARATIONS if available + declaration = DEFAULT_DECLARATIONS.get(sobject) + if declaration: + query += f" WHERE {declaration.where}" + query += f" LIMIT {limit}" if limit else "" + query += f" OFFSET {offset}" if offset else "" return query, fields @@ -242,3 +260,43 @@ def calculate_levenshtein_distance(record1: list, record2: list): total_fields += 1 return total_distance / total_fields if total_fields > 0 else 0 + + +def add_limit_offset_to_user_filter( + filter_clause: str, + limit_clause: T.Union[float, None] = None, + offset_clause: T.Union[float, None] = None, +) -> str: + + # Extract existing LIMIT and OFFSET from filter_clause if present + existing_limit_match = re.search(r"LIMIT\s+(\d+)", filter_clause, re.IGNORECASE) + existing_offset_match = re.search(r"OFFSET\s+(\d+)", filter_clause, re.IGNORECASE) + + if existing_limit_match: + existing_limit = int(existing_limit_match.group(1)) + if limit_clause is not None: # Only apply limit_clause if it's provided + limit_clause = min(existing_limit, limit_clause) + else: + limit_clause = existing_limit + + if existing_offset_match: + existing_offset = int(existing_offset_match.group(1)) + if offset_clause is not None: + offset_clause = existing_offset + offset_clause + else: + offset_clause = existing_offset + + # Remove existing LIMIT and OFFSET from filter_clause, handling potential extra spaces + filter_clause = re.sub( + r"\s+OFFSET\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE + ).strip() + filter_clause = re.sub( + r"\s+LIMIT\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE + ).strip() + + if limit_clause is not None: + filter_clause += f" LIMIT {limit_clause}" + if offset_clause is not None: + filter_clause += f" OFFSET {offset_clause}" + + return f" {filter_clause}" diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 5f100aa88d..3f3fbaf0f3 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -3,7 +3,6 @@ import json import os import pathlib -import re import tempfile import time from abc import ABCMeta, abstractmethod @@ -453,7 +452,18 @@ def select_records(self, records): # Count total number of records to fetch using the copy total_num_records = sum(1 for _ in records_copy) - # Since OFFSET is not supported in bulk, we can run only over 1 api_batch_size + # Set LIMIT condition + if ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.ALL + ): + limit_clause = None + elif ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.MATCH + ): + limit_clause = total_num_records + # Generate and execute SOQL query # (not passing offset as it is not supported in Bulk) ( @@ -462,58 +472,17 @@ def select_records(self, records): ) = self.select_operation_executor.select_generate_query( sobject=self.sobject, fields=self.fields, - limit=self.api_options.get("batch_size", 500), + user_filter=self.selection_filter if self.selection_filter else None, + limit=limit_clause, offset=None, ) - if self.selection_filter: - # Generate user filter query if selection_filter is present (offset clause not supported) - user_query = generate_user_filter_query( - filter_clause=self.selection_filter, - sobject=self.sobject, - fields=["Id"], - limit_clause=self.api_options.get("batch_size", 500), - offset_clause=None, - ) - # Execute the user query using Bulk API - user_query_executor = get_query_operation( - sobject=self.sobject, - fields=["Id"], - api_options=self.api_options, - context=self, - query=user_query, - api=DataApi.BULK, - ) - user_query_executor.query() - user_query_records = user_query_executor.get_results() - - # Find intersection based on 'Id' - user_query_ids = ( - list(record[0] for record in user_query_records) - if user_query_records - else [] - ) # Execute the main select query using Bulk API select_query_records = self._execute_select_query( select_query=select_query, query_fields=query_fields ) - # If user_query_ids exist, filter select_query_records based on the intersection of Ids - if self.selection_filter: - # Create a dictionary to map IDs to their corresponding records - id_to_record_map = { - record[query_fields.index("Id")]: record - for record in select_query_records - } - # Extend query_records in the order of user_query_ids - query_records.extend( - record - for id in user_query_ids - if (record := id_to_record_map.get(id)) is not None - ) - else: - query_records.extend(select_query_records) - + query_records.extend(select_query_records) # Post-process the query results ( selected_records, @@ -812,100 +781,52 @@ def convert(rec, fields): self.results = [] query_records = [] - user_query_records = [] # Create a copy of the generator using tee records, records_copy = tee(records) # Count total number of records to fetch using the copy total_num_records = sum(1 for _ in records_copy) - # Set offset - offset = 0 - - # Define condition - def condition(retrieval_mode, offset, total_num_records): - if retrieval_mode == SelectRecordRetrievalMode.ALL: - return True - elif retrieval_mode == SelectRecordRetrievalMode.MATCH: - return offset < total_num_records - - # Process in batches - while condition( - self.select_operation_executor.retrieval_mode, offset, total_num_records - ): - # Generate the SOQL query based on the selection strategy - ( - select_query, - query_fields, - ) = self.select_operation_executor.select_generate_query( - sobject=self.sobject, - fields=self.fields, - limit=self.api_options.get("batch_size"), - offset=offset, - ) - # If user given selection filter present, create composite request - if self.selection_filter: - # Generate user query - user_query = generate_user_filter_query( - filter_clause=self.selection_filter, - sobject=self.sobject, - fields=["Id"], - limit_clause=self.api_options.get("batch_size"), - offset_clause=offset, - ) - # Execute composite query - ( - current_user_query_records, - current_query_records, - ) = self._execute_composite_query( - select_query=select_query, - user_query=user_query, - query_fields=query_fields, - ) - # Break if org has no more records - if ( - len(current_user_query_records) == 0 - and len(current_query_records) == 0 - ): - break + # Set LIMIT condition + if ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.ALL + ): + limit_clause = None + elif ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.MATCH + ): + limit_clause = total_num_records - # Extend to each - user_query_records.extend(current_user_query_records) - query_records.extend(current_query_records) + # Generate the SOQL query based on the selection strategy + ( + select_query, + query_fields, + ) = self.select_operation_executor.select_generate_query( + sobject=self.sobject, + fields=self.fields, + user_filter=self.selection_filter if self.selection_filter else None, + limit=limit_clause, + offset=None, + ) - else: - # Handle the case where self.selection_query is None (and hence user_query is also None) - response = self.sf.restful( - requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" + # Handle the case where self.selection_query is None (and hence user_query is also None) + response = self.sf.restful( + requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" + ) + query_records.extend( + list(convert(rec, query_fields) for rec in response["records"]) + ) + while True: + if not response["done"]: + response = self.sf.query_more( + response["nextRecordsUrl"], identifier_is_url=True ) - current_query_records = list( - convert(rec, query_fields) for rec in response["records"] + query_records.extend( + list(convert(rec, query_fields) for rec in response["records"]) ) - # Break if nothing is returned - if len(current_query_records) == 0: - break - # Extend the query records - query_records.extend(current_query_records) - - # Update offset - offset += self.api_options.get("batch_size") - - # Find intersection if filter given - if self.selection_filter: - # Find intersection based on 'Id' - user_query_ids = list(record[0] for record in user_query_records) - # Create a dictionary to map IDs to their corresponding records - id_to_record_map = { - record[query_fields.index("Id")]: record for record in query_records - } - - # Extend insersection_query_records in the order of user_query_ids - insersection_query_records = [ - record - for id in user_query_ids - if (record := id_to_record_map.get(id)) is not None - ] - else: - insersection_query_records = query_records + else: + break # Post-process the query results for this batch ( @@ -913,7 +834,7 @@ def condition(retrieval_mode, offset, total_num_records): error_message, ) = self.select_operation_executor.select_post_process( load_records=records, - query_records=insersection_query_records, + query_records=query_records, num_records=total_num_records, sobject=self.sobject, ) @@ -1103,53 +1024,3 @@ def get_dml_operation( selection_strategy=selection_strategy, selection_filter=selection_filter, ) - - -def generate_user_filter_query( - filter_clause: str, - sobject: str, - fields: list, - limit_clause: Union[float, None] = None, - offset_clause: Union[float, None] = None, -) -> str: - """ - Generates a SOQL query with the provided filter, object, fields, limit, and offset clauses. - Handles cases where the filter clause already contains LIMIT or OFFSET, and avoids multiple spaces. - """ - - # Extract existing LIMIT and OFFSET from filter_clause if present - existing_limit_match = re.search(r"LIMIT\s+(\d+)", filter_clause, re.IGNORECASE) - existing_offset_match = re.search(r"OFFSET\s+(\d+)", filter_clause, re.IGNORECASE) - - if existing_limit_match: - existing_limit = int(existing_limit_match.group(1)) - if limit_clause is not None: # Only apply limit_clause if it's provided - limit_clause = min(existing_limit, limit_clause) - else: - limit_clause = existing_limit - - if existing_offset_match: - existing_offset = int(existing_offset_match.group(1)) - if offset_clause is not None: - offset_clause = existing_offset + offset_clause - else: - offset_clause = existing_offset - - # Remove existing LIMIT and OFFSET from filter_clause, handling potential extra spaces - filter_clause = re.sub( - r"\s+OFFSET\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE - ).strip() - filter_clause = re.sub( - r"\s+LIMIT\s+\d+\s*", " ", filter_clause, flags=re.IGNORECASE - ).strip() - - # Construct the SOQL query - fields_str = ", ".join(fields) - query = f"SELECT {fields_str} FROM {sobject} {filter_clause}" - - if limit_clause is not None: - query += f" LIMIT {limit_clause}" - if offset_clause is not None: - query += f" OFFSET {offset_clause}" - - return query From 40097c17ea86b4229685f1218cc6c39603fc2b70 Mon Sep 17 00:00:00 2001 From: Jawadtp Date: Tue, 5 Nov 2024 19:46:39 +0530 Subject: [PATCH 21/34] Add ANN algorithm for large number of records for similarity strategy --- .pre-commit-config.yaml | 6 +- cumulusci/tasks/bulkdata/select_utils.py | 203 ++++++++++++++++++++++- 2 files changed, 205 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62af507949..b1a928eafd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_language_version: python: python3 repos: - repo: https://github.com/ambv/black - rev: 22.3.0 + rev: 24.10.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks @@ -18,12 +18,12 @@ repos: - id: rst-linter exclude: "docs" - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black", "--filter-files"] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v2.5.1 + rev: v4.0.0-alpha.8 hooks: - id: prettier - repo: local diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index daa993d045..741ed17056 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -2,6 +2,12 @@ import re import typing as T +import numpy as np +import pandas as pd +from annoy import AnnoyIndex +from sklearn.feature_extraction.text import HashingVectorizer +from sklearn.preprocessing import StandardScaler + from cumulusci.core.enums import StrEnum from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( DEFAULT_DECLARATIONS, @@ -159,7 +165,7 @@ def similarity_generate_query( def similarity_post_process( - load_records: list, query_records: list, sobject: str + load_records, query_records: list, sobject: str ) -> T.Tuple[T.List[dict], T.Union[str, None]]: """Processes the query results for the similarity selection strategy""" # Handle case where query returns 0 records @@ -167,6 +173,107 @@ def similarity_post_process( error_message = f"No records found for {sobject} in the target org." return [], error_message + load_records = list(load_records) + load_record_count, query_record_count = len(load_records), len(query_records) + + complexity_constant = load_record_count * query_record_count + + print(complexity_constant) + + closest_records = [] + + if complexity_constant < 1000: + closest_records = annoy_post_process(load_records, query_records) + else: + closest_records = levenshtein_post_process(load_records, query_records) + + print(closest_records) + + return closest_records + + +def annoy_post_process( + load_records: list, query_records: list +) -> T.Tuple[T.List[dict], T.Union[str, None]]: + """Processes the query results for the similarity selection strategy using Annoy algorithm for large number of records""" + + query_records = replace_empty_strings_with_missing(query_records) + load_records = replace_empty_strings_with_missing(load_records) + + print("Query records: ") + print(query_records) + + print("Load records: ") + print(load_records) + + print("\n\n\n\n") + + hash_features = 100 + num_trees = 10 + + query_record_ids = [record[0] for record in query_records] + query_record_data = [record[1:] for record in query_records] + + record_to_id_map = { + tuple(query_record_data[i]): query_record_ids[i] + for i in range(len(query_records)) + } + + final_load_vectors, final_query_vectors = vectorize_records( + load_records, query_record_data, hash_features=hash_features + ) + + # Create Annoy index for nearest neighbor search + vector_dimension = final_query_vectors.shape[1] + annoy_index = AnnoyIndex(vector_dimension, "euclidean") + + for i in range(len(final_query_vectors)): + annoy_index.add_item(i, final_query_vectors[i]) + + # Build the index + annoy_index.build(num_trees) + + # Find nearest neighbors for each query vector + n_neighbors = 1 + + closest_records = [] + + for i, load_vector in enumerate(final_load_vectors): + # Get nearest neighbors' indices and distances + nearest_neighbors = annoy_index.get_nns_by_vector( + load_vector, n_neighbors, include_distances=True + ) + neighbor_indices = nearest_neighbors[0] # Indices of nearest neighbors + distances = nearest_neighbors[1] # Distances to nearest neighbors + + load_record = load_records[i] # Get the query record for the current index + print(f"Load record {i + 1}: {load_record}\n") # Print the query record + + # Print the nearest neighbors for the current query + print(f"Nearest neighbors for load record {i + 1}:") + + for j, neighbor_index in enumerate(neighbor_indices): + # Retrieve the corresponding record from the database + record = query_record_data[neighbor_index] + distance = distances[j] + + # Print the record and its distance + print(f" Neighbor {j + 1}: {record}, Distance: {distance:.6f}") + closest_record_id = record_to_id_map[tuple(record)] + print("Record id:" + closest_record_id) + closest_records.append( + {"id": closest_record_id, "success": True, "created": False} + ) + + print("\n") # Add a newline for better readability between query results + + return closest_records, None + + +def levenshtein_post_process( + load_records: list, query_records: list +) -> T.Tuple[T.List[dict], T.Union[str, None]]: + """Processes the query results for the similarity selection strategy using Levenshtein algorithm for small number of records""" closest_records = [] for record in load_records: @@ -300,3 +407,97 @@ def add_limit_offset_to_user_filter( filter_clause += f" OFFSET {offset_clause}" return f" {filter_clause}" + + +def determine_field_types(df): + numerical_features = [] + boolean_features = [] + categorical_features = [] + + for col in df.columns: + # Check if the column can be converted to numeric + try: + # Attempt to convert to numeric + df[col] = pd.to_numeric(df[col], errors="raise") + numerical_features.append(col) + except ValueError: + # Check for boolean values + if df[col].str.lower().isin(["true", "false"]).all(): + # Map to actual boolean values + df[col] = df[col].str.lower().map({"true": True, "false": False}) + boolean_features.append(col) + else: + categorical_features.append(col) + + return numerical_features, boolean_features, categorical_features + + +def vectorize_records(db_records, query_records, hash_features): + # Convert database records and query records to DataFrames + df_db = pd.DataFrame(db_records) + df_query = pd.DataFrame(query_records) + + # Dynamically determine field types + numerical_features, boolean_features, categorical_features = determine_field_types( + df_db + ) + + # Fit StandardScaler on the numerical features of the database records + scaler = StandardScaler() + if numerical_features: + df_db[numerical_features] = scaler.fit_transform(df_db[numerical_features]) + df_query[numerical_features] = scaler.transform(df_query[numerical_features]) + + # Use HashingVectorizer to transform the categorical features + hashing_vectorizer = HashingVectorizer( + n_features=hash_features, alternate_sign=False + ) + + # For db_records + hashed_categorical_data_db = [] + for col in categorical_features: + hashed_db = hashing_vectorizer.fit_transform(df_db[col]).toarray() + hashed_categorical_data_db.append(hashed_db) + + # For query_records + hashed_categorical_data_query = [] + for col in categorical_features: + hashed_query = hashing_vectorizer.transform(df_query[col]).toarray() + hashed_categorical_data_query.append(hashed_query) + + # Combine all feature types into a single vector for the database records + db_vectors = [] + if numerical_features: + db_vectors.append(df_db[numerical_features].values) + if boolean_features: + db_vectors.append( + df_db[boolean_features].astype(int).values + ) # Convert boolean to int + if hashed_categorical_data_db: + db_vectors.append(np.hstack(hashed_categorical_data_db)) + + # Concatenate database vectors + final_db_vectors = np.hstack(db_vectors) + + # Combine all feature types into a single vector for the query records + query_vectors = [] + if numerical_features: + query_vectors.append(df_query[numerical_features].values) + if boolean_features: + query_vectors.append( + df_query[boolean_features].astype(int).values + ) # Convert boolean to int + if hashed_categorical_data_query: + query_vectors.append(np.hstack(hashed_categorical_data_query)) + + # Concatenate query vectors + final_query_vectors = np.hstack(query_vectors) + + return final_db_vectors, final_query_vectors + + +def replace_empty_strings_with_missing(records): + return [ + [(field if field != "" else "missing") for field in record] + for record in records + ] From 83d45db946a512bbaabd094e9199d2df9d870bd0 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Thu, 7 Nov 2024 11:23:37 +0530 Subject: [PATCH 22/34] Reference parent level record during similarity matching --- cumulusci/tasks/bulkdata/load.py | 117 +++++++++++++- cumulusci/tasks/bulkdata/mapping_parser.py | 18 ++- .../tasks/bulkdata/query_transformers.py | 60 ++++++++ cumulusci/tasks/bulkdata/select_utils.py | 87 ++++++----- cumulusci/tasks/bulkdata/step.py | 143 ++++++++++++++---- 5 files changed, 350 insertions(+), 75 deletions(-) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index f83199050a..d4050c0aca 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -27,6 +27,7 @@ AddMappingFiltersToQuery, AddPersonAccountsToQuery, AddRecordTypesToQuery, + DynamicLookupQueryExtender, ) from cumulusci.tasks.bulkdata.step import ( DEFAULT_BULK_BATCH_SIZE, @@ -314,6 +315,7 @@ def configure_step(self, mapping): bulk_mode = mapping.bulk_mode or self.bulk_mode or "Parallel" api_options = {"batch_size": mapping.batch_size, "bulk_mode": bulk_mode} num_records_in_target = None + content_type = None fields = mapping.get_load_field_list() @@ -343,6 +345,8 @@ def configure_step(self, mapping): api_options["update_key"] = mapping.update_key[0] action = DataOperationType.UPSERT elif mapping.action == DataOperationType.SELECT: + # Set content type to json + content_type = "JSON" # Bulk process expects DataOpertionType to be QUERY action = DataOperationType.QUERY # Determine number of records in the target org @@ -354,6 +358,97 @@ def configure_step(self, mapping): for entry in record_count_response["sObjects"] } num_records_in_target = sobject_map.get(mapping.sf_object, None) + + # Check for similarity selection strategy and modify fields accordingly + if mapping.selection_strategy == "similarity": + # Describe the object to determine polymorphic lookups + describe_result = self.sf.restful( + f"sobjects/{mapping.sf_object}/describe" + ) + polymorphic_fields = { + field["name"]: field + for field in describe_result["fields"] + if field["type"] == "reference" + } + + # Loop through each lookup to get the corresponding fields + for name, lookup in mapping.lookups.items(): + if name in fields: + # Get the index of the lookup field before removing it + insert_index = fields.index(name) + # Remove the lookup field from fields + fields.remove(name) + + # Check if this lookup field is polymorphic + if ( + name in polymorphic_fields + and len(polymorphic_fields[name]["referenceTo"]) > 1 + ): + # Convert to list if string + if not isinstance(lookup.table, list): + lookup.table = [lookup.table] + # Polymorphic field handling + polymorphic_references = lookup.table + relationship_name = polymorphic_fields[name][ + "relationshipName" + ] + + # Loop through each polymorphic type (e.g., Contact, Lead) + for ref_type in polymorphic_references: + # Find the mapping step for this polymorphic type + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.sf_object == ref_type + ), + None, + ) + + if lookup_mapping_step: + lookup_fields = ( + lookup_mapping_step.get_load_field_list() + ) + # Insert fields in the format {relationship_name}.{ref_type}.{lookup_field} + for field in lookup_fields: + fields.insert( + insert_index, + f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}", + ) + insert_index += 1 + + else: + # Non-polymorphic field handling + lookup_table = lookup.table + + if isinstance(lookup_table, list): + lookup_table = lookup_table[0] + + # Get the mapping step for the non-polymorphic reference + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.sf_object == lookup_table + ), + None, + ) + + if lookup_mapping_step: + relationship_name = polymorphic_fields[name][ + "relationshipName" + ] + lookup_fields = ( + lookup_mapping_step.get_load_field_list() + ) + + # Insert the new fields at the same position as the removed lookup field + for field in lookup_fields: + fields.insert( + insert_index, f"{relationship_name}.{field}" + ) + insert_index += 1 + else: action = mapping.action @@ -376,6 +471,7 @@ def configure_step(self, mapping): volume=volume, selection_strategy=mapping.selection_strategy, selection_filter=mapping.selection_filter, + content_type=content_type, ) return step, query @@ -406,6 +502,9 @@ def _stream_queried_data(self, mapping, local_ids, query): pkey = row[0] row = list(row[1:]) + statics + # Replace None values in row with empty strings + row = [value if value is not None else "" for value in row] + if mapping.anchor_date and (date_context[0] or date_context[1]): row = adjust_relative_dates( mapping, date_context, row, DataOperationType.INSERT @@ -475,9 +574,21 @@ def _query_db(self, mapping): AddMappingFiltersToQuery, AddUpsertsToQuery, ] - transformers = [ - AddLookupsToQuery(mapping, self.metadata, model, self._old_format) - ] + transformers = [] + if ( + mapping.action == DataOperationType.SELECT + and mapping.selection_strategy == "similarity" + ): + transformers.append( + DynamicLookupQueryExtender( + mapping, self.mapping, self.metadata, model, self._old_format + ) + ) + else: + transformers.append( + AddLookupsToQuery(mapping, self.metadata, model, self._old_format) + ) + transformers.extend([cls(mapping, self.metadata, model) for cls in classes]) if mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping): diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index e812ca7d16..c9009f82fc 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -103,15 +103,15 @@ class MappingStep(CCIDictModel): batch_size: int = None oid_as_pk: bool = False # this one should be discussed and probably deprecated record_type: Optional[str] = None # should be discussed and probably deprecated - bulk_mode: Optional[ - Literal["Serial", "Parallel"] - ] = None # default should come from task options + bulk_mode: Optional[Literal["Serial", "Parallel"]] = ( + None # default should come from task options + ) anchor_date: Optional[Union[str, date]] = None soql_filter: Optional[str] = None # soql_filter property selection_strategy: SelectStrategy = SelectStrategy.STANDARD # selection strategy - selection_filter: Optional[ - str - ] = None # filter to be added at the end of select query + selection_filter: Optional[str] = ( + None # filter to be added at the end of select query + ) update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts @validator("bulk_mode", "api", "action", "selection_strategy", pre=True) @@ -678,7 +678,9 @@ def _infer_and_validate_lookups(mapping: Dict, sf: Salesforce): if len(target_objects) == 1: # This is a non-polymorphic lookup. target_index = list(sf_objects.values()).index(target_objects[0]) - if target_index > idx or target_index == idx: + if ( + target_index > idx or target_index == idx + ) and m.action != DataOperationType.SELECT: # This is a non-polymorphic after step. lookup.after = list(mapping.keys())[idx] else: @@ -730,7 +732,7 @@ def validate_and_inject_mapping( if drop_missing: # Drop any steps with sObjects that are not present. - for (include, step_name) in zip(should_continue, list(mapping.keys())): + for include, step_name in zip(should_continue, list(mapping.keys())): if not include: del mapping[step_name] diff --git a/cumulusci/tasks/bulkdata/query_transformers.py b/cumulusci/tasks/bulkdata/query_transformers.py index aef23f5dc3..eda7a2cabe 100644 --- a/cumulusci/tasks/bulkdata/query_transformers.py +++ b/cumulusci/tasks/bulkdata/query_transformers.py @@ -86,6 +86,66 @@ def join_for_lookup(lookup): return [join_for_lookup(lookup) for lookup in self.lookups] +class DynamicLookupQueryExtender(LoadQueryExtender): + """Dynamically adds columns and joins for all fields in lookup tables, handling polymorphic lookups""" + + def __init__( + self, mapping, all_mappings, metadata, model, _old_format: bool + ) -> None: + super().__init__(mapping, metadata, model) + self._old_format = _old_format + self.all_mappings = all_mappings + self.lookups = [ + lookup for lookup in self.mapping.lookups.values() if not lookup.after + ] + + @cached_property + def columns_to_add(self): + """Add all relevant fields from lookup tables directly without CASE, with support for polymorphic lookups.""" + columns = [] + for lookup in self.lookups: + tables = lookup.table if isinstance(lookup.table, list) else [lookup.table] + lookup.aliased_table = [ + aliased(self.metadata.tables[table]) for table in tables + ] + + for aliased_table, table_name in zip(lookup.aliased_table, tables): + # Find the mapping step for this polymorphic type + lookup_mapping_step = next( + ( + step + for step in self.all_mappings.values() + if step.table == table_name + ), + None, + ) + if lookup_mapping_step: + load_fields = lookup_mapping_step.get_load_field_list() + for field in load_fields: + matching_column = next( + (col for col in aliased_table.columns if col.name == field) + ) + columns.append( + matching_column.label(f"{aliased_table.name}_{field}") + ) + return columns + + @cached_property + def outerjoins_to_add(self): + """Add outer joins for each lookup table directly, including handling for polymorphic lookups.""" + + def join_for_lookup(lookup, aliased_table): + key_field = lookup.get_lookup_key_field(self.model) + value_column = getattr(self.model, key_field) + return (aliased_table, aliased_table.columns.id == value_column) + + joins = [] + for lookup in self.lookups: + for aliased_table in lookup.aliased_table: + joins.append(join_for_lookup(lookup, aliased_table)) + return joins + + class AddRecordTypesToQuery(LoadQueryExtender): """Adds columns, joins and filters relatinng to recordtypes""" diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 741ed17056..d1092504f4 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -142,14 +142,54 @@ def similarity_generate_query( limit: T.Union[int, None], offset: T.Union[int, None], ) -> T.Tuple[str, T.List[str]]: - """Generates the SOQL query for the similarity selection strategy""" - # Construct the query with the WHERE clause (if it exists) - if "Id" not in fields: - fields.insert(0, "Id") - fields_to_query = ", ".join(field for field in fields if field) - + """Generates the SOQL query for the similarity selection strategy, with support for TYPEOF on polymorphic fields.""" + + # Pre-process the new fields format to create a nested dict structure for TYPEOF clauses + nested_fields = {} + regular_fields = [] + + for field in fields: + components = field.split(".") + if len(components) >= 3: + # Handle polymorphic fields (format: {relationship_name}.{ref_obj}.{ref_field}) + relationship, ref_obj, ref_field = ( + components[0], + components[1], + components[2], + ) + if relationship not in nested_fields: + nested_fields[relationship] = {} + if ref_obj not in nested_fields[relationship]: + nested_fields[relationship][ref_obj] = [] + nested_fields[relationship][ref_obj].append(ref_field) + else: + # Handle regular fields (format: {field}) + regular_fields.append(field) + + # Construct the query fields + query_fields = [] + + # Build TYPEOF clauses for polymorphic fields + for relationship, references in nested_fields.items(): + type_clauses = [] + for ref_obj, ref_fields in references.items(): + fields_clause = ", ".join(ref_fields) + type_clauses.append(f"WHEN {ref_obj} THEN {fields_clause}") + type_clause = f"TYPEOF {relationship} {' '.join(type_clauses)} END" + query_fields.append(type_clause) + + # Add regular fields to the query + query_fields.extend(regular_fields) + + # Ensure "Id" is included in the fields list for identification + if "Id" not in query_fields: + query_fields.insert(0, "Id") + + # Build the main SOQL query + fields_to_query = ", ".join(query_fields) query = f"SELECT {fields_to_query} FROM {sobject}" + # Add the user-defined filter clause or default clause if user_filter: query += add_limit_offset_to_user_filter( filter_clause=user_filter, limit_clause=limit, offset_clause=offset @@ -161,7 +201,12 @@ def similarity_generate_query( query += f" WHERE {declaration.where}" query += f" LIMIT {limit}" if limit else "" query += f" OFFSET {offset}" if offset else "" - return query, fields + + # Return the original input fields with "Id" added if needed + if "Id" not in fields: + fields.insert(0, "Id") + + return query, fields # Return the original input fields with "Id" def similarity_post_process( @@ -178,8 +223,6 @@ def similarity_post_process( complexity_constant = load_record_count * query_record_count - print(complexity_constant) - closest_records = [] if complexity_constant < 1000: @@ -187,8 +230,6 @@ def similarity_post_process( else: closest_records = levenshtein_post_process(load_records, query_records) - print(closest_records) - return closest_records @@ -200,14 +241,6 @@ def annoy_post_process( query_records = replace_empty_strings_with_missing(query_records) load_records = replace_empty_strings_with_missing(load_records) - print("Query records: ") - print(query_records) - - print("Load records: ") - print(load_records) - - print("\n\n\n\n") - hash_features = 100 num_trees = 10 @@ -244,29 +277,15 @@ def annoy_post_process( load_vector, n_neighbors, include_distances=True ) neighbor_indices = nearest_neighbors[0] # Indices of nearest neighbors - distances = nearest_neighbors[1] # Distances to nearest neighbors - load_record = load_records[i] # Get the query record for the current index - print(f"Load record {i + 1}: {load_record}\n") # Print the query record - - # Print the nearest neighbors for the current query - print(f"Nearest neighbors for load record {i + 1}:") - - for j, neighbor_index in enumerate(neighbor_indices): + for neighbor_index in neighbor_indices: # Retrieve the corresponding record from the database record = query_record_data[neighbor_index] - distance = distances[j] - - # Print the record and its distance - print(f" Neighbor {j + 1}: {record}, Distance: {distance:.6f}") closest_record_id = record_to_id_map[tuple(record)] - print("Record id:" + closest_record_id) closest_records.append( {"id": closest_record_id, "success": True, "created": False} ) - print("\n") # Add a newline for better readability between query results - return closest_records, None diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 3f3fbaf0f3..b664b48ffc 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -352,6 +352,7 @@ def __init__( fields, selection_strategy=SelectStrategy.STANDARD, selection_filter=None, + content_type=None, ): super().__init__( sobject=sobject, @@ -369,12 +370,13 @@ def __init__( self.select_operation_executor = SelectOperationExecutor(selection_strategy) self.selection_filter = selection_filter + self.content_type = content_type if content_type else "CSV" def start(self): self.job_id = self.bulk.create_job( self.sobject, self.operation.value, - contentType="CSV", + contentType=self.content_type, concurrency=self.api_options.get("bulk_mode", "Parallel"), external_id_name=self.api_options.get("update_key"), ) @@ -498,31 +500,39 @@ def select_records(self, records): # Update job result based on selection outcome self.job_result = DataOperationJobResult( - status=DataOperationStatus.SUCCESS - if len(self.select_results) - else DataOperationStatus.JOB_FAILURE, + status=( + DataOperationStatus.SUCCESS + if len(self.select_results) + else DataOperationStatus.JOB_FAILURE + ), job_errors=[error_message] if error_message else [], records_processed=len(self.select_results), total_row_errors=0, ) def _execute_select_query(self, select_query: str, query_fields: List[str]): - """Executes the select Bulk API query and retrieves the results.""" + """Executes the select Bulk API query, retrieves results in JSON, and converts to CSV format if needed.""" self.batch_id = self.bulk.query(self.job_id, select_query) - self._wait_for_job(self.job_id) + self.bulk.wait_for_batch(self.job_id, self.batch_id) result_ids = self.bulk.get_query_batch_result_ids( self.batch_id, job_id=self.job_id ) select_query_records = [] + for result_id in result_ids: - uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}" + # Modify URI to request JSON format + uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}?format=json" + # Download JSON data with download_file(uri, self.bulk) as f: - reader = csv.reader(f) - self.headers = next(reader) - if "Records not found for this query" in self.headers: - break - for row in reader: - select_query_records.append(row[: len(query_fields)]) + data = json.load(f) + # Get headers from fields, expanding nested structures for TYPEOF results + self.headers = query_fields + + # Convert each record to a flat row + for record in data: + flat_record = flatten_record(record, self.headers) + select_query_records.append(flat_record) + return select_query_records def _batch(self, records, n, char_limit=10000000): @@ -641,6 +651,7 @@ def __init__( fields, selection_strategy=SelectStrategy.STANDARD, selection_filter=None, + content_type=None, ): super().__init__( sobject=sobject, @@ -655,7 +666,9 @@ def __init__( field["name"]: field for field in getattr(context.sf, sobject).describe()["fields"] } - self.boolean_fields = [f for f in fields if describe[f]["type"] == "boolean"] + self.boolean_fields = [ + f for f in fields if "." not in f and describe[f]["type"] == "boolean" + ] self.api_options = api_options.copy() self.api_options["batch_size"] = ( self.api_options.get("batch_size") or DEFAULT_REST_BATCH_SIZE @@ -666,6 +679,7 @@ def __init__( self.select_operation_executor = SelectOperationExecutor(selection_strategy) self.selection_filter = selection_filter + self.content_type = content_type def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) @@ -764,9 +778,11 @@ def load_records(self, records): row_errors = len([res for res in self.results if not res["success"]]) self.job_result = DataOperationJobResult( - DataOperationStatus.SUCCESS - if not row_errors - else DataOperationStatus.ROW_FAILURE, + ( + DataOperationStatus.SUCCESS + if not row_errors + else DataOperationStatus.ROW_FAILURE + ), [], len(self.results), row_errors, @@ -775,10 +791,6 @@ def load_records(self, records): def select_records(self, records): """Executes a SOQL query to select records and adds them to results""" - def convert(rec, fields): - """Helper function to convert record values to strings, handling None values""" - return [str(rec[f]) if rec[f] is not None else "" for f in fields] - self.results = [] query_records = [] # Create a copy of the generator using tee @@ -814,17 +826,18 @@ def convert(rec, fields): response = self.sf.restful( requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" ) - query_records.extend( - list(convert(rec, query_fields) for rec in response["records"]) - ) + # Convert each record to a flat row + for record in response["records"]: + flat_record = flatten_record(record, query_fields) + query_records.append(flat_record) while True: if not response["done"]: response = self.sf.query_more( response["nextRecordsUrl"], identifier_is_url=True ) - query_records.extend( - list(convert(rec, query_fields) for rec in response["records"]) - ) + for record in response["records"]: + flat_record = flatten_record(record, query_fields) + query_records.append(flat_record) else: break @@ -844,9 +857,11 @@ def convert(rec, fields): # Update the job result based on the overall selection outcome self.job_result = DataOperationJobResult( - status=DataOperationStatus.SUCCESS - if len(self.results) # Check the overall results length - else DataOperationStatus.JOB_FAILURE, + status=( + DataOperationStatus.SUCCESS + if len(self.results) # Check the overall results length + else DataOperationStatus.JOB_FAILURE + ), job_errors=[error_message] if error_message else [], records_processed=len(self.results), total_row_errors=0, @@ -988,6 +1003,7 @@ def get_dml_operation( api: Optional[DataApi] = DataApi.SMART, selection_strategy: SelectStrategy = SelectStrategy.STANDARD, selection_filter: Union[str, None] = None, + content_type: Union[str, None] = None, ) -> BaseDmlOperation: """Create an appropriate DmlOperation instance for the given parameters, selecting between REST and Bulk APIs based upon volume (Bulk used at volumes over 2000 records, @@ -1023,4 +1039,71 @@ def get_dml_operation( fields=fields, selection_strategy=selection_strategy, selection_filter=selection_filter, + content_type=content_type, ) + + +def extract_flattened_headers(query_fields): + """Extract headers from query fields, including handling of TYPEOF fields.""" + headers = [] + + for field in query_fields: + if isinstance(field, dict): + # Handle TYPEOF / polymorphic fields + for lookup, references in field.items(): + # Assuming each reference is a list of dictionaries + for ref_type in references: + for ref_obj, ref_fields in ref_type.items(): + for nested_field in ref_fields: + headers.append( + f"{lookup}.{ref_obj}.{nested_field}" + ) # Flatten the structure + else: + # Regular fields + headers.append(field) + + return headers + + +def flatten_record(record, headers): + """Flatten each record to match headers, handling nested fields.""" + flat_record = [] + + for field in headers: + components = field.split(".") + value = "" + + # Handle lookup fields with two or three components + if len(components) >= 2: + lookup_field = components[0] + lookup = record.get(lookup_field, None) + + # Check if lookup field exists in the record + if lookup is None: + value = "" + else: + if len(components) == 2: + # Handle fields with two components: {lookup}.{ref_field} + ref_field = components[1] + value = lookup.get(ref_field, "") + elif len(components) == 3: + # Handle fields with three components: {lookup}.{ref_obj}.{ref_field} + ref_obj, ref_field = components[1], components[2] + # Check if the type matches the specified ref_obj + if lookup.get("attributes", {}).get("type") == ref_obj: + value = lookup.get(ref_field, "") + else: + value = "" + + else: + # Regular fields or non-polymorphic fields + value = record.get(field, "") + + # Set None values to empty string + if value is None: + value = "" + + # Append the resolved value to the flattened record + flat_record.append(value) + + return flat_record From a64e43866e65f9e6f8d84ad6c972ed858d8a0310 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Thu, 7 Nov 2024 12:50:30 +0530 Subject: [PATCH 23/34] Fix for test import failure --- cumulusci/tasks/bulkdata/step.py | 2 + cumulusci/tasks/bulkdata/tests/test_step.py | 190 ++++++++++---------- 2 files changed, 96 insertions(+), 96 deletions(-) diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index b664b48ffc..cb86bda6fa 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -1102,6 +1102,8 @@ def flatten_record(record, headers): # Set None values to empty string if value is None: value = "" + elif not isinstance(value, str): + value = str(value) # Append the resolved value to the flattened record flat_record.append(value) diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index c182a92996..da13a9a8eb 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -20,7 +20,6 @@ RestApiDmlOperation, RestApiQueryOperation, download_file, - generate_user_filter_query, get_dml_operation, get_query_operation, ) @@ -2491,98 +2490,97 @@ def test_cleanup_date_strings__upsert_update(self, operation): import pytest - -def test_generate_user_filter_query_basic(): - """Tests basic query generation without existing LIMIT or OFFSET.""" - filter_clause = "WHERE Name = 'John'" - sobject = "Account" - fields = ["Id", "Name"] - limit_clause = 10 - offset_clause = 5 - - expected_query = ( - "SELECT Id, Name FROM Account WHERE Name = 'John' LIMIT 10 OFFSET 5" - ) - assert ( - generate_user_filter_query( - filter_clause, sobject, fields, limit_clause, offset_clause - ) - == expected_query - ) - - -def test_generate_user_filter_query_existing_limit(): - """Tests handling of existing LIMIT in the filter clause.""" - filter_clause = "WHERE Name = 'John' LIMIT 20" - sobject = "Contact" - fields = ["Id", "FirstName"] - limit_clause = 5 # Should override the existing LIMIT - offset_clause = None - - expected_query = "SELECT Id, FirstName FROM Contact WHERE Name = 'John' LIMIT 5" - assert ( - generate_user_filter_query( - filter_clause, sobject, fields, limit_clause, offset_clause - ) - == expected_query - ) - - -def test_generate_user_filter_query_existing_offset(): - """Tests handling of existing OFFSET in the filter clause.""" - filter_clause = "WHERE Name = 'John' OFFSET 15" - sobject = "Opportunity" - fields = ["Id", "Name"] - limit_clause = None - offset_clause = 10 # Should add to the existing OFFSET - - expected_query = "SELECT Id, Name FROM Opportunity WHERE Name = 'John' OFFSET 25" - assert ( - generate_user_filter_query( - filter_clause, sobject, fields, limit_clause, offset_clause - ) - == expected_query - ) - - -def test_generate_user_filter_query_no_limit_or_offset(): - """Tests when no limit or offset is provided or present in the filter.""" - filter_clause = "WHERE Name = 'John' LIMIT 5 OFFSET 20" - sobject = "Lead" - fields = ["Id", "Name", "Email"] - limit_clause = None - offset_clause = None - - expected_query = ( - "SELECT Id, Name, Email FROM Lead WHERE Name = 'John' LIMIT 5 OFFSET 20" - ) - print( - generate_user_filter_query( - filter_clause, sobject, fields, limit_clause, offset_clause - ) - ) - assert ( - generate_user_filter_query( - filter_clause, sobject, fields, limit_clause, offset_clause - ) - == expected_query - ) - - -def test_generate_user_filter_query_case_insensitivity(): - """Tests case-insensitivity for LIMIT and OFFSET.""" - filter_clause = "where name = 'John' offset 5 limit 20" - sobject = "Task" - fields = ["Id", "Subject"] - limit_clause = 15 - offset_clause = 20 - - expected_query = ( - "SELECT Id, Subject FROM Task where name = 'John' LIMIT 15 OFFSET 25" - ) - assert ( - generate_user_filter_query( - filter_clause, sobject, fields, limit_clause, offset_clause - ) - == expected_query - ) +# def test_generate_user_filter_query_basic(): +# """Tests basic query generation without existing LIMIT or OFFSET.""" +# filter_clause = "WHERE Name = 'John'" +# sobject = "Account" +# fields = ["Id", "Name"] +# limit_clause = 10 +# offset_clause = 5 + +# expected_query = ( +# "SELECT Id, Name FROM Account WHERE Name = 'John' LIMIT 10 OFFSET 5" +# ) +# assert ( +# generate_user_filter_query( +# filter_clause, sobject, fields, limit_clause, offset_clause +# ) +# == expected_query +# ) + + +# def test_generate_user_filter_query_existing_limit(): +# """Tests handling of existing LIMIT in the filter clause.""" +# filter_clause = "WHERE Name = 'John' LIMIT 20" +# sobject = "Contact" +# fields = ["Id", "FirstName"] +# limit_clause = 5 # Should override the existing LIMIT +# offset_clause = None + +# expected_query = "SELECT Id, FirstName FROM Contact WHERE Name = 'John' LIMIT 5" +# assert ( +# generate_user_filter_query( +# filter_clause, sobject, fields, limit_clause, offset_clause +# ) +# == expected_query +# ) + + +# def test_generate_user_filter_query_existing_offset(): +# """Tests handling of existing OFFSET in the filter clause.""" +# filter_clause = "WHERE Name = 'John' OFFSET 15" +# sobject = "Opportunity" +# fields = ["Id", "Name"] +# limit_clause = None +# offset_clause = 10 # Should add to the existing OFFSET + +# expected_query = "SELECT Id, Name FROM Opportunity WHERE Name = 'John' OFFSET 25" +# assert ( +# generate_user_filter_query( +# filter_clause, sobject, fields, limit_clause, offset_clause +# ) +# == expected_query +# ) + + +# def test_generate_user_filter_query_no_limit_or_offset(): +# """Tests when no limit or offset is provided or present in the filter.""" +# filter_clause = "WHERE Name = 'John' LIMIT 5 OFFSET 20" +# sobject = "Lead" +# fields = ["Id", "Name", "Email"] +# limit_clause = None +# offset_clause = None + +# expected_query = ( +# "SELECT Id, Name, Email FROM Lead WHERE Name = 'John' LIMIT 5 OFFSET 20" +# ) +# print( +# generate_user_filter_query( +# filter_clause, sobject, fields, limit_clause, offset_clause +# ) +# ) +# assert ( +# generate_user_filter_query( +# filter_clause, sobject, fields, limit_clause, offset_clause +# ) +# == expected_query +# ) + + +# def test_generate_user_filter_query_case_insensitivity(): +# """Tests case-insensitivity for LIMIT and OFFSET.""" +# filter_clause = "where name = 'John' offset 5 limit 20" +# sobject = "Task" +# fields = ["Id", "Subject"] +# limit_clause = 15 +# offset_clause = 20 + +# expected_query = ( +# "SELECT Id, Subject FROM Task where name = 'John' LIMIT 15 OFFSET 25" +# ) +# assert ( +# generate_user_filter_query( +# filter_clause, sobject, fields, limit_clause, offset_clause +# ) +# == expected_query +# ) From f86f94da5c102a95c3ba235207285dd70c287a98 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Thu, 7 Nov 2024 14:37:43 +0530 Subject: [PATCH 24/34] Fix for no records if parent sobject not found --- cumulusci/tasks/bulkdata/select_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index d1092504f4..0e842a09e1 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -175,7 +175,7 @@ def similarity_generate_query( for ref_obj, ref_fields in references.items(): fields_clause = ", ".join(ref_fields) type_clauses.append(f"WHEN {ref_obj} THEN {fields_clause}") - type_clause = f"TYPEOF {relationship} {' '.join(type_clauses)} END" + type_clause = f"TYPEOF {relationship} {' '.join(type_clauses)} ELSE Id END" query_fields.append(type_clause) # Add regular fields to the query @@ -206,7 +206,7 @@ def similarity_generate_query( if "Id" not in fields: fields.insert(0, "Id") - return query, fields # Return the original input fields with "Id" + return query, fields def similarity_post_process( @@ -226,9 +226,9 @@ def similarity_post_process( closest_records = [] if complexity_constant < 1000: - closest_records = annoy_post_process(load_records, query_records) - else: closest_records = levenshtein_post_process(load_records, query_records) + else: + closest_records = annoy_post_process(load_records, query_records) return closest_records From 7fae06e54f76d0508d13c8fb4ffbcf81f8cb0c26 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Fri, 8 Nov 2024 03:57:08 +0530 Subject: [PATCH 25/34] Functionality to prioritize user specified fields --- cumulusci/tasks/bulkdata/load.py | 9 +- cumulusci/tasks/bulkdata/mapping_parser.py | 44 +++--- .../tasks/bulkdata/query_transformers.py | 6 +- cumulusci/tasks/bulkdata/select_utils.py | 134 +++++++++++++----- cumulusci/tasks/bulkdata/step.py | 33 +++++ cumulusci/tasks/bulkdata/utils.py | 14 ++ 6 files changed, 180 insertions(+), 60 deletions(-) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index d4050c0aca..ced885744b 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -360,7 +360,7 @@ def configure_step(self, mapping): num_records_in_target = sobject_map.get(mapping.sf_object, None) # Check for similarity selection strategy and modify fields accordingly - if mapping.selection_strategy == "similarity": + if mapping.select_options.strategy == "similarity": # Describe the object to determine polymorphic lookups describe_result = self.sf.restful( f"sobjects/{mapping.sf_object}/describe" @@ -469,8 +469,9 @@ def configure_step(self, mapping): fields=fields, api=mapping.api, volume=volume, - selection_strategy=mapping.selection_strategy, - selection_filter=mapping.selection_filter, + selection_strategy=mapping.select_options.strategy, + selection_filter=mapping.select_options.filter, + selection_priority_fields=mapping.select_options.priority_fields, content_type=content_type, ) return step, query @@ -577,7 +578,7 @@ def _query_db(self, mapping): transformers = [] if ( mapping.action == DataOperationType.SELECT - and mapping.selection_strategy == "similarity" + and mapping.select_options.strategy == "similarity" ): transformers.append( DynamicLookupQueryExtender( diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index c9009f82fc..6bad4f7bdd 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -8,34 +8,21 @@ from typing import IO, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union from pydantic import Field, ValidationError, root_validator, validator -from requests.structures import CaseInsensitiveDict as RequestsCaseInsensitiveDict from simple_salesforce import Salesforce from typing_extensions import Literal from cumulusci.core.enums import StrEnum from cumulusci.core.exceptions import BulkDataException from cumulusci.tasks.bulkdata.dates import iso_to_date -from cumulusci.tasks.bulkdata.select_utils import SelectStrategy +from cumulusci.tasks.bulkdata.select_utils import SelectOptions, SelectStrategy from cumulusci.tasks.bulkdata.step import DataApi, DataOperationType +from cumulusci.tasks.bulkdata.utils import CaseInsensitiveDict from cumulusci.utils import convert_to_snake_case from cumulusci.utils.yaml.model_parser import CCIDictModel logger = getLogger(__name__) -class CaseInsensitiveDict(RequestsCaseInsensitiveDict): - def __init__(self, *args, **kwargs): - self._canonical_keys = {} - super().__init__(*args, **kwargs) - - def canonical_key(self, name): - return self._canonical_keys[name.lower()] - - def __setitem__(self, key, value): - super().__setitem__(key, value) - self._canonical_keys[key.lower()] = key - - class MappingLookup(CCIDictModel): "Lookup relationship between two tables." table: Union[str, List[str]] # Support for polymorphic lookups @@ -85,7 +72,7 @@ class BulkMode(StrEnum): ENUM_VALUES = { v.value.lower(): v.value - for enum in [BulkMode, DataApi, DataOperationType, SelectStrategy] + for enum in [BulkMode, DataApi, DataOperationType] for v in enum.__members__.values() } @@ -108,13 +95,12 @@ class MappingStep(CCIDictModel): ) anchor_date: Optional[Union[str, date]] = None soql_filter: Optional[str] = None # soql_filter property - selection_strategy: SelectStrategy = SelectStrategy.STANDARD # selection strategy - selection_filter: Optional[str] = ( - None # filter to be added at the end of select query + select_options: Optional[SelectOptions] = Field( + default_factory=lambda: SelectOptions(strategy=SelectStrategy.STANDARD) ) update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts - @validator("bulk_mode", "api", "action", "selection_strategy", pre=True) + @validator("bulk_mode", "api", "action", pre=True) def case_normalize(cls, val): if isinstance(val, Enum): return val @@ -134,6 +120,24 @@ def split_update_key(cls, val): ), "`update_key` should be a field name or list of field names." assert False, "Should be unreachable" # pragma: no cover + @root_validator + def validate_priority_fields(cls, values): + select_options = values.get("select_options") + fields_ = values.get("fields_", {}) + + if select_options and select_options.priority_fields: + priority_field_names = set(select_options.priority_fields.keys()) + field_names = set(fields_.keys()) + + # Check if all priority fields are present in the fields + missing_fields = priority_field_names - field_names + if missing_fields: + raise ValueError( + f"Priority fields {missing_fields} are not present in 'fields'" + ) + + return values + def get_oid_as_pk(self): """Returns True if using Salesforce Ids as primary keys.""" return "Id" in self.fields diff --git a/cumulusci/tasks/bulkdata/query_transformers.py b/cumulusci/tasks/bulkdata/query_transformers.py index eda7a2cabe..b4daa4bd93 100644 --- a/cumulusci/tasks/bulkdata/query_transformers.py +++ b/cumulusci/tasks/bulkdata/query_transformers.py @@ -123,7 +123,11 @@ def columns_to_add(self): load_fields = lookup_mapping_step.get_load_field_list() for field in load_fields: matching_column = next( - (col for col in aliased_table.columns if col.name == field) + ( + col + for col in aliased_table.columns + if col.name == lookup_mapping_step.fields[field] + ) ) columns.append( matching_column.label(f"{aliased_table.name}_{field}") diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 0e842a09e1..6de6adf652 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -1,10 +1,12 @@ import random import re import typing as T +from enum import Enum import numpy as np import pandas as pd from annoy import AnnoyIndex +from pydantic import Field, validator from sklearn.feature_extraction.text import HashingVectorizer from sklearn.preprocessing import StandardScaler @@ -12,6 +14,8 @@ from cumulusci.tasks.bulkdata.extract_dataset_utils.hardcoded_default_declarations import ( DEFAULT_DECLARATIONS, ) +from cumulusci.tasks.bulkdata.utils import CaseInsensitiveDict +from cumulusci.utils.yaml.model_parser import CCIDictModel class SelectStrategy(StrEnum): @@ -30,6 +34,35 @@ class SelectRecordRetrievalMode(StrEnum): MATCH = "match" +ENUM_VALUES = { + v.value.lower(): v.value + for enum in [SelectStrategy] + for v in enum.__members__.values() +} + + +class SelectOptions(CCIDictModel): + filter: T.Optional[str] = None # Optional filter for selection + strategy: SelectStrategy = SelectStrategy.STANDARD # Strategy for selection + priority_fields: T.Dict[str, str] = Field({}) + + @validator("strategy", pre=True) + def validate_strategy(cls, value): + if isinstance(value, Enum): + return value + if value is not None: + return ENUM_VALUES.get(value.lower()) + raise ValueError(f"Invalid strategy value: {value}") + + @validator("priority_fields", pre=True) + def standardize_fields_to_dict(cls, values): + if values is None: + values = {} + if type(values) is list: + values = {elem: elem for elem in values} + return CaseInsensitiveDict(values) + + class SelectOperationExecutor: def __init__(self, strategy: SelectStrategy): self.strategy = strategy @@ -68,7 +101,12 @@ def select_generate_query( ) def select_post_process( - self, load_records, query_records: list, num_records: int, sobject: str + self, + load_records, + query_records: list, + num_records: int, + sobject: str, + weights: list, ): # For STANDARD strategy if self.strategy == SelectStrategy.STANDARD: @@ -78,7 +116,10 @@ def select_post_process( # For SIMILARITY strategy elif self.strategy == SelectStrategy.SIMILARITY: return similarity_post_process( - load_records=load_records, query_records=query_records, sobject=sobject + load_records=load_records, + query_records=query_records, + sobject=sobject, + weights=weights, ) # For RANDOM strategy elif self.strategy == SelectStrategy.RANDOM: @@ -210,7 +251,7 @@ def similarity_generate_query( def similarity_post_process( - load_records, query_records: list, sobject: str + load_records, query_records: list, sobject: str, weights: list ) -> T.Tuple[T.List[dict], T.Union[str, None]]: """Processes the query results for the similarity selection strategy""" # Handle case where query returns 0 records @@ -226,15 +267,15 @@ def similarity_post_process( closest_records = [] if complexity_constant < 1000: - closest_records = levenshtein_post_process(load_records, query_records) + closest_records = levenshtein_post_process(load_records, query_records, weights) else: - closest_records = annoy_post_process(load_records, query_records) + closest_records = annoy_post_process(load_records, query_records, weights) return closest_records def annoy_post_process( - load_records: list, query_records: list + load_records: list, query_records: list, weights: list ) -> T.Tuple[T.List[dict], T.Union[str, None]]: """Processes the query results for the similarity selection strategy using Annoy algorithm for large number of records""" @@ -253,7 +294,7 @@ def annoy_post_process( } final_load_vectors, final_query_vectors = vectorize_records( - load_records, query_record_data, hash_features=hash_features + load_records, query_record_data, hash_features=hash_features, weights=weights ) # Create Annoy index for nearest neighbor search @@ -290,13 +331,13 @@ def annoy_post_process( def levenshtein_post_process( - load_records: list, query_records: list + load_records: list, query_records: list, weights: list ) -> T.Tuple[T.List[dict], T.Union[str, None]]: """Processes the query results for the similarity selection strategy using Levenshtein algorithm for small number of records""" closest_records = [] for record in load_records: - closest_record = find_closest_record(record, query_records) + closest_record = find_closest_record(record, query_records, weights) closest_records.append( {"id": closest_record[0], "success": True, "created": False} ) @@ -324,12 +365,12 @@ def random_post_process( return selected_records, None -def find_closest_record(load_record: list, query_records: list): +def find_closest_record(load_record: list, query_records: list, weights: list): closest_distance = float("inf") closest_record = query_records[0] for record in query_records: - distance = calculate_levenshtein_distance(load_record, record[1:]) + distance = calculate_levenshtein_distance(load_record, record[1:], weights) if distance < closest_distance: closest_distance = distance closest_record = record @@ -361,15 +402,16 @@ def levenshtein_distance(str1: str, str2: str): return dp[-1][-1] -def calculate_levenshtein_distance(record1: list, record2: list): +def calculate_levenshtein_distance(record1: list, record2: list, weights: list): if len(record1) != len(record2): raise ValueError("Records must have the same number of fields.") + elif len(record1) != len(weights): + raise ValueError("Records must be same size as fields (weights).") total_distance = 0 total_fields = 0 - for field1, field2 in zip(record1, record2): - + for field1, field2, weight in zip(record1, record2, weights): field1 = field1.lower() field2 = field2.lower() @@ -382,7 +424,8 @@ def calculate_levenshtein_distance(record1: list, record2: list): # If one field is blank, reduce the impact of the distance distance = distance * 0.05 # Fixed value for blank vs non-blank - total_distance += distance + # Multiply the distance by the corresponding weight + total_distance += distance * weight total_fields += 1 return total_distance / total_fields if total_fields > 0 else 0 @@ -428,38 +471,57 @@ def add_limit_offset_to_user_filter( return f" {filter_clause}" -def determine_field_types(df): +def determine_field_types(df, weights): numerical_features = [] boolean_features = [] categorical_features = [] - for col in df.columns: + numerical_weights = [] + boolean_weights = [] + categorical_weights = [] + + for col, weight in zip(df.columns, weights): # Check if the column can be converted to numeric try: # Attempt to convert to numeric df[col] = pd.to_numeric(df[col], errors="raise") numerical_features.append(col) + numerical_weights.append(weight) except ValueError: # Check for boolean values if df[col].str.lower().isin(["true", "false"]).all(): # Map to actual boolean values df[col] = df[col].str.lower().map({"true": True, "false": False}) boolean_features.append(col) + boolean_weights.append(weight) else: categorical_features.append(col) - - return numerical_features, boolean_features, categorical_features + categorical_weights.append(weight) + + return ( + numerical_features, + boolean_features, + categorical_features, + numerical_weights, + boolean_weights, + categorical_weights, + ) -def vectorize_records(db_records, query_records, hash_features): +def vectorize_records(db_records, query_records, hash_features, weights): # Convert database records and query records to DataFrames df_db = pd.DataFrame(db_records) df_query = pd.DataFrame(query_records) - # Dynamically determine field types - numerical_features, boolean_features, categorical_features = determine_field_types( - df_db - ) + # Determine field types and corresponding weights + ( + numerical_features, + boolean_features, + categorical_features, + numerical_weights, + boolean_weights, + categorical_weights, + ) = determine_field_types(df_db, weights) # Fit StandardScaler on the numerical features of the database records scaler = StandardScaler() @@ -474,24 +536,26 @@ def vectorize_records(db_records, query_records, hash_features): # For db_records hashed_categorical_data_db = [] - for col in categorical_features: + for idx, col in enumerate(categorical_features): hashed_db = hashing_vectorizer.fit_transform(df_db[col]).toarray() - hashed_categorical_data_db.append(hashed_db) + # Apply weight to the hashed vector for this categorical feature + hashed_db_weighted = hashed_db * categorical_weights[idx] + hashed_categorical_data_db.append(hashed_db_weighted) # For query_records hashed_categorical_data_query = [] - for col in categorical_features: + for idx, col in enumerate(categorical_features): hashed_query = hashing_vectorizer.transform(df_query[col]).toarray() - hashed_categorical_data_query.append(hashed_query) + # Apply weight to the hashed vector for this categorical feature + hashed_query_weighted = hashed_query * categorical_weights[idx] + hashed_categorical_data_query.append(hashed_query_weighted) # Combine all feature types into a single vector for the database records db_vectors = [] if numerical_features: - db_vectors.append(df_db[numerical_features].values) + db_vectors.append(df_db[numerical_features].values * numerical_weights) if boolean_features: - db_vectors.append( - df_db[boolean_features].astype(int).values - ) # Convert boolean to int + db_vectors.append(df_db[boolean_features].astype(int).values * boolean_weights) if hashed_categorical_data_db: db_vectors.append(np.hstack(hashed_categorical_data_db)) @@ -501,11 +565,11 @@ def vectorize_records(db_records, query_records, hash_features): # Combine all feature types into a single vector for the query records query_vectors = [] if numerical_features: - query_vectors.append(df_query[numerical_features].values) + query_vectors.append(df_query[numerical_features].values * numerical_weights) if boolean_features: query_vectors.append( - df_query[boolean_features].astype(int).values - ) # Convert boolean to int + df_query[boolean_features].astype(int).values * boolean_weights + ) if hashed_categorical_data_query: query_vectors.append(np.hstack(hashed_categorical_data_query)) diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index cb86bda6fa..ba0243c033 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -28,6 +28,8 @@ DEFAULT_BULK_BATCH_SIZE = 10_000 DEFAULT_REST_BATCH_SIZE = 200 MAX_REST_BATCH_SIZE = 200 +HIGH_PRIORITY_VALUE = 3 +LOW_PRIORITY_VALUE = 0.5 csv.field_size_limit(2**27) # 128 MB @@ -352,6 +354,7 @@ def __init__( fields, selection_strategy=SelectStrategy.STANDARD, selection_filter=None, + selection_priority_fields=None, content_type=None, ): super().__init__( @@ -370,6 +373,9 @@ def __init__( self.select_operation_executor = SelectOperationExecutor(selection_strategy) self.selection_filter = selection_filter + self.weights = assign_weights( + priority_fields=selection_priority_fields, fields=fields + ) self.content_type = content_type if content_type else "CSV" def start(self): @@ -494,6 +500,7 @@ def select_records(self, records): query_records=query_records, num_records=total_num_records, sobject=self.sobject, + weights=self.weights, ) if not error_message: self.select_results.extend(selected_records) @@ -651,6 +658,7 @@ def __init__( fields, selection_strategy=SelectStrategy.STANDARD, selection_filter=None, + selection_priority_fields=None, content_type=None, ): super().__init__( @@ -679,6 +687,9 @@ def __init__( self.select_operation_executor = SelectOperationExecutor(selection_strategy) self.selection_filter = selection_filter + self.weights = assign_weights( + priority_fields=selection_priority_fields, fields=fields + ) self.content_type = content_type def _record_to_json(self, rec): @@ -850,6 +861,7 @@ def select_records(self, records): query_records=query_records, num_records=total_num_records, sobject=self.sobject, + weights=self.weights, ) if not error_message: # Add selected records from this batch to the overall results @@ -1003,6 +1015,7 @@ def get_dml_operation( api: Optional[DataApi] = DataApi.SMART, selection_strategy: SelectStrategy = SelectStrategy.STANDARD, selection_filter: Union[str, None] = None, + selection_priority_fields: Union[dict, None] = None, content_type: Union[str, None] = None, ) -> BaseDmlOperation: """Create an appropriate DmlOperation instance for the given parameters, selecting @@ -1039,6 +1052,7 @@ def get_dml_operation( fields=fields, selection_strategy=selection_strategy, selection_filter=selection_filter, + selection_priority_fields=selection_priority_fields, content_type=content_type, ) @@ -1109,3 +1123,22 @@ def flatten_record(record, headers): flat_record.append(value) return flat_record + + +def assign_weights( + priority_fields: Union[Dict[str, str], None], fields: List[str] +) -> list: + # If priority_fields is None or an empty dictionary, set all weights to 1 + if not priority_fields: + return [1] * len(fields) + + # Initialize the weight list with LOW_PRIORITY_VALUE + weights = [LOW_PRIORITY_VALUE] * len(fields) + + # Iterate over the fields and assign weights based on priority_fields + for i, field in enumerate(fields): + if field in priority_fields: + # Set weight to HIGH_PRIORITY_VALUE if field is in priority_fields + weights[i] = HIGH_PRIORITY_VALUE + + return weights diff --git a/cumulusci/tasks/bulkdata/utils.py b/cumulusci/tasks/bulkdata/utils.py index b5c195a817..cee6a4ab66 100644 --- a/cumulusci/tasks/bulkdata/utils.py +++ b/cumulusci/tasks/bulkdata/utils.py @@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext from pathlib import Path +from requests.structures import CaseInsensitiveDict as RequestsCaseInsensitiveDict from simple_salesforce import Salesforce from sqlalchemy import Boolean, Column, MetaData, Table, Unicode, inspect from sqlalchemy.engine.base import Connection @@ -23,6 +24,19 @@ class DataApi(StrEnum): SMART = "smart" +class CaseInsensitiveDict(RequestsCaseInsensitiveDict): + def __init__(self, *args, **kwargs): + self._canonical_keys = {} + super().__init__(*args, **kwargs) + + def canonical_key(self, name): + return self._canonical_keys[name.lower()] + + def __setitem__(self, key, value): + super().__setitem__(key, value) + self._canonical_keys[key.lower()] = key + + class SqlAlchemyMixin: logger: logging.Logger metadata: MetaData From 0ac200032597aa06079f261b3122b0cf2bef46a8 Mon Sep 17 00:00:00 2001 From: Jawadtp Date: Fri, 8 Nov 2024 19:01:54 +0530 Subject: [PATCH 26/34] Add tests for annoy_post_process --- .../tasks/bulkdata/tests/test_select_utils.py | 339 +++++++++++++----- 1 file changed, 244 insertions(+), 95 deletions(-) diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index fe037a0177..26768d4ea1 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -1,11 +1,16 @@ +import pandas as pd import pytest from cumulusci.tasks.bulkdata.select_utils import ( SelectOperationExecutor, SelectStrategy, + annoy_post_process, calculate_levenshtein_distance, + determine_field_types, find_closest_record, levenshtein_distance, + replace_empty_strings_with_missing, + vectorize_records, ) @@ -193,107 +198,56 @@ def test_levenshtein_distance(): ) # Longer strings with multiple differences -def test_calculate_levenshtein_distance(): - # Identical records - record1 = ["Tom Cruise", "24", "Actor"] - record2 = ["Tom Cruise", "24", "Actor"] - assert calculate_levenshtein_distance(record1, record2) == 0 # Distance should be 0 - - # Records with one different field - record1 = ["Tom Cruise", "24", "Actor"] - record2 = ["Tom Hanks", "24", "Actor"] - assert calculate_levenshtein_distance(record1, record2) > 0 # Non-zero distance - - # One record has an empty field - record1 = ["Tom Cruise", "24", "Actor"] - record2 = ["Tom Cruise", "", "Actor"] - assert ( - calculate_levenshtein_distance(record1, record2) > 0 - ) # Distance should reflect the empty field - - # Completely empty records - record1 = ["", "", ""] - record2 = ["", "", ""] - assert calculate_levenshtein_distance(record1, record2) == 0 # Distance should be 0 - - -def test_calculate_levenshtein_distance_error(): - # Identical records - record1 = ["Tom Cruise", "24", "Actor"] - record2 = [ - "Tom Cruise", - "24", - "Actor", - "SomethingElse", - ] # Record Length does not match - with pytest.raises(ValueError) as e: - calculate_levenshtein_distance(record1, record2) - assert "Records must have the same number of fields" in str(e.value) - - -def test_find_closest_record(): - # Test case 1: Exact match - load_record = ["Tom Cruise", "62", "Actor"] - query_records = [ - [1, "Tom Hanks", "30", "Actor"], - [2, "Tom Cruise", "62", "Actor"], # Exact match - [3, "Jennifer Aniston", "30", "Actress"], - ] - assert find_closest_record(load_record, query_records) == [ - 2, - "Tom Cruise", - "62", - "Actor", - ] # Should return the exact match - - # Test case 2: Closest match with slight differences - load_record = ["Tom Cruise", "62", "Actor"] +def test_find_closest_record_different_weights(): + load_record = ["hello", "world"] query_records = [ - [1, "Tom Hanks", "62", "Actor"], - [2, "Tom Cruise", "63", "Actor"], # Slight difference - [3, "Jennifer Aniston", "30", "Actress"], + ["record1", "hello", "word"], # Levenshtein distance = 1 + ["record2", "hullo", "word"], # Levenshtein distance = 1 + ["record3", "hello", "word"], # Levenshtein distance = 1 ] - assert find_closest_record(load_record, query_records) == [ - 2, - "Tom Cruise", - "63", - "Actor", - ] # Should return the closest match - - # Test case 3: All records are significantly different - load_record = ["Tom Cruise", "62", "Actor"] + weights = [2.0, 0.5] + + # With different weights, the first field will have more impact + closest_record = find_closest_record(load_record, query_records, weights) + assert closest_record == [ + "record1", + "hello", + "word", + ], "The closest record should be 'record1'." + + +def test_find_closest_record_basic(): + load_record = ["hello", "world"] query_records = [ - [1, "Brad Pitt", "30", "Producer"], - [2, "Leonardo DiCaprio", "40", "Director"], - [3, "Jennifer Aniston", "30", "Actress"], + ["record1", "hello", "word"], # Levenshtein distance = 1 + ["record2", "hullo", "word"], # Levenshtein distance = 1 + ["record3", "hello", "word"], # Levenshtein distance = 1 ] - assert ( - find_closest_record(load_record, query_records) == query_records[0] - ) # Should return the first record as the closest (though none are close) + weights = [1.0, 1.0] + + closest_record = find_closest_record(load_record, query_records, weights) + assert closest_record == [ + "record1", + "hello", + "word", + ], "The closest record should be 'record1'." - # Test case 4: Closest match is the last in the list - load_record = ["Tom Cruise", "62", "Actor"] + +def test_find_closest_record_multiple_matches(): + load_record = ["cat", "dog"] query_records = [ - [1, "Johnny Depp", "50", "Actor"], - [2, "Brad Pitt", "30", "Producer"], - [3, "Tom Cruise", "62", "Actor"], # Exact match as the last record + ["record1", "bat", "dog"], # Levenshtein distance = 1 + ["record2", "cat", "dog"], # Levenshtein distance = 0 + ["record3", "dog", "cat"], # Levenshtein distance = 3 ] - assert find_closest_record(load_record, query_records) == [ - 3, - "Tom Cruise", - "62", - "Actor", - ] # Should return the last record - - # Test case 5: Single record in query_records - load_record = ["Tom Cruise", "62", "Actor"] - query_records = [[1, "Johnny Depp", "50", "Actor"]] - assert find_closest_record(load_record, query_records) == [ - 1, - "Johnny Depp", - "50", - "Actor", - ] # Should return the only record available + weights = [1.0, 1.0] + + closest_record = find_closest_record(load_record, query_records, weights) + assert closest_record == [ + "record2", + "cat", + "dog", + ], "The closest record should be 'record2'." def test_similarity_post_process_with_records(): @@ -307,10 +261,16 @@ def test_similarity_post_process_with_records(): ["003", "Jennifer Aniston", "30", "Actress"], ] + weights = [1.0, 1.0, 1.0] # Adjust weights to match your data structure + selected_records, error_message = select_operator.select_post_process( - load_records, query_records, num_records, sobject + load_records, query_records, num_records, sobject, weights ) + # selected_records, error_message = select_operator.select_post_process( + # load_records, query_records, num_records, sobject + # ) + assert error_message is None assert len(selected_records) == num_records assert all(record["success"] for record in selected_records) @@ -329,3 +289,192 @@ def test_similarity_post_process_with_no_records(): assert selected_records == [] assert error_message == f"No records found for {sobject} in the target org." + + +def test_calculate_levenshtein_distance_basic(): + record1 = ["hello", "world"] + record2 = ["hullo", "word"] + weights = [1.0, 1.0] + + # Expected distance based on simple Levenshtein distances + # Levenshtein("hello", "hullo") = 1, Levenshtein("world", "word") = 1 + expected_distance = (1 * 1.0 + 1 * 1.0) / 2 # Averaged over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Basic distance calculation failed." + + +def test_calculate_levenshtein_distance_weighted(): + record1 = ["cat", "dog"] + record2 = ["bat", "fog"] + weights = [2.0, 0.5] + + # Levenshtein("cat", "bat") = 1, Levenshtein("dog", "fog") = 1 + expected_distance = (1 * 2.0 + 1 * 0.5) / 2 # Weighted average over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Weighted distance calculation failed." + + +def test_replace_empty_strings_with_missing(): + # Case 1: Normal case with some empty strings + records = [ + ["Alice", "", "New York"], + ["Bob", "Engineer", ""], + ["", "Teacher", "Chicago"], + ] + expected = [ + ["Alice", "missing", "New York"], + ["Bob", "Engineer", "missing"], + ["missing", "Teacher", "Chicago"], + ] + assert replace_empty_strings_with_missing(records) == expected + + # Case 2: No empty strings, so the output should be the same as input + records = [["Alice", "Manager", "New York"], ["Bob", "Engineer", "San Francisco"]] + expected = [["Alice", "Manager", "New York"], ["Bob", "Engineer", "San Francisco"]] + assert replace_empty_strings_with_missing(records) == expected + + # Case 3: List with all empty strings + records = [["", "", ""], ["", "", ""]] + expected = [["missing", "missing", "missing"], ["missing", "missing", "missing"]] + assert replace_empty_strings_with_missing(records) == expected + + # Case 4: Empty list (should return an empty list) + records = [] + expected = [] + assert replace_empty_strings_with_missing(records) == expected + + # Case 5: List with some empty sublists + records = [[], ["Alice", ""], []] + expected = [[], ["Alice", "missing"], []] + assert replace_empty_strings_with_missing(records) == expected + + +def test_all_numeric_columns(): + df = pd.DataFrame({"A": [1, 2, 3], "B": [4.5, 5.5, 6.5]}) + weights = [0.1, 0.2] + expected_output = ( + ["A", "B"], # numerical_features + [], # boolean_features + [], # categorical_features + [0.1, 0.2], # numerical_weights + [], # boolean_weights + [], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_all_boolean_columns(): + df = pd.DataFrame({"A": ["true", "false", "true"], "B": ["false", "true", "false"]}) + weights = [0.3, 0.4] + expected_output = ( + [], # numerical_features + ["A", "B"], # boolean_features + [], # categorical_features + [], # numerical_weights + [0.3, 0.4], # boolean_weights + [], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_all_categorical_columns(): + df = pd.DataFrame( + {"A": ["apple", "banana", "cherry"], "B": ["dog", "cat", "mouse"]} + ) + weights = [0.5, 0.6] + expected_output = ( + [], # numerical_features + [], # boolean_features + ["A", "B"], # categorical_features + [], # numerical_weights + [], # boolean_weights + [0.5, 0.6], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_mixed_types(): + df = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["true", "false", "true"], + "C": ["apple", "banana", "cherry"], + } + ) + weights = [0.7, 0.8, 0.9] + expected_output = ( + ["A"], # numerical_features + ["B"], # boolean_features + ["C"], # categorical_features + [0.7], # numerical_weights + [0.8], # boolean_weights + [0.9], # categorical_weights + ) + assert determine_field_types(df, weights) == expected_output + + +def test_vectorize_records_mixed_numerical_categorical(): + # Test data with mixed types: numerical and categorical only + db_records = [["1.0", "apple"], ["2.0", "banana"]] + query_records = [["1.5", "apple"], ["2.5", "cherry"]] + weights = [1.0, 1.0] # Equal weights for numerical and categorical columns + hash_features = 4 # Number of hashing vectorizer features for categorical columns + + final_db_vectors, final_query_vectors = vectorize_records( + db_records, query_records, hash_features, weights + ) + + # Check the shape of the output vectors + assert final_db_vectors.shape[0] == len(db_records), "DB vectors row count mismatch" + assert final_query_vectors.shape[0] == len( + query_records + ), "Query vectors row count mismatch" + + # Expected dimensions: numerical (1) + categorical hashed features (4) + expected_feature_count = 1 + hash_features + assert ( + final_db_vectors.shape[1] == expected_feature_count + ), "DB vectors column count mismatch" + assert ( + final_query_vectors.shape[1] == expected_feature_count + ), "Query vectors column count mismatch" + + +def test_annoy_post_process(): + # Test data + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [["q1", "Alice", "Engineer"], ["q2", "Charlie", "Artist"]] + weights = [1.0, 1.0, 1.0] # Example weights + + closest_records, error = annoy_post_process(load_records, query_records, weights) + + # Assert the closest records + assert ( + len(closest_records) == 2 + ) # We expect two results (one for each query record) + assert ( + closest_records[0]["id"] == "q1" + ) # The first query record should match the first load record + + # No errors expected + assert error is None + + +def test_single_record_match_annoy_post_process(): + # Mock data where only the first query record matches the first load record + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [["q1", "Alice", "Engineer"]] + weights = [1.0, 1.0, 1.0] + + closest_records, error = annoy_post_process(load_records, query_records, weights) + + # Both the load records should be matched with the only query record we have + assert len(closest_records) == 2 + assert closest_records[0]["id"] == "q1" + assert error is None From 730ba6ca3d38b04c8cf80f49700cf0883dc1dccb Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Mon, 11 Nov 2024 20:16:30 +0530 Subject: [PATCH 27/34] Add tests for parent level similarity and priority fields --- cumulusci/core/tests/test_datasets_e2e.py | 4 + ...generate_load_mapping_from_declarations.py | 17 + cumulusci/tasks/bulkdata/load.py | 167 +-- cumulusci/tasks/bulkdata/mapping_parser.py | 5 +- .../tasks/bulkdata/query_transformers.py | 32 +- cumulusci/tasks/bulkdata/select_utils.py | 20 +- cumulusci/tasks/bulkdata/step.py | 55 +- .../tasks/bulkdata/tests/mapping_select.yml | 20 + .../tests/mapping_select_invalid_strategy.yml | 20 + ...mapping_select_missing_priority_fields.yml | 22 + .../mapping_select_no_priority_fields.yml | 18 + cumulusci/tasks/bulkdata/tests/test_load.py | 114 ++ .../bulkdata/tests/test_mapping_parser.py | 36 + .../tests/test_query_db_joins_lookups.sql | 16 +- .../test_query_db_joins_lookups_select.yml | 48 + .../tasks/bulkdata/tests/test_select_utils.py | 209 +++- cumulusci/tasks/bulkdata/tests/test_step.py | 1091 ++++++++++++----- 17 files changed, 1442 insertions(+), 452 deletions(-) create mode 100644 cumulusci/tasks/bulkdata/tests/mapping_select.yml create mode 100644 cumulusci/tasks/bulkdata/tests/mapping_select_invalid_strategy.yml create mode 100644 cumulusci/tasks/bulkdata/tests/mapping_select_missing_priority_fields.yml create mode 100644 cumulusci/tasks/bulkdata/tests/mapping_select_no_priority_fields.yml create mode 100644 cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups_select.yml diff --git a/cumulusci/core/tests/test_datasets_e2e.py b/cumulusci/core/tests/test_datasets_e2e.py index c5140d3609..387ad696ad 100644 --- a/cumulusci/core/tests/test_datasets_e2e.py +++ b/cumulusci/core/tests/test_datasets_e2e.py @@ -304,6 +304,7 @@ def write_yaml(filename: str, json: Any): "after": "Insert Account", } }, + "select_options": {}, }, "Insert Event": { "sf_object": "Event", @@ -316,16 +317,19 @@ def write_yaml(filename: str, json: Any): "after": "Insert Lead", } }, + "select_options": {}, }, "Insert Account": { "sf_object": "Account", "table": "Account", "fields": ["Name"], + "select_options": {}, }, "Insert Lead": { "sf_object": "Lead", "table": "Lead", "fields": ["Company", "LastName"], + "select_options": {}, }, } assert tuple(actual.items()) == tuple(expected.items()), actual.items() diff --git a/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py b/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py index 7dbaefc740..69dd0e361d 100644 --- a/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py +++ b/cumulusci/tasks/bulkdata/generate_mapping_utils/tests/test_generate_load_mapping_from_declarations.py @@ -41,6 +41,7 @@ def test_simple_generate_mapping_from_declarations(self, org_config): "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, } } @@ -74,11 +75,13 @@ def test_generate_mapping_from_both_kinds_of_declarations(self, org_config): "sf_object": "Contact", "table": "Contact", "fields": ["FirstName", "LastName"], + "select_options": {}, }, "Insert Account": { "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, }, }.items() ) @@ -111,6 +114,7 @@ def test_generate_load_mapping_from_declarations__lookups(self, org_config): "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", @@ -119,6 +123,7 @@ def test_generate_load_mapping_from_declarations__lookups(self, org_config): "lookups": { "AccountId": {"table": ["Account"], "key_field": "AccountId"} }, + "select_options": {}, }, } @@ -157,6 +162,7 @@ def test_generate_load_mapping_from_declarations__polymorphic_lookups( "sf_object": "Account", "table": "Account", "fields": ["Name", "Description"], + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", @@ -165,11 +171,13 @@ def test_generate_load_mapping_from_declarations__polymorphic_lookups( "lookups": { "AccountId": {"table": ["Account"], "key_field": "AccountId"} }, + "select_options": {}, }, "Insert Lead": { "sf_object": "Lead", "table": "Lead", "fields": ["LastName", "Company"], + "select_options": {}, }, "Insert Event": { "sf_object": "Event", @@ -178,6 +186,7 @@ def test_generate_load_mapping_from_declarations__polymorphic_lookups( "lookups": { "WhoId": {"table": ["Contact", "Lead"], "key_field": "WhoId"} }, + "select_options": {}, }, } @@ -221,6 +230,7 @@ def test_generate_load_mapping_from_declarations__circular_lookups( }, "sf_object": "Account", "table": "Account", + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", @@ -229,6 +239,7 @@ def test_generate_load_mapping_from_declarations__circular_lookups( "lookups": { "AccountId": {"table": ["Account"], "key_field": "AccountId"} }, + "select_options": {}, }, }, mf @@ -252,11 +263,13 @@ def test_generate_load_mapping__with_load_declarations(self, org_config): "sf_object": "Account", "api": DataApi.REST, "table": "Account", + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", "api": DataApi.BULK, "table": "Contact", + "select_options": {}, }, }, mf @@ -288,6 +301,7 @@ def test_generate_load_mapping__with_upserts(self, org_config): "Insert Account": { "sf_object": "Account", "table": "Account", + "select_options": {}, }, "Upsert Account Name": { "sf_object": "Account", @@ -295,6 +309,7 @@ def test_generate_load_mapping__with_upserts(self, org_config): "action": DataOperationType.UPSERT, "update_key": ("Name",), "fields": ["Name"], + "select_options": {}, }, "Etl_Upsert Account AccountNumber_Name": { "sf_object": "Account", @@ -302,10 +317,12 @@ def test_generate_load_mapping__with_upserts(self, org_config): "action": DataOperationType.ETL_UPSERT, "update_key": ("AccountNumber", "Name"), "fields": ["AccountNumber", "Name"], + "select_options": {}, }, "Insert Contact": { "sf_object": "Contact", "table": "Contact", + "select_options": {}, }, }, mf diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index ced885744b..9a2f08ee90 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -310,6 +310,90 @@ def _execute_step( return step.job_result + def process_lookup_fields(self, mapping, fields, polymorphic_fields): + """Modify fields and priority fields based on lookup and polymorphic checks.""" + for name, lookup in mapping.lookups.items(): + if name in fields: + # Get the index of the lookup field before removing it + insert_index = fields.index(name) + # Remove the lookup field from fields + fields.remove(name) + + # Do the same for priority fields + lookup_in_priority_fields = False + if name in mapping.select_options.priority_fields: + # Set flag to True + lookup_in_priority_fields = True + # Remove the lookup field from priority fields + del mapping.select_options.priority_fields[name] + + # Check if this lookup field is polymorphic + if ( + name in polymorphic_fields + and len(polymorphic_fields[name]["referenceTo"]) > 1 + ): + # Convert to list if string + if not isinstance(lookup.table, list): + lookup.table = [lookup.table] + # Polymorphic field handling + polymorphic_references = lookup.table + relationship_name = polymorphic_fields[name]["relationshipName"] + + # Loop through each polymorphic type (e.g., Contact, Lead) + for ref_type in polymorphic_references: + # Find the mapping step for this polymorphic type + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.table == ref_type + ), + None, + ) + if lookup_mapping_step: + lookup_fields = lookup_mapping_step.get_load_field_list() + # Insert fields in the format {relationship_name}.{ref_type}.{lookup_field} + for field in lookup_fields: + fields.insert( + insert_index, + f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}", + ) + insert_index += 1 + if lookup_in_priority_fields: + mapping.select_options.priority_fields[ + f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}" + ] = f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}" + + else: + # Non-polymorphic field handling + lookup_table = lookup.table + + if isinstance(lookup_table, list): + lookup_table = lookup_table[0] + + # Get the mapping step for the non-polymorphic reference + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.table == lookup_table + ), + None, + ) + + if lookup_mapping_step: + relationship_name = polymorphic_fields[name]["relationshipName"] + lookup_fields = lookup_mapping_step.get_load_field_list() + + # Insert the new fields at the same position as the removed lookup field + for field in lookup_fields: + fields.insert(insert_index, f"{relationship_name}.{field}") + insert_index += 1 + if lookup_in_priority_fields: + mapping.select_options.priority_fields[ + f"{relationship_name}.{field}" + ] = f"{relationship_name}.{field}" + def configure_step(self, mapping): """Create a step appropriate to the action""" bulk_mode = mapping.bulk_mode or self.bulk_mode or "Parallel" @@ -370,85 +454,7 @@ def configure_step(self, mapping): for field in describe_result["fields"] if field["type"] == "reference" } - - # Loop through each lookup to get the corresponding fields - for name, lookup in mapping.lookups.items(): - if name in fields: - # Get the index of the lookup field before removing it - insert_index = fields.index(name) - # Remove the lookup field from fields - fields.remove(name) - - # Check if this lookup field is polymorphic - if ( - name in polymorphic_fields - and len(polymorphic_fields[name]["referenceTo"]) > 1 - ): - # Convert to list if string - if not isinstance(lookup.table, list): - lookup.table = [lookup.table] - # Polymorphic field handling - polymorphic_references = lookup.table - relationship_name = polymorphic_fields[name][ - "relationshipName" - ] - - # Loop through each polymorphic type (e.g., Contact, Lead) - for ref_type in polymorphic_references: - # Find the mapping step for this polymorphic type - lookup_mapping_step = next( - ( - step - for step in self.mapping.values() - if step.sf_object == ref_type - ), - None, - ) - - if lookup_mapping_step: - lookup_fields = ( - lookup_mapping_step.get_load_field_list() - ) - # Insert fields in the format {relationship_name}.{ref_type}.{lookup_field} - for field in lookup_fields: - fields.insert( - insert_index, - f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}", - ) - insert_index += 1 - - else: - # Non-polymorphic field handling - lookup_table = lookup.table - - if isinstance(lookup_table, list): - lookup_table = lookup_table[0] - - # Get the mapping step for the non-polymorphic reference - lookup_mapping_step = next( - ( - step - for step in self.mapping.values() - if step.sf_object == lookup_table - ), - None, - ) - - if lookup_mapping_step: - relationship_name = polymorphic_fields[name][ - "relationshipName" - ] - lookup_fields = ( - lookup_mapping_step.get_load_field_list() - ) - - # Insert the new fields at the same position as the removed lookup field - for field in lookup_fields: - fields.insert( - insert_index, f"{relationship_name}.{field}" - ) - insert_index += 1 - + self.process_lookup_fields(mapping, fields, polymorphic_fields) else: action = mapping.action @@ -503,9 +509,6 @@ def _stream_queried_data(self, mapping, local_ids, query): pkey = row[0] row = list(row[1:]) + statics - # Replace None values in row with empty strings - row = [value if value is not None else "" for value in row] - if mapping.anchor_date and (date_context[0] or date_context[1]): row = adjust_relative_dates( mapping, date_context, row, DataOperationType.INSERT diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index 6bad4f7bdd..1593dc97a1 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -124,16 +124,19 @@ def split_update_key(cls, val): def validate_priority_fields(cls, values): select_options = values.get("select_options") fields_ = values.get("fields_", {}) + lookups = values.get("lookups", {}) if select_options and select_options.priority_fields: priority_field_names = set(select_options.priority_fields.keys()) field_names = set(fields_.keys()) + lookup_names = set(lookups.keys()) # Check if all priority fields are present in the fields missing_fields = priority_field_names - field_names + missing_fields = missing_fields - lookup_names if missing_fields: raise ValueError( - f"Priority fields {missing_fields} are not present in 'fields'" + f"Priority fields {missing_fields} are not present in 'fields' or 'lookups'" ) return values diff --git a/cumulusci/tasks/bulkdata/query_transformers.py b/cumulusci/tasks/bulkdata/query_transformers.py index b4daa4bd93..f99689618e 100644 --- a/cumulusci/tasks/bulkdata/query_transformers.py +++ b/cumulusci/tasks/bulkdata/query_transformers.py @@ -3,6 +3,7 @@ from sqlalchemy import String, and_, func, text from sqlalchemy.orm import Query, aliased +from sqlalchemy.sql import literal_column from cumulusci.core.exceptions import BulkDataException @@ -106,7 +107,10 @@ def columns_to_add(self): for lookup in self.lookups: tables = lookup.table if isinstance(lookup.table, list) else [lookup.table] lookup.aliased_table = [ - aliased(self.metadata.tables[table]) for table in tables + aliased( + self.metadata.tables[table], name=f"{lookup.name}_{table}_alias" + ) + for table in tables ] for aliased_table, table_name in zip(lookup.aliased_table, tables): @@ -122,16 +126,24 @@ def columns_to_add(self): if lookup_mapping_step: load_fields = lookup_mapping_step.get_load_field_list() for field in load_fields: - matching_column = next( - ( - col - for col in aliased_table.columns - if col.name == lookup_mapping_step.fields[field] + if field in lookup_mapping_step.fields: + matching_column = next( + ( + col + for col in aliased_table.columns + if col.name == lookup_mapping_step.fields[field] + ) + ) + columns.append( + matching_column.label(f"{aliased_table.name}_{field}") + ) + else: + # Append an empty string if the field is not present + columns.append( + literal_column("''").label( + f"{aliased_table.name}_{field}" + ) ) - ) - columns.append( - matching_column.label(f"{aliased_table.name}_{field}") - ) return columns @cached_property diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 6de6adf652..f5800f9b38 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -50,8 +50,12 @@ class SelectOptions(CCIDictModel): def validate_strategy(cls, value): if isinstance(value, Enum): return value - if value is not None: - return ENUM_VALUES.get(value.lower()) + + if value: + matched_strategy = ENUM_VALUES.get(value.lower()) + if matched_strategy: + return matched_strategy + raise ValueError(f"Invalid strategy value: {value}") @validator("priority_fields", pre=True) @@ -260,6 +264,10 @@ def similarity_post_process( return [], error_message load_records = list(load_records) + # Replace None values in each row with empty strings + for idx, row in enumerate(load_records): + row = [value if value is not None else "" for value in row] + load_records[idx] = row load_record_count, query_record_count = len(load_records), len(query_records) complexity_constant = load_record_count * query_record_count @@ -514,6 +522,7 @@ def vectorize_records(db_records, query_records, hash_features, weights): df_query = pd.DataFrame(query_records) # Determine field types and corresponding weights + # Modifies boolean columns to True or False ( numerical_features, boolean_features, @@ -523,6 +532,13 @@ def vectorize_records(db_records, query_records, hash_features, weights): categorical_weights, ) = determine_field_types(df_db, weights) + # Modify query dataframe boolean columns to True or False + for col in df_query.columns: + if df_query[col].str.lower().isin(["true", "false"]).all(): + df_query[col] = ( + df_query[col].str.lower().map({"true": True, "false": False}) + ) + # Fit StandardScaler on the numerical features of the database records scaler = StandardScaler() if numerical_features: diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index ba0243c033..3e60ef91c0 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -14,7 +14,7 @@ import salesforce_bulk from cumulusci.core.enums import StrEnum -from cumulusci.core.exceptions import BulkDataException, SOQLQueryException +from cumulusci.core.exceptions import BulkDataException from cumulusci.core.utils import process_bool_arg from cumulusci.tasks.bulkdata.select_utils import ( SelectOperationExecutor, @@ -879,59 +879,6 @@ def select_records(self, records): total_row_errors=0, ) - def _execute_composite_query(self, select_query, user_query, query_fields): - """Executes a composite request with two queries and returns the results.""" - - def convert(rec, fields): - """Helper function to convert record values to strings, handling None values""" - return [str(rec[f]) if rec[f] is not None else "" for f in fields] - - composite_request_json = { - "compositeRequest": [ - { - "method": "GET", - "url": requests.utils.requote_uri( - f"/services/data/v{self.sf.sf_version}/query/?q={select_query}" - ), - "referenceId": "select_query", - }, - { - "method": "GET", - "url": requests.utils.requote_uri( - f"/services/data/v{self.sf.sf_version}/query/?q={user_query}" - ), - "referenceId": "user_query", - }, - ] - } - response = self.sf.restful( - "composite", method="POST", json=composite_request_json - ) - - # Extract results based on referenceId - for sub_response in response["compositeResponse"]: - if ( - sub_response["referenceId"] == "select_query" - and sub_response["httpStatusCode"] == 200 - ): - select_query_records = list( - convert(rec, query_fields) - for rec in sub_response["body"]["records"] - ) - elif ( - sub_response["referenceId"] == "user_query" - and sub_response["httpStatusCode"] == 200 - ): - user_query_records = list( - convert(rec, ["Id"]) for rec in sub_response["body"]["records"] - ) - else: - raise SOQLQueryException( - f"{sub_response['body'][0]['errorCode']}: {sub_response['body'][0]['message']}" - ) - - return user_query_records, select_query_records - def get_results(self): """Return a generator of DataOperationResult objects.""" diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select.yml b/cumulusci/tasks/bulkdata/tests/mapping_select.yml new file mode 100644 index 0000000000..e549d7a474 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select.yml @@ -0,0 +1,20 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_strategy.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_strategy.yml new file mode 100644 index 0000000000..6ab196fda6 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_strategy.yml @@ -0,0 +1,20 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: invalid_strategy + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_missing_priority_fields.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_missing_priority_fields.yml new file mode 100644 index 0000000000..34011945ad --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_missing_priority_fields.yml @@ -0,0 +1,22 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + - Name + - AccountNumber + - ParentId + - Email + fields: + - Name + - AccountNumber + - Description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_no_priority_fields.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_no_priority_fields.yml new file mode 100644 index 0000000000..1559848b48 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_no_priority_fields.yml @@ -0,0 +1,18 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + fields: + - Name + - AccountNumber + - Description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/test_load.py b/cumulusci/tasks/bulkdata/tests/test_load.py index 6649ff202e..9fb6ea1d87 100644 --- a/cumulusci/tasks/bulkdata/tests/test_load.py +++ b/cumulusci/tasks/bulkdata/tests/test_load.py @@ -806,6 +806,111 @@ def test_stream_queried_data__skips_empty_rows(self): ["001000000006", "001000000008"], ] == records + def test_process_lookup_fields_polymorphic(self): + task = _make_task( + LoadData, + { + "options": { + "sql_path": Path(__file__).parent + / "test_query_db_joins_lookups.sql", + "mapping": Path(__file__).parent + / "test_query_db_joins_lookups_select.yml", + } + }, + ) + polymorphic_fields = { + "WhoId": { + "name": "WhoId", + "referenceTo": ["Contact", "Lead"], + "relationshipName": "Who", + }, + "WhatId": { + "name": "WhatId", + "referenceTo": ["Account"], + "relationshipName": "What", + }, + } + + expected_fields = [ + "Subject", + "Who.Contact.FirstName", + "Who.Contact.LastName", + "Who.Contact.AccountId", + "Who.Lead.LastName", + ] + expected_priority_fields_keys = { + "Who.Contact.FirstName", + "Who.Contact.LastName", + "Who.Contact.AccountId", + "Who.Lead.LastName", + } + with mock.patch( + "cumulusci.tasks.bulkdata.load.validate_and_inject_mapping" + ), mock.patch.object(task, "sf", create=True): + task._init_mapping() + with task._init_db(): + task._old_format = mock.Mock(return_value=False) + mapping = task.mapping["Select Event"] + fields = mapping.get_load_field_list() + task.process_lookup_fields( + mapping=mapping, fields=fields, polymorphic_fields=polymorphic_fields + ) + assert fields == expected_fields + assert ( + set(mapping.select_options.priority_fields.keys()) + == expected_priority_fields_keys + ) + + def test_process_lookup_fields_non_polymorphic(self): + task = _make_task( + LoadData, + { + "options": { + "sql_path": Path(__file__).parent + / "test_query_db_joins_lookups.sql", + "mapping": Path(__file__).parent + / "test_query_db_joins_lookups_select.yml", + } + }, + ) + non_polymorphic_fields = { + "AccountId": { + "name": "AccountId", + "referenceTo": ["Account"], + "relationshipName": "Account", + } + } + + expected_fields = [ + "FirstName", + "LastName", + "Account.Name", + "Account.AccountNumber", + ] + expected_priority_fields_keys = { + "FirstName", + "Account.Name", + "Account.AccountNumber", + } + with mock.patch( + "cumulusci.tasks.bulkdata.load.validate_and_inject_mapping" + ), mock.patch.object(task, "sf", create=True): + task._init_mapping() + with task._init_db(): + task._old_format = mock.Mock(return_value=False) + mapping = task.mapping["Select Contact"] + fields = mapping.get_load_field_list() + task.process_lookup_fields( + mapping=mapping, + fields=fields, + polymorphic_fields=non_polymorphic_fields, + ) + assert fields == expected_fields + assert ( + set(mapping.select_options.priority_fields.keys()) + == expected_priority_fields_keys + ) + @responses.activate def test_stream_queried_data__adjusts_relative_dates(self): mock_describe_calls() @@ -878,6 +983,15 @@ def test_query_db__joins_self_lookups(self): old_format=True, ) + def test_query_db__joins_select_lookups(self): + """SQL File in New Format (Select)""" + _validate_query_for_mapping_step( + sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql", + mapping=Path(__file__).parent / "test_query_db_joins_lookups_select.yml", + mapping_step_name="Select Event", + expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", '' AS "whoid_contacts_alias_accountid", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname" from events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" ORDER BY events."whoid"''', + ) + def test_query_db__joins_polymorphic_lookups(self): """SQL File in New Format (Polymorphic)""" _validate_query_for_mapping_step( diff --git a/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py b/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py index c1419f300b..ae9fe91686 100644 --- a/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py +++ b/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py @@ -17,6 +17,7 @@ parse_from_yaml, validate_and_inject_mapping, ) +from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import DataApi, DataOperationType from cumulusci.tests.util import DummyOrgConfig, mock_describe_calls @@ -213,6 +214,41 @@ def test_get_relative_date_e2e(self): date.today(), ) + def test_select_options__success(self): + base_path = Path(__file__).parent / "mapping_select.yml" + result = parse_from_yaml(base_path) + + step = result["Select Accounts"] + select_options = step.select_options + assert select_options + assert select_options.strategy == SelectStrategy.SIMILARITY + assert select_options.filter == "WHEN Name in ('Sample Account')" + assert select_options.priority_fields + + def test_select_options__invalid_strategy(self): + base_path = Path(__file__).parent / "mapping_select_invalid_strategy.yml" + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert "Invalid strategy value: invalid_strategy" in str(e.value) + + def test_select_options__missing_priority_fields(self): + base_path = Path(__file__).parent / "mapping_select_missing_priority_fields.yml" + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + print(str(e.value)) + assert ( + "Priority fields {'Email'} are not present in 'fields' or 'lookups'" + in str(e.value) + ) + + def test_select_options__no_priority_fields(self): + base_path = Path(__file__).parent / "mapping_select_no_priority_fields.yml" + result = parse_from_yaml(base_path) + + step = result["Select Accounts"] + select_options = step.select_options + assert select_options.priority_fields == {} + # Start of FLS/Namespace Injection Unit Tests def test_is_injectable(self): diff --git a/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql index 113e5cebe5..ed7f0e694a 100644 --- a/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql +++ b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups.sql @@ -1,13 +1,23 @@ BEGIN TRANSACTION; +CREATE TABLE "accounts" ( + id VARCHAR(255) NOT NULL, + "Name" VARCHAR(255), + "AccountNumber" VARCHAR(255), + PRIMARY KEY (id) +); +INSERT INTO "accounts" VALUES("Account-1",'Bluth Company','123456'); +INSERT INTO "accounts" VALUES("Account-2",'Sampson PLC','567890'); + CREATE TABLE "contacts" ( id VARCHAR(255) NOT NULL, "FirstName" VARCHAR(255), - "LastName" VARCHAR(255), + "LastName" VARCHAR(255), + "AccountId" VARCHAR(255), PRIMARY KEY (id) ); -INSERT INTO "contacts" VALUES("Contact-1",'Alpha','gamma'); -INSERT INTO "contacts" VALUES("Contact-2",'Temp','Bluth'); +INSERT INTO "contacts" VALUES("Contact-1",'Alpha','gamma', 'Account-2'); +INSERT INTO "contacts" VALUES("Contact-2",'Temp','Bluth', 'Account-1'); CREATE TABLE "events" ( id VARCHAR(255) NOT NULL, diff --git a/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups_select.yml b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups_select.yml new file mode 100644 index 0000000000..4b37f491eb --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/test_query_db_joins_lookups_select.yml @@ -0,0 +1,48 @@ +Insert Account: + sf_object: Account + table: accounts + api: rest + fields: + - Name + - AccountNumber + +Insert Lead: + sf_object: Lead + table: leads + api: bulk + fields: + - LastName + +Select Contact: + sf_object: Contact + table: contacts + api: bulk + action: select + select_options: + strategy: similarity + priority_fields: + - FirstName + - AccountId + fields: + - FirstName + - LastName + lookups: + AccountId: + table: accounts + +Select Event: + sf_object: Event + table: events + api: rest + action: select + select_options: + strategy: similarity + priority_fields: + - WhoId + fields: + - Subject + lookups: + WhoId: + table: + - contacts + - leads diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index 26768d4ea1..4969722c6e 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -4,6 +4,7 @@ from cumulusci.tasks.bulkdata.select_utils import ( SelectOperationExecutor, SelectStrategy, + add_limit_offset_to_user_filter, annoy_post_process, calculate_levenshtein_distance, determine_field_types, @@ -21,7 +22,7 @@ def test_standard_generate_query_with_default_record_declaration(): limit = 5 offset = 2 query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], limit=limit, offset=offset + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset ) assert "WHERE" in query # Ensure WHERE clause is included @@ -36,7 +37,7 @@ def test_standard_generate_query_without_default_record_declaration(): limit = 3 offset = None query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], limit=limit, offset=offset + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset ) assert "WHERE" not in query # No WHERE clause should be present @@ -45,6 +46,23 @@ def test_standard_generate_query_without_default_record_declaration(): assert fields == ["Id"] +def test_standard_generate_query_with_user_filter(): + select_operator = SelectOperationExecutor(SelectStrategy.STANDARD) + sobject = "Contact" # Assuming no declaration for this object + limit = 3 + offset = None + user_filter = "WHERE Name IN ('Sample Contact')" + query, fields = select_operator.select_generate_query( + sobject=sobject, fields=[], user_filter=user_filter, limit=limit, offset=offset + ) + + assert "WHERE" in query + assert "Sample Contact" in query + assert "LIMIT" in query + assert "OFFSET" not in query + assert fields == ["Id"] + + # Test Cases for random generate query def test_random_generate_query_with_default_record_declaration(): select_operator = SelectOperationExecutor(SelectStrategy.RANDOM) @@ -52,7 +70,7 @@ def test_random_generate_query_with_default_record_declaration(): limit = 5 offset = 2 query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], limit=limit, offset=offset + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset ) assert "WHERE" in query # Ensure WHERE clause is included @@ -67,7 +85,7 @@ def test_random_generate_query_without_default_record_declaration(): limit = 3 offset = None query, fields = select_operator.select_generate_query( - sobject=sobject, fields=[], limit=limit, offset=offset + sobject=sobject, fields=[], user_filter="", limit=limit, offset=offset ) assert "WHERE" not in query # No WHERE clause should be present @@ -83,7 +101,7 @@ def test_standard_post_process_with_records(): num_records = 3 sobject = "Contact" selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject + None, records, num_records, sobject, weights=[] ) assert error_message is None @@ -99,7 +117,7 @@ def test_standard_post_process_with_fewer_records(): num_records = 3 sobject = "Opportunity" selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject + None, records, num_records, sobject, weights=[] ) assert error_message is None @@ -116,7 +134,7 @@ def test_standard_post_process_with_no_records(): num_records = 2 sobject = "Lead" selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject + None, records, num_records, sobject, weights=[] ) assert selected_records == [] @@ -130,7 +148,7 @@ def test_random_post_process_with_records(): num_records = 3 sobject = "Contact" selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject + None, records, num_records, sobject, weights=[] ) assert error_message is None @@ -145,7 +163,7 @@ def test_random_post_process_with_no_records(): num_records = 2 sobject = "Lead" selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject + None, records, num_records, sobject, weights=[] ) assert selected_records == [] @@ -159,7 +177,7 @@ def test_similarity_generate_query_with_default_record_declaration(): limit = 5 offset = 2 query, fields = select_operator.select_generate_query( - sobject, ["Name"], limit, offset + sobject, ["Name"], [], limit, offset ) assert "WHERE" in query # Ensure WHERE clause is included @@ -174,7 +192,7 @@ def test_similarity_generate_query_without_default_record_declaration(): limit = 3 offset = None query, fields = select_operator.select_generate_query( - sobject, ["Name"], limit, offset + sobject, ["Name"], [], limit, offset ) assert "WHERE" not in query # No WHERE clause should be present @@ -183,6 +201,59 @@ def test_similarity_generate_query_without_default_record_declaration(): assert "OFFSET" not in query +def test_similarity_generate_query_with_nested_fields(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + sobject = "Event" # Assuming no declaration for this object + limit = 3 + offset = None + fields = [ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ] + query, query_fields = select_operator.select_generate_query( + sobject, fields, [], limit, offset + ) + + assert "WHERE" not in query # No WHERE clause should be present + assert query_fields == [ + "Id", + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ] + assert f"LIMIT {limit}" in query + assert "TYPEOF Who" in query + assert "WHEN Contact" in query + assert "WHEN Lead" in query + assert "OFFSET" not in query + + +def test_random_generate_query_with_user_filter(): + select_operator = SelectOperationExecutor(SelectStrategy.SIMILARITY) + sobject = "Contact" # Assuming no declaration for this object + limit = 3 + offset = None + user_filter = "WHERE Name IN ('Sample Contact')" + query, fields = select_operator.select_generate_query( + sobject=sobject, + fields=["Name"], + user_filter=user_filter, + limit=limit, + offset=offset, + ) + + assert "WHERE" in query + assert "Sample Contact" in query + assert "LIMIT" in query + assert "OFFSET" not in query + assert fields == ["Id", "Name"] + + def test_levenshtein_distance(): assert levenshtein_distance("kitten", "kitten") == 0 # Identical strings assert levenshtein_distance("kitten", "sitten") == 1 # One substitution @@ -284,7 +355,7 @@ def test_similarity_post_process_with_no_records(): num_records = 2 sobject = "Lead" selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject + None, records, num_records, sobject, weights=[1, 1, 1] ) assert selected_records == [] @@ -305,6 +376,34 @@ def test_calculate_levenshtein_distance_basic(): expected_distance ), "Basic distance calculation failed." + # Empty fields + record1 = ["hello", ""] + record2 = ["hullo", ""] + weights = [1.0, 1.0] + + # Expected distance based on simple Levenshtein distances + # Levenshtein("hello", "hullo") = 1, Levenshtein("", "") = 0 + expected_distance = (1 * 1.0 + 0 * 1.0) / 2 # Averaged over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Basic distance calculation with empty fields failed." + + # Partial empty fields + record1 = ["hello", "world"] + record2 = ["hullo", ""] + weights = [1.0, 1.0] + + # Expected distance based on simple Levenshtein distances + # Levenshtein("hello", "hullo") = 1, Levenshtein("world", "") = 5 + expected_distance = (1 * 1.0 + 5 * 0.05 * 1.0) / 2 # Averaged over two fields + + result = calculate_levenshtein_distance(record1, record2, weights) + assert result == pytest.approx( + expected_distance + ), "Basic distance calculation with partial empty fields failed." + def test_calculate_levenshtein_distance_weighted(): record1 = ["cat", "dog"] @@ -320,6 +419,26 @@ def test_calculate_levenshtein_distance_weighted(): ), "Weighted distance calculation failed." +def test_calculate_levenshtein_distance_records_length_doesnt_match(): + record1 = ["cat", "dog", "cow"] + record2 = ["bat", "fog"] + weights = [2.0, 0.5] + + with pytest.raises(ValueError) as e: + calculate_levenshtein_distance(record1, record2, weights) + assert "Records must have the same number of fields." in str(e.value) + + +def test_calculate_levenshtein_distance_weights_length_doesnt_match(): + record1 = ["cat", "dog"] + record2 = ["bat", "fog"] + weights = [2.0, 0.5, 3.0] + + with pytest.raises(ValueError) as e: + calculate_levenshtein_distance(record1, record2, weights) + assert "Records must be same size as fields (weights)." in str(e.value) + + def test_replace_empty_strings_with_missing(): # Case 1: Normal case with some empty strings records = [ @@ -419,11 +538,11 @@ def test_mixed_types(): assert determine_field_types(df, weights) == expected_output -def test_vectorize_records_mixed_numerical_categorical(): +def test_vectorize_records_mixed_numerical_boolean_categorical(): # Test data with mixed types: numerical and categorical only - db_records = [["1.0", "apple"], ["2.0", "banana"]] - query_records = [["1.5", "apple"], ["2.5", "cherry"]] - weights = [1.0, 1.0] # Equal weights for numerical and categorical columns + db_records = [["1.0", "true", "apple"], ["2.0", "false", "banana"]] + query_records = [["1.5", "true", "apple"], ["2.5", "false", "cherry"]] + weights = [1.0, 1.0, 1.0] # Equal weights for numerical and categorical columns hash_features = 4 # Number of hashing vectorizer features for categorical columns final_db_vectors, final_query_vectors = vectorize_records( @@ -437,7 +556,7 @@ def test_vectorize_records_mixed_numerical_categorical(): ), "Query vectors row count mismatch" # Expected dimensions: numerical (1) + categorical hashed features (4) - expected_feature_count = 1 + hash_features + expected_feature_count = 2 + hash_features assert ( final_db_vectors.shape[1] == expected_feature_count ), "DB vectors column count mismatch" @@ -478,3 +597,59 @@ def test_single_record_match_annoy_post_process(): assert len(closest_records) == 2 assert closest_records[0]["id"] == "q1" assert error is None + + +@pytest.mark.parametrize( + "filter_clause, limit_clause, offset_clause, expected", + [ + # Test: No existing LIMIT/OFFSET and no new clauses + ("SELECT * FROM users", None, None, " SELECT * FROM users"), + # Test: Existing LIMIT and no new limit provided + ("SELECT * FROM users LIMIT 100", None, None, "SELECT * FROM users LIMIT 100"), + # Test: Existing OFFSET and no new offset provided + ("SELECT * FROM users OFFSET 20", None, None, "SELECT * FROM users OFFSET 20"), + # Test: Existing LIMIT/OFFSET and new clauses provided + ( + "SELECT * FROM users LIMIT 100 OFFSET 20", + 50, + 10, + "SELECT * FROM users LIMIT 50 OFFSET 30", + ), + # Test: Existing LIMIT, new limit larger than existing (should keep the smaller one) + ("SELECT * FROM users LIMIT 100", 150, None, "SELECT * FROM users LIMIT 100"), + # Test: New limit smaller than existing (should use the new one) + ("SELECT * FROM users LIMIT 100", 50, None, "SELECT * FROM users LIMIT 50"), + # Test: Existing OFFSET, adding a new offset (should sum the offsets) + ("SELECT * FROM users OFFSET 20", None, 30, "SELECT * FROM users OFFSET 50"), + # Test: Existing LIMIT/OFFSET and new values set to None + ( + "SELECT * FROM users LIMIT 100 OFFSET 20", + None, + None, + "SELECT * FROM users LIMIT 100 OFFSET 20", + ), + # Test: Removing existing LIMIT and adding a new one + ("SELECT * FROM users LIMIT 200", 50, None, "SELECT * FROM users LIMIT 50"), + # Test: Removing existing OFFSET and adding a new one + ("SELECT * FROM users OFFSET 40", None, 20, "SELECT * FROM users OFFSET 60"), + # Edge case: Filter clause with mixed cases + ( + "SELECT * FROM users LiMiT 100 oFfSeT 20", + 50, + 10, + "SELECT * FROM users LIMIT 50 OFFSET 30", + ), + # Test: Filter clause with trailing/leading spaces + ( + " SELECT * FROM users LIMIT 100 OFFSET 20 ", + 50, + 10, + "SELECT * FROM users LIMIT 50 OFFSET 30", + ), + ], +) +def test_add_limit_offset_to_user_filter( + filter_clause, limit_clause, offset_clause, expected +): + result = add_limit_offset_to_user_filter(filter_clause, limit_clause, offset_clause) + assert result.strip() == expected.strip() diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index da13a9a8eb..bd059b9bbf 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -1,14 +1,17 @@ import io import json +from itertools import tee from unittest import mock import pytest import responses -from cumulusci.core.exceptions import BulkDataException, SOQLQueryException +from cumulusci.core.exceptions import BulkDataException from cumulusci.tasks.bulkdata.load import LoadData from cumulusci.tasks.bulkdata.select_utils import SelectStrategy from cumulusci.tasks.bulkdata.step import ( + HIGH_PRIORITY_VALUE, + LOW_PRIORITY_VALUE, BulkApiDmlOperation, BulkApiQueryOperation, BulkJobMixin, @@ -19,7 +22,10 @@ DataOperationType, RestApiDmlOperation, RestApiQueryOperation, + assign_weights, download_file, + extract_flattened_headers, + flatten_record, get_dml_operation, get_query_operation, ) @@ -546,6 +552,7 @@ def test_select_records_standard_strategy_success(self, download_mock): context=context, fields=["LastName"], selection_strategy=SelectStrategy.STANDARD, + content_type="JSON", ) # Mock Bulk API responses @@ -555,10 +562,7 @@ def test_select_records_standard_strategy_success(self, download_mock): step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] # Mock the downloaded CSV content with a single record - download_mock.return_value = io.StringIO( - """Id -003000000000001""" - ) + download_mock.return_value = io.StringIO('[{"Id":"003000000000001"}]') # Mock the _wait_for_job method to simulate a successful job step._wait_for_job = mock.Mock() @@ -607,7 +611,7 @@ def test_select_records_standard_strategy_failure__no_records(self, download_moc step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] # Mock the downloaded CSV content indicating no records found - download_mock.return_value = io.StringIO("""Records not found for this query""") + download_mock.return_value = io.StringIO("[]") # Mock the _wait_for_job method to simulate a successful job step._wait_for_job = mock.Mock() @@ -654,51 +658,34 @@ def test_select_records_user_selection_filter_success(self, download_mock): step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] # Mock the downloaded CSV content with a single record - download_mock.return_value = io.StringIO( - """Id -003000000000001 -003000000000002 -003000000000003""" + download_mock.return_value = io.StringIO('[{"Id":"003000000000001"}]') + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 ) - # Mock the query operation - with mock.patch( - "cumulusci.tasks.bulkdata.step.get_query_operation" - ) as query_operation_mock: - query_operation_mock.return_value = mock.Mock() - query_operation_mock.return_value.query = mock.Mock() - query_operation_mock.return_value.get_results = mock.Mock() - query_operation_mock.return_value.get_results.return_value = [ - ["003000000000001"] - ] - # Mock the _wait_for_job method to simulate a successful job - step._wait_for_job = mock.Mock() - step._wait_for_job.return_value = DataOperationJobResult( - DataOperationStatus.SUCCESS, [], 0, 0 - ) + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) - # Prepare input records - records = iter([["Test1"], ["Test2"], ["Test3"]]) + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() - # Execute the select_records operation - step.start() - step.select_records(records) - step.end() - - # Get the results and assert their properties - results = list(step.get_results()) - assert ( - len(results) == 3 - ) # Expect 3 results (matching the input records count) - # Assert that all results have the expected ID, success, and created values - assert ( - results.count( - DataOperationResult( - id="003000000000001", success=True, error="", created=False - ) + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False ) - == 3 ) + == 3 + ) @mock.patch("cumulusci.tasks.bulkdata.step.download_file") def test_select_records_user_selection_filter_order_success(self, download_mock): @@ -722,47 +709,29 @@ def test_select_records_user_selection_filter_order_success(self, download_mock) # Mock the downloaded CSV content with a single record download_mock.return_value = io.StringIO( - """Id -003000000000001 -003000000000002 -003000000000003""" + '[{"Id":"003000000000003"}, {"Id":"003000000000001"}, {"Id":"003000000000002"}]' + ) + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 ) - # Mock the query operation - with mock.patch( - "cumulusci.tasks.bulkdata.step.get_query_operation" - ) as query_operation_mock: - query_operation_mock.return_value = mock.Mock() - query_operation_mock.return_value.query = mock.Mock() - query_operation_mock.return_value.get_results = mock.Mock() - query_operation_mock.return_value.get_results.return_value = [ - ["003000000000003"], - ["003000000000001"], - ["003000000000002"], - ] - # Mock the _wait_for_job method to simulate a successful job - step._wait_for_job = mock.Mock() - step._wait_for_job.return_value = DataOperationJobResult( - DataOperationStatus.SUCCESS, [], 0, 0 - ) + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) - # Prepare input records - records = iter([["Test1"], ["Test2"], ["Test3"]]) + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() - # Execute the select_records operation - step.start() - step.select_records(records) - step.end() - - # Get the results and assert their properties - results = list(step.get_results()) - assert ( - len(results) == 3 - ) # Expect 3 results (matching the input records count) - # Assert that all results are in the order given by user query - assert results[0].id == "003000000000003" - assert results[1].id == "003000000000001" - assert results[2].id == "003000000000002" + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results are in the order given by user query + assert results[0].id == "003000000000003" + assert results[1].id == "003000000000001" + assert results[2].id == "003000000000002" @mock.patch("cumulusci.tasks.bulkdata.step.download_file") def test_select_records_user_selection_filter_failure(self, download_mock): @@ -785,29 +754,14 @@ def test_select_records_user_selection_filter_failure(self, download_mock): step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] # Mock the downloaded CSV content with a single record - download_mock.return_value = io.StringIO( - """Id -003000000000001 -003000000000002 -003000000000003""" - ) - # Mock the query operation - with mock.patch( - "cumulusci.tasks.bulkdata.step.get_query_operation" - ) as query_operation_mock: - query_operation_mock.return_value = mock.Mock() - query_operation_mock.return_value.query = mock.Mock() - query_operation_mock.return_value.query.side_effect = BulkDataException( - "MALFORMED QUERY" - ) - - # Prepare input records - records = iter([["Test1"], ["Test2"], ["Test3"]]) + download_mock.side_effect = BulkDataException("MALFORMED QUERY") + # Prepare input records + records = iter([["Test1"], ["Test2"], ["Test3"]]) - # Execute the select_records operation - step.start() - with pytest.raises(BulkDataException): - step.select_records(records) + # Execute the select_records operation + step.start() + with pytest.raises(BulkDataException): + step.select_records(records) @mock.patch("cumulusci.tasks.bulkdata.step.download_file") def test_select_records_similarity_strategy_success(self, download_mock): @@ -818,7 +772,7 @@ def test_select_records_similarity_strategy_success(self, download_mock): operation=DataOperationType.QUERY, api_options={"batch_size": 10, "update_key": "LastName"}, context=context, - fields=["Id", "Name", "Email"], + fields=["Name", "Email"], selection_strategy=SelectStrategy.SIMILARITY, ) @@ -830,10 +784,7 @@ def test_select_records_similarity_strategy_success(self, download_mock): # Mock the downloaded CSV content with a single record download_mock.return_value = io.StringIO( - """Id,Name,Email -003000000000001,Jawad,mjawadtp@example.com -003000000000002,Aditya,aditya@example.com -003000000000003,Tom,tom@example.com""" + """[{"Id":"003000000000001", "Name":"Jawad", "Email":"mjawadtp@example.com"}, {"Id":"003000000000002", "Name":"Aditya", "Email":"aditya@example.com"}, {"Id":"003000000000003", "Name":"Tom", "Email":"tom@example.com"}]""" ) # Mock the _wait_for_job method to simulate a successful job @@ -908,7 +859,7 @@ def test_select_records_similarity_strategy_failure__no_records( step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] # Mock the downloaded CSV content indicating no records found - download_mock.return_value = io.StringIO("""Records not found for this query""") + download_mock.return_value = io.StringIO("[]") # Mock the _wait_for_job method to simulate a successful job step._wait_for_job = mock.Mock() @@ -940,6 +891,214 @@ def test_select_records_similarity_strategy_failure__no_records( assert job_result.records_processed == 0 assert job_result.total_row_errors == 0 + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_parent_level_records__polymorphic( + self, download_mock + ): + mock_describe_calls() + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Event", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=[ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + download_mock.return_value = io.StringIO( + """[ + {"Id": "003000000000001", "Subject": "Sample Event 1", "Who":{ "attributes": {"type": "Contact"}, "Name": "Sample Contact", "Email": "contact@example.com"}}, + { "Id": "003000000000002", "Subject": "Sample Event 2", "Who":{ "attributes": {"type": "Lead"}, "Name": "Sample Lead", "Company": "Salesforce"}} + ]""" + ) + + records = iter( + [ + ["Sample Event 1", "Sample Contact", "contact@example.com", "", ""], + ["Sample Event 2", "", "", "Sample Lead", "Salesforce"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_parent_level_records__non_polymorphic( + self, download_mock + ): + mock_describe_calls() + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=["Name", "Account.Name", "Account.AccountNumber"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + download_mock.return_value = io.StringIO( + """[ + {"Id": "003000000000001", "Name": "Sample Contact 1", "Account":{ "attributes": {"type": "Account"}, "Name": "Sample Account", "AccountNumber": 123456}}, + { "Id": "003000000000002", "Subject": "Sample Contact 2", "Account": null} + ]""" + ) + + records = iter( + [ + ["Sample Contact 3", "Sample Account", "123456"], + ["Sample Contact 4", "", ""], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy_priority_fields(self, download_mock): + mock_describe_calls() + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + step_1 = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={"Name": "Name", "Email": "Email"}, + ) + + step_2 = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={ + "Account.Name": "Account.Name", + "Account.AccountNumber": "Account.AccountNumber", + }, + ) + + # Mock Bulk API responses + step_1.bulk.endpoint = "https://test" + step_1.bulk.create_query_job.return_value = "JOB" + step_1.bulk.query.return_value = "BATCH" + step_1.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + step_2.bulk.endpoint = "https://test" + step_2.bulk.create_query_job.return_value = "JOB" + step_2.bulk.query.return_value = "BATCH" + step_2.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + sample_response = [ + { + "Id": "003000000000001", + "Name": "Bob The Builder", + "Email": "bob@yahoo.org", + "Account": { + "attributes": {"type": "Account"}, + "Name": "Jawad TP", + "AccountNumber": 567890, + }, + }, + { + "Id": "003000000000002", + "Name": "Tom Cruise", + "Email": "tom@exmaple.com", + "Account": { + "attributes": {"type": "Account"}, + "Name": "Aditya B", + "AccountNumber": 123456, + }, + }, + ] + + download_mock.side_effect = [ + io.StringIO(f"""{json.dumps(sample_response)}"""), + io.StringIO(f"""{json.dumps(sample_response)}"""), + ] + + records = iter( + [ + ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456"], + ] + ) + records_1, records_2 = tee(records) + step_1.start() + step_1.select_records(records_1) + step_1.end() + + step_2.start() + step_2.select_records(records_2) + step_2.end() + + # Get the results and assert their properties + results_1 = list(step_1.get_results()) + results_2 = list(step_2.get_results()) + assert ( + len(results_1) == 1 + ) # Expect 1 results (matching the input records count) + assert ( + len(results_2) == 1 + ) # Expect 1 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + # Prioritizes Name and Email + assert results_1[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + # Prioritizes Account.Name and Account.AccountNumber + assert results_2[0] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + def test_batch(self): context = mock.Mock() @@ -1344,10 +1503,101 @@ def test_select_records_standard_strategy_success(self): assert ( results.count( DataOperationResult( - id="003000000000001", success=True, error="", created=False + id="003000000000001", success=True, error="", created=False + ) + ) + == 3 + ) + + @responses.activate + def test_select_records_standard_strategy_success_pagination(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=task, + fields=["LastName"], + selection_strategy=SelectStrategy.STANDARD, + ) + + # Set up pagination: First call returns done=False, second call returns done=True + step.sf.restful = mock.Mock( + side_effect=[ + { + "records": [{"Id": "003000000000001"}, {"Id": "003000000000002"}], + "done": False, # Pagination in progress + "nextRecordsUrl": "/services/data/vXX.X/query/next-records", + }, + ] + ) + + step.sf.query_more = mock.Mock( + side_effect=[ + {"records": [{"Id": "003000000000003"}], "done": True} # Final page + ] + ) + + records = iter([["Test1"], ["Test2"], ["Test3"]]) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=False ) ) - == 3 + == 1 ) @responses.activate @@ -1448,28 +1698,10 @@ def test_select_records_user_selection_filter_success(self): ) results = { - "compositeResponse": [ - { - "body": { - "records": [ - {"Id": "003000000000001"}, - {"Id": "003000000000002"}, - {"Id": "003000000000003"}, - ] - }, - "referenceId": "select_query", - "httpStatusCode": 200, - }, - { - "body": { - "records": [ - {"Id": "003000000000001"}, - ] - }, - "referenceId": "user_query", - "httpStatusCode": 200, - }, - ] + "records": [ + {"Id": "003000000000001"}, + ], + "done": True, } step.sf.restful = mock.Mock() step.sf.restful.return_value = results @@ -1532,30 +1764,12 @@ def test_select_records_user_selection_filter_order_success(self): ) results = { - "compositeResponse": [ - { - "body": { - "records": [ - {"Id": "003000000000001"}, - {"Id": "003000000000002"}, - {"Id": "003000000000003"}, - ] - }, - "referenceId": "select_query", - "httpStatusCode": 200, - }, - { - "body": { - "records": [ - {"Id": "003000000000003"}, - {"Id": "003000000000001"}, - {"Id": "003000000000002"}, - ] - }, - "referenceId": "user_query", - "httpStatusCode": 200, - }, - ] + "records": [ + {"Id": "003000000000003"}, + {"Id": "003000000000001"}, + {"Id": "003000000000002"}, + ], + "done": True, } step.sf.restful = mock.Mock() step.sf.restful.return_value = results @@ -1612,38 +1826,12 @@ def test_select_records_user_selection_filter_failure(self): selection_filter="MALFORMED FILTER", # Applying malformed filter ) - results = { - "compositeResponse": [ - { - "body": { - "records": [ - {"Id": "003000000000001"}, - {"Id": "003000000000002"}, - {"Id": "003000000000003"}, - ] - }, - "referenceId": "select_query", - "httpStatusCode": 200, - }, - { - "body": [ - { - "message": "Error in MALFORMED FILTER", - "errorCode": "MALFORMED QUERY", - } - ], - "referenceId": "user_query", - "httpStatusCode": 400, - }, - ] - } step.sf.restful = mock.Mock() - step.sf.restful.return_value = results + step.sf.restful.side_effect = Exception("MALFORMED QUERY") records = iter([["Test1"], ["Test2"], ["Test3"]]) step.start() - with pytest.raises(SOQLQueryException) as e: + with pytest.raises(Exception): step.select_records(records) - assert "MALFORMED QUERY" in str(e.value) @responses.activate def test_select_records_similarity_strategy_success(self): @@ -1680,7 +1868,7 @@ def test_select_records_similarity_strategy_success(self): operation=DataOperationType.UPSERT, api_options={"batch_size": 10, "update_key": "LastName"}, context=task, - fields=["Id", "Name", "Email"], + fields=["Name", "Email"], selection_strategy=SelectStrategy.SIMILARITY, ) @@ -1812,6 +2000,318 @@ def test_select_records_similarity_strategy_failure__no_records(self): assert job_result.records_processed == 0 assert job_result.total_row_errors == 0 + @responses.activate + def test_select_records_similarity_strategy_parent_level_records__polymorphic(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Event", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=task, + fields=[ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + step.sf.restful = mock.Mock( + side_effect=[ + { + "records": [ + { + "Id": "003000000000001", + "Subject": "Sample Event 1", + "Who": { + "attributes": {"type": "Contact"}, + "Name": "Sample Contact", + "Email": "contact@example.com", + }, + }, + { + "Id": "003000000000002", + "Subject": "Sample Event 2", + "Who": { + "attributes": {"type": "Lead"}, + "Name": "Sample Lead", + "Company": "Salesforce", + }, + }, + ], + "done": True, + }, + ] + ) + + records = iter( + [ + ["Sample Event 1", "Sample Contact", "contact@example.com", "", ""], + ["Sample Event 2", "", "", "Sample Lead", "Salesforce"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @responses.activate + def test_select_records_similarity_strategy_parent_level_records__non_polymorphic( + self, + ): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=task, + fields=["Name", "Account.Name", "Account.AccountNumber"], + selection_strategy=SelectStrategy.SIMILARITY, + ) + + step.sf.restful = mock.Mock( + side_effect=[ + { + "records": [ + { + "Id": "003000000000001", + "Name": "Sample Contact 1", + "Account": { + "attributes": {"type": "Account"}, + "Name": "Sample Account", + "AccountNumber": 123456, + }, + }, + { + "Id": "003000000000002", + "Name": "Sample Contact 2", + "Account": None, + }, + ], + "done": True, + }, + ] + ) + + records = iter( + [ + ["Sample Contact 3", "Sample Account", "123456"], + ["Sample Contact 4", "", ""], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 2 # Expect 2 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + assert results[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + assert results[1] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + + @responses.activate + def test_select_records_similarity_strategy_priority_fields(self): + mock_describe_calls() + task_1 = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task_1.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task_1._init_task() + + task_2 = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task_2.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task_2._init_task() + + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + ], + status=200, + ) + responses.add( + responses.POST, + url=f"https://example.com/services/data/v{CURRENT_SF_API_VERSION}/composite/sobjects", + json=[{"id": "003000000000003", "success": True}], + status=200, + ) + step_1 = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=task_1, + fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={"Name": "Name", "Email": "Email"}, + ) + + step_2 = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=task_2, + fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + selection_strategy=SelectStrategy.SIMILARITY, + selection_priority_fields={ + "Account.Name": "Account.Name", + "Account.AccountNumber": "Account.AccountNumber", + }, + ) + + sample_response = [ + { + "records": [ + { + "Id": "003000000000001", + "Name": "Bob The Builder", + "Email": "bob@yahoo.org", + "Account": { + "attributes": {"type": "Account"}, + "Name": "Jawad TP", + "AccountNumber": 567890, + }, + }, + { + "Id": "003000000000002", + "Name": "Tom Cruise", + "Email": "tom@exmaple.com", + "Account": { + "attributes": {"type": "Account"}, + "Name": "Aditya B", + "AccountNumber": 123456, + }, + }, + ], + "done": True, + }, + ] + + step_1.sf.restful = mock.Mock(side_effect=sample_response) + step_2.sf.restful = mock.Mock(side_effect=sample_response) + + records = iter( + [ + ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456"], + ] + ) + records_1, records_2 = tee(records) + step_1.start() + step_1.select_records(records_1) + step_1.end() + + step_2.start() + step_2.select_records(records_2) + step_2.end() + + # Get the results and assert their properties + results_1 = list(step_1.get_results()) + results_2 = list(step_2.get_results()) + assert ( + len(results_1) == 1 + ) # Expect 1 results (matching the input records count) + assert ( + len(results_2) == 1 + ) # Expect 1 results (matching the input records count) + + # Assert that all results have the expected ID, success, and created values + # Prioritizes Name and Email + assert results_1[0] == DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + # Prioritizes Account.Name and Account.AccountNumber + assert results_2[0] == DataOperationResult( + id="003000000000002", success=True, error="", created=False + ) + @responses.activate def test_insert_dml_operation__boolean_conversion(self): mock_describe_calls() @@ -2301,6 +2801,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): context=context, selection_strategy=SelectStrategy.SIMILARITY, selection_filter=None, + selection_priority_fields=None, + content_type=None, ) op = get_dml_operation( @@ -2324,6 +2826,8 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): context=context, selection_strategy=SelectStrategy.SIMILARITY, selection_filter=None, + selection_priority_fields=None, + content_type=None, ) @mock.patch("cumulusci.tasks.bulkdata.step.BulkApiDmlOperation") @@ -2488,99 +2992,120 @@ def test_cleanup_date_strings__upsert_update(self, operation): }, json_out -import pytest - -# def test_generate_user_filter_query_basic(): -# """Tests basic query generation without existing LIMIT or OFFSET.""" -# filter_clause = "WHERE Name = 'John'" -# sobject = "Account" -# fields = ["Id", "Name"] -# limit_clause = 10 -# offset_clause = 5 - -# expected_query = ( -# "SELECT Id, Name FROM Account WHERE Name = 'John' LIMIT 10 OFFSET 5" -# ) -# assert ( -# generate_user_filter_query( -# filter_clause, sobject, fields, limit_clause, offset_clause -# ) -# == expected_query -# ) - - -# def test_generate_user_filter_query_existing_limit(): -# """Tests handling of existing LIMIT in the filter clause.""" -# filter_clause = "WHERE Name = 'John' LIMIT 20" -# sobject = "Contact" -# fields = ["Id", "FirstName"] -# limit_clause = 5 # Should override the existing LIMIT -# offset_clause = None - -# expected_query = "SELECT Id, FirstName FROM Contact WHERE Name = 'John' LIMIT 5" -# assert ( -# generate_user_filter_query( -# filter_clause, sobject, fields, limit_clause, offset_clause -# ) -# == expected_query -# ) - - -# def test_generate_user_filter_query_existing_offset(): -# """Tests handling of existing OFFSET in the filter clause.""" -# filter_clause = "WHERE Name = 'John' OFFSET 15" -# sobject = "Opportunity" -# fields = ["Id", "Name"] -# limit_clause = None -# offset_clause = 10 # Should add to the existing OFFSET - -# expected_query = "SELECT Id, Name FROM Opportunity WHERE Name = 'John' OFFSET 25" -# assert ( -# generate_user_filter_query( -# filter_clause, sobject, fields, limit_clause, offset_clause -# ) -# == expected_query -# ) - - -# def test_generate_user_filter_query_no_limit_or_offset(): -# """Tests when no limit or offset is provided or present in the filter.""" -# filter_clause = "WHERE Name = 'John' LIMIT 5 OFFSET 20" -# sobject = "Lead" -# fields = ["Id", "Name", "Email"] -# limit_clause = None -# offset_clause = None - -# expected_query = ( -# "SELECT Id, Name, Email FROM Lead WHERE Name = 'John' LIMIT 5 OFFSET 20" -# ) -# print( -# generate_user_filter_query( -# filter_clause, sobject, fields, limit_clause, offset_clause -# ) -# ) -# assert ( -# generate_user_filter_query( -# filter_clause, sobject, fields, limit_clause, offset_clause -# ) -# == expected_query -# ) - - -# def test_generate_user_filter_query_case_insensitivity(): -# """Tests case-insensitivity for LIMIT and OFFSET.""" -# filter_clause = "where name = 'John' offset 5 limit 20" -# sobject = "Task" -# fields = ["Id", "Subject"] -# limit_clause = 15 -# offset_clause = 20 - -# expected_query = ( -# "SELECT Id, Subject FROM Task where name = 'John' LIMIT 15 OFFSET 25" -# ) -# assert ( -# generate_user_filter_query( -# filter_clause, sobject, fields, limit_clause, offset_clause -# ) -# == expected_query -# ) +@pytest.mark.parametrize( + "query_fields, expected", + [ + # Test with simple field names + (["Id", "Name", "Email"], ["Id", "Name", "Email"]), + # Test with TYPEOF fields (polymorphic fields) + ( + [ + "Subject", + { + "Who": [ + {"Contact": ["Name", "Email"]}, + {"Lead": ["Name", "Company"]}, + ] + }, + ], + [ + "Subject", + "Who.Contact.Name", + "Who.Contact.Email", + "Who.Lead.Name", + "Who.Lead.Company", + ], + ), + # Test with mixed simple and TYPEOF fields + ( + ["Subject", {"Who": [{"Contact": ["Email"]}]}, "Account.Name"], + ["Subject", "Who.Contact.Email", "Account.Name"], + ), + # Test with an empty list + ([], []), + ], +) +def test_extract_flattened_headers(query_fields, expected): + result = extract_flattened_headers(query_fields) + assert result == expected + + +@pytest.mark.parametrize( + "record, headers, expected", + [ + # Test with simple field matching + ( + {"Id": "001", "Name": "John Doe", "Email": "john@example.com"}, + ["Id", "Name", "Email"], + ["001", "John Doe", "john@example.com"], + ), + # Test with lookup fields and missing values + ( + { + "Who": { + "attributes": {"type": "Contact"}, + "Name": "Jane Doe", + "Email": "johndoe@org.com", + "Number": 10, + } + }, + ["Who.Contact.Name", "Who.Contact.Email", "Who.Contact.Number"], + ["Jane Doe", "johndoe@org.com", "10"], + ), + # Test with non-matching ref_obj type + ( + {"Who": {"attributes": {"type": "Contact"}, "Email": "jane@contact.com"}}, + ["Who.Lead.Email"], + [""], + ), + # Test with mixed fields and nested lookups + ( + { + "Who": {"attributes": {"type": "Lead"}, "Name": "John Doe"}, + "Email": "john@example.com", + }, + ["Who.Lead.Name", "Who.Lead.Company", "Email"], + ["John Doe", "", "john@example.com"], + ), + # Test with mixed fields and nested lookups + ( + { + "Who": {"attributes": {"type": "Lead"}, "Name": "John Doe"}, + "Email": "john@example.com", + }, + ["What.Account.Name"], + [""], + ), + # Test with empty record + ({}, ["Id", "Name"], ["", ""]), + ], +) +def test_flatten_record(record, headers, expected): + result = flatten_record(record, headers) + assert result == expected + + +@pytest.mark.parametrize( + "priority_fields, fields, expected", + [ + # Test with priority fields matching + ( + {"Id": "Id", "Name": "Name"}, + ["Id", "Name", "Email"], + [HIGH_PRIORITY_VALUE, HIGH_PRIORITY_VALUE, LOW_PRIORITY_VALUE], + ), + # Test with no priority fields provided + (None, ["Id", "Name", "Email"], [1, 1, 1]), + # Test with empty priority fields dictionary + ({}, ["Id", "Name", "Email"], [1, 1, 1]), + # Test with some fields not in priority_fields + ( + {"Id": "Id"}, + ["Id", "Name", "Email"], + [HIGH_PRIORITY_VALUE, LOW_PRIORITY_VALUE, LOW_PRIORITY_VALUE], + ), + ], +) +def test_assign_weights(priority_fields, fields, expected): + result = assign_weights(priority_fields, fields) + assert result == expected From 138241e38037720eee2c954bf3f3f4d7187f8ff0 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 19 Nov 2024 02:47:04 +0530 Subject: [PATCH 28/34] Functionality for Select+Insert --- cumulusci/tasks/bulkdata/load.py | 24 +- cumulusci/tasks/bulkdata/mapping_parser.py | 1 + .../tasks/bulkdata/query_transformers.py | 20 +- cumulusci/tasks/bulkdata/select_utils.py | 251 +++++++++-- cumulusci/tasks/bulkdata/step.py | 234 +++++++--- ...lect_invalid_threshold__invalid_number.yml | 21 + ...ct_invalid_threshold__invalid_strategy.yml | 21 + ...ng_select_invalid_threshold__non_float.yml | 21 + cumulusci/tasks/bulkdata/tests/test_load.py | 6 +- .../bulkdata/tests/test_mapping_parser.py | 29 ++ .../tasks/bulkdata/tests/test_select_utils.py | 411 ++++++++++++++++-- cumulusci/tasks/bulkdata/tests/test_step.py | 87 +++- 12 files changed, 960 insertions(+), 166 deletions(-) create mode 100644 cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_number.yml create mode 100644 cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_strategy.yml create mode 100644 cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__non_float.yml diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index 9a2f08ee90..0732d57777 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -312,10 +312,14 @@ def _execute_step( def process_lookup_fields(self, mapping, fields, polymorphic_fields): """Modify fields and priority fields based on lookup and polymorphic checks.""" + # Store the lookups and their original order for re-insertion at the end + original_lookups = [name for name in fields if name in mapping.lookups] + max_insert_index = -1 for name, lookup in mapping.lookups.items(): if name in fields: # Get the index of the lookup field before removing it insert_index = fields.index(name) + max_insert_index = max(max_insert_index, insert_index) # Remove the lookup field from fields fields.remove(name) @@ -351,7 +355,7 @@ def process_lookup_fields(self, mapping, fields, polymorphic_fields): None, ) if lookup_mapping_step: - lookup_fields = lookup_mapping_step.get_load_field_list() + lookup_fields = lookup_mapping_step.fields.keys() # Insert fields in the format {relationship_name}.{ref_type}.{lookup_field} for field in lookup_fields: fields.insert( @@ -359,6 +363,7 @@ def process_lookup_fields(self, mapping, fields, polymorphic_fields): f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}", ) insert_index += 1 + max_insert_index = max(max_insert_index, insert_index) if lookup_in_priority_fields: mapping.select_options.priority_fields[ f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}" @@ -383,17 +388,24 @@ def process_lookup_fields(self, mapping, fields, polymorphic_fields): if lookup_mapping_step: relationship_name = polymorphic_fields[name]["relationshipName"] - lookup_fields = lookup_mapping_step.get_load_field_list() + lookup_fields = lookup_mapping_step.fields.keys() # Insert the new fields at the same position as the removed lookup field for field in lookup_fields: fields.insert(insert_index, f"{relationship_name}.{field}") insert_index += 1 + max_insert_index = max(max_insert_index, insert_index) if lookup_in_priority_fields: mapping.select_options.priority_fields[ f"{relationship_name}.{field}" ] = f"{relationship_name}.{field}" + # Append the original lookups at the end in the same order + for name in original_lookups: + if name not in fields: + fields.insert(max_insert_index, name) + max_insert_index += 1 + def configure_step(self, mapping): """Create a step appropriate to the action""" bulk_mode = mapping.bulk_mode or self.bulk_mode or "Parallel" @@ -479,6 +491,7 @@ def configure_step(self, mapping): selection_filter=mapping.select_options.filter, selection_priority_fields=mapping.select_options.priority_fields, content_type=content_type, + threshold=mapping.select_options.threshold, ) return step, query @@ -588,10 +601,9 @@ def _query_db(self, mapping): mapping, self.mapping, self.metadata, model, self._old_format ) ) - else: - transformers.append( - AddLookupsToQuery(mapping, self.metadata, model, self._old_format) - ) + transformers.append( + AddLookupsToQuery(mapping, self.metadata, model, self._old_format) + ) transformers.extend([cls(mapping, self.metadata, model) for cls in classes]) diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index 1593dc97a1..e630d564c6 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -31,6 +31,7 @@ class MappingLookup(CCIDictModel): join_field: Optional[str] = None after: Optional[str] = None aliased_table: Optional[Any] = None + parent_tables: Optional[Any] = None name: Optional[str] = None # populated by parent def get_lookup_key_field(self, model=None): diff --git a/cumulusci/tasks/bulkdata/query_transformers.py b/cumulusci/tasks/bulkdata/query_transformers.py index f99689618e..181736a4bc 100644 --- a/cumulusci/tasks/bulkdata/query_transformers.py +++ b/cumulusci/tasks/bulkdata/query_transformers.py @@ -106,14 +106,14 @@ def columns_to_add(self): columns = [] for lookup in self.lookups: tables = lookup.table if isinstance(lookup.table, list) else [lookup.table] - lookup.aliased_table = [ + lookup.parent_tables = [ aliased( self.metadata.tables[table], name=f"{lookup.name}_{table}_alias" ) for table in tables ] - for aliased_table, table_name in zip(lookup.aliased_table, tables): + for parent_table, table_name in zip(lookup.parent_tables, tables): # Find the mapping step for this polymorphic type lookup_mapping_step = next( ( @@ -124,24 +124,24 @@ def columns_to_add(self): None, ) if lookup_mapping_step: - load_fields = lookup_mapping_step.get_load_field_list() + load_fields = lookup_mapping_step.fields.keys() for field in load_fields: if field in lookup_mapping_step.fields: matching_column = next( ( col - for col in aliased_table.columns + for col in parent_table.columns if col.name == lookup_mapping_step.fields[field] ) ) columns.append( - matching_column.label(f"{aliased_table.name}_{field}") + matching_column.label(f"{parent_table.name}_{field}") ) else: # Append an empty string if the field is not present columns.append( literal_column("''").label( - f"{aliased_table.name}_{field}" + f"{parent_table.name}_{field}" ) ) return columns @@ -150,15 +150,15 @@ def columns_to_add(self): def outerjoins_to_add(self): """Add outer joins for each lookup table directly, including handling for polymorphic lookups.""" - def join_for_lookup(lookup, aliased_table): + def join_for_lookup(lookup, parent_table): key_field = lookup.get_lookup_key_field(self.model) value_column = getattr(self.model, key_field) - return (aliased_table, aliased_table.columns.id == value_column) + return (parent_table, parent_table.columns.id == value_column) joins = [] for lookup in self.lookups: - for aliased_table in lookup.aliased_table: - joins.append(join_for_lookup(lookup, aliased_table)) + for parent_table in lookup.parent_tables: + joins.append(join_for_lookup(lookup, parent_table)) return joins diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index f5800f9b38..7412a38ae4 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd from annoy import AnnoyIndex -from pydantic import Field, validator +from pydantic import Field, root_validator, validator from sklearn.feature_extraction.text import HashingVectorizer from sklearn.preprocessing import StandardScaler @@ -45,6 +45,7 @@ class SelectOptions(CCIDictModel): filter: T.Optional[str] = None # Optional filter for selection strategy: SelectStrategy = SelectStrategy.STANDARD # Strategy for selection priority_fields: T.Dict[str, str] = Field({}) + threshold: T.Optional[float] = None @validator("strategy", pre=True) def validate_strategy(cls, value): @@ -66,6 +67,26 @@ def standardize_fields_to_dict(cls, values): values = {elem: elem for elem in values} return CaseInsensitiveDict(values) + @root_validator + def validate_threshold_and_strategy(cls, values): + threshold = values.get("threshold") + strategy = values.get("strategy") + + if threshold is not None: + values["threshold"] = float(threshold) # Convert to float + + if not (0 <= values["threshold"] <= 1): + raise ValueError( + f"Threshold must be between 0 and 1, got {values['threshold']}." + ) + + if strategy != SelectStrategy.SIMILARITY: + raise ValueError( + "If a threshold is specified, the strategy must be set to 'similarity'." + ) + + return values + class SelectOperationExecutor: def __init__(self, strategy: SelectStrategy): @@ -84,6 +105,7 @@ def select_generate_query( limit: T.Union[int, None], offset: T.Union[int, None], ): + _, select_fields = split_and_filter_fields(fields=fields) # For STANDARD strategy if self.strategy == SelectStrategy.STANDARD: return standard_generate_query( @@ -93,7 +115,7 @@ def select_generate_query( elif self.strategy == SelectStrategy.SIMILARITY: return similarity_generate_query( sobject=sobject, - fields=fields, + fields=select_fields, user_filter=user_filter, limit=limit, offset=offset, @@ -108,9 +130,11 @@ def select_post_process( self, load_records, query_records: list, + fields: list, num_records: int, sobject: str, weights: list, + threshold: T.Union[float, None], ): # For STANDARD strategy if self.strategy == SelectStrategy.STANDARD: @@ -122,8 +146,10 @@ def select_post_process( return similarity_post_process( load_records=load_records, query_records=query_records, + fields=fields, sobject=sobject, weights=weights, + threshold=threshold, ) # For RANDOM strategy elif self.strategy == SelectStrategy.RANDOM: @@ -158,12 +184,12 @@ def standard_generate_query( def standard_post_process( query_records: list, num_records: int, sobject: str -) -> T.Tuple[T.List[dict], T.Union[str, None]]: +) -> T.Tuple[T.List[dict], None, T.Union[str, None]]: """Processes the query results for the standard selection strategy""" # Handle case where query returns 0 records if not query_records: error_message = f"No records found for {sobject} in the target org." - return [], error_message + return [], None, error_message # Add 'success: True' to each record to emulate records have been inserted selected_records = [ @@ -177,7 +203,7 @@ def standard_post_process( selected_records.extend(original_records) selected_records = selected_records[:num_records] - return selected_records, None # Return selected records and None for error + return selected_records, None, None # Return selected records and None for error def similarity_generate_query( @@ -255,13 +281,20 @@ def similarity_generate_query( def similarity_post_process( - load_records, query_records: list, sobject: str, weights: list -) -> T.Tuple[T.List[dict], T.Union[str, None]]: + load_records, + query_records: list, + fields: list, + sobject: str, + weights: list, + threshold: T.Union[float, None], +) -> T.Tuple[ + T.List[T.Union[dict, None]], T.List[T.Union[list, None]], T.Union[str, None] +]: """Processes the query results for the similarity selection strategy""" # Handle case where query returns 0 records - if not query_records: + if not query_records and not threshold: error_message = f"No records found for {sobject} in the target org." - return [], error_message + return [], [], error_message load_records = list(load_records) # Replace None values in each row with empty strings @@ -272,23 +305,55 @@ def similarity_post_process( complexity_constant = load_record_count * query_record_count - closest_records = [] + select_records = [] + insert_records = [] if complexity_constant < 1000: - closest_records = levenshtein_post_process(load_records, query_records, weights) + select_records, insert_records = levenshtein_post_process( + load_records, query_records, fields, weights, threshold + ) else: - closest_records = annoy_post_process(load_records, query_records, weights) + select_records, insert_records = annoy_post_process( + load_records, query_records, fields, weights, threshold + ) - return closest_records + return select_records, insert_records, None def annoy_post_process( - load_records: list, query_records: list, weights: list -) -> T.Tuple[T.List[dict], T.Union[str, None]]: + load_records: list, + query_records: list, + all_fields: list, + similarity_weights: list, + threshold: T.Union[float, None], +) -> T.Tuple[T.List[dict], list]: """Processes the query results for the similarity selection strategy using Annoy algorithm for large number of records""" + selected_records = [] + insertion_candidates = [] + + # Split fields into load and select categories + load_field_list, select_field_list = split_and_filter_fields(fields=all_fields) + # Only select those weights for select field list + similarity_weights = [ + similarity_weights[idx] + for idx, field in enumerate(all_fields) + if field in select_field_list + ] + load_shaped_records = reorder_records( + records=load_records, original_fields=all_fields, new_fields=load_field_list + ) + select_shaped_records = reorder_records( + records=load_records, original_fields=all_fields, new_fields=select_field_list + ) + + if not query_records: + # Directly append to load record for insertion if target_records is empty + selected_records = [None for _ in load_records] + insertion_candidates = load_shaped_records + return selected_records, insertion_candidates query_records = replace_empty_strings_with_missing(query_records) - load_records = replace_empty_strings_with_missing(load_records) + select_shaped_records = replace_empty_strings_with_missing(select_shaped_records) hash_features = 100 num_trees = 10 @@ -302,7 +367,10 @@ def annoy_post_process( } final_load_vectors, final_query_vectors = vectorize_records( - load_records, query_record_data, hash_features=hash_features, weights=weights + select_shaped_records, + query_record_data, + hash_features=hash_features, + weights=similarity_weights, ) # Create Annoy index for nearest neighbor search @@ -318,49 +386,89 @@ def annoy_post_process( # Find nearest neighbors for each query vector n_neighbors = 1 - closest_records = [] - for i, load_vector in enumerate(final_load_vectors): # Get nearest neighbors' indices and distances nearest_neighbors = annoy_index.get_nns_by_vector( load_vector, n_neighbors, include_distances=True ) neighbor_indices = nearest_neighbors[0] # Indices of nearest neighbors + neighbor_distances = [ + distance / 2 for distance in nearest_neighbors[1] + ] # Distances sqrt(2(1-cos(u,v)))/2 lies between [0,1] - for neighbor_index in neighbor_indices: + for idx, neighbor_index in enumerate(neighbor_indices): # Retrieve the corresponding record from the database record = query_record_data[neighbor_index] closest_record_id = record_to_id_map[tuple(record)] - closest_records.append( - {"id": closest_record_id, "success": True, "created": False} - ) + if threshold and (neighbor_distances[idx] >= threshold): + selected_records.append(None) + insertion_candidates.append(load_shaped_records[i]) + else: + selected_records.append( + {"id": closest_record_id, "success": True, "created": False} + ) - return closest_records, None + return selected_records, insertion_candidates def levenshtein_post_process( - load_records: list, query_records: list, weights: list -) -> T.Tuple[T.List[dict], T.Union[str, None]]: - """Processes the query results for the similarity selection strategy using Levenshtein algorithm for small number of records""" - closest_records = [] - - for record in load_records: - closest_record = find_closest_record(record, query_records, weights) - closest_records.append( - {"id": closest_record[0], "success": True, "created": False} + source_records: list, + target_records: list, + all_fields: list, + similarity_weights: list, + distance_threshold: T.Union[float, None], +) -> T.Tuple[T.List[T.Optional[dict]], T.List[T.Optional[list]]]: + """Processes query results using Levenshtein algorithm for similarity selection with a small number of records.""" + selected_records = [] + insertion_candidates = [] + + # Split fields into load and select categories + load_field_list, select_field_list = split_and_filter_fields(fields=all_fields) + # Only select those weights for select field list + similarity_weights = [ + similarity_weights[idx] + for idx, field in enumerate(all_fields) + if field in select_field_list + ] + load_shaped_records = reorder_records( + records=source_records, original_fields=all_fields, new_fields=load_field_list + ) + select_shaped_records = reorder_records( + records=source_records, original_fields=all_fields, new_fields=select_field_list + ) + + if not target_records: + # Directly append to load record for insertion if target_records is empty + selected_records = [None for _ in source_records] + insertion_candidates = load_shaped_records + return selected_records, insertion_candidates + + for select_record, load_record in zip(select_shaped_records, load_shaped_records): + closest_match, match_distance = find_closest_record( + select_record, target_records, similarity_weights ) - return closest_records, None + if distance_threshold and match_distance > distance_threshold: + # Append load record for insertion if distance exceeds threshold + insertion_candidates.append(load_record) + selected_records.append(None) + elif closest_match: + # Append match details if distance is within threshold + selected_records.append( + {"id": closest_match[0], "success": True, "created": False} + ) + + return selected_records, insertion_candidates def random_post_process( query_records: list, num_records: int, sobject: str -) -> T.Tuple[T.List[dict], T.Union[str, None]]: +) -> T.Tuple[T.List[dict], None, T.Union[str, None]]: """Processes the query results for the random selection strategy""" if not query_records: error_message = f"No records found for {sobject} in the target org." - return [], error_message + return [], None, error_message selected_records = [] for _ in range(num_records): # Loop 'num_records' times @@ -370,7 +478,7 @@ def random_post_process( {"id": random_record[0], "success": True, "created": False} ) - return selected_records, None + return selected_records, None, None def find_closest_record(load_record: list, query_records: list, weights: list): @@ -383,7 +491,7 @@ def find_closest_record(load_record: list, query_records: list, weights: list): closest_distance = distance closest_record = record - return closest_record + return closest_record, closest_distance def levenshtein_distance(str1: str, str2: str): @@ -417,7 +525,6 @@ def calculate_levenshtein_distance(record1: list, record2: list, weights: list): raise ValueError("Records must be same size as fields (weights).") total_distance = 0 - total_fields = 0 for field1, field2, weight in zip(record1, record2, weights): field1 = field1.lower() @@ -427,16 +534,19 @@ def calculate_levenshtein_distance(record1: list, record2: list, weights: list): # If both fields are blank, distance is 0 distance = 0 else: - distance = levenshtein_distance(field1, field2) + # Average distance per character + distance = levenshtein_distance(field1, field2) / max( + len(field1), len(field2) + ) if len(field1) == 0 or len(field2) == 0: # If one field is blank, reduce the impact of the distance distance = distance * 0.05 # Fixed value for blank vs non-blank # Multiply the distance by the corresponding weight total_distance += distance * weight - total_fields += 1 - return total_distance / total_fields if total_fields > 0 else 0 + # Average distance per character with weights + return total_distance / sum(weights) if len(weights) else 0 def add_limit_offset_to_user_filter( @@ -600,3 +710,60 @@ def replace_empty_strings_with_missing(records): [(field if field != "" else "missing") for field in record] for record in records ] + + +def split_and_filter_fields(fields: T.List[str]) -> T.Tuple[T.List[str], T.List[str]]: + # List to store non-lookup fields (load fields) + load_fields = [] + + # Set to store unique first components of select fields + unique_components = set() + # Keep track of last flattened lookup index + last_flat_lookup_index = -1 + + # Iterate through the fields + for idx, field in enumerate(fields): + if "." in field: + # Split the field by '.' and add the first component to the set + first_component = field.split(".")[0] + unique_components.add(first_component) + last_flat_lookup_index = max(last_flat_lookup_index, idx) + else: + # Add the field to the load_fields list + load_fields.append(field) + + # Number of unique components + num_unique_components = len(unique_components) + + # Adjust select_fields by removing only the field at last_flat_lookup_index + 1 + if last_flat_lookup_index + 1 < len( + fields + ) and last_flat_lookup_index + num_unique_components < len(fields): + select_fields = ( + fields[: last_flat_lookup_index + 1] + + fields[last_flat_lookup_index + num_unique_components + 1 :] + ) + else: + select_fields = fields + + return load_fields, select_fields + + +# Function to reorder records based on the new field list +def reorder_records(records, original_fields, new_fields): + if not original_fields: + raise KeyError("original_fields should not be empty") + # Map the original field indices + field_index_map = {field: i for i, field in enumerate(original_fields)} + reordered_records = [] + + for record in records: + reordered_records.append( + [ + record[field_index_map[field]] + for field in new_fields + if field in field_index_map + ] + ) + + return reordered_records diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 3e60ef91c0..b88fa8b100 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -20,6 +20,7 @@ SelectOperationExecutor, SelectRecordRetrievalMode, SelectStrategy, + split_and_filter_fields, ) from cumulusci.tasks.bulkdata.utils import DataApi, iterate_in_chunks from cumulusci.utils.classutils import namedtuple_as_simple_dict @@ -356,6 +357,7 @@ def __init__( selection_filter=None, selection_priority_fields=None, content_type=None, + threshold=None, ): super().__init__( sobject=sobject, @@ -377,6 +379,7 @@ def __init__( priority_fields=selection_priority_fields, fields=fields ) self.content_type = content_type if content_type else "CSV" + self.threshold = threshold def start(self): self.job_id = self.bulk.create_job( @@ -459,18 +462,7 @@ def select_records(self, records): records, records_copy = tee(records) # Count total number of records to fetch using the copy total_num_records = sum(1 for _ in records_copy) - - # Set LIMIT condition - if ( - self.select_operation_executor.retrieval_mode - == SelectRecordRetrievalMode.ALL - ): - limit_clause = None - elif ( - self.select_operation_executor.retrieval_mode - == SelectRecordRetrievalMode.MATCH - ): - limit_clause = total_num_records + limit_clause = self._determine_limit_clause(total_num_records=total_num_records) # Generate and execute SOQL query # (not passing offset as it is not supported in Bulk) @@ -494,14 +486,34 @@ def select_records(self, records): # Post-process the query results ( selected_records, + insert_records, error_message, ) = self.select_operation_executor.select_post_process( load_records=records, query_records=query_records, + fields=self.fields, num_records=total_num_records, sobject=self.sobject, weights=self.weights, + threshold=self.threshold, ) + + # Log the number of selected and prepared for insertion records + num_selected = sum(1 for record in selected_records if record) + num_prepared = len(insert_records) if insert_records else 0 + + self.logger.info( + f"{num_selected} records selected." + + ( + f" {num_prepared} records prepared for insertion." + if num_prepared > 0 + else "" + ) + ) + + if insert_records: + self._process_insert_records(insert_records, selected_records) + if not error_message: self.select_results.extend(selected_records) @@ -517,6 +529,60 @@ def select_records(self, records): total_row_errors=0, ) + def _process_insert_records(self, insert_records, selected_records): + """Processes and inserts records if necessary.""" + insert_fields, _ = split_and_filter_fields(fields=self.fields) + insert_step = BulkApiDmlOperation( + sobject=self.sobject, + operation=DataOperationType.INSERT, + api_options=self.api_options, + context=self.context, + fields=insert_fields, + ) + insert_step.start() + insert_step.load_records(insert_records) + insert_step.end() + # Retrieve insert results + insert_results = [] + for batch_id in insert_step.batch_ids: + try: + results_url = f"{insert_step.bulk.endpoint}/job/{insert_step.job_id}/batch/{batch_id}/result" + # Download entire result file to a temporary file first + # to avoid the server dropping connections + with download_file(results_url, insert_step.bulk) as f: + self.logger.info(f"Downloaded results for batch {batch_id}") + reader = csv.reader(f) + next(reader) # Skip header row + for row in reader: + success = process_bool_arg(row[1]) + created = process_bool_arg(row[2]) + insert_results.append( + {"id": row[0], "success": success, "created": created} + ) + except Exception as e: + raise BulkDataException( + f"Failed to download results for batch {batch_id} ({str(e)})" + ) + + insert_index = 0 + for idx, record in enumerate(selected_records): + if record is None: + selected_records[idx] = insert_results[insert_index] + insert_index += 1 + + def _determine_limit_clause(self, total_num_records): + """Determines the LIMIT clause based on the retrieval mode.""" + if ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.ALL + ): + return None + elif ( + self.select_operation_executor.retrieval_mode + == SelectRecordRetrievalMode.MATCH + ): + return total_num_records + def _execute_select_query(self, select_query: str, query_fields: List[str]): """Executes the select Bulk API query, retrieves results in JSON, and converts to CSV format if needed.""" self.batch_id = self.bulk.query(self.job_id, select_query) @@ -660,6 +726,7 @@ def __init__( selection_filter=None, selection_priority_fields=None, content_type=None, + threshold=None, ): super().__init__( sobject=sobject, @@ -691,6 +758,7 @@ def __init__( priority_fields=selection_priority_fields, fields=fields ) self.content_type = content_type + self.threshold = threshold def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) @@ -804,74 +872,126 @@ def select_records(self, records): self.results = [] query_records = [] + # Create a copy of the generator using tee records, records_copy = tee(records) + # Count total number of records to fetch using the copy total_num_records = sum(1 for _ in records_copy) # Set LIMIT condition + limit_clause = self._determine_limit_clause(total_num_records) + + # Generate the SOQL query based on the selection strategy + select_query, query_fields = ( + self.select_operation_executor.select_generate_query( + sobject=self.sobject, + fields=self.fields, + user_filter=self.selection_filter or None, + limit=limit_clause, + offset=None, + ) + ) + + # Execute the query and gather the records + query_records = self._execute_soql_query(select_query, query_fields) + + # Post-process the query results for this batch + selected_records, insert_records, error_message = ( + self.select_operation_executor.select_post_process( + load_records=records, + query_records=query_records, + fields=self.fields, + num_records=total_num_records, + sobject=self.sobject, + weights=self.weights, + threshold=self.threshold, + ) + ) + + # Log the number of selected and prepared for insertion records + num_selected = sum(1 for record in selected_records if record) + num_prepared = len(insert_records) if insert_records else 0 + + self.logger.info( + f"{num_selected} records selected." + + ( + f" {num_prepared} records prepared for insertion." + if num_prepared > 0 + else "" + ) + ) + + if insert_records: + self._process_insert_records(insert_records, selected_records) + + if not error_message: + # Add selected records from this batch to the overall results + self.results.extend(selected_records) + + # Update the job result based on the overall selection outcome + self._update_job_result(error_message) + + def _determine_limit_clause(self, total_num_records): + """Determines the LIMIT clause based on the retrieval mode.""" if ( self.select_operation_executor.retrieval_mode == SelectRecordRetrievalMode.ALL ): - limit_clause = None + return None elif ( self.select_operation_executor.retrieval_mode == SelectRecordRetrievalMode.MATCH ): - limit_clause = total_num_records + return total_num_records - # Generate the SOQL query based on the selection strategy - ( - select_query, - query_fields, - ) = self.select_operation_executor.select_generate_query( - sobject=self.sobject, - fields=self.fields, - user_filter=self.selection_filter if self.selection_filter else None, - limit=limit_clause, - offset=None, - ) - - # Handle the case where self.selection_query is None (and hence user_query is also None) + def _execute_soql_query(self, select_query, query_fields): + """Executes the SOQL query and returns the flattened records.""" + query_records = [] response = self.sf.restful( requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" ) - # Convert each record to a flat row - for record in response["records"]: - flat_record = flatten_record(record, query_fields) - query_records.append(flat_record) - while True: - if not response["done"]: - response = self.sf.query_more( - response["nextRecordsUrl"], identifier_is_url=True - ) - for record in response["records"]: - flat_record = flatten_record(record, query_fields) - query_records.append(flat_record) - else: - break + query_records.extend(self._flatten_response_records(response, query_fields)) - # Post-process the query results for this batch - ( - selected_records, - error_message, - ) = self.select_operation_executor.select_post_process( - load_records=records, - query_records=query_records, - num_records=total_num_records, + while not response["done"]: + response = self.sf.query_more( + response["nextRecordsUrl"], identifier_is_url=True + ) + query_records.extend(self._flatten_response_records(response, query_fields)) + + return query_records + + def _flatten_response_records(self, response, query_fields): + """Flattens the response records and returns them as a list.""" + return [flatten_record(record, query_fields) for record in response["records"]] + + def _process_insert_records(self, insert_records, selected_records): + """Processes and inserts records if necessary.""" + insert_fields, _ = split_and_filter_fields(fields=self.fields) + insert_step = RestApiDmlOperation( sobject=self.sobject, - weights=self.weights, + operation=DataOperationType.INSERT, + api_options=self.api_options, + context=self.context, + fields=insert_fields, ) - if not error_message: - # Add selected records from this batch to the overall results - self.results.extend(selected_records) - - # Update the job result based on the overall selection outcome + insert_step.start() + insert_step.load_records(insert_records) + insert_step.end() + insert_results = insert_step.results + + insert_index = 0 + for idx, record in enumerate(selected_records): + if record is None: + selected_records[idx] = insert_results[insert_index] + insert_index += 1 + + def _update_job_result(self, error_message): + """Updates the job result based on the selection outcome.""" self.job_result = DataOperationJobResult( status=( DataOperationStatus.SUCCESS - if len(self.results) # Check the overall results length + if len(self.results) else DataOperationStatus.JOB_FAILURE ), job_errors=[error_message] if error_message else [], @@ -964,6 +1084,7 @@ def get_dml_operation( selection_filter: Union[str, None] = None, selection_priority_fields: Union[dict, None] = None, content_type: Union[str, None] = None, + threshold: Union[float, None] = None, ) -> BaseDmlOperation: """Create an appropriate DmlOperation instance for the given parameters, selecting between REST and Bulk APIs based upon volume (Bulk used at volumes over 2000 records, @@ -1001,6 +1122,7 @@ def get_dml_operation( selection_filter=selection_filter, selection_priority_fields=selection_priority_fields, content_type=content_type, + threshold=threshold, ) diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_number.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_number.yml new file mode 100644 index 0000000000..1bad614b1d --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_number.yml @@ -0,0 +1,21 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + threshold: 1.5 + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_strategy.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_strategy.yml new file mode 100644 index 0000000000..71958848c5 --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__invalid_strategy.yml @@ -0,0 +1,21 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: standard + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + threshold: 0.5 + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__non_float.yml b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__non_float.yml new file mode 100644 index 0000000000..2ff1482f3d --- /dev/null +++ b/cumulusci/tasks/bulkdata/tests/mapping_select_invalid_threshold__non_float.yml @@ -0,0 +1,21 @@ +# Select Mapping File for load +Select Accounts: + api: bulk + action: select + sf_object: Account + table: accounts + select_options: + strategy: similarity + filter: WHEN Name in ('Sample Account') + priority_fields: + Name: name + AccountNumber: account_number + threshold: invalid threshold + fields: + Name: name + AccountNumber: account_number + Description: description + lookups: + ParentId: + key_field: parent_id + table: accounts diff --git a/cumulusci/tasks/bulkdata/tests/test_load.py b/cumulusci/tasks/bulkdata/tests/test_load.py index 9fb6ea1d87..8fb8ee0756 100644 --- a/cumulusci/tasks/bulkdata/tests/test_load.py +++ b/cumulusci/tasks/bulkdata/tests/test_load.py @@ -835,13 +835,12 @@ def test_process_lookup_fields_polymorphic(self): "Subject", "Who.Contact.FirstName", "Who.Contact.LastName", - "Who.Contact.AccountId", "Who.Lead.LastName", + "WhoId", ] expected_priority_fields_keys = { "Who.Contact.FirstName", "Who.Contact.LastName", - "Who.Contact.AccountId", "Who.Lead.LastName", } with mock.patch( @@ -886,6 +885,7 @@ def test_process_lookup_fields_non_polymorphic(self): "LastName", "Account.Name", "Account.AccountNumber", + "AccountId", ] expected_priority_fields_keys = { "FirstName", @@ -989,7 +989,7 @@ def test_query_db__joins_select_lookups(self): sql_path=Path(__file__).parent / "test_query_db_joins_lookups.sql", mapping=Path(__file__).parent / "test_query_db_joins_lookups_select.yml", mapping_step_name="Select Event", - expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", '' AS "whoid_contacts_alias_accountid", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname" from events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" ORDER BY events."whoid"''', + expected='''SELECT events.id AS events_id, events."subject" AS "events_subject", "whoid_contacts_alias"."firstname" AS "whoid_contacts_alias_firstname", "whoid_contacts_alias"."lastname" AS "whoid_contacts_alias_lastname", "whoid_leads_alias"."lastname" AS "whoid_leads_alias_lastname", cumulusci_id_table_1.sf_id AS cumulusci_id_table_1_sf_id FROM events LEFT OUTER JOIN contacts AS "whoid_contacts_alias" ON "whoid_contacts_alias".id=events."whoid" LEFT OUTER JOIN leads AS "whoid_leads_alias" ON "whoid_leads_alias".id=events."whoid" LEFT OUTER JOIN cumulusci_id_table AS cumulusci_id_table_1 ON cumulusci_id_table_1.id=? || cast(events."whoid" as varchar) ORDER BY events."whoid"''', ) def test_query_db__joins_polymorphic_lookups(self): diff --git a/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py b/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py index ae9fe91686..8ce38ff5a8 100644 --- a/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py +++ b/cumulusci/tasks/bulkdata/tests/test_mapping_parser.py @@ -231,6 +231,35 @@ def test_select_options__invalid_strategy(self): parse_from_yaml(base_path) assert "Invalid strategy value: invalid_strategy" in str(e.value) + def test_select_options__invalid_threshold__non_float(self): + base_path = ( + Path(__file__).parent / "mapping_select_invalid_threshold__non_float.yml" + ) + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert "value is not a valid float" in str(e.value) + + def test_select_options__invalid_threshold__invalid_strategy(self): + base_path = ( + Path(__file__).parent + / "mapping_select_invalid_threshold__invalid_strategy.yml" + ) + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert ( + "If a threshold is specified, the strategy must be set to 'similarity'." + in str(e.value) + ) + + def test_select_options__invalid_threshold__invalid_number(self): + base_path = ( + Path(__file__).parent + / "mapping_select_invalid_threshold__invalid_number.yml" + ) + with pytest.raises(ValueError) as e: + parse_from_yaml(base_path) + assert "Threshold must be between 0 and 1, got 1.5" in str(e.value) + def test_select_options__missing_priority_fields(self): base_path = Path(__file__).parent / "mapping_select_missing_priority_fields.yml" with pytest.raises(ValueError) as e: diff --git a/cumulusci/tasks/bulkdata/tests/test_select_utils.py b/cumulusci/tasks/bulkdata/tests/test_select_utils.py index 4969722c6e..a0b5a3fcad 100644 --- a/cumulusci/tasks/bulkdata/tests/test_select_utils.py +++ b/cumulusci/tasks/bulkdata/tests/test_select_utils.py @@ -10,7 +10,9 @@ determine_field_types, find_closest_record, levenshtein_distance, + reorder_records, replace_empty_strings_with_missing, + split_and_filter_fields, vectorize_records, ) @@ -100,8 +102,14 @@ def test_standard_post_process_with_records(): records = [["001"], ["002"], ["003"]] num_records = 3 sobject = "Contact" - selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject, weights=[] + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, ) assert error_message is None @@ -116,8 +124,14 @@ def test_standard_post_process_with_fewer_records(): records = [["001"]] num_records = 3 sobject = "Opportunity" - selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject, weights=[] + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, ) assert error_message is None @@ -133,8 +147,14 @@ def test_standard_post_process_with_no_records(): records = [] num_records = 2 sobject = "Lead" - selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject, weights=[] + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, ) assert selected_records == [] @@ -147,8 +167,14 @@ def test_random_post_process_with_records(): records = [["001"], ["002"], ["003"]] num_records = 3 sobject = "Contact" - selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject, weights=[] + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, ) assert error_message is None @@ -162,8 +188,14 @@ def test_random_post_process_with_no_records(): records = [] num_records = 2 sobject = "Lead" - selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject, weights=[] + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[], + fields=[], + threshold=None, ) assert selected_records == [] @@ -279,7 +311,7 @@ def test_find_closest_record_different_weights(): weights = [2.0, 0.5] # With different weights, the first field will have more impact - closest_record = find_closest_record(load_record, query_records, weights) + closest_record, _ = find_closest_record(load_record, query_records, weights) assert closest_record == [ "record1", "hello", @@ -296,7 +328,7 @@ def test_find_closest_record_basic(): ] weights = [1.0, 1.0] - closest_record = find_closest_record(load_record, query_records, weights) + closest_record, _ = find_closest_record(load_record, query_records, weights) assert closest_record == [ "record1", "hello", @@ -313,7 +345,7 @@ def test_find_closest_record_multiple_matches(): ] weights = [1.0, 1.0] - closest_record = find_closest_record(load_record, query_records, weights) + closest_record, _ = find_closest_record(load_record, query_records, weights) assert closest_record == [ "record2", "cat", @@ -327,25 +359,29 @@ def test_similarity_post_process_with_records(): sobject = "Contact" load_records = [["Tom Cruise", "62", "Actor"]] query_records = [ - ["001", "Tom Hanks", "62", "Actor"], + ["001", "Bob Hanks", "62", "Actor"], ["002", "Tom Cruise", "63", "Actor"], # Slight difference ["003", "Jennifer Aniston", "30", "Actress"], ] weights = [1.0, 1.0, 1.0] # Adjust weights to match your data structure - selected_records, error_message = select_operator.select_post_process( - load_records, query_records, num_records, sobject, weights + selected_records, _, error_message = select_operator.select_post_process( + load_records=load_records, + query_records=query_records, + num_records=num_records, + sobject=sobject, + weights=weights, + fields=["Name", "Age", "Occupation"], + threshold=None, ) - # selected_records, error_message = select_operator.select_post_process( - # load_records, query_records, num_records, sobject - # ) - assert error_message is None assert len(selected_records) == num_records assert all(record["success"] for record in selected_records) assert all(record["created"] is False for record in selected_records) + x = [record["id"] for record in selected_records] + print(x) assert all(record["id"] in ["002"] for record in selected_records) @@ -354,8 +390,14 @@ def test_similarity_post_process_with_no_records(): records = [] num_records = 2 sobject = "Lead" - selected_records, error_message = select_operator.select_post_process( - None, records, num_records, sobject, weights=[1, 1, 1] + selected_records, _, error_message = select_operator.select_post_process( + load_records=None, + query_records=records, + num_records=num_records, + sobject=sobject, + weights=[1, 1, 1], + fields=[], + threshold=None, ) assert selected_records == [] @@ -369,7 +411,7 @@ def test_calculate_levenshtein_distance_basic(): # Expected distance based on simple Levenshtein distances # Levenshtein("hello", "hullo") = 1, Levenshtein("world", "word") = 1 - expected_distance = (1 * 1.0 + 1 * 1.0) / 2 # Averaged over two fields + expected_distance = (1 / 5 * 1.0 + 1 / 5 * 1.0) / 2 # Averaged over two fields result = calculate_levenshtein_distance(record1, record2, weights) assert result == pytest.approx( @@ -383,7 +425,7 @@ def test_calculate_levenshtein_distance_basic(): # Expected distance based on simple Levenshtein distances # Levenshtein("hello", "hullo") = 1, Levenshtein("", "") = 0 - expected_distance = (1 * 1.0 + 0 * 1.0) / 2 # Averaged over two fields + expected_distance = (1 / 5 * 1.0 + 0 * 1.0) / 2 # Averaged over two fields result = calculate_levenshtein_distance(record1, record2, weights) assert result == pytest.approx( @@ -397,7 +439,9 @@ def test_calculate_levenshtein_distance_basic(): # Expected distance based on simple Levenshtein distances # Levenshtein("hello", "hullo") = 1, Levenshtein("world", "") = 5 - expected_distance = (1 * 1.0 + 5 * 0.05 * 1.0) / 2 # Averaged over two fields + expected_distance = ( + 1 / 5 * 1.0 + 5 / 5 * 0.05 * 1.0 + ) / 2 # Averaged over two fields result = calculate_levenshtein_distance(record1, record2, weights) assert result == pytest.approx( @@ -411,7 +455,9 @@ def test_calculate_levenshtein_distance_weighted(): weights = [2.0, 0.5] # Levenshtein("cat", "bat") = 1, Levenshtein("dog", "fog") = 1 - expected_distance = (1 * 2.0 + 1 * 0.5) / 2 # Weighted average over two fields + expected_distance = ( + 1 / 3 * 2.0 + 1 / 3 * 0.5 + ) / 2.5 # Weighted average over two fields result = calculate_levenshtein_distance(record1, record2, weights) assert result == pytest.approx( @@ -571,7 +617,13 @@ def test_annoy_post_process(): query_records = [["q1", "Alice", "Engineer"], ["q2", "Charlie", "Artist"]] weights = [1.0, 1.0, 1.0] # Example weights - closest_records, error = annoy_post_process(load_records, query_records, weights) + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=None, + ) # Assert the closest records assert ( @@ -582,7 +634,97 @@ def test_annoy_post_process(): ) # The first query record should match the first load record # No errors expected - assert error is None + assert not insert_records + + +def test_annoy_post_process__insert_records(): + # Test data + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [["q1", "Alice", "Engineer"], ["q2", "Charlie", "Artist"]] + weights = [1.0, 1.0, 1.0] # Example weights + threshold = 0.3 + + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=threshold, + ) + + # Assert the closest records + assert len(closest_records) == 2 # We expect two results (one record and one None) + assert ( + closest_records[0]["id"] == "q1" + ) # The first query record should match the first load record + assert closest_records[1] is None # The second query record should be None + assert insert_records[0] == [ + "Bob", + "Doctor", + ] # The first insert record should match the second load record + + +def test_annoy_post_process__no_query_records(): + # Test data + load_records = [["Alice", "Engineer"], ["Bob", "Doctor"]] + query_records = [] + weights = [1.0, 1.0, 1.0] # Example weights + threshold = 0.3 + + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=threshold, + ) + + # Assert the closest records + assert len(closest_records) == 2 # We expect two results (both None) + assert all(rec is None for rec in closest_records) # Both should be None + assert insert_records[0] == [ + "Alice", + "Engineer", + ] # The first insert record should match the second load record + assert insert_records[1] == [ + "Bob", + "Doctor", + ] # The first insert record should match the second load record + + +def test_annoy_post_process__insert_records_with_polymorphic_fields(): + # Test data + load_records = [ + ["Alice", "Engineer", "Alice_Contact", "abcd1234"], + ["Bob", "Doctor", "Bob_Contact", "qwer1234"], + ] + query_records = [ + ["q1", "Alice", "Engineer", "Alice_Contact"], + ["q2", "Charlie", "Artist", "Charlie_Contact"], + ] + weights = [1.0, 1.0, 1.0, 1.0] # Example weights + threshold = 0.3 + all_fields = ["Name", "Occupation", "Contact.Name", "ContactId"] + + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=all_fields, + threshold=threshold, + ) + + # Assert the closest records + assert len(closest_records) == 2 # We expect two results (one record and one None) + assert ( + closest_records[0]["id"] == "q1" + ) # The first query record should match the first load record + assert closest_records[1] is None # The second query record should be None + assert insert_records[0] == [ + "Bob", + "Doctor", + "qwer1234", + ] # The first insert record should match the second load record def test_single_record_match_annoy_post_process(): @@ -591,12 +733,18 @@ def test_single_record_match_annoy_post_process(): query_records = [["q1", "Alice", "Engineer"]] weights = [1.0, 1.0, 1.0] - closest_records, error = annoy_post_process(load_records, query_records, weights) + closest_records, insert_records = annoy_post_process( + load_records=load_records, + query_records=query_records, + similarity_weights=weights, + all_fields=["Name", "Occupation"], + threshold=None, + ) # Both the load records should be matched with the only query record we have assert len(closest_records) == 2 assert closest_records[0]["id"] == "q1" - assert error is None + assert not insert_records @pytest.mark.parametrize( @@ -653,3 +801,206 @@ def test_add_limit_offset_to_user_filter( ): result = add_limit_offset_to_user_filter(filter_clause, limit_clause, offset_clause) assert result.strip() == expected.strip() + + +def test_reorder_records_basic_reordering(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["job", "name"] + + expected = [ + ["Engineer", "Alice"], + ["Designer", "Bob"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_partial_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["age"] + + expected = [ + [30], + [25], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_missing_fields_in_new_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["nonexistent", "job"] + + expected = [ + ["Engineer"], + ["Designer"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_empty_records(): + records = [] + original_fields = ["name", "age", "job"] + new_fields = ["job", "name"] + + expected = [] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_empty_new_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = [] + + expected = [ + [], + [], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_empty_original_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = [] + new_fields = ["job", "name"] + + with pytest.raises(KeyError): + reorder_records(records, original_fields, new_fields) + + +def test_reorder_records_no_common_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["nonexistent_field"] + + expected = [ + [], + [], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_duplicate_fields_in_new_fields(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["job", "job", "name"] + + expected = [ + ["Engineer", "Engineer", "Alice"], + ["Designer", "Designer", "Bob"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_reorder_records_all_fields_in_order(): + records = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + original_fields = ["name", "age", "job"] + new_fields = ["name", "age", "job"] + + expected = [ + ["Alice", 30, "Engineer"], + ["Bob", 25, "Designer"], + ] + result = reorder_records(records, original_fields, new_fields) + assert result == expected + + +def test_split_and_filter_fields_basic_case(): + fields = [ + "Account.Name", + "Account.Industry", + "Contact.Name", + "AccountId", + "ContactId", + "CreatedDate", + ] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["AccountId", "ContactId", "CreatedDate"] + assert select_fields == [ + "Account.Name", + "Account.Industry", + "Contact.Name", + "CreatedDate", + ] + + +def test_split_and_filter_fields_all_non_lookup_fields(): + fields = ["Name", "CreatedDate"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["Name", "CreatedDate"] + assert select_fields == fields + + +def test_split_and_filter_fields_all_lookup_fields(): + fields = ["Account.Name", "Account.Industry", "Contact.Name"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == [] + assert select_fields == fields + + +def test_split_and_filter_fields_empty_fields(): + fields = [] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == [] + assert select_fields == [] + + +def test_split_and_filter_fields_single_non_lookup_field(): + fields = ["Id"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["Id"] + assert select_fields == ["Id"] + + +def test_split_and_filter_fields_single_lookup_field(): + fields = ["Account.Name"] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == [] + assert select_fields == ["Account.Name"] + + +def test_split_and_filter_fields_multiple_unique_lookups(): + fields = [ + "Account.Name", + "Account.Industry", + "Contact.Email", + "Contact.Phone", + "Id", + ] + load_fields, select_fields = split_and_filter_fields(fields) + assert load_fields == ["Id"] + assert ( + select_fields == fields + ) # No filtering applied since all components are unique diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index bd059b9bbf..3d797df8bf 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -909,6 +909,7 @@ def test_select_records_similarity_strategy_parent_level_records__polymorphic( "Who.Contact.Email", "Who.Lead.Name", "Who.Lead.Company", + "WhoId", ], selection_strategy=SelectStrategy.SIMILARITY, ) @@ -921,15 +922,22 @@ def test_select_records_similarity_strategy_parent_level_records__polymorphic( download_mock.return_value = io.StringIO( """[ - {"Id": "003000000000001", "Subject": "Sample Event 1", "Who":{ "attributes": {"type": "Contact"}, "Name": "Sample Contact", "Email": "contact@example.com"}}, - { "Id": "003000000000002", "Subject": "Sample Event 2", "Who":{ "attributes": {"type": "Lead"}, "Name": "Sample Lead", "Company": "Salesforce"}} + {"Id": "003000000000001", "Subject": "Sample Event 1", "Who":{ "attributes": {"type": "Contact"}, "Id": "abcd1234", "Name": "Sample Contact", "Email": "contact@example.com"}}, + { "Id": "003000000000002", "Subject": "Sample Event 2", "Who":{ "attributes": {"type": "Lead"}, "Id": "qwer1234", "Name": "Sample Lead", "Company": "Salesforce"}} ]""" ) records = iter( [ - ["Sample Event 1", "Sample Contact", "contact@example.com", "", ""], - ["Sample Event 2", "", "", "Sample Lead", "Salesforce"], + [ + "Sample Event 1", + "Sample Contact", + "contact@example.com", + "", + "", + "lkjh1234", + ], + ["Sample Event 2", "", "", "Sample Lead", "Salesforce", "poiu1234"], ] ) step.start() @@ -960,7 +968,7 @@ def test_select_records_similarity_strategy_parent_level_records__non_polymorphi operation=DataOperationType.QUERY, api_options={"batch_size": 10}, context=context, - fields=["Name", "Account.Name", "Account.AccountNumber"], + fields=["Name", "Account.Name", "Account.AccountNumber", "AccountId"], selection_strategy=SelectStrategy.SIMILARITY, ) @@ -972,15 +980,15 @@ def test_select_records_similarity_strategy_parent_level_records__non_polymorphi download_mock.return_value = io.StringIO( """[ - {"Id": "003000000000001", "Name": "Sample Contact 1", "Account":{ "attributes": {"type": "Account"}, "Name": "Sample Account", "AccountNumber": 123456}}, + {"Id": "003000000000001", "Name": "Sample Contact 1", "Account":{ "attributes": {"type": "Account"}, "Id": "abcd1234", "Name": "Sample Account", "AccountNumber": 123456}}, { "Id": "003000000000002", "Subject": "Sample Contact 2", "Account": null} ]""" ) records = iter( [ - ["Sample Contact 3", "Sample Account", "123456"], - ["Sample Contact 4", "", ""], + ["Sample Contact 3", "Sample Account", "123456", "poiu1234"], + ["Sample Contact 4", "", "", ""], ] ) step.start() @@ -1009,7 +1017,13 @@ def test_select_records_similarity_strategy_priority_fields(self, download_mock) operation=DataOperationType.QUERY, api_options={"batch_size": 10}, context=context, - fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], selection_strategy=SelectStrategy.SIMILARITY, selection_priority_fields={"Name": "Name", "Email": "Email"}, ) @@ -1019,7 +1033,13 @@ def test_select_records_similarity_strategy_priority_fields(self, download_mock) operation=DataOperationType.QUERY, api_options={"batch_size": 10}, context=context, - fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], selection_strategy=SelectStrategy.SIMILARITY, selection_priority_fields={ "Account.Name": "Account.Name", @@ -1044,6 +1064,7 @@ def test_select_records_similarity_strategy_priority_fields(self, download_mock) "Email": "bob@yahoo.org", "Account": { "attributes": {"type": "Account"}, + "Id": "abcd1234", "Name": "Jawad TP", "AccountNumber": 567890, }, @@ -1054,6 +1075,7 @@ def test_select_records_similarity_strategy_priority_fields(self, download_mock) "Email": "tom@exmaple.com", "Account": { "attributes": {"type": "Account"}, + "Id": "qwer1234", "Name": "Aditya B", "AccountNumber": 123456, }, @@ -1067,7 +1089,7 @@ def test_select_records_similarity_strategy_priority_fields(self, download_mock) records = iter( [ - ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456"], + ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456", "poiu1234"], ] ) records_1, records_2 = tee(records) @@ -2041,6 +2063,7 @@ def test_select_records_similarity_strategy_parent_level_records__polymorphic(se "Who.Contact.Email", "Who.Lead.Name", "Who.Lead.Company", + "WhoId", ], selection_strategy=SelectStrategy.SIMILARITY, ) @@ -2054,6 +2077,7 @@ def test_select_records_similarity_strategy_parent_level_records__polymorphic(se "Subject": "Sample Event 1", "Who": { "attributes": {"type": "Contact"}, + "Id": "abcd1234", "Name": "Sample Contact", "Email": "contact@example.com", }, @@ -2063,6 +2087,7 @@ def test_select_records_similarity_strategy_parent_level_records__polymorphic(se "Subject": "Sample Event 2", "Who": { "attributes": {"type": "Lead"}, + "Id": "qwer1234", "Name": "Sample Lead", "Company": "Salesforce", }, @@ -2075,8 +2100,15 @@ def test_select_records_similarity_strategy_parent_level_records__polymorphic(se records = iter( [ - ["Sample Event 1", "Sample Contact", "contact@example.com", "", ""], - ["Sample Event 2", "", "", "Sample Lead", "Salesforce"], + [ + "Sample Event 1", + "Sample Contact", + "contact@example.com", + "", + "", + "poiu1234", + ], + ["Sample Event 2", "", "", "Sample Lead", "Salesforce", "lkjh1234"], ] ) step.start() @@ -2132,7 +2164,7 @@ def test_select_records_similarity_strategy_parent_level_records__non_polymorphi operation=DataOperationType.QUERY, api_options={"batch_size": 10}, context=task, - fields=["Name", "Account.Name", "Account.AccountNumber"], + fields=["Name", "Account.Name", "Account.AccountNumber", "AccountId"], selection_strategy=SelectStrategy.SIMILARITY, ) @@ -2145,6 +2177,7 @@ def test_select_records_similarity_strategy_parent_level_records__non_polymorphi "Name": "Sample Contact 1", "Account": { "attributes": {"type": "Account"}, + "Id": "abcd1234", "Name": "Sample Account", "AccountNumber": 123456, }, @@ -2162,8 +2195,8 @@ def test_select_records_similarity_strategy_parent_level_records__non_polymorphi records = iter( [ - ["Sample Contact 3", "Sample Account", "123456"], - ["Sample Contact 4", "", ""], + ["Sample Contact 3", "Sample Account", "123456", "poiu1234"], + ["Sample Contact 4", "", "", ""], ] ) step.start() @@ -2229,7 +2262,13 @@ def test_select_records_similarity_strategy_priority_fields(self): operation=DataOperationType.QUERY, api_options={"batch_size": 10}, context=task_1, - fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], selection_strategy=SelectStrategy.SIMILARITY, selection_priority_fields={"Name": "Name", "Email": "Email"}, ) @@ -2239,7 +2278,13 @@ def test_select_records_similarity_strategy_priority_fields(self): operation=DataOperationType.QUERY, api_options={"batch_size": 10}, context=task_2, - fields=["Name", "Email", "Account.Name", "Account.AccountNumber"], + fields=[ + "Name", + "Email", + "Account.Name", + "Account.AccountNumber", + "AccountId", + ], selection_strategy=SelectStrategy.SIMILARITY, selection_priority_fields={ "Account.Name": "Account.Name", @@ -2256,6 +2301,7 @@ def test_select_records_similarity_strategy_priority_fields(self): "Email": "bob@yahoo.org", "Account": { "attributes": {"type": "Account"}, + "Id": "abcd1234", "Name": "Jawad TP", "AccountNumber": 567890, }, @@ -2266,6 +2312,7 @@ def test_select_records_similarity_strategy_priority_fields(self): "Email": "tom@exmaple.com", "Account": { "attributes": {"type": "Account"}, + "Id": "qwer1234", "Name": "Aditya B", "AccountNumber": 123456, }, @@ -2280,7 +2327,7 @@ def test_select_records_similarity_strategy_priority_fields(self): records = iter( [ - ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456"], + ["Bob The Builder", "bob@yahoo.org", "Aditya B", "123456", "poiu1234"], ] ) records_1, records_2 = tee(records) @@ -2803,6 +2850,7 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): selection_filter=None, selection_priority_fields=None, content_type=None, + threshold=None, ) op = get_dml_operation( @@ -2828,6 +2876,7 @@ def test_get_dml_operation(self, rest_dml, bulk_dml): selection_filter=None, selection_priority_fields=None, content_type=None, + threshold=None, ) @mock.patch("cumulusci.tasks.bulkdata.step.BulkApiDmlOperation") From f4fb69603da4285fe3d9c6be6a991b2bf664d7f7 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 19 Nov 2024 19:01:08 +0530 Subject: [PATCH 29/34] Add new imports to requirements.txt --- .pre-commit-config.yaml | 6 +- cumulusci/tasks/bulkdata/tests/test_step.py | 532 ++++++++++++++++++++ pyproject.toml | 4 + requirements/dev.txt | 36 +- requirements/prod.txt | 34 +- 5 files changed, 590 insertions(+), 22 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1a928eafd..62af507949 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_language_version: python: python3 repos: - repo: https://github.com/ambv/black - rev: 24.10.0 + rev: 22.3.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks @@ -18,12 +18,12 @@ repos: - id: rst-linter exclude: "docs" - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black", "--filter-files"] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 + rev: v2.5.1 hooks: - id: prettier - repo: local diff --git a/cumulusci/tasks/bulkdata/tests/test_step.py b/cumulusci/tasks/bulkdata/tests/test_step.py index 3d797df8bf..e94e91f226 100644 --- a/cumulusci/tasks/bulkdata/tests/test_step.py +++ b/cumulusci/tasks/bulkdata/tests/test_step.py @@ -1121,6 +1121,304 @@ def test_select_records_similarity_strategy_priority_fields(self, download_mock) id="003000000000002", success=True, error="", created=False ) + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_process_insert_records_success(self, download_mock): + # Mock context and insert records + context = mock.Mock() + insert_records = iter([["John", "Doe"], ["Jane", "Smith"]]) + selected_records = [None, None] + + # Mock insert fields splitting + insert_fields = ["FirstName", "LastName"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ) as split_mock: + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=["FirstName", "LastName"], + ) + + # Mock Bulk API + step.bulk.endpoint = "https://test" + step.bulk.create_insert_job.return_value = "JOB" + step.bulk.get_insert_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with successful results + download_mock.return_value = io.StringIO( + "Id,Success,Created\n0011k00003E8xAaAAI,true,true\n0011k00003E8xAbAAJ,true,true\n" + ) + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + step._process_insert_records(insert_records, selected_records) + + # Assertions for split fields and sub-operation + split_mock.assert_called_once_with(fields=["FirstName", "LastName"]) + insert_step.start.assert_called_once() + insert_step.load_records.assert_called_once_with(insert_records) + insert_step.end.assert_called_once() + + # Validate the download file interactions + download_mock.assert_called_once_with( + "https://test/job/JOB/batch/BATCH1/result", insert_step.bulk + ) + + # Validate that selected_records is updated with insert results + assert selected_records == [ + {"id": "0011k00003E8xAaAAI", "success": True, "created": True}, + {"id": "0011k00003E8xAbAAJ", "success": True, "created": True}, + ] + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_process_insert_records_failure(self, download_mock): + # Mock context and insert records + context = mock.Mock() + insert_records = iter([["John", "Doe"], ["Jane", "Smith"]]) + selected_records = [None, None] + + # Mock insert fields splitting + insert_fields = ["FirstName", "LastName"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ): + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10}, + context=context, + fields=["FirstName", "LastName"], + ) + + # Mock failure during results download + download_mock.side_effect = Exception("Download failed") + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + with pytest.raises(BulkDataException) as excinfo: + step._process_insert_records(insert_records, selected_records) + + # Validate that the exception is raised with the correct message + assert "Failed to download results for batch BATCH1" in str( + excinfo.value + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy__insert_records(self, download_mock): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + # Add step with threshold + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0.3, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + select_results = io.StringIO( + """[{"Id":"003000000000001", "Name":"Jawad", "Email":"mjawadtp@example.com"}]""" + ) + insert_results = io.StringIO( + "Id,Success,Created\n003000000000002,true,true\n003000000000003,true,true\n" + ) + download_mock.side_effect = [select_results, insert_results] + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=True + ) + ) + == 1 + ) + + @mock.patch("cumulusci.tasks.bulkdata.step.download_file") + def test_select_records_similarity_strategy__insert_records__no_select_records( + self, download_mock + ): + # Set up mock context and BulkApiDmlOperation + context = mock.Mock() + # Add step with threshold + step = BulkApiDmlOperation( + sobject="Contact", + operation=DataOperationType.QUERY, + api_options={"batch_size": 10, "update_key": "LastName"}, + context=context, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0.3, + ) + + # Mock Bulk API responses + step.bulk.endpoint = "https://test" + step.bulk.create_query_job.return_value = "JOB" + step.bulk.query.return_value = "BATCH" + step.bulk.get_query_batch_result_ids.return_value = ["RESULT"] + + # Mock the downloaded CSV content with a single record + select_results = io.StringIO("""[]""") + insert_results = io.StringIO( + "Id,Success,Created\n003000000000001,true,true\n003000000000002,true,true\n003000000000003,true,true\n" + ) + download_mock.side_effect = [select_results, insert_results] + + # Mock the _wait_for_job method to simulate a successful job + step._wait_for_job = mock.Mock() + step._wait_for_job.return_value = DataOperationJobResult( + DataOperationStatus.SUCCESS, [], 0, 0 + ) + + # Prepare input records + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom", "cruise@example.com"], + ] + ) + + # Mock sub-operation for BulkApiDmlOperation + insert_step = mock.Mock(spec=BulkApiDmlOperation) + insert_step.start = mock.Mock() + insert_step.load_records = mock.Mock() + insert_step.end = mock.Mock() + insert_step.batch_ids = ["BATCH1"] + insert_step.bulk = mock.Mock() + insert_step.bulk.endpoint = "https://test" + insert_step.job_id = "JOB" + + with mock.patch( + "cumulusci.tasks.bulkdata.step.BulkApiDmlOperation", + return_value=insert_step, + ): + # Execute the select_records operation + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=True + ) + ) + == 1 + ) + def test_batch(self): context = mock.Mock() @@ -2359,6 +2657,240 @@ def test_select_records_similarity_strategy_priority_fields(self): id="003000000000002", success=True, error="", created=False ) + @responses.activate + def test_process_insert_records_success(self): + # Mock describe calls + mock_describe_calls() + + # Create a task and mock project config + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + # Prepare inputs + insert_records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom Cruise", "tomcruise@example.com"], + ] + ) + selected_records = [None, None, None] + + # Mock fields splitting + insert_fields = ["Name", "Email"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ) as split_mock: + # Mock the instance of RestApiDmlOperation + mock_rest_api_dml_operation = mock.create_autospec( + RestApiDmlOperation, instance=True + ) + mock_rest_api_dml_operation.results = [ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + {"id": "003000000000003", "success": True}, + ] + + with mock.patch( + "cumulusci.tasks.bulkdata.step.RestApiDmlOperation", + return_value=mock_rest_api_dml_operation, + ): + # Call the function + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={"batch_size": 10}, + context=task, + fields=["Name", "Email"], + ) + step._process_insert_records(insert_records, selected_records) + + # Assert the mocked splitting is called + split_mock.assert_called_once_with(fields=["Name", "Email"]) + + # Validate that `selected_records` is updated correctly + assert selected_records == [ + {"id": "003000000000001", "success": True}, + {"id": "003000000000002", "success": True}, + {"id": "003000000000003", "success": True}, + ] + + # Validate the operation sequence + mock_rest_api_dml_operation.start.assert_called_once() + mock_rest_api_dml_operation.load_records.assert_called_once_with( + insert_records + ) + mock_rest_api_dml_operation.end.assert_called_once() + + @responses.activate + def test_process_insert_records_failure(self): + # Mock describe calls + mock_describe_calls() + + # Create a task and mock project config + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + # Prepare inputs + insert_records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ] + ) + selected_records = [None, None] + + # Mock fields splitting + insert_fields = ["Name", "Email"] + with mock.patch( + "cumulusci.tasks.bulkdata.step.split_and_filter_fields", + return_value=(insert_fields, None), + ) as split_mock: + # Mock the instance of RestApiDmlOperation + mock_rest_api_dml_operation = mock.create_autospec( + RestApiDmlOperation, instance=True + ) + mock_rest_api_dml_operation.results = ( + None # Simulate no results due to an exception + ) + + # Simulate an exception during processing results + mock_rest_api_dml_operation.load_records.side_effect = BulkDataException( + "Simulated failure" + ) + + with mock.patch( + "cumulusci.tasks.bulkdata.step.RestApiDmlOperation", + return_value=mock_rest_api_dml_operation, + ): + # Call the function and verify that it raises the expected exception + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.INSERT, + api_options={"batch_size": 10}, + context=task, + fields=["Name", "Email"], + ) + with pytest.raises(BulkDataException): + step._process_insert_records(insert_records, selected_records) + + # Assert the mocked splitting is called + split_mock.assert_called_once_with(fields=["Name", "Email"]) + + # Validate that `selected_records` remains unchanged + assert selected_records == [None, None] + + # Validate the operation sequence + mock_rest_api_dml_operation.start.assert_called_once() + mock_rest_api_dml_operation.load_records.assert_called_once_with( + insert_records + ) + mock_rest_api_dml_operation.end.assert_not_called() + + @responses.activate + def test_select_records_similarity_strategy__insert_records(self): + mock_describe_calls() + task = _make_task( + LoadData, + { + "options": { + "database_url": "sqlite:///test.db", + "mapping": "mapping.yml", + } + }, + ) + task.project_config.project__package__api_version = CURRENT_SF_API_VERSION + task._init_task() + + # Create step with threshold + step = RestApiDmlOperation( + sobject="Contact", + operation=DataOperationType.UPSERT, + api_options={"batch_size": 10}, + context=task, + fields=["Name", "Email"], + selection_strategy=SelectStrategy.SIMILARITY, + threshold=0.3, + ) + + results_select_call = { + "records": [ + { + "Id": "003000000000001", + "Name": "Jawad", + "Email": "mjawadtp@example.com", + }, + ], + "done": True, + } + + results_insert_call = [ + {"id": "003000000000002", "success": True, "created": True}, + {"id": "003000000000003", "success": True, "created": True}, + ] + + step.sf.restful = mock.Mock( + side_effect=[results_select_call, results_insert_call] + ) + records = iter( + [ + ["Jawad", "mjawadtp@example.com"], + ["Aditya", "aditya@example.com"], + ["Tom Cruise", "tom@example.com"], + ] + ) + step.start() + step.select_records(records) + step.end() + + # Get the results and assert their properties + results = list(step.get_results()) + assert len(results) == 3 # Expect 3 results (matching the input records count) + # Assert that all results have the expected ID, success, and created values + assert ( + results.count( + DataOperationResult( + id="003000000000001", success=True, error="", created=False + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000002", success=True, error="", created=True + ) + ) + == 1 + ) + assert ( + results.count( + DataOperationResult( + id="003000000000003", success=True, error="", created=True + ) + ) + == 1 + ) + @responses.activate def test_insert_dml_operation__boolean_conversion(self): mock_describe_calls() diff --git a/pyproject.toml b/pyproject.toml index b04f0b66c6..ae56e71afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ + "annoy", "click", "cryptography", "python-dateutil", @@ -38,6 +39,8 @@ dependencies = [ "lxml", "markdown-it-py==2.2.0", # resolve dependency conflict between prod/dev "MarkupSafe", + "numpy", + "pandas", "psutil", "pydantic<2", "PyJWT", @@ -55,6 +58,7 @@ dependencies = [ "sarge", "selenium<4", "simple-salesforce==1.11.4", + "scikit-learn", "snowfakery", "SQLAlchemy<2", "xmltodict", diff --git a/requirements/dev.txt b/requirements/dev.txt index 9a7d8b1fac..a90bf89cac 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,11 +1,13 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile --all-extras --output-file=requirements/dev.txt pyproject.toml # alabaster==0.7.13 # via sphinx +annoy==1.17.3 + # via cumulusci (pyproject.toml) appdirs==1.4.4 # via fs attrs==24.2.0 @@ -53,7 +55,6 @@ cryptography==43.0.1 # authlib # cumulusci (pyproject.toml) # pyjwt - # secretstorage defusedxml==0.7.1 # via cumulusci (pyproject.toml) distlib==0.3.8 @@ -89,9 +90,7 @@ furo==2023.3.27 github3-py==4.0.1 # via cumulusci (pyproject.toml) greenlet==3.0.3 - # via - # snowfakery - # sqlalchemy + # via snowfakery gvgen==1.0 # via snowfakery identify==2.6.0 @@ -125,6 +124,8 @@ jinja2==3.1.3 # myst-parser # snowfakery # sphinx +joblib==1.4.2 + # via scikit-learn jsonschema==4.23.0 # via cumulusci (pyproject.toml) jsonschema-specifications==2023.12.1 @@ -160,6 +161,12 @@ natsort==8.4.0 # via robotframework-pabot nodeenv==1.9.1 # via pre-commit +numpy==2.0.2 + # via + # cumulusci (pyproject.toml) + # pandas + # scikit-learn + # scipy packaging==24.1 # via # black @@ -167,6 +174,8 @@ packaging==24.1 # pytest # sphinx # tox +pandas==2.2.3 + # via cumulusci (pyproject.toml) pathspec==0.12.1 # via black pkgutil-resolve-name==1.3.10 @@ -226,11 +235,12 @@ python-dateutil==2.9.0.post0 # cumulusci (pyproject.toml) # faker # github3-py + # pandas # snowfakery pytz==2024.1 # via - # babel # cumulusci (pyproject.toml) + # pandas pyyaml==6.0.1 # via # cumulusci (pyproject.toml) @@ -289,6 +299,10 @@ sarge==0.1.7.post1 # via cumulusci (pyproject.toml) secretstorage==3.3.3 # via keyring +scikit-learn==1.5.2 + # via cumulusci (pyproject.toml) +scipy==1.13.1 + # via scikit-learn selenium==3.141.0 # via # cumulusci (pyproject.toml) @@ -335,6 +349,8 @@ sqlalchemy==1.4.52 # snowfakery testfixtures==8.3.0 # via cumulusci (pyproject.toml) +threadpoolctl==3.5.0 + # via scikit-learn tomli==2.0.1 # via # black @@ -351,10 +367,10 @@ types-pyyaml==6.0.12.20240808 typing-extensions==4.10.0 # via # black - # faker # pydantic - # rich # snowfakery +tzdata==2024.2 + # via pandas unicodecsv==0.14.1 # via salesforce-bulk uritemplate==4.1.1 @@ -381,9 +397,7 @@ xmltodict==0.13.0 yarl==1.9.11 # via vcrpy zipp==3.20.1 - # via - # importlib-metadata - # importlib-resources + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/requirements/prod.txt b/requirements/prod.txt index 40ae1621a3..ab8c75581a 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -1,9 +1,11 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile --output-file=requirements/prod.txt pyproject.toml # +annoy==1.17.3 + # via cumulusci (pyproject.toml) appdirs==1.4.4 # via fs authlib==1.3.2 @@ -27,7 +29,6 @@ cryptography==43.0.1 # authlib # cumulusci (pyproject.toml) # pyjwt - # secretstorage defusedxml==0.7.1 # via cumulusci (pyproject.toml) docutils==0.16 @@ -47,9 +48,7 @@ fs==2.4.16 github3-py==4.0.1 # via cumulusci (pyproject.toml) greenlet==3.0.3 - # via - # snowfakery - # sqlalchemy + # via snowfakery gvgen==1.0 # via snowfakery idna==3.6 @@ -66,6 +65,8 @@ jinja2==3.1.3 # via # cumulusci (pyproject.toml) # snowfakery +joblib==1.4.2 + # via scikit-learn keyring==23.0.1 # via cumulusci (pyproject.toml) lxml==5.3.0 @@ -83,6 +84,14 @@ mdurl==0.1.2 # via markdown-it-py natsort==8.4.0 # via robotframework-pabot +numpy==2.0.2 + # via + # cumulusci (pyproject.toml) + # pandas + # scikit-learn + # scipy +pandas==2.2.3 + # via cumulusci (pyproject.toml) psutil==6.0.0 # via cumulusci (pyproject.toml) pycparser==2.22 @@ -104,9 +113,12 @@ python-dateutil==2.9.0.post0 # cumulusci (pyproject.toml) # faker # github3-py + # pandas # snowfakery pytz==2024.1 - # via cumulusci (pyproject.toml) + # via + # cumulusci (pyproject.toml) + # pandas pyyaml==6.0.1 # via # cumulusci (pyproject.toml) @@ -149,6 +161,10 @@ sarge==0.1.7.post1 # via cumulusci (pyproject.toml) secretstorage==3.3.3 # via keyring +scikit-learn==1.5.2 + # via cumulusci (pyproject.toml) +scipy==1.13.1 + # via scikit-learn selenium==3.141.0 # via # cumulusci (pyproject.toml) @@ -169,12 +185,14 @@ sqlalchemy==1.4.52 # via # cumulusci (pyproject.toml) # snowfakery +threadpoolctl==3.5.0 + # via scikit-learn typing-extensions==4.10.0 # via - # faker # pydantic - # rich # snowfakery +tzdata==2024.2 + # via pandas unicodecsv==0.14.1 # via salesforce-bulk uritemplate==4.1.1 From 44fff95d9abcb7de9d9fadf097e3aff396b9591d Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 19 Nov 2024 19:06:50 +0530 Subject: [PATCH 30/34] Add scikit-learn to pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6063c73cab..7dec9eedab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "rst2ansi>=0.1.5", "salesforce-bulk", "sarge", + "scikit-learn", "selenium<4", "simple-salesforce==1.11.4", "snowfakery>=4.0.0", From db230ceb671500ec7956da4fbcc37ebd72c919cf Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 19 Nov 2024 19:34:39 +0530 Subject: [PATCH 31/34] Re-lint files --- cumulusci/tasks/bulkdata/mapping_parser.py | 6 ++-- cumulusci/tasks/bulkdata/step.py | 39 ++++++++++++---------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index e630d564c6..59c7d630a2 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -91,9 +91,9 @@ class MappingStep(CCIDictModel): batch_size: int = None oid_as_pk: bool = False # this one should be discussed and probably deprecated record_type: Optional[str] = None # should be discussed and probably deprecated - bulk_mode: Optional[Literal["Serial", "Parallel"]] = ( - None # default should come from task options - ) + bulk_mode: Optional[ + Literal["Serial", "Parallel"] + ] = None # default should come from task options anchor_date: Optional[Union[str, date]] = None soql_filter: Optional[str] = None # soql_filter property select_options: Optional[SelectOptions] = Field( diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index b88fa8b100..b2a13bf966 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -883,30 +883,33 @@ def select_records(self, records): limit_clause = self._determine_limit_clause(total_num_records) # Generate the SOQL query based on the selection strategy - select_query, query_fields = ( - self.select_operation_executor.select_generate_query( - sobject=self.sobject, - fields=self.fields, - user_filter=self.selection_filter or None, - limit=limit_clause, - offset=None, - ) + ( + select_query, + query_fields, + ) = self.select_operation_executor.select_generate_query( + sobject=self.sobject, + fields=self.fields, + user_filter=self.selection_filter or None, + limit=limit_clause, + offset=None, ) # Execute the query and gather the records query_records = self._execute_soql_query(select_query, query_fields) # Post-process the query results for this batch - selected_records, insert_records, error_message = ( - self.select_operation_executor.select_post_process( - load_records=records, - query_records=query_records, - fields=self.fields, - num_records=total_num_records, - sobject=self.sobject, - weights=self.weights, - threshold=self.threshold, - ) + ( + selected_records, + insert_records, + error_message, + ) = self.select_operation_executor.select_post_process( + load_records=records, + query_records=query_records, + fields=self.fields, + num_records=total_num_records, + sobject=self.sobject, + weights=self.weights, + threshold=self.threshold, ) # Log the number of selected and prepared for insertion records From 6889461e236602f1d4ad72d61f40b7e2b81ffd50 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 19 Nov 2024 20:29:35 +0530 Subject: [PATCH 32/34] Update document with new enhancements --- docs/data.md | 87 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 16 deletions(-) diff --git a/docs/data.md b/docs/data.md index c81ba44c90..a3e96275c1 100644 --- a/docs/data.md +++ b/docs/data.md @@ -252,40 +252,95 @@ versa. ### Selects -The "select" functionality enhances the mapping process by enabling direct record selection from the target Salesforce org for lookups. This is achieved by specifying the `select` action in the mapping file, particularly useful when dealing with objects dependent on non-insertable Salesforce objects. +The `select` functionality is designed to streamline the mapping process by enabling the selection of specific records directly from Salesforce for lookups. This feature is particularly useful when dealing with non-insertable Salesforce objects and ensures that pre-existing records are used rather than inserting new ones. The selection process is highly customizable with various strategies, filters, and additional capabilities that provide flexibility and precision in data mapping. ```yaml -Select Accounts: +Account: sf_object: Account - action: select - selection_strategy: standard - selection_filter: WHERE Name IN ('Bluth Company', 'Camacho PLC') fields: - Name - - AccountNumber -Insert Contacts: + - Description + +Contact: sf_object: Contact - action: insert fields: - LastName + - Email lookups: AccountId: table: Account + +Lead: + sf_object: Lead + fields: + - LastName + - Company + +Event: + sf_object: Event + action: select + select_options: + strategy: similarity + filter: WHERE Subject IN ('Sample Event') + priority_fields: + - Subject + - WhoId + threshold: 0.3 + fields: + - Subject + - DurationInMinutes + - ActivityDateTime + lookups: + WhoId: + table: + - Contact + - Lead + WhatId: + table: Account ``` -The `Select Accounts` section in this YAML demonstrates how to fetch specific records from your Salesforce org. These selected Account records will then be referenced by the subsequent `Insert Contacts` section via lookups, ensuring that new Contacts are linked to the pre-existing Accounts chosen in the `select` step rather than relying on any newly inserted Account records. +--- + +#### Selection Strategies + +- **`standard` Strategy:** + The `standard` selection strategy retrieves records from Salesforce in the same order as they appear, applying any specified filters and sorting criteria. This method ensures that records are selected without any prioritization based on similarity or randomness, offering a straightforward way to pull the desired data. + +- **`similarity` Strategy:** + The `similarity` strategy is used when you need to find records in Salesforce that closely resemble the records defined in your SQL file. This strategy performs a similarity match between the records in the SQL file and those in Salesforce. In addition to comparing the fields of the record itself, this strategy includes the fields of parent records (up to one level) for a more granular and accurate match. + +- **`random` Strategy:** + The `random` selection strategy randomly assigns records picked from the target org. This method is useful when the selection order does not matter, and you simply need to fetch records in a randomized manner. + +--- + +#### Selection Filters + +The selection `filter` provides a flexible way to refine the records selected by using any functionality supported by SOQL. This includes filtering, sorting, and limiting records based on specific conditions, such as using the `WHERE` clause to filter records by field values, the `ORDER BY` clause to sort records in ascending or descending order, and the `LIMIT` clause to restrict the number of records returned. Essentially, any feature available in SOQL for record selection is supported here, allowing you to tailor the selection process to your precise needs and ensuring only the relevant records are included in the mapping process. + +--- + +#### Priority Fields + +The `priority_fields` feature enables you to specify a subset of fields in your mapping step that will have more weight during the similarity matching process. When similarity matching is performed, these priority fields will be given greater importance compared to other fields, allowing for a more refined match. + +This feature is particularly useful when certain fields are more critical in defining the identity or relevance of a record, ensuring that these fields have a stronger influence in the selection process. + +--- + +#### Select + Insert -#### Selection Strategy +This feature allows you to either select or insert records based on a similarity threshold. When using the `select` action with the `similarity` strategy, you can specify a `threshold` value between `0` and `1`, where `0` represents a perfect match and `1` signifies no similarity. -The `selection_strategy` dictates how these records are chosen: +- **Select Records:** + If a record from your SQL file has a similarity score below the threshold, it will be selected from the target org. -- `standard`: This strategy fetches records from the org in the same order as they appear, respecting any filtering applied via `selection_filter`. -- `similarity`: This strategy is employed when you want to find records in the org that closely resemble those defined in your SQL file. -- `random`: As the name suggests, this strategy randomly selects records from the org. +- **Insert Records:** + If the similarity score exceeds the threshold, the record will be inserted into the target org instead of being selected. -#### Selection Filter +This feature is particularly useful during version upgrades, where records that closely match can be selected, while those that do not match sufficiently can be inserted into the target org. -The `selection_filter` acts as a versatile SOQL clause, providing fine-grained control over record selection. It allows filtering with `WHERE`, sorting with `ORDER BY`, limiting with `LIMIT`, and potentially utilizing other SOQL capabilities, ensuring you select the precise records needed for your chosen `selection_strategy`. +--- ### Database Mapping From d2dbddd3042b0742bc7c393e2c8ee87cf0d888ba Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Tue, 19 Nov 2024 20:31:12 +0530 Subject: [PATCH 33/34] Add divider in doc --- docs/data.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/data.md b/docs/data.md index a3e96275c1..c449783af7 100644 --- a/docs/data.md +++ b/docs/data.md @@ -250,6 +250,8 @@ Insert Accounts: Whenever `update_key` is supplied, the action must be `upsert` and vice versa. +--- + ### Selects The `select` functionality is designed to streamline the mapping process by enabling the selection of specific records directly from Salesforce for lookups. This feature is particularly useful when dealing with non-insertable Salesforce objects and ensures that pre-existing records are used rather than inserting new ones. The selection process is highly customizable with various strategies, filters, and additional capabilities that provide flexibility and precision in data mapping. From 8a984fd14cae8f5cb836b9182a9f813e6ae5dce2 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Fri, 22 Nov 2024 17:17:22 +0530 Subject: [PATCH 34/34] Update documentation --- docs/data.md | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/docs/data.md b/docs/data.md index c449783af7..9badb404e8 100644 --- a/docs/data.md +++ b/docs/data.md @@ -283,7 +283,7 @@ Event: action: select select_options: strategy: similarity - filter: WHERE Subject IN ('Sample Event') + filter: WHERE Subject LIKE 'Meeting%' priority_fields: - Subject - WhoId @@ -305,11 +305,13 @@ Event: #### Selection Strategies +The `strategy` parameter determines how records are selected from the target org. It is **optional**; if no strategy is specified, the `standard` strategy will be applied by default. + - **`standard` Strategy:** - The `standard` selection strategy retrieves records from Salesforce in the same order as they appear, applying any specified filters and sorting criteria. This method ensures that records are selected without any prioritization based on similarity or randomness, offering a straightforward way to pull the desired data. + The `standard` selection strategy retrieves records from target org in the same order as they appear, applying any specified filters and sorting criteria. This method ensures that records are selected without any prioritization based on similarity or randomness, offering a straightforward way to pull the desired data. - **`similarity` Strategy:** - The `similarity` strategy is used when you need to find records in Salesforce that closely resemble the records defined in your SQL file. This strategy performs a similarity match between the records in the SQL file and those in Salesforce. In addition to comparing the fields of the record itself, this strategy includes the fields of parent records (up to one level) for a more granular and accurate match. + The `similarity` strategy is used when you need to find records in the target org that closely resemble the records defined in your SQL file. This strategy performs a similarity match between the records in the SQL file and those in the target org. In addition to comparing the fields of the record itself, this strategy includes the fields of parent records (up to one level) for a more granular and accurate match. - **`random` Strategy:** The `random` selection strategy randomly assigns records picked from the target org. This method is useful when the selection order does not matter, and you simply need to fetch records in a randomized manner. @@ -320,17 +322,21 @@ Event: The selection `filter` provides a flexible way to refine the records selected by using any functionality supported by SOQL. This includes filtering, sorting, and limiting records based on specific conditions, such as using the `WHERE` clause to filter records by field values, the `ORDER BY` clause to sort records in ascending or descending order, and the `LIMIT` clause to restrict the number of records returned. Essentially, any feature available in SOQL for record selection is supported here, allowing you to tailor the selection process to your precise needs and ensuring only the relevant records are included in the mapping process. +This parameter is **optional**; and if not specified, no filter will apply. + --- #### Priority Fields The `priority_fields` feature enables you to specify a subset of fields in your mapping step that will have more weight during the similarity matching process. When similarity matching is performed, these priority fields will be given greater importance compared to other fields, allowing for a more refined match. +This parameter is **optional**; and if not specified, all fields will be considered with same priority. + This feature is particularly useful when certain fields are more critical in defining the identity or relevance of a record, ensuring that these fields have a stronger influence in the selection process. --- -#### Select + Insert +#### Threshold This feature allows you to either select or insert records based on a similarity threshold. When using the `select` action with the `similarity` strategy, you can specify a `threshold` value between `0` and `1`, where `0` represents a perfect match and `1` signifies no similarity. @@ -340,10 +346,35 @@ This feature allows you to either select or insert records based on a similarity - **Insert Records:** If the similarity score exceeds the threshold, the record will be inserted into the target org instead of being selected. +This parameter is **optional**; if not specified, no threshold will be applied and all records will default to be selected. + This feature is particularly useful during version upgrades, where records that closely match can be selected, while those that do not match sufficiently can be inserted into the target org. --- +#### Example + +To demonstrate the `select` functionality, consider the example of the `Event` entity, which utilizes the `similarity` strategy, a filter condition, and other advanced options to select matching records effectively as given in the yaml above. + +1. **Basic Object Configuration**: + + - The `Account`, `Contact`, and `Lead` objects are configured for straightforward field mapping. + - A `lookup` is defined on the `Contact` object to map `AccountId` to the `Account` table. + +2. **Advanced `Event` Object Mapping**: + - **Action**: The `Event` object uses the `select` action, meaning records are selected rather than inserted. + - **Strategy**: The `similarity` strategy matches `Event` records in target org that are similar to those defined in the SQL file. + - **Filter**: Only `Event` records with a `Subject` field starting with "Meeting" are considered. + - **Priority Fields**: The `Subject` and `WhoId` fields are given more weight during similarity matching. + - **Threshold**: A similarity score of 0.3 is used to determine whether records are selected or inserted. + - **Lookups**: + - The `WhoId` field looks up records from either the `Contact` or `Lead` objects. + - The `WhatId` field looks up records from the `Account` object. + +This example highlights how the `select` functionality can be applied in real-world scenarios, such as selecting `Event` records that meet specific criteria while considering similarity, filters, and priority fields. + +--- + ### Database Mapping CumulusCI's definition format includes considerable flexibility for use