[WIP] Use fresh pipelines rather than linker pipeline #2062

Closed · wants to merge 7 commits
57 changes: 40 additions & 17 deletions splink/linker.py
@@ -535,6 +535,17 @@ def _initialise_df_concat_with_tf(self, materialise=True):

return nodes_with_tf

def _enqueue_df_concat_with_tf(self, pipeline: SQLPipeline):
@RobinL (Member, Author) commented on Mar 16, 2024:

This is no good because it doesn't deal with the cache.

This is really fiddly because you either:

  • want to execute the pipeline, add the result to the cache and return the dataframe
  • retrieve the table from the cache if it exists
  • return a pipeline with the right SQL enqueued

So there's no easy version of this that uses a fresh SQL pipeline: you can pass one in, but then it's unclear what to return.

Another option would be to allow the pipeline to queue input dataframes; that seems promising.
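
A minimal sketch of what "allow the pipeline to queue input dataframes" could look like. The class and method names (`SketchPipeline`, `append_input_dataframe`) are illustrative only and are not part of this PR:

```python
from typing import List, Tuple


class SketchPipeline:
    """Illustrative only: a pipeline that can hold already-materialised
    input dataframes alongside its queue of SQL tasks."""

    def __init__(self) -> None:
        self.queue: List[Tuple[str, str]] = []  # (sql, output_table_name) pairs
        self.input_dataframes: list = []        # tables that already exist

    def enqueue_sql(self, sql: str, output_table_name: str) -> None:
        self.queue.append((sql, output_table_name))

    def append_input_dataframe(self, df) -> None:
        # A cached/materialised table is registered as an input, so the
        # pipeline can reference it without re-running any SQL
        self.input_dataframes.append(df)
```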

@RobinL (Member, Author) commented:

In that case, a fresh SQL pipeline is passed in. If the table is in the cache, then the pipeline is returned with it queued up as an input dataframe.

If it's not in the cache, then either:

  • materialisation case: a pipeline is returned with the table in the cache
  • non-materialisation case: a pipeline is returned with the SQL enqueued
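
A rough sketch of that behaviour, assuming the linker exposes its materialised tables as `self._intermediate_table_cache` and the pipeline gains the hypothetical `append_input_dataframe` method from the sketch above (neither is settled API in this PR); the helper functions are the ones already used in this diff:

```python
def _enqueue_df_concat_with_tf(self, pipeline: SQLPipeline, materialise: bool = True):
    # Sketch only: the cache attribute and append_input_dataframe are assumptions
    cache_key = "__splink__df_concat_with_tf"
    if cache_key in self._intermediate_table_cache:
        # Cache hit: queue the existing table as an input dataframe
        pipeline.append_input_dataframe(self._intermediate_table_cache[cache_key])
        return pipeline

    # Cache miss: enqueue the SQL that builds the table
    pipeline.enqueue_sql(vertically_concatenate_sql(self), "__splink__df_concat")
    for task in compute_all_term_frequencies_sqls(self):
        pipeline.enqueue_sql(task["sql"], task["output_table_name"])

    if materialise:
        # Materialisation case: execute now, cache the result, and return a
        # fresh pipeline with that table queued as an input dataframe
        df = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
        self._intermediate_table_cache[cache_key] = df
        fresh_pipeline = SQLPipeline()
        fresh_pipeline.append_input_dataframe(df)
        return fresh_pipeline

    # Non-materialisation case: return the pipeline with the SQL enqueued
    return pipeline
```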


sql = vertically_concatenate_sql(self)
pipeline.enqueue_sql(sql, "__splink__df_concat")

sqls = compute_all_term_frequencies_sqls(self)
for sql in sqls:
pipeline.enqueue_sql(sql["sql"], sql["output_table_name"])

return pipeline

def _table_to_splink_dataframe(
self, templated_name, physical_name
) -> SplinkDataFrame:
@@ -943,19 +954,24 @@ def deterministic_link(self) -> SplinkDataFrame:
SplinkDataFrame allow you to access the underlying data.
"""

pipeline = SQLPipeline()

# Allows clustering during a deterministic linkage.
# This is used in `cluster_pairwise_predictions_at_threshold`
# to set the cluster threshold to 1
self._deterministic_link_mode = True

concat_with_tf = self._initialise_df_concat_with_tf()
self._enqueue_df_concat_with_tf(pipeline)
concat_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)

exploding_br_with_id_tables = materialise_exploded_id_tables(self)

sqls = block_using_rules_sqls(self)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])
pipeline.enqueue_list_of_sqls(sqls)

deterministic_link_df = self._execute_sql_pipeline([concat_with_tf])
deterministic_link_df = self.db_api.sql_pipeline_to_splink_dataframe(
pipeline, [concat_with_tf]
)
[b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
return deterministic_link_df

@@ -1283,50 +1299,57 @@ def predict(

"""

pipeline = SQLPipeline()

# If materialise_after_computing_term_frequencies=False and the user only
# calls predict, it runs as a single pipeline with no materialisation
# of anything.

# _initialise_df_concat_with_tf returns None if the table doesn't exist
# and only SQL is queued in this step.
nodes_with_tf = self._initialise_df_concat_with_tf(
materialise=materialise_after_computing_term_frequencies
)
self._enqueue_df_concat_with_tf(pipeline)

# In duckdb, calls to random() in a CTE pipeline cause problems:
# https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
if self._settings_obj.salting_required:
materialise_after_computing_term_frequencies = True

input_dataframes = []
if nodes_with_tf:
if materialise_after_computing_term_frequencies:
nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
input_dataframes.append(nodes_with_tf)
pipeline = SQLPipeline()

# If exploded blocking rules exist, we need to materialise
# the tables of ID pairs
exploding_br_with_id_tables = materialise_exploded_id_tables(self)

sqls = block_using_rules_sqls(self)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])
pipeline.enqueue_list_of_sqls(sqls)

repartition_after_blocking = getattr(self, "repartition_after_blocking", False)

# repartition after blocking only exists on the SparkLinker
if repartition_after_blocking:
df_blocked = self._execute_sql_pipeline(input_dataframes)
df_blocked = self.db_api.sql_pipeline_to_splink_dataframe(
pipeline, input_dataframes
)
input_dataframes.append(df_blocked)

sql = compute_comparison_vector_values_sql(
self._settings_obj._columns_to_select_for_comparison_vector_values
)
self._enqueue_sql(sql, "__splink__df_comparison_vectors")
pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

sqls = predict_from_comparison_vectors_sqls_using_settings(
self._settings_obj,
threshold_match_probability,
threshold_match_weight,
sql_infinity_expression=self._infinity_expression,
)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])
pipeline.enqueue_list_of_sqls(sqls)

predictions = self._execute_sql_pipeline(input_dataframes)
predictions = self.db_api.sql_pipeline_to_splink_dataframe(
pipeline, input_dataframes
)
self._predict_warning()

[b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables]
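
For readability, the control flow that the new `predict()` code implements can be condensed as below. This is a paraphrase of the diff rather than additional code in the PR, and it omits the DuckDB salting workaround, the exploded-blocking-rule tables, and the Spark repartitioning branch:

```python
# Condensed paraphrase of the new predict() flow (this body lives inside the method)
pipeline = SQLPipeline()
self._enqueue_df_concat_with_tf(pipeline)  # queue concat + term frequency SQL

input_dataframes = []
if materialise_after_computing_term_frequencies:
    # Break here: materialise nodes_with_tf, then continue with a fresh pipeline
    nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
    input_dataframes.append(nodes_with_tf)
    pipeline = SQLPipeline()

# Otherwise everything below is appended to the same, single pipeline and
# nothing is materialised until the final call
pipeline.enqueue_list_of_sqls(block_using_rules_sqls(self))
pipeline.enqueue_sql(
    compute_comparison_vector_values_sql(
        self._settings_obj._columns_to_select_for_comparison_vector_values
    ),
    "__splink__df_comparison_vectors",
)
pipeline.enqueue_list_of_sqls(
    predict_from_comparison_vectors_sqls_using_settings(
        self._settings_obj,
        threshold_match_probability,
        threshold_match_weight,
        sql_infinity_expression=self._infinity_expression,
    )
)
predictions = self.db_api.sql_pipeline_to_splink_dataframe(pipeline, input_dataframes)
```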
5 changes: 5 additions & 0 deletions splink/pipeline.py
@@ -1,5 +1,6 @@
import logging
from copy import deepcopy
from typing import List

import sqlglot
from sqlglot.errors import ParseError
@@ -45,6 +46,10 @@ def enqueue_sql(self, sql, output_table_name):
sql_task = SQLTask(sql, output_table_name)
self.queue.append(sql_task)

def enqueue_list_of_sqls(self, sql_list: List[dict]):
for sql_dict in sql_list:
self.enqueue_sql(sql_dict["sql"], sql_dict["output_table_name"])

def generate_pipeline_parts(self, input_dataframes):
parts = deepcopy(self.queue)
for df in input_dataframes:
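
A minimal usage sketch of the new `enqueue_list_of_sqls` helper. The import path follows the file location above, and the table names and SQL are illustrative; the list shape matches the dicts produced by helpers such as `block_using_rules_sqls`:

```python
from splink.pipeline import SQLPipeline

pipeline = SQLPipeline()

# Each entry supplies a piece of SQL and the name of the table it creates,
# so a whole list of generated steps can be queued in one call.
pipeline.enqueue_list_of_sqls(
    [
        {"sql": "select * from __splink__df_concat", "output_table_name": "step_1"},
        {"sql": "select * from step_1", "output_table_name": "step_2"},
    ]
)
```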