diff --git a/splink/linker.py b/splink/linker.py
index ab3f7fbae2..36770e4d26 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -1305,11 +1305,12 @@ def predict(
         # In duckdb, calls to random() in a CTE pipeline cause problems:
         # https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
         if self._settings_obj.salting_required:
-            materialise = True
+            materialise_after_computing_term_frequencies = True
 
         input_dataframes = []
         if materialise_after_computing_term_frequencies:
             nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
+            input_dataframes.append(nodes_with_tf)
             pipeline = SQLPipeline()
 
         # If exploded blocking rules exist, we need to materialise
@@ -1317,20 +1318,21 @@ def predict(
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)
 
         sqls = block_using_rules_sqls(self)
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)
 
         repartition_after_blocking = getattr(self, "repartition_after_blocking", False)
 
         # repartition after blocking only exists on the SparkLinker
         if repartition_after_blocking:
-            df_blocked = self._execute_sql_pipeline(input_dataframes)
+            df_blocked = self.db_api.sql_pipeline_to_splink_dataframe(
+                pipeline, input_dataframes
+            )
             input_dataframes.append(df_blocked)
 
         sql = compute_comparison_vector_values_sql(
             self._settings_obj._columns_to_select_for_comparison_vector_values
         )
-        self._enqueue_sql(sql, "__splink__df_comparison_vectors")
+        pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")
 
         sqls = predict_from_comparison_vectors_sqls_using_settings(
             self._settings_obj,
@@ -1338,10 +1340,11 @@ def predict(
             threshold_match_weight,
             sql_infinity_expression=self._infinity_expression,
         )
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)
 
-        predictions = self._execute_sql_pipeline(input_dataframes)
+        predictions = self.db_api.sql_pipeline_to_splink_dataframe(
+            pipeline, input_dataframes
+        )
         self._predict_warning()
 
         [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
diff --git a/splink/pipeline.py b/splink/pipeline.py
index 3a5d418ed9..40cd70bc89 100644
--- a/splink/pipeline.py
+++ b/splink/pipeline.py
@@ -1,5 +1,6 @@
 import logging
 from copy import deepcopy
+from typing import List
 
 import sqlglot
 from sqlglot.errors import ParseError
@@ -45,6 +46,10 @@ def enqueue_sql(self, sql, output_table_name):
         sql_task = SQLTask(sql, output_table_name)
         self.queue.append(sql_task)
 
+    def enqueue_list_of_sqls(self, sql_list: List[dict]):
+        for sql_dict in sql_list:
+            self.enqueue_sql(sql_dict["sql"], sql_dict["output_table_name"])
+
     def generate_pipeline_parts(self, input_dataframes):
         parts = deepcopy(self.queue)
         for df in input_dataframes:
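
Taken together, the two hunks replace the linker's implicit SQL queue (self._enqueue_sql / self._execute_sql_pipeline) with an explicit SQLPipeline object that is executed through the backend's db_api. The sketch below, which is not part of the patch, illustrates the resulting call pattern; only SQLPipeline, enqueue_sql, enqueue_list_of_sqls and db_api.sql_pipeline_to_splink_dataframe come from the diff, while the SQL strings and table names are hypothetical placeholders.

# Illustrative sketch only, not part of the patch. Assumes the API surface
# shown in the diff; SQL strings and table names here are hypothetical.
from splink.pipeline import SQLPipeline

pipeline = SQLPipeline()

# Enqueue a single statement, as predict() now does for the
# comparison-vector step:
pipeline.enqueue_sql(
    "select * from __splink__df_blocked",  # hypothetical SQL
    "__splink__df_comparison_vectors",
)

# Helpers such as block_using_rules_sqls return a list of
# {"sql": ..., "output_table_name": ...} dicts; the new convenience
# method enqueues them all in order:
sqls = [
    {"sql": "select 1 as id", "output_table_name": "__splink__step_1"},
    {"sql": "select * from __splink__step_1", "output_table_name": "__splink__step_2"},
]
pipeline.enqueue_list_of_sqls(sqls)

# Execution is delegated to the backend, with any previously materialised
# frames passed in explicitly rather than tracked as linker state:
# result = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline, input_dataframes)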