diff --git a/splink/linker.py b/splink/linker.py
index eb6f6f77ca..ebe71b017d 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -535,6 +535,17 @@ def _initialise_df_concat_with_tf(self, materialise=True):
 
         return nodes_with_tf
 
+    def _enqueue_df_concat_with_tf(self, pipeline: SQLPipeline):
+
+        sql = vertically_concatenate_sql(self)
+        pipeline.enqueue_sql(sql, "__splink__df_concat")
+
+        sqls = compute_all_term_frequencies_sqls(self)
+        for sql in sqls:
+            pipeline.enqueue_sql(sql["sql"], sql["output_table_name"])
+
+        return pipeline
+
     def _table_to_splink_dataframe(
         self, templated_name, physical_name
     ) -> SplinkDataFrame:
@@ -943,19 +954,24 @@ def deterministic_link(self) -> SplinkDataFrame:
             SplinkDataFrame allow you to access the underlying data.
         """
 
+        pipeline = SQLPipeline()
+
         # Allows clustering during a deterministic linkage.
         # This is used in `cluster_pairwise_predictions_at_threshold`
         # to set the cluster threshold to 1
        self._deterministic_link_mode = True
 
-        concat_with_tf = self._initialise_df_concat_with_tf()
+        self._enqueue_df_concat_with_tf(pipeline)
+        concat_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)
 
         sqls = block_using_rules_sqls(self)
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)
 
-        deterministic_link_df = self._execute_sql_pipeline([concat_with_tf])
+        deterministic_link_df = self.db_api.sql_pipeline_to_splink_dataframe(
+            pipeline, [concat_with_tf]
+        )
         [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
 
         return deterministic_link_df
@@ -1283,39 +1299,45 @@ def predict(
 
         """
 
+        pipeline = SQLPipeline()
+
         # If materialise_after_computing_term_frequencies=False and the user only
         # calls predict, it runs as a single pipeline with no materialisation
        # of anything.
 
-        # _initialise_df_concat_with_tf returns None if the table doesn't exist
-        # and only SQL is queued in this step.
-        nodes_with_tf = self._initialise_df_concat_with_tf(
-            materialise=materialise_after_computing_term_frequencies
-        )
+        self._enqueue_df_concat_with_tf(pipeline)
+
+        # In duckdb, calls to random() in a CTE pipeline cause problems:
+        # https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
+        if self._settings_obj.salting_required:
+            materialise_after_computing_term_frequencies = True
 
         input_dataframes = []
-        if nodes_with_tf:
+        if materialise_after_computing_term_frequencies:
+            nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
             input_dataframes.append(nodes_with_tf)
+            pipeline = SQLPipeline()
 
         # If exploded blocking rules exist, we need to materialise
         # the tables of ID pairs
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)
 
         sqls = block_using_rules_sqls(self)
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)
 
         repartition_after_blocking = getattr(self, "repartition_after_blocking", False)
 
         # repartition after blocking only exists on the SparkLinker
         if repartition_after_blocking:
-            df_blocked = self._execute_sql_pipeline(input_dataframes)
+            df_blocked = self.db_api.sql_pipeline_to_splink_dataframe(
+                pipeline, input_dataframes
+            )
             input_dataframes.append(df_blocked)
 
         sql = compute_comparison_vector_values_sql(
             self._settings_obj._columns_to_select_for_comparison_vector_values
         )
-        self._enqueue_sql(sql, "__splink__df_comparison_vectors")
+        pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")
 
         sqls = predict_from_comparison_vectors_sqls_using_settings(
             self._settings_obj,
@@ -1323,10 +1345,11 @@ def predict(
             threshold_match_weight,
             sql_infinity_expression=self._infinity_expression,
         )
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)
 
-        predictions = self._execute_sql_pipeline(input_dataframes)
+        predictions = self.db_api.sql_pipeline_to_splink_dataframe(
+            pipeline, input_dataframes
+        )
         self._predict_warning()
 
         [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
diff --git a/splink/pipeline.py b/splink/pipeline.py
index 3a5d418ed9..40cd70bc89 100644
--- a/splink/pipeline.py
+++ b/splink/pipeline.py
@@ -1,5 +1,6 @@
 import logging
 from copy import deepcopy
+from typing import List
 
 import sqlglot
 from sqlglot.errors import ParseError
@@ -45,6 +46,10 @@ def enqueue_sql(self, sql, output_table_name):
         sql_task = SQLTask(sql, output_table_name)
         self.queue.append(sql_task)
 
+    def enqueue_list_of_sqls(self, sql_list: List[dict]):
+        for sql_dict in sql_list:
+            self.enqueue_sql(sql_dict["sql"], sql_dict["output_table_name"])
+
     def generate_pipeline_parts(self, input_dataframes):
         parts = deepcopy(self.queue)
         for df in input_dataframes:
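
Below is a minimal, illustrative sketch (not part of the patch) of how the SQLPipeline interface touched by this diff might be exercised on its own. It assumes SQLPipeline is importable from splink.pipeline as shown above; the SQL strings and output table names are invented for the example, and the final execution call is left as a comment because it needs a configured linker backend.

    from splink.pipeline import SQLPipeline

    # Build a queue of named CTE steps; later steps can reference earlier
    # ones via their output_table_name.
    pipeline = SQLPipeline()
    pipeline.enqueue_sql(
        "select * from __splink__df_concat",  # illustrative SQL only
        "__splink__df_blocked",
    )

    # enqueue_list_of_sqls accepts the same list-of-dicts shape returned by
    # helpers such as block_using_rules_sqls: each dict carries "sql" and
    # "output_table_name".
    pipeline.enqueue_list_of_sqls(
        [
            {"sql": "select 1 as step", "output_table_name": "__splink__step_1"},
            {"sql": "select 2 as step", "output_table_name": "__splink__step_2"},
        ]
    )

    # Execution is delegated to the backend, e.g. (assuming a constructed linker):
    # df = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline, input_dataframes)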