Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed May 16, 2024
1 parent 01b8d6e commit ef7fdd8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
15 changes: 13 additions & 2 deletions scripts/reduce_notebook_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,31 @@ def modify_notebook(file_path):
data["cells"] = data["cells"][:19]
changed = True

if "sqlite" in file_path:
max_pairs = 3e5
head_num = 800
else:
max_pairs = 1e5
head_num = 400

for cell in data["cells"]:
if cell["cell_type"] == "code":
source = cell["source"]
new_source = []
for line in source:
if "splink_datasets" in line and "=" in line:
parts = line.split("=")
parts[1] = parts[1].strip() + ".head(400)"
parts[1] = parts[1].strip() + f".head({head_num})"
new_line = " = ".join(parts) + "\n"
new_source.append(new_line)
changed = True
elif "estimate_u_using_random_sampling(" in line:
new_line = (
re.sub(r"max_pairs=\d+(\.\d+)?[eE]\d+", "max_pairs=1e5", line)
re.sub(
r"max_pairs=\d+(\.\d+)?[eE]\d+",
f"max_pairs={max_pairs}",
line,
)
+ "\n"
)
new_source.append(new_line)
Expand Down
4 changes: 2 additions & 2 deletions splink/internals/blocking_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def count_comparisons_from_blocking_rule(
blocking_rule_creator: Union[BlockingRuleCreator, str, Dict[str, Any]],
link_type: user_input_link_type_options,
db_api: DatabaseAPISubClass,
unique_id_column_name: str,
unique_id_column_name: str = "unique_id",
source_dataset_column_name: Optional[str] = None,
compute_post_filter_count: bool = True,
max_rows_limit: int = int(1e9),
Expand Down Expand Up @@ -574,7 +574,7 @@ def cumulative_comparisons_to_be_scored_from_blocking_rules_data(
blocking_rule_creators: Iterable[Union[BlockingRuleCreator, str, Dict[str, Any]]],
link_type: user_input_link_type_options,
db_api: DatabaseAPISubClass,
unique_id_column_name: str,
unique_id_column_name: str = "unique_id",
max_rows_limit: int = int(1e9),
source_dataset_column_name: Optional[str] = None,
) -> pd.DataFrame:
Expand Down

0 comments on commit ef7fdd8

Please sign in to comment.