From b380674210060ad6a4baee25b980fccd0e3b0c15 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Mon, 18 Mar 2024 10:43:51 +0000
Subject: [PATCH] forgot to remove input_dataframes

---
 splink/linker.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/splink/linker.py b/splink/linker.py
index f11688fd53..b9494c8614 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -1311,39 +1311,34 @@ def predict(
         """

+        pipeline = CTEPipeline()
+
         # If materialise_after_computing_term_frequencies=False and the user only
         # calls predict, it runs as a single pipeline with no materialisation
         # of anything.
-        # _initialise_df_concat_with_tf returns None if the table doesn't exist
-        # and only SQL is queued in this step.
-        nodes_with_tf = self._initialise_df_concat_with_tf(
-            materialise=materialise_after_computing_term_frequencies
+        pipeline = self._enqueue_df_concat_with_tf(
+            pipeline, materialise=materialise_after_computing_term_frequencies
         )

-        input_dataframes = []
-        if nodes_with_tf:
-            input_dataframes.append(nodes_with_tf)
-
         # If exploded blocking rules exist, we need to materialise
         # the tables of ID pairs
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)

         sqls = block_using_rules_sqls(self)
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)

         repartition_after_blocking = getattr(self, "repartition_after_blocking", False)

         # repartition after blocking only exists on the SparkLinker
         if repartition_after_blocking:
-            df_blocked = self._execute_sql_pipeline(input_dataframes)
-            input_dataframes.append(df_blocked)
+            df_blocked = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
+            pipeline.append_input_dataframe(df_blocked)

         sql = compute_comparison_vector_values_sql(
             self._settings_obj._columns_to_select_for_comparison_vector_values
         )

-        self._enqueue_sql(sql, "__splink__df_comparison_vectors")
+        pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

         sqls = predict_from_comparison_vectors_sqls_using_settings(
             self._settings_obj,
@@ -1351,10 +1346,9 @@ def predict(
             threshold_match_weight,
             sql_infinity_expression=self._infinity_expression,
         )
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)

-        predictions = self._execute_sql_pipeline(input_dataframes)
+        predictions = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
         self._predict_warning()

         [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
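
For context, the refactor replaces the linker's internal _enqueue_sql / _execute_sql_pipeline bookkeeping with an explicit CTEPipeline object that is executed via the backend db_api. Below is a minimal sketch of the resulting call pattern, using only the method names visible in the diff; the sqls list, df_blocked dataframe and output table name are illustrative placeholders, not a runnable Splink example:

    # build up a pipeline of SQL steps instead of enqueueing on the linker
    pipeline = CTEPipeline()

    # queue a single SQL statement under a named output table
    pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

    # or queue a list of {"sql": ..., "output_table_name": ...} dicts in one call
    pipeline.enqueue_list_of_sqls(sqls)

    # previously materialised results can be fed back in as pipeline inputs
    pipeline.append_input_dataframe(df_blocked)

    # execution now lives on the backend db_api rather than on the Linker itself
    predictions = db_api.sql_pipeline_to_splink_dataframe(pipeline)

One consequence of this design, visible in the hunk above, is that the pipeline object carries its own input dataframes, so the locally tracked input_dataframes list (the "forgot to remove" part of the subject line) is no longer needed.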