use fresh pipeline for predict
RobinL committed Mar 15, 2024
1 parent b734b09 commit fa7d560
Showing 2 changed files with 16 additions and 8 deletions.
19 changes: 11 additions & 8 deletions splink/linker.py
@@ -1305,43 +1305,46 @@ def predict(
         # In duckdb, calls to random() in a CTE pipeline cause problems:
         # https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
         if self._settings_obj.salting_required:
-            materialise = True
+            materialise_after_computing_term_frequencies = True

         input_dataframes = []
         if materialise_after_computing_term_frequencies:
             nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
             input_dataframes.append(nodes_with_tf)
+            pipeline = SQLPipeline()

         # If exploded blocking rules exist, we need to materialise
         # the tables of ID pairs
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)

         sqls = block_using_rules_sqls(self)
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)

         repartition_after_blocking = getattr(self, "repartition_after_blocking", False)

         # repartition after blocking only exists on the SparkLinker
         if repartition_after_blocking:
-            df_blocked = self._execute_sql_pipeline(input_dataframes)
+            df_blocked = self.db_api.sql_pipeline_to_splink_dataframe(
+                pipeline, input_dataframes
+            )
             input_dataframes.append(df_blocked)

         sql = compute_comparison_vector_values_sql(
             self._settings_obj._columns_to_select_for_comparison_vector_values
         )
-        self._enqueue_sql(sql, "__splink__df_comparison_vectors")
+        pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

         sqls = predict_from_comparison_vectors_sqls_using_settings(
             self._settings_obj,
             threshold_match_probability,
             threshold_match_weight,
             sql_infinity_expression=self._infinity_expression,
         )
-        for sql in sqls:
-            self._enqueue_sql(sql["sql"], sql["output_table_name"])
+        pipeline.enqueue_list_of_sqls(sqls)

-        predictions = self._execute_sql_pipeline(input_dataframes)
+        predictions = self.db_api.sql_pipeline_to_splink_dataframe(
+            pipeline, input_dataframes
+        )
         self._predict_warning()

         [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
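The linker.py change above follows one pattern: once the queued SQL has been executed and materialised, a brand-new SQLPipeline is created, so SQL enqueued later is not appended to a queue that has already run. Below is a minimal, self-contained sketch of that flow; SQLTask, SQLPipeline and FakeDBAPI are simplified stand-ins written for illustration, not Splink's real classes, and the SQL strings are placeholders.

from typing import List


class SQLTask:
    def __init__(self, sql: str, output_table_name: str):
        self.sql = sql
        self.output_table_name = output_table_name


class SQLPipeline:
    def __init__(self):
        self.queue: List[SQLTask] = []

    def enqueue_sql(self, sql: str, output_table_name: str) -> None:
        self.queue.append(SQLTask(sql, output_table_name))


class FakeDBAPI:
    # Stand-in for a database backend: "executes" the queued tasks and
    # returns the name of the final output table.
    def sql_pipeline_to_splink_dataframe(self, pipeline, input_dataframes=()):
        executed = [task.output_table_name for task in pipeline.queue]
        print("executing:", executed, "inputs:", list(input_dataframes))
        return executed[-1] if executed else None


db_api = FakeDBAPI()
pipeline = SQLPipeline()
pipeline.enqueue_sql("SELECT /* tf */ ...", "__splink__df_concat_with_tf")

input_dataframes = []
materialise_after_computing_term_frequencies = True
if materialise_after_computing_term_frequencies:
    # Execute everything queued so far and keep the concrete result...
    nodes_with_tf = db_api.sql_pipeline_to_splink_dataframe(pipeline)
    input_dataframes.append(nodes_with_tf)
    # ...then start from a clean queue: later SQL must not be enqueued
    # onto a pipeline that has already been executed.
    pipeline = SQLPipeline()

pipeline.enqueue_sql("SELECT /* blocking */ ...", "__splink__df_blocked")
predictions = db_api.sql_pipeline_to_splink_dataframe(pipeline, input_dataframes)
print("final:", predictions)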
5 changes: 5 additions & 0 deletions splink/pipeline.py
@@ -1,5 +1,6 @@
 import logging
 from copy import deepcopy
+from typing import List

 import sqlglot
 from sqlglot.errors import ParseError
@@ -45,6 +46,10 @@ def enqueue_sql(self, sql, output_table_name):
         sql_task = SQLTask(sql, output_table_name)
         self.queue.append(sql_task)

+    def enqueue_list_of_sqls(self, sql_list: List[dict]):
+        for sql_dict in sql_list:
+            self.enqueue_sql(sql_dict["sql"], sql_dict["output_table_name"])
+
     def generate_pipeline_parts(self, input_dataframes):
         parts = deepcopy(self.queue)
         for df in input_dataframes:
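The new enqueue_list_of_sqls helper consumes the list-of-dicts shape that Splink's SQL-generating functions return, each dict carrying "sql" and "output_table_name" keys, as seen in the linker.py hunk. A minimal usage sketch, using simplified stub classes rather than the real ones from splink/pipeline.py:

from typing import List


class SQLTask:
    def __init__(self, sql: str, output_table_name: str):
        self.sql = sql
        self.output_table_name = output_table_name


class SQLPipeline:
    def __init__(self):
        self.queue: List[SQLTask] = []

    def enqueue_sql(self, sql: str, output_table_name: str) -> None:
        self.queue.append(SQLTask(sql, output_table_name))

    def enqueue_list_of_sqls(self, sql_list: List[dict]) -> None:
        # Each dict mirrors the shape Splink's SQL-generating helpers return.
        for sql_dict in sql_list:
            self.enqueue_sql(sql_dict["sql"], sql_dict["output_table_name"])


pipeline = SQLPipeline()
pipeline.enqueue_list_of_sqls(
    [
        {"sql": "SELECT * FROM a", "output_table_name": "t1"},
        {"sql": "SELECT * FROM t1", "output_table_name": "t2"},
    ]
)
print([task.output_table_name for task in pipeline.queue])  # ['t1', 't2']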
