Type hinting improvements - mypy conformance #2051

Merged: 49 commits, Mar 14, 2024

Commits
01d9876
fix type-checking imports
ADBond Mar 12, 2024
281fb13
bump mypy version
ADBond Mar 12, 2024
678cd19
fix SettingsCreator path logic
ADBond Mar 12, 2024
30d7a96
add tf_adjustment_chart test
ADBond Mar 12, 2024
ab6f1e0
wrap possible issue in `read_resource`
ADBond Mar 12, 2024
37e278f
typing dialect base class
ADBond Mar 12, 2024
4a2856e
more specific type-hinting
ADBond Mar 12, 2024
f2ec72c
expand on variables and types
ADBond Mar 12, 2024
cfa75c2
fix types + var names
ADBond Mar 12, 2024
5a7e85d
guard for missing `_dialect_name_for_factory` in lookup
ADBond Mar 12, 2024
a72a0af
SplinkDataFrame -> ABC
ADBond Mar 12, 2024
e4d9930
clarify vars
ADBond Mar 12, 2024
020d428
lint with black
ADBond Mar 12, 2024
689bf1f
allow ColumnExpression args
ADBond Mar 12, 2024
ad7f311
+ base var
ADBond Mar 12, 2024
8f0cdc5
improve type-hinting settings
ADBond Mar 12, 2024
ca0e587
permissively typed dict, rather than inferred Noneish
ADBond Mar 12, 2024
abc5fa1
colexp type hint
ADBond Mar 12, 2024
e684d39
update temp-skipped typing modules
ADBond Mar 12, 2024
d83e950
clarify var names
ADBond Mar 12, 2024
c15add2
fix return type
ADBond Mar 12, 2024
20fcfa4
permissive dict
ADBond Mar 12, 2024
cce16e4
type hinting and variable clarifications
ADBond Mar 12, 2024
66c5195
update lockfile
ADBond Mar 12, 2024
de6c36c
mypy github action
ADBond Mar 12, 2024
f1de170
mypy action - install correct group for library
ADBond Mar 12, 2024
86a2746
specify cache to this workflow
ADBond Mar 12, 2024
deb8617
optional atts and ignoring downstream/compat issues
ADBond Mar 12, 2024
dccf266
type hints and variable renamings
ADBond Mar 12, 2024
4111e30
lint with black
ADBond Mar 12, 2024
7f46419
br typing
ADBond Mar 13, 2024
b05c34d
optional attrs marked as such
ADBond Mar 13, 2024
366ad51
only final on setters
ADBond Mar 13, 2024
6bbcb67
more temporary mypy exclusions
ADBond Mar 13, 2024
2f0ee81
generalise type
ADBond Mar 13, 2024
4adce08
use alias to remove cyclicity
ADBond Mar 13, 2024
7f34cc4
type hint
ADBond Mar 13, 2024
bc6155b
guard against sqlglot optionals
ADBond Mar 13, 2024
57b376a
handle some blocking rule edge cases
ADBond Mar 13, 2024
c358ead
correct arg type
ADBond Mar 13, 2024
1ef4e09
default tf off - clarify
ADBond Mar 13, 2024
72f7f7e
concrete type for merging dicts
ADBond Mar 13, 2024
7ec9765
lint with black
ADBond Mar 13, 2024
95f0ebe
annotate init to check body
ADBond Mar 13, 2024
ccc358a
Settings always dialected
ADBond Mar 13, 2024
399dc0c
annotate default path, remove needless code
ADBond Mar 13, 2024
96bd6cd
lint with black
ADBond Mar 13, 2024
bccf9f3
clearer type - Sequence for covariance
ADBond Mar 14, 2024
e146b48
split out logic for getting cluster ids when not supplied
ADBond Mar 14, 2024
60 changes: 60 additions & 0 deletions .github/workflows/mypy.yml
@@ -0,0 +1,60 @@
name: Type hinting with mypy
on:
  pull_request:
    branches:
      - master
      - '**dev'
    paths:
      - splink/**
      - tests/**
      - pyproject.toml

jobs:
  mypy:
    runs-on: ubuntu-20.04
    name: Check type hinting with mypy
    steps:
      #----------------------------------------------
      # check-out repo and set-up python
      #----------------------------------------------
      - name: Check out repository
        uses: actions/checkout@v3
      - name: Set up python
        id: setup-python
        uses: actions/setup-python@v4
        with:
          python-version: 3.9.10
      #----------------------------------------------
      # set up environment
      #----------------------------------------------
      - name: Load cached Poetry installation
        uses: actions/cache@v2
        with:
          path: ~/.local
          key: poetry-0
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: 1.8.2
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true
      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v2
        with:
          path: .venv
          key: venv-typehint-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-00
      - name: Install dependencies
        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
        run: poetry install --no-interaction --no-root --with=typechecking
      - name: Install library
        run: poetry install --no-interaction
      #----------------------------------------------
      # run mypy
      #----------------------------------------------
      - name: Run mypy
        run: |
          source .venv/bin/activate
          mypy splink

600 changes: 243 additions & 357 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions pyproject.toml
@@ -56,7 +56,7 @@ rapidfuzz = ">=2.0.3"
[tool.poetry.group.typechecking]
optional = true
[tool.poetry.group.typechecking.dependencies]
mypy = "1.7.0"
mypy = "1.9.0"

[tool.poetry.extras]
pyspark = ["pyspark"]
@@ -112,16 +112,17 @@ markers = [
packages = "splink"
# temporary exclusions
exclude = [
    # modules getting substantial rewrites:
    '.*comparison_imports\.py$',
    '.*comparison.*library\.py',
    'comparison_level_composition',
    # modules with large number of errors
    '.*comparison.*library\.py',
    '.*linker\.py',
    '/settings_validation/'
]
# for now at least allow implicit optionals
# to cut down on noise. Easy to fix.
implicit_optional = true
# for now, ignore missing imports
# can remove later and install stubs, where existent
ignore_missing_imports = true
ignore_missing_imports = true
# don't follow imports to modules we don't want to typecheck yet
# eventually restore this back to the default "normal"
follow_imports = "silent"
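
Note on `implicit_optional = true`: it stops mypy flagging parameters annotated with a plain type but defaulted to `None`. A minimal, hypothetical snippet (not from the Splink codebase) showing the pattern this setting tolerates:

def describe(label: str = None) -> str:
    # with implicit_optional = true, mypy treats `label` as Optional[str];
    # under mypy's default setting this default would be reported as an error
    if label is None:
        return "no label"
    return f"label: {label}"
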
10 changes: 6 additions & 4 deletions splink/accuracy.py
@@ -151,10 +151,12 @@ def _select_found_by_blocking_rules(linker: "Linker"):
    brs = linker._settings_obj._blocking_rules_to_generate_predictions

    if brs:
        brs = [move_l_r_table_prefix_to_column_suffix(b.blocking_rule_sql) for b in brs]
        brs = [f"(coalesce({b}, false))" for b in brs]
        brs = " OR ".join(brs)
        br_col = f" ({brs}) "
        br_strings = [
            move_l_r_table_prefix_to_column_suffix(b.blocking_rule_sql) for b in brs
        ]
        wrapped_br_strings = [f"(coalesce({b}, false))" for b in br_strings]
        full_br_string = " OR ".join(wrapped_br_strings)
        br_col = f" ({full_br_string}) "
    else:
        br_col = " 1=1 "

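
The renaming above addresses a pattern mypy rejects: re-using one variable name for values of different types. A hypothetical, simplified illustration (not Splink code):

def join_rules(rules: list[str]) -> str:
    # mypy error: Incompatible types in assignment
    # (expression has type "str", variable has type "list[str]")
    rules = " OR ".join(rules)
    return rules

def join_rules_fixed(rules: list[str]) -> str:
    joined = " OR ".join(rules)  # a new name keeps each variable's type stable
    return joined
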
22 changes: 17 additions & 5 deletions splink/blocking.py
@@ -1,12 +1,13 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional

from sqlglot import parse_one
from sqlglot.expressions import Column, Join
from sqlglot.optimizer.eliminate_joins import join_condition

from .exceptions import SplinkException
from .input_column import InputColumn
from .misc import ensure_is_list
from .splink_dataframe import SplinkDataFrame
@@ -19,7 +20,7 @@
    from .linker import Linker


def blocking_rule_to_obj(br) -> BlockingRule:
def blocking_rule_to_obj(br: BlockingRule | dict | str) -> BlockingRule:
    if isinstance(br, BlockingRule):
        return br
    elif isinstance(br, dict):
@@ -262,9 +263,15 @@ def __init__(
        sqlglot_dialect: str = None,
        array_columns_to_explode: list = [],
    ):
        super().__init__(blocking_rule, sqlglot_dialect)
        if isinstance(blocking_rule, BlockingRule):
            blocking_rule_sql = blocking_rule.blocking_rule_sql
        elif isinstance(blocking_rule, dict):
            blocking_rule_sql = blocking_rule["blocking_rule_sql"]
        else:
            blocking_rule_sql = blocking_rule
        super().__init__(blocking_rule_sql, sqlglot_dialect)
        self.array_columns_to_explode: List[str] = array_columns_to_explode
        self.exploded_id_pair_table: SplinkDataFrame = None
        self.exploded_id_pair_table: Optional[SplinkDataFrame] = None

    def marginal_exploded_id_pairs_table_sql(self, linker: Linker, br: BlockingRule):
        """generates a table of the marginal id pairs from the exploded blocking rule
@@ -325,7 +332,12 @@ def exclude_pairs_generated_by_this_rule_sql(self, linker: Linker):
        unique_id_input_columns = (
            linker._settings_obj.column_info_settings.unique_id_input_columns
        )
        splink_df = self.exploded_id_pair_table
        if (splink_df := self.exploded_id_pair_table) is None:
            raise SplinkException(
                "Must use `materialise_exploded_id_table(linker)` "
                "to set `exploded_id_pair_table` before calling "
                "exclude_pairs_generated_by_this_rule_sql()."
            )
        ids_to_compare_sql = f"select * from {splink_df.physical_name}"

        id_expr_l = _composite_unique_id_from_nodes_sql(unique_id_input_columns, "l")
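
The guard added above is a common way to satisfy mypy once an attribute is declared Optional: narrow it to a concrete value (and fail loudly) before using it. A minimal, hypothetical sketch of the pattern, not Splink's actual classes:

from typing import Optional

class Job:
    def __init__(self) -> None:
        self.result_table: Optional[str] = None  # populated later by a materialise step

    def result_sql(self) -> str:
        if (table := self.result_table) is None:
            raise RuntimeError("materialise the result table before requesting its SQL")
        return f"select * from {table}"  # mypy now knows `table` is a str
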
9 changes: 7 additions & 2 deletions splink/blocking_rule_library.py
@@ -75,6 +75,8 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str:


class _Merge(BlockingRuleCreator):
    _clause = ""

    @final
    def __init__(
        self,
@@ -90,7 +92,10 @@ def __init__(
            raise ValueError(
                f"Must provide at least one blocking rule to {type(self)}()"
            )
        self.blocking_rules = blocking_rules
        blocking_rule_creators = [
            CustomRule(**br) if isinstance(br, dict) else br for br in blocking_rules
        ]
        self.blocking_rules = blocking_rule_creators

    @property
    def salting_partitions(self):
@@ -165,7 +170,7 @@ def block_on(
    )

    if len(col_names_or_exprs) == 1:
        br = ExactMatchRule(col_names_or_exprs[0])
        br: BlockingRuleCreator = ExactMatchRule(col_names_or_exprs[0])
    else:
        br = And(*[ExactMatchRule(c) for c in col_names_or_exprs])

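
Annotating `br: BlockingRuleCreator` on the first branch is the usual fix for mypy inferring a variable's type from its first assignment and then rejecting a sibling branch. A hypothetical, simplified sketch of that behaviour (not Splink's class hierarchy):

class Rule: ...
class ExactRule(Rule): ...
class CombinedRule(Rule): ...

def choose(single: bool) -> Rule:
    rule: Rule = ExactRule()  # without the annotation, mypy pins `rule` to ExactRule
    if not single:
        rule = CombinedRule()  # ...and would then reject this assignment
    return rule
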
108 changes: 64 additions & 44 deletions splink/cluster_studio.py
@@ -3,7 +3,7 @@
import json
import os
import random
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

from jinja2 import Template

@@ -42,7 +42,7 @@ def _clusters_sql(df_clustered_nodes, cluster_ids: list) -> str:

def df_clusters_as_records(
    linker: "Linker", df_clustered_nodes: SplinkDataFrame, cluster_ids: list
):
) -> list[dict]:
    """Retrieves distinct clusters which exist in df_clustered_nodes based on
    list of cluster IDs provided and converts them to a record dictionary.

@@ -86,7 +86,7 @@ def _nodes_sql(df_clustered_nodes, cluster_ids) -> str:

def create_df_nodes(
    linker: "Linker", df_clustered_nodes: SplinkDataFrame, cluster_ids: list
):
) -> SplinkDataFrame:
    """Retrieves nodes from df_clustered_nodes for list of cluster IDs provided.

    Args:
@@ -150,7 +150,7 @@ def df_edges_as_records(

def _get_random_cluster_ids(
    linker: "Linker", connected_components: SplinkDataFrame, sample_size: int, seed=None
):
) -> list[str]:
    sql = f"""
    select count(distinct cluster_id) as count
    from {connected_components.physical_name}
@@ -189,7 +189,7 @@ def _get_random_cluster_ids(

def _get_cluster_id_of_each_size(
    linker: "Linker", connected_components: SplinkDataFrame, rows_per_partition: int
):
) -> list[dict]:
    unique_id_col_name = linker._settings_obj.column_info_settings.unique_id_column_name
    sql = f"""
    select
@@ -233,7 +233,7 @@ def _get_lowest_density_clusters(
    df_cluster_metrics: SplinkDataFrame,
    rows_per_partition: int,
    min_nodes: int,
):
) -> list[dict]:
    """Returns lowest density clusters of different sizes by
    performing stratified sampling.

@@ -277,6 +277,55 @@ def _get_lowest_density_clusters(
    return df_lowest_density_clusters.as_record_dict()


def _get_cluster_ids(
    linker: "Linker",
    df_clustered_nodes: SplinkDataFrame,
    sampling_method,
    sample_size,
    sample_seed,
    _df_cluster_metrics: Optional[SplinkDataFrame] = None,
) -> tuple[list, list]:
    if sampling_method == "random":
        cluster_ids = _get_random_cluster_ids(
            linker, df_clustered_nodes, sample_size, sample_seed
        )
        cluster_names = []
    elif sampling_method == "by_cluster_size":
        cluster_id_infos = _get_cluster_id_of_each_size(
            linker, df_clustered_nodes, rows_per_partition=1
        )
        if len(cluster_id_infos) > sample_size:
            cluster_id_infos = random.sample(cluster_id_infos, k=sample_size)
        cluster_names = [
            f"Cluster ID: {c['cluster_id']}, size: {c['cluster_size']}"
            for c in cluster_id_infos
        ]
        cluster_ids = [c["cluster_id"] for c in cluster_id_infos]
    elif sampling_method == "lowest_density_clusters_by_size":
        if _df_cluster_metrics is None:
            raise SplinkException(
                """To sample by density, you must provide a cluster metrics table
                containing density. This can be generated by calling the
                _compute_graph_metrics method on the linker."""
            )
        # Using sensible default for min_nodes. Might become option
        # for users in future
        cluster_id_infos = _get_lowest_density_clusters(
            linker, _df_cluster_metrics, rows_per_partition=1, min_nodes=3
        )
        if len(cluster_id_infos) > sample_size:
            cluster_id_infos = random.sample(cluster_id_infos, k=sample_size)
        cluster_names = [
            f"""Cluster ID: {c['cluster_id']}, density (4dp): {c['density_4dp']},
            size: {c['cluster_size']}"""
            for c in cluster_id_infos
        ]
        cluster_ids = [c["cluster_id"] for c in cluster_id_infos]
    else:
        raise ValueError(f"Unknown sampling method {sampling_method}")
    return cluster_ids, cluster_names


def render_splink_cluster_studio_html(
    linker: "Linker",
    df_predicted_edges: SplinkDataFrame,
@@ -285,7 +334,7 @@
    sampling_method="random",
    sample_size=10,
    sample_seed=None,
    cluster_ids: list = None,
    cluster_ids: list[str] = None,
    cluster_names: list = None,
    overwrite: bool = False,
    _df_cluster_metrics: SplinkDataFrame = None,
@@ -296,43 +345,15 @@
"cluster_colname": "cluster_id",
"prob_colname": "match_probability",
}
named_clusters_dict = None
if cluster_ids is None:
if sampling_method == "random":
cluster_ids = _get_random_cluster_ids(
linker, df_clustered_nodes, sample_size, sample_seed
)
if sampling_method == "by_cluster_size":
cluster_ids = _get_cluster_id_of_each_size(linker, df_clustered_nodes, 1)
if len(cluster_ids) > sample_size:
cluster_ids = random.sample(cluster_ids, k=sample_size)
cluster_names = [
f"Cluster ID: {c['cluster_id']}, size: {c['cluster_size']}"
for c in cluster_ids
]
cluster_ids = [c["cluster_id"] for c in cluster_ids]
named_clusters_dict = dict(zip(cluster_ids, cluster_names))
if sampling_method == "lowest_density_clusters_by_size":
if _df_cluster_metrics is None:
raise SplinkException(
"""To sample by density, you must provide a cluster metrics table
containing density. This can be generated by calling the
_compute_graph_metrics method on the linker."""
)
# Using sensible default for min_nodes. Might become option
# for users in future
cluster_ids = _get_lowest_density_clusters(
linker, _df_cluster_metrics, rows_per_partition=1, min_nodes=3
)
if len(cluster_ids) > sample_size:
cluster_ids = random.sample(cluster_ids, k=sample_size)
cluster_names = [
f"""Cluster ID: {c['cluster_id']}, density (4dp): {c['density_4dp']},
size: {c['cluster_size']}"""
for c in cluster_ids
]
cluster_ids = [c["cluster_id"] for c in cluster_ids]
named_clusters_dict = dict(zip(cluster_ids, cluster_names))
cluster_ids, cluster_names = _get_cluster_ids(
linker,
df_clustered_nodes,
sampling_method,
sample_size,
sample_seed,
_df_cluster_metrics,
)

    cluster_recs = df_clusters_as_records(linker, df_clustered_nodes, cluster_ids)
    df_nodes = create_df_nodes(linker, df_clustered_nodes, cluster_ids)
@@ -356,7 +377,6 @@ def render_splink_cluster_studio_html(
    if cluster_names:
        named_clusters_dict = dict(zip(cluster_ids, cluster_names))

    if named_clusters_dict:
        template_data["named_clusters"] = json.dumps(
            named_clusters_dict, cls=EverythingEncoder
        )
6 changes: 4 additions & 2 deletions splink/column_expression.py
@@ -1,8 +1,10 @@
from __future__ import annotations

import re
import string
from copy import copy
from functools import partial
from typing import Union
from typing import Callable, Union

import sqlglot

@@ -38,7 +40,7 @@ class ColumnExpression:

    def __init__(self, sql_expression: str, sql_dialect: SplinkDialect = None):
        self.raw_sql_expression = sql_expression
        self.operations = []
        self.operations: list[Callable] = []
        if sql_dialect is not None:
            self.sql_dialect: SplinkDialect = sql_dialect

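Annotating `self.operations: list[Callable]` addresses the "need type annotation" complaint mypy raises for attributes initialised to an empty container. A small, hypothetical sketch of the same pattern (not the ColumnExpression class itself):

from typing import Callable

class TextPipeline:
    def __init__(self) -> None:
        # a bare `[]` gives mypy no element type, so spell it out here
        self.steps: list[Callable[[str], str]] = []

    def add_step(self, step: Callable[[str], str]) -> None:
        self.steps.append(step)

    def run(self, value: str) -> str:
        for step in self.steps:
            value = step(value)
        return value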