Commit b380674

forgot to remove input_dataframes
RobinL committed Mar 18, 2024
1 parent f09b264 commit b380674
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions splink/linker.py
@@ -1311,50 +1311,44 @@ def predict(
         """

+        pipeline = CTEPipeline()
+
         # If materialise_after_computing_term_frequencies=False and the user only
         # calls predict, it runs as a single pipeline with no materialisation
         # of anything.

-        # _initialise_df_concat_with_tf returns None if the table doesn't exist
-        # and only SQL is queued in this step.
-        nodes_with_tf = self._initialise_df_concat_with_tf(
-            materialise=materialise_after_computing_term_frequencies
+        pipeline = self._enqueue_df_concat_with_tf(
+            pipeline, materialise=materialise_after_computing_term_frequencies
         )

-        input_dataframes = []
-        if nodes_with_tf:
-            input_dataframes.append(nodes_with_tf)
-
         # If exploded blocking rules exist, we need to materialise
         # the tables of ID pairs
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)

         sqls = block_using_rules_sqls(self)
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)

         repartition_after_blocking = getattr(self, "repartition_after_blocking", False)

         # repartition after blocking only exists on the SparkLinker
         if repartition_after_blocking:
-            df_blocked = self._execute_sql_pipeline(input_dataframes)
-            input_dataframes.append(df_blocked)
+            df_blocked = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
+            pipeline.append_input_dataframe(df_blocked)

         sql = compute_comparison_vector_values_sql(
             self._settings_obj._columns_to_select_for_comparison_vector_values
         )
-        self._enqueue_sql(sql, "__splink__df_comparison_vectors")
+        pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

         sqls = predict_from_comparison_vectors_sqls_using_settings(
             self._settings_obj,
             threshold_match_probability,
             threshold_match_weight,
             sql_infinity_expression=self._infinity_expression,
         )
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)

-        predictions = self._execute_sql_pipeline(input_dataframes)
+        predictions = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
         self._predict_warning()

         [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
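For context on the pattern this refactor moves towards: instead of queueing SQL on the linker itself (self._enqueue_sql, self._execute_sql_pipeline), an explicit CTEPipeline object is threaded through the method and executed via self.db_api.sql_pipeline_to_splink_dataframe(pipeline). The sketch below is a toy illustration of that idea, not Splink's actual implementation: the class name ToyCTEPipeline and the generate_pipeline_sql method are invented for this example, while enqueue_sql, enqueue_list_of_sqls, and append_input_dataframe mirror the calls visible in the diff.

# Toy sketch of a CTE pipeline (hypothetical; Splink's real CTEPipeline
# differs in detail). Each enqueued step becomes a named CTE and the final
# step's SQL is the outer SELECT.
class ToyCTEPipeline:
    def __init__(self):
        self.queued_sql = []         # (sql, output_table_name) pairs
        self.input_dataframes = []   # materialised tables the pipeline reads

    def enqueue_sql(self, sql, output_table_name):
        self.queued_sql.append((sql, output_table_name))

    def enqueue_list_of_sqls(self, sqls):
        # As in the diff: each element is a dict with "sql" and
        # "output_table_name" keys
        for s in sqls:
            self.enqueue_sql(s["sql"], s["output_table_name"])

    def append_input_dataframe(self, df):
        # A materialised intermediate result is registered on the pipeline,
        # so no separate input_dataframes list is needed on the linker
        self.input_dataframes.append(df)

    def generate_pipeline_sql(self):
        # All queued steps except the last become named CTEs
        *ctes, (final_sql, _) = self.queued_sql
        if not ctes:
            return final_sql
        with_clause = ",\n".join(f"{name} AS (\n{sql}\n)" for sql, name in ctes)
        return f"WITH {with_clause}\n{final_sql}"

pipeline = ToyCTEPipeline()
pipeline.enqueue_sql("SELECT * FROM __splink__input", "step_1")
pipeline.enqueue_sql("SELECT * FROM step_1 WHERE match_weight > 0", "step_2")
print(pipeline.generate_pipeline_sql())
# WITH step_1 AS (
# SELECT * FROM __splink__input
# )
# SELECT * FROM step_1 WHERE match_weight > 0

Because the pipeline object tracks both its queued SQL and its materialised inputs, the separate input_dataframes list the old code built up on the linker becomes redundant, which is the leftover this commit removes.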
