
Commit 7878be8

Merge pull request #2086 from moj-analytical-services/enqueue_and_compute_methods

Enqueue and compute methods
RobinL authored Mar 21, 2024
2 parents cfc2ad3 + e70d66d commit 7878be8
Showing 7 changed files with 76 additions and 45 deletions.
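The recurring change across these files is to retire the Linker method _enqueue_df_concat_with_tf in favour of two module-level helpers in splink/vertically_concatenate.py: enqueue_df_concat_with_tf, which lazily adds the concatenate-and-term-frequency SQL to an existing CTE pipeline, and compute_df_concat_with_tf, which executes that SQL, caches the result and returns a SplinkDataFrame. A minimal before/after sketch of the calling pattern seen at the call sites below (illustrative only; linker is an existing Linker instance):

    # Before: Linker method, removed by this PR
    pipeline = CTEPipeline(reusable=False)
    pipeline = linker._enqueue_df_concat_with_tf(pipeline)

    # After: module-level function that materialises the table,
    # then a fresh pipeline is seeded from the materialised result
    pipeline = CTEPipeline(reusable=False)
    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
    pipeline = CTEPipeline([nodes_with_tf], reusable=False)
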
8 changes: 6 additions & 2 deletions splink/accuracy.py
@@ -7,6 +7,7 @@
 from .pipeline import CTEPipeline
 from .predict import predict_from_comparison_vectors_sqls_using_settings
 from .sql_transform import move_l_r_table_prefix_to_column_suffix
+from .vertically_concatenate import compute_df_concat_with_tf

 if TYPE_CHECKING:
     from .linker import Linker
@@ -168,7 +169,9 @@ def truth_space_table_from_labels_table(
     linker, labels_tablename, threshold_actual=0.5, match_weight_round_to_nearest=None
 ):
     pipeline = CTEPipeline(reusable=False)
-    pipeline = linker._enqueue_df_concat_with_tf(pipeline)
+
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)

     sqls = predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename)
     pipeline.enqueue_list_of_sqls(sqls)
@@ -269,7 +272,8 @@ def prediction_errors_from_labels_table(
     threshold=0.5,
 ):
     pipeline = CTEPipeline(reusable=False)
-    pipeline = linker._enqueue_df_concat_with_tf(pipeline)
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)

     sqls = predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename)

9 changes: 5 additions & 4 deletions splink/blocking.py
@@ -13,6 +13,7 @@
 from .pipeline import CTEPipeline
 from .splink_dataframe import SplinkDataFrame
 from .unique_id_concat import _composite_unique_id_from_nodes_sql
+from .vertically_concatenate import compute_df_concat_with_tf

 logger = logging.getLogger(__name__)

@@ -397,13 +398,12 @@ def materialise_exploded_id_tables(linker: Linker):
     exploding_blocking_rules = [
         br for br in blocking_rules if isinstance(br, ExplodingBlockingRule)
     ]
+    if len(exploding_blocking_rules) == 0:
+        return []
     exploded_tables = []

     pipeline = CTEPipeline(reusable=False)
-    linker._enqueue_df_concat_with_tf(pipeline)
-    nodes_with_tf = linker._intermediate_table_cache.get_with_logging(
-        "__splink__df_concat_with_tf"
-    )
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)

     input_colnames = {col.name for col in nodes_with_tf.columns}

@@ -434,6 +434,7 @@ def materialise_exploded_id_tables(linker: Linker):
         marginal_ids_table = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
         br.exploded_id_pair_table = marginal_ids_table
         exploded_tables.append(marginal_ids_table)
+
     return exploding_blocking_rules


4 changes: 3 additions & 1 deletion splink/em_training_session.py
@@ -26,6 +26,7 @@
     Settings,
     TrainingSettings,
 )
+from .vertically_concatenate import compute_df_concat_with_tf

 logger = logging.getLogger(__name__)

@@ -178,7 +179,8 @@ def _comparison_vectors(self):
         self._training_log_message()

         pipeline = CTEPipeline()
-        pipeline = self._original_linker._enqueue_df_concat_with_tf(pipeline)
+        nodes_with_tf = compute_df_concat_with_tf(self._original_linker, pipeline)
+        pipeline = CTEPipeline([nodes_with_tf], reusable=False)

         sqls = block_using_rules_sqls(
             self._original_linker, [self._blocking_rule_for_training]
4 changes: 3 additions & 1 deletion splink/estimate_u.py
@@ -16,6 +16,7 @@
     m_u_records_to_lookup_dict,
 )
 from .pipeline import CTEPipeline
+from .vertically_concatenate import compute_df_concat_with_tf

 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
 if TYPE_CHECKING:
@@ -57,7 +58,8 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     logger.info("----- Estimating u probabilities using random sampling -----")
     pipeline = CTEPipeline(reusable=False)

-    pipeline = linker._enqueue_df_concat_with_tf(pipeline)
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)

     original_settings_obj = linker._settings_obj

10 changes: 7 additions & 3 deletions splink/labelling_tool.py
@@ -8,7 +8,9 @@
 from jinja2 import Template

 from .misc import EverythingEncoder, read_resource
+from .pipeline import CTEPipeline
 from .splink_dataframe import SplinkDataFrame
+from .vertically_concatenate import compute_df_concat_with_tf

 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
 if TYPE_CHECKING:
@@ -21,8 +23,10 @@ def generate_labelling_tool_comparisons(
     linker: "Linker", unique_id, source_dataset, match_weight_threshold=-4
 ):
     # ensure the tf table exists
-    concat_with_tf = linker._initialise_df_concat_with_tf()
+    pipeline = CTEPipeline(reusable=False)
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)

+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)
     settings = linker._settings_obj

     source_dataset_condition = ""
@@ -40,8 +44,8 @@
     {source_dataset_condition}
     """

-    linker._enqueue_sql(sql, "__splink__df_labelling_tool_record")
-    splink_df = linker._execute_sql_pipeline([concat_with_tf])
+    pipeline.enqueue_sql(sql, "__splink__df_labelling_tool_record")
+    splink_df = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)

     matches = linker.find_matches_to_new_records(
         splink_df.physical_name, match_weight_threshold=match_weight_threshold
47 changes: 13 additions & 34 deletions splink/linker.py
@@ -9,7 +9,7 @@
 from statistics import median
 from typing import Dict, Optional, Union

-
+from .vertically_concatenate import enqueue_df_concat_with_tf, compute_df_concat_with_tf
 from splink.input_column import InputColumn
 from splink.settings_validation.log_invalid_columns import (
     InvalidColumnsLogger,
@@ -536,35 +536,6 @@ def _initialise_df_concat_with_tf(self, materialise=True):

         return nodes_with_tf

-    def _enqueue_df_concat_with_tf(self, pipeline: CTEPipeline, materialise=True):
-
-        cache = self._intermediate_table_cache
-
-        if "__splink__df_concat_with_tf" in cache:
-            nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
-            pipeline.append_input_dataframe(nodes_with_tf)
-            return pipeline
-
-        # In duckdb, calls to random() in a CTE pipeline cause problems:
-        # https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
-        if self._settings_obj.salting_required:
-            materialise = True
-
-        sql = vertically_concatenate_sql(self)
-        pipeline.enqueue_sql(sql, "__splink__df_concat")
-
-        sqls = compute_all_term_frequencies_sqls(self)
-        pipeline.enqueue_list_of_sqls(sqls)
-
-        if materialise:
-            # Can't use break lineage here because we need nodes_with_tf
-            # so it can be explicitly set to the named cache
-            nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
-            cache["__splink__df_concat_with_tf"] = nodes_with_tf
-            pipeline = CTEPipeline(input_dataframes=[nodes_with_tf])
-
-        return pipeline
-
     def _table_to_splink_dataframe(
         self, templated_name, physical_name
     ) -> SplinkDataFrame:
@@ -978,7 +949,8 @@ def deterministic_link(self) -> SplinkDataFrame:
         # to set the cluster threshold to 1
         self._deterministic_link_mode = True

-        pipeline = self._enqueue_df_concat_with_tf(pipeline)
+        df_concat_with_tf = compute_df_concat_with_tf(self, pipeline)
+        pipeline = CTEPipeline([df_concat_with_tf], reusable=False)

         exploding_br_with_id_tables = materialise_exploded_id_tables(self)

@@ -1321,9 +1293,16 @@ def predict(
         # calls predict, it runs as a single pipeline with no materialisation
         # of anything.

-        pipeline = self._enqueue_df_concat_with_tf(
-            pipeline, materialise=materialise_after_computing_term_frequencies
-        )
+        # In duckdb, calls to random() in a CTE pipeline cause problems:
+        # https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
+        if (
+            materialise_after_computing_term_frequencies
+            or self._sql_dialect == "duckdb"
+        ):
+            df_concat_with_tf = compute_df_concat_with_tf(self, pipeline)
+            pipeline = CTEPipeline([df_concat_with_tf], reusable=False)
+        else:
+            pipeline = enqueue_df_concat_with_tf(self, pipeline)

         # If exploded blocking rules exist, we need to materialise
         # the tables of ID pairs
39 changes: 39 additions & 0 deletions splink/vertically_concatenate.py
@@ -3,6 +3,10 @@
 import logging
 from typing import TYPE_CHECKING

+from .pipeline import CTEPipeline
+from .splink_dataframe import SplinkDataFrame
+from .term_frequencies import compute_all_term_frequencies_sqls
+
 logger = logging.getLogger(__name__)

 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
@@ -74,3 +78,38 @@ def vertically_concatenate_sql(linker: Linker) -> str:
     """

     return sql
+
+
+def enqueue_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> CTEPipeline:
+
+    cache = linker._intermediate_table_cache
+    if "__splink__df_concat_with_tf" in cache:
+        nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
+        pipeline.append_input_dataframe(nodes_with_tf)
+        return pipeline
+
+    sql = vertically_concatenate_sql(linker)
+    pipeline.enqueue_sql(sql, "__splink__df_concat")
+
+    sqls = compute_all_term_frequencies_sqls(linker)
+    pipeline.enqueue_list_of_sqls(sqls)
+
+    return pipeline
+
+
+def compute_df_concat_with_tf(linker: Linker, pipeline) -> SplinkDataFrame:
+    cache = linker._intermediate_table_cache
+    db_api = linker.db_api
+
+    if "__splink__df_concat_with_tf" in cache:
+        return cache.get_with_logging("__splink__df_concat_with_tf")
+
+    sql = vertically_concatenate_sql(linker)
+    pipeline.enqueue_sql(sql, "__splink__df_concat")
+
+    sqls = compute_all_term_frequencies_sqls(linker)
+    pipeline.enqueue_list_of_sqls(sqls)
+
+    nodes_with_tf = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+    cache["__splink__df_concat_with_tf"] = nodes_with_tf
+    return nodes_with_tf
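For orientation, a sketch of how the changed call sites use these two helpers (a sketch only, not part of the diff; it assumes an existing linker with a configured db_api):

    from splink.pipeline import CTEPipeline
    from splink.vertically_concatenate import (
        compute_df_concat_with_tf,
        enqueue_df_concat_with_tf,
    )

    # Materialising variant: executes the enqueued SQL now, caches the result
    # under "__splink__df_concat_with_tf", and seeds a fresh pipeline with it.
    pipeline = CTEPipeline(reusable=False)
    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
    pipeline = CTEPipeline([nodes_with_tf], reusable=False)

    # Lazy variant: only enqueues the SQL, so downstream CTEs run in a single
    # pipeline. predict() uses this branch unless materialisation is requested
    # or the SQL dialect is duckdb.
    pipeline = CTEPipeline(reusable=False)
    pipeline = enqueue_df_concat_with_tf(linker, pipeline)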
