From 59bdc386acf840a38682dea11f13213f0450530a Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Wed, 15 May 2024 14:09:01 +0100 Subject: [PATCH] rename arg --- splink/analyse_blocking.py | 8 ++++---- tests/test_analyse_blocking.py | 24 +++++++++++++----------- tests/test_full_example_duckdb.py | 2 +- tests/test_full_example_postgres.py | 2 +- tests/test_total_comparison_count.py | 2 +- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/splink/analyse_blocking.py b/splink/analyse_blocking.py index 6478aea969..78849589ec 100644 --- a/splink/analyse_blocking.py +++ b/splink/analyse_blocking.py @@ -479,15 +479,15 @@ def count_comparisons_from_blocking_rule( compute_post_filter_count: bool = True, max_rows_limit: int = int(1e9), ) -> dict[str, Union[int, str]]: - blocking_rule = to_blocking_rule_creator(blocking_rule_creator).get_blocking_rule( - db_api.sql_dialect.name - ) + blocking_rule_creator = to_blocking_rule_creator( + blocking_rule_creator + ).get_blocking_rule(db_api.sql_dialect.name) splink_df_dict = db_api.register_multiple_tables(table_or_tables) return _count_comparisons_generated_from_blocking_rule( splink_df_dict=splink_df_dict, - blocking_rule=blocking_rule, + blocking_rule=blocking_rule_creator, link_type=link_type, db_api=db_api, compute_post_filter_count=compute_post_filter_count, diff --git a/tests/test_analyse_blocking.py b/tests/test_analyse_blocking.py index 76e4407544..ed84224429 100644 --- a/tests/test_analyse_blocking.py +++ b/tests/test_analyse_blocking.py @@ -49,13 +49,13 @@ def test_analyse_blocking_slow_methodology(test_helpers, dialect): } res_dict = count_comparisons_from_blocking_rule( - table_or_tables=df_1, blocking_rule="1=1", **args + table_or_tables=df_1, blocking_rule_creator="1=1", **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] assert res == 4 * 3 / 2 res_dict = count_comparisons_from_blocking_rule( - table_or_tables=df_1, blocking_rule=block_on("first_name"), **args + table_or_tables=df_1, blocking_rule_creator=block_on("first_name"), **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] @@ -63,33 +63,35 @@ def test_analyse_blocking_slow_methodology(test_helpers, dialect): args["link_type"] = "link_only" res_dict = count_comparisons_from_blocking_rule( - table_or_tables=[df_1, df_2], blocking_rule="1=1", **args + table_or_tables=[df_1, df_2], blocking_rule_creator="1=1", **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] assert res == 4 * 3 res_dict = count_comparisons_from_blocking_rule( - table_or_tables=[df_1, df_2], blocking_rule=block_on("surname"), **args + table_or_tables=[df_1, df_2], blocking_rule_creator=block_on("surname"), **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] assert res == 1 res_dict = count_comparisons_from_blocking_rule( - table_or_tables=[df_1, df_2], blocking_rule=block_on("first_name"), **args + table_or_tables=[df_1, df_2], + blocking_rule_creator=block_on("first_name"), + **args, ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] assert res == 3 res_dict = count_comparisons_from_blocking_rule( - table_or_tables=[df_1, df_2, df_3], blocking_rule="1=1", **args + table_or_tables=[df_1, df_2, df_3], blocking_rule_creator="1=1", **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] assert res == 4 * 3 + 4 * 2 + 2 * 3 args["link_type"] = "link_and_dedupe" res_dict = count_comparisons_from_blocking_rule( - table_or_tables=[df_1, df_2], blocking_rule="1=1", **args + table_or_tables=[df_1, df_2], blocking_rule_creator="1=1", **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] expected = 4 * 3 + (4 * 3 / 2) + (3 * 2 / 2) @@ -97,14 +99,14 @@ def test_analyse_blocking_slow_methodology(test_helpers, dialect): rule = "l.first_name = r.first_name and l.surname = r.surname" res_dict = count_comparisons_from_blocking_rule( - table_or_tables=[df_1, df_2], blocking_rule=rule, **args + table_or_tables=[df_1, df_2], blocking_rule_creator=rule, **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] assert res == 1 rule = block_on("first_name", "surname") res_dict = count_comparisons_from_blocking_rule( - table_or_tables=[df_1, df_2], blocking_rule=rule, **args + table_or_tables=[df_1, df_2], blocking_rule_creator=rule, **args ) res = res_dict["number_of_comparisons_to_be_scored_post_filter_conditions"] assert res == 1 @@ -416,7 +418,7 @@ def test_analyse_blocking_fast_methodology_edge_cases(): for br in blocking_rules: res_dict = count_comparisons_from_blocking_rule( table_or_tables=df, - blocking_rule=br, + blocking_rule_creator=br, link_type="dedupe_only", db_api=db_api, unique_id_column_name="unique_id", @@ -453,7 +455,7 @@ def test_analyse_blocking_fast_methodology_edge_cases(): for br in blocking_rules: res_dict = count_comparisons_from_blocking_rule( table_or_tables=[df_l, df_r], - blocking_rule=br, + blocking_rule_creator=br, link_type="link_only", db_api=db_api, unique_id_column_name="unique_id", diff --git a/tests/test_full_example_duckdb.py b/tests/test_full_example_duckdb.py index f47394d253..4cfd7b2fe8 100644 --- a/tests/test_full_example_duckdb.py +++ b/tests/test_full_example_duckdb.py @@ -44,7 +44,7 @@ def test_full_example_duckdb(tmp_path): count_comparisons_from_blocking_rule( table_or_tables=df, - blocking_rule='l.first_name = r.first_name and l."SUR name" = r."SUR name"', + blocking_rule_creator='l.first_name = r.first_name and l."SUR name" = r."SUR name"', # noqa: E501 link_type="dedupe_only", db_api=db_api, unique_id_column_name="unique_id", diff --git a/tests/test_full_example_postgres.py b/tests/test_full_example_postgres.py index b4d67a45a9..2eb927cdf4 100644 --- a/tests/test_full_example_postgres.py +++ b/tests/test_full_example_postgres.py @@ -29,7 +29,7 @@ def test_full_example_postgres(tmp_path, pg_engine): count_comparisons_from_blocking_rule( table_or_tables=df, - blocking_rule='l.first_name = r.first_name and l."surname" = r."surname"', + blocking_rule_creator='l.first_name = r.first_name and l."surname" = r."surname"', # noqa: E501 link_type="dedupe_only", db_api=db_api, unique_id_column_name="unique_id", diff --git a/tests/test_total_comparison_count.py b/tests/test_total_comparison_count.py index 83381e5fe5..4cf637450d 100644 --- a/tests/test_total_comparison_count.py +++ b/tests/test_total_comparison_count.py @@ -87,7 +87,7 @@ def make_dummy_frame(row_count): res_dict = count_comparisons_from_blocking_rule( table_or_tables=dfs, - blocking_rule="1=1", + blocking_rule_creator="1=1", link_type=link_type, db_api=db_api, unique_id_column_name="unique_id",