Refactor block_using_rules_sql to follow normal pattern and avoid confusion #1695

Merged · 3 commits · Nov 3, 2023
9 changes: 5 additions & 4 deletions splink/analyse_blocking.py
@@ -5,7 +5,7 @@

import pandas as pd

from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sql
from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sqls
from .misc import calculate_cartesian, calculate_reduction_ratio

# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
@@ -83,16 +83,17 @@ def cumulative_comparisons_generated_by_blocking_rules(
cartesian = calculate_cartesian(row_count_df, settings_obj._link_type)

# Calculate the total number of rows generated by each blocking rule
sql = block_using_rules_sql(linker)
linker._enqueue_sql(sql, "__splink__df_blocked_data")
sqls = block_using_rules_sqls(linker)
for sql in sqls:
linker._enqueue_sql(sql["sql"], sql["output_table_name"])

brs_as_objs = linker._settings_obj_._blocking_rules_to_generate_predictions

sql = """
select
count(*) as row_count,
match_key
from __splink__df_blocked_data
from __splink__df_blocked
group by match_key
order by cast(match_key as int) asc
"""
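A note on the counting query above: it tallies how many pairwise comparisons each blocking rule produced, grouped by the match_key column of the blocked table, and the cast to int is what gives numeric rather than lexical ordering. A minimal pandas sketch of the same computation, using made-up match_key values rather than real Splink output:

```python
import pandas as pd

# Hypothetical blocked pairs; match_key identifies which blocking rule generated each pair.
blocked = pd.DataFrame({"match_key": ["0", "0", "1", "10", "10", "2"]})

# Equivalent of: select count(*) as row_count, match_key ... group by match_key
counts = (
    blocked.groupby("match_key", as_index=False)
    .size()
    .rename(columns={"size": "row_count"})
)

# Equivalent of: order by cast(match_key as int) asc
# Without the cast, "10" would sort between "1" and "2".
counts = counts.sort_values(by="match_key", key=lambda s: s.astype(int))
print(counts)
```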
90 changes: 55 additions & 35 deletions splink/blocking.py
@@ -197,7 +197,7 @@ def _sql_gen_where_condition(link_type, unique_id_cols):


# flake8: noqa: C901
def block_using_rules_sql(linker: Linker):
def block_using_rules_sqls(linker: Linker):
"""Use the blocking rules specified in the linker's settings object to
generate a SQL statement that will create pairwise record comparisons
according to the blocking rule(s).
@@ -206,6 +206,54 @@ def block_using_rules_sqls(linker: Linker):
so that duplicate comparisons are not generated.
"""

sqls = []

# For the two dataset link only, rather than a self join of
# __splink__df_concat_with_tf, it's much faster to split the input
# into two tables, and join (because then Splink doesn't have to evaluate
# intra-dataset comparisons).
# see https://github.com/moj-analytical-services/splink/pull/1359
if (
linker._two_dataset_link_only
and not linker._find_new_matches_mode
and not linker._compare_two_records_mode
):
source_dataset_col = linker._source_dataset_column_name
# Need df_l to be the one with the lowest id to preserve the property
# that the left dataset is the one with the lowest concatenated id
keys = linker._input_tables_dict.keys()
keys = list(sorted(keys))
df_l = linker._input_tables_dict[keys[0]]
df_r = linker._input_tables_dict[keys[1]]

# This also needs to work for training u
if linker._train_u_using_random_sample_mode:
spl_switch = "_sample"
else:
spl_switch = ""

sql = f"""
select * from __splink__df_concat_with_tf{spl_switch}
where {source_dataset_col} = '{df_l.templated_name}'
"""
sqls.append(
{
"sql": sql,
"output_table_name": f"__splink__df_concat_with_tf{spl_switch}_left",
}
)

sql = f"""
select * from __splink__df_concat_with_tf{spl_switch}
where {source_dataset_col} = '{df_r.templated_name}'
"""
sqls.append(
{
"sql": sql,
"output_table_name": f"__splink__df_concat_with_tf{spl_switch}_right",
}
)

if type(linker).__name__ in ["SparkLinker"]:
apply_salt = True
else:
@@ -243,36 +291,6 @@ def block_using_rules_sql(linker: Linker):
" will not be implemented for this run."
)

if (
linker._two_dataset_link_only
and not linker._find_new_matches_mode
and not linker._compare_two_records_mode
):
source_dataset_col = linker._source_dataset_column_name
# Need df_l to be the one with the lowest id to preserve the property
# that the left dataset is the one with the lowest concatenated id
keys = linker._input_tables_dict.keys()
keys = list(sorted(keys))
df_l = linker._input_tables_dict[keys[0]]
df_r = linker._input_tables_dict[keys[1]]

if linker._train_u_using_random_sample_mode:
sample_switch = "_sample"
else:
sample_switch = ""

sql = f"""
select * from __splink__df_concat_with_tf{sample_switch}
where {source_dataset_col} = '{df_l.templated_name}'
"""
linker._enqueue_sql(sql, f"__splink__df_concat_with_tf{sample_switch}_left")

sql = f"""
select * from __splink__df_concat_with_tf{sample_switch}
where {source_dataset_col} = '{df_r.templated_name}'
"""
linker._enqueue_sql(sql, f"__splink__df_concat_with_tf{sample_switch}_right")

# Cover the case where there are no blocking rules
# This is a bit of a hack where if you do a self-join on 'true'
# you create a cartesian product, rather than having separate code
@@ -287,7 +305,7 @@
else:
probability = ""

sqls = []
br_sqls = []
for br in blocking_rules:
# Apply our salted rules to resolve skew issues. If no salt was
# selected to be added, then apply the initial blocking rule.
@@ -310,8 +328,10 @@
{where_condition}
"""

sqls.append(sql)
br_sqls.append(sql)

sql = "union all".join(br_sqls)

sql = "union all".join(sqls)
sqls.append({"sql": sql, "output_table_name": "__splink__df_blocked"})

return sql
return sqls
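With this refactor, block_using_rules_sqls returns a list of dicts rather than a single SQL string, each dict carrying "sql" and "output_table_name" keys. In the two-dataset link-only case the list starts with the two split-input queries, and it always ends with the union-all blocking query. A rough sketch of the shape a caller can expect (the SQL text and the source_dataset values are illustrative, not the exact strings Splink generates):

```python
# Sketch only: the shape of the value returned by block_using_rules_sqls.
sqls = [
    # Present only for the two-dataset link-only case: split the concatenated
    # input so the join never has to evaluate intra-dataset comparisons.
    {
        "sql": "select * from __splink__df_concat_with_tf where source_dataset = 'df_left'",
        "output_table_name": "__splink__df_concat_with_tf_left",
    },
    {
        "sql": "select * from __splink__df_concat_with_tf where source_dataset = 'df_right'",
        "output_table_name": "__splink__df_concat_with_tf_right",
    },
    # Always present: the per-blocking-rule selects joined with 'union all'.
    {
        "sql": "select ... union all select ...",
        "output_table_name": "__splink__df_blocked",
    },
]

# Every call site now drains the list in order:
for sql in sqls:
    print(sql["output_table_name"])  # in Splink: linker._enqueue_sql(sql["sql"], sql["output_table_name"])
```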
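On the "# Cover the case where there are no blocking rules" comment kept above: a self-join whose condition is always true reproduces the full cartesian product, which is why no separate code path is needed. A throwaway illustration with an in-memory SQLite table (hypothetical data, not Splink code):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("create table df (unique_id integer)")
con.executemany("insert into df values (?)", [(1,), (2,), (3,)])

# Joining on a condition that is always true pairs every record with every record.
rows = con.execute(
    """
    select l.unique_id as unique_id_l, r.unique_id as unique_id_r
    from df as l
    inner join df as r on 1 = 1
    """
).fetchall()
print(len(rows))  # 9 = 3 * 3, the full cartesian product (before any where condition)
```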
7 changes: 4 additions & 3 deletions splink/em_training_session.py
@@ -4,7 +4,7 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from .blocking import BlockingRule, block_using_rules_sql
from .blocking import BlockingRule, block_using_rules_sqls
from .charts import (
m_u_parameters_interactive_history_chart,
match_weights_interactive_history_chart,
@@ -151,8 +151,9 @@ def _comparison_vectors(self):

nodes_with_tf = self._original_linker._initialise_df_concat_with_tf()

sql = block_using_rules_sql(self._training_linker)
self._training_linker._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(self._training_linker)
for sql in sqls:
self._training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])

# repartition after blocking only exists on the SparkLinker
repartition_after_blocking = getattr(
7 changes: 4 additions & 3 deletions splink/estimate_u.py
@@ -4,7 +4,7 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from .blocking import block_using_rules_sql
from .blocking import block_using_rules_sqls
from .comparison_vector_values import compute_comparison_vector_values_sql
from .expectation_maximisation import (
compute_new_parameters_sql,
@@ -106,8 +106,9 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):

settings_obj._blocking_rules_to_generate_predictions = []

sql = block_using_rules_sql(training_linker)
training_linker._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(training_linker)
for sql in sqls:
training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])

# repartition after blocking only exists on the SparkLinker
repartition_after_blocking = getattr(
28 changes: 16 additions & 12 deletions splink/linker.py
@@ -33,7 +33,7 @@
)
from .blocking import (
BlockingRule,
block_using_rules_sql,
block_using_rules_sqls,
blocking_rule_to_obj,
)
from .cache_dict_with_logging import CacheDictWithLogging
@@ -1406,8 +1406,9 @@ def deterministic_link(self) -> SplinkDataFrame:
self._deterministic_link_mode = True

concat_with_tf = self._initialise_df_concat_with_tf()
sql = block_using_rules_sql(self)
self._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(self)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])
return self._execute_sql_pipeline([concat_with_tf])

def estimate_u_using_random_sampling(
@@ -1728,8 +1729,9 @@ def predict(
if nodes_with_tf:
input_dataframes.append(nodes_with_tf)

sql = block_using_rules_sql(self)
self._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(self)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])

repartition_after_blocking = getattr(self, "repartition_after_blocking", False)

@@ -1853,8 +1855,9 @@ def find_matches_to_new_records(

add_unique_id_and_source_dataset_cols_if_needed(self, new_records_df)

sql = block_using_rules_sql(self)
self._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(self)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])

sql = compute_comparison_vector_values_sql(self._settings_obj)
self._enqueue_sql(sql, "__splink__df_comparison_vectors")
@@ -1937,8 +1940,9 @@ def compare_two_records(self, record_1: dict, record_2: dict):

self._enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf")

sql = block_using_rules_sql(self)
self._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(self)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])

sql = compute_comparison_vector_values_sql(self._settings_obj)
self._enqueue_sql(sql, "__splink__df_comparison_vectors")
@@ -1993,9 +1997,9 @@ def _self_link(self) -> SplinkDataFrame:

nodes_with_tf = self._initialise_df_concat_with_tf()

sql = block_using_rules_sql(self)

self._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(self)
for sql in sqls:
self._enqueue_sql(sql["sql"], sql["output_table_name"])

sql = compute_comparison_vector_values_sql(self._settings_obj)

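The call sites above now all follow the same enqueue loop, which is the "normal pattern" the PR title refers to: each step carries its own SQL and output table name, and steps are queued in order. Outside of Splink, the same idea can be sketched with an in-memory SQLite pipeline (the helper, table names, and data below are hypothetical, not Splink's API):

```python
import sqlite3

def run_pipeline(con: sqlite3.Connection, steps: list) -> str:
    """Materialise each step as a table named by output_table_name; return the last name."""
    name = ""
    for step in steps:
        name = step["output_table_name"]
        con.execute(f"create table {name} as {step['sql']}")
    return name

con = sqlite3.connect(":memory:")
con.execute("create table input_data (unique_id integer, first_name text)")
con.executemany("insert into input_data values (?, ?)", [(1, "ann"), (2, "ann"), (3, "bob")])

steps = [
    {"sql": "select * from input_data", "output_table_name": "df_concat"},
    {
        # A toy blocking rule: pair records that share a first_name,
        # keeping unique_id_l < unique_id_r so duplicate comparisons are not generated.
        "sql": (
            "select l.unique_id as unique_id_l, r.unique_id as unique_id_r "
            "from df_concat as l inner join df_concat as r "
            "on l.first_name = r.first_name and l.unique_id < r.unique_id"
        ),
        "output_table_name": "df_blocked",
    },
]

final = run_pipeline(con, steps)
print(con.execute(f"select * from {final}").fetchall())  # [(1, 2)]
```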
8 changes: 5 additions & 3 deletions splink/m_training.py
@@ -1,7 +1,7 @@
import logging
from copy import deepcopy

from .blocking import BlockingRule, block_using_rules_sql
from .blocking import BlockingRule, block_using_rules_sqls
from .comparison_vector_values import compute_comparison_vector_values_sql
from .expectation_maximisation import (
compute_new_parameters_sql,
@@ -34,8 +34,10 @@ def estimate_m_values_from_label_column(linker, df_dict, label_colname):

concat_with_tf = linker._initialise_df_concat_with_tf()

sql = block_using_rules_sql(training_linker)
training_linker._enqueue_sql(sql, "__splink__df_blocked")
sqls = block_using_rules_sqls(training_linker)

for sql in sqls:
training_linker._enqueue_sql(sql["sql"], sql["output_table_name"])

sql = compute_comparison_vector_values_sql(settings_obj)
