diff --git a/helpers/regression_imputation.py b/helpers/regression_imputation.py
new file mode 100644
index 000000000..3eedbcabc
--- /dev/null
+++ b/helpers/regression_imputation.py
@@ -0,0 +1,68 @@
+"""
+Regression test to compare two versions of outputs
+Reads two csv files, old and new
+Selects the columns of interest
+Joins old and new on key columns, outer
+Checks which records are in old only (left), new only (right) or both
+Compares if the old and new values are the same within tolerance
+Saves the outputs
+"""
+
+#%% Configuration settings
+import pandas as pd
+
+# Input folder and file names
+root_path = "R:/BERD Results System Development 2023/DAP_emulation/2023_surveys/BERD/06_imputation/imputation_qa/"
+in_file_old = "2023_full_responses_imputed_24-09-10_v764.csv"
+in_file_new = "tmp_qa_output2.csv"
+
+# Output folder and file
+out_fol = root_path
+out_file = "imputation_breakdown_check.csv"
+
+# Columns to select
+key_cols = ["reference", "instance"]
+value_col = "211"
+other_cols = [
+    "200",
+    "201",
+    "formtype",
+    "imp_class",
+    "imp_marker",
+    "status",
+]
+tolerance = 0.001
+
+#%% Read files, keeping only the columns of interest
+cols_read = key_cols + [value_col] + other_cols
+df_old = pd.read_csv(root_path + in_file_old, usecols=cols_read)
+df_new = pd.read_csv(root_path + in_file_new, usecols=cols_read)
+
+#%% join old and new
+df_merge = df_old.merge(df_new, on=key_cols, how="inner", suffixes=("_old", "_new"))
+
+#%% save the inner join for inspection (overwritten by the final output below)
+df_merge.to_csv(root_path + out_file, index=False)
+
+
+#%% Filter good statuses only
+imp_markers_to_keep = ["TMI", "CF", "MoR", "constructed"]
+df_old_good = df_old[df_old["imp_marker"].isin(imp_markers_to_keep)]
+df_new_good = df_new[df_new["imp_marker"].isin(imp_markers_to_keep)]
+
+#%% sizes
+print(f"Old size: {df_old_good.shape}")
+print(f"New size: {df_new_good.shape}")
+
+#%% Join old and new, outer, with an indicator of left/right/both membership
+df_merge = df_old_good.merge(
+    df_new_good, on=key_cols, how="outer", suffixes=("_old", "_new"), indicator=True
+)
+#%% Compare the values
+df_merge["value_different"] = (
+    df_merge[value_col + "_old"] - df_merge[value_col + "_new"]
+) ** 2 > tolerance**2
+
+# %% Save output
+df_merge.to_csv(out_fol + out_file, index=False)
+
+# %%
diff --git a/src/_version.py b/src/_version.py
index bf7882637..8c0d5d5bb 100644
--- a/src/_version.py
+++ b/src/_version.py
@@ -1 +1 @@
-__version__ = "1.1.7"
+__version__ = "2.0.0"
diff --git a/src/construction/all_data_construction.py b/src/construction/all_data_construction.py
index ec0ff98df..2961a4dd1 100644
--- a/src/construction/all_data_construction.py
+++ b/src/construction/all_data_construction.py
@@ -24,6 +24,7 @@ def all_data_construction(
     construction_df: pd.DataFrame,
     snapshot_df: pd.DataFrame,
     construction_logger: logging.Logger,
+    config: dict,
     is_northern_ireland: bool = False,
 ) -> pd.DataFrame:
     """Run all data construction on the GB or NI data.
@@ -122,7 +123,7 @@ def all_data_construction(
     # Check breakdowns
     if not is_northern_ireland:
         updated_snapshot_df = run_breakdown_validation(
-            updated_snapshot_df, check="constructed"
+            updated_snapshot_df, config, check="constructed"
         )
 
     construction_logger.info(f"Construction edited {construction_df.shape[0]} rows.")
diff --git a/src/construction/construction_main.py b/src/construction/construction_main.py
index 583dc00d8..0847fc770 100644
--- a/src/construction/construction_main.py
+++ b/src/construction/construction_main.py
@@ -3,7 +3,6 @@
 from typing import Callable
 
 import pandas as pd
-import numpy as np
 
 from src.construction.construction_read_validate import (
     read_validate_all_construction_files,
@@ -63,7 +62,7 @@ def run_construction(  # noqa: C901
             is_northern_ireland=True,
         )
         updated_snapshot_df = all_data_construction(
-            df, snapshot_df, construction_logger, is_northern_ireland=True
+            df, snapshot_df, construction_logger, config, is_northern_ireland=True
         )
 
     elif is_run_all_data_construction:
@@ -73,7 +72,7 @@ def run_construction(  # noqa: C901
             config, check_file_exists, read_csv, construction_logger
         )
         updated_snapshot_df = all_data_construction(
-            df, snapshot_df, construction_logger
+            df, snapshot_df, construction_logger, config
         )
 
     elif is_run_postcode_construction:
diff --git a/src/dev_config.yaml b/src/dev_config.yaml
index 789ab825a..2a7636035 100644
--- a/src/dev_config.yaml
+++ b/src/dev_config.yaml
@@ -199,6 +199,30 @@ breakdowns:
     - "headcount_tec_f"
     - "headcount_oth_m"
     - "headcount_oth_f"
+# consistency checks: in each list the last column is the total column, and the
+# sum of the preceding breakdown columns should equal it
+consistency_checks:
+  2xx_totals:
+    purchases_split: ["222", "223", "203"]
+    sal_oth_expend: ["202", "203", "204"]
+    research_expend: ["205", "206", "207", "204"]
+    capex: ["219", "220", "209", "210"]
+    intram: ["204", "210", "211"]
+    funding: ["212", "214", "216", "242", "250", "243", "244", "245", "246", "247", "248", "249", "218"]
+    ownership: ["225", "226", "227", "228", "229", "237", "218"]
+    equality: ["211", "218"]
+  3xx_totals:
+    purchases: ["302", "303", "304", "305"]
+  4xx_totals:
+    emp_civil: ["405", "407", "409", "411"]
+    emp_defence: ["406", "408", "410", "412"]
+  5xx_totals:
+    hc_res_m: ["501", "503", "505", "507"]
+    hc_res_f: ["502", "504", "506", "508"]
+  apportioned_totals:
+    employment: ["emp_researcher", "emp_technician", "emp_other", "emp_total"]
+    hc_male: ["headcount_res_m", "headcount_tec_m", "headcount_oth_m", "headcount_tot_m"]
+    hc_female: ["headcount_res_f", "headcount_tec_f", "headcount_oth_f", "headcount_tot_f"]
+    hc_tot: ["headcount_tot_m", "headcount_tot_f", "headcount_total"]
+
 s3:
   ssl_file: "/etc/pki/tls/certs/ca-bundle.crt"
   s3_bucket: "onscdp-dev-data01-5320d6ca"
diff --git a/src/imputation/imputation_helpers.py b/src/imputation/imputation_helpers.py
index 41be0a2ef..d3728a7bd 100644
--- a/src/imputation/imputation_helpers.py
+++ b/src/imputation/imputation_helpers.py
@@ -279,11 +279,11 @@ def split_df_on_imp_class(df: pd.DataFrame, exclusion_list: List = ["817", "nan"
     (Product Group)- these will generally be filtered out from the imputation classes.
 
     Where short forms are under consideration, "817" imputation classes will be excluded
-    
+
     Args:
         df (pd.DataFrame): The dataframe to split
-        exclusion_list (List, optional): A list of imputation classes to exclude.
-    
+        exclusion_list (List, optional): A list of imputation classes to exclude.
+
     Returns:
         pd.DataFrame: The filtered dataframe with the invalid imp classes removed
         pd.DataFrame: The excluded dataframe
@@ -363,9 +363,22 @@ def calculate_totals(df):
     return df
 
 
+def breakdown_checks_after_imputation(df: pd.DataFrame, config: dict) -> None:
+    """After imputation, check required breakdown columns still sum to their totals.
+
+    Args:
+        df (pd.DataFrame): The dataframe with imputed values.
+        config (dict): The pipeline config, which supplies the dictionary of
+            checks: the last col in each list is the total col, and the sum of
+            the other cols should equal the total.
+
+    Returns:
+        None
+    """
+    # imported locally so this helper adds no module-level dependency
+    from src.utils.breakdown_validation import run_breakdown_validation
+
+    run_breakdown_validation(df, config, check="imputed")
+
+
 def tidy_imputation_dataframe(df: pd.DataFrame, to_impute_cols: List) -> pd.DataFrame:
     """Update cols with imputed values and remove rows and columns no longer needed.
-    
+
     Args:
         df (pd.DataFrame): The dataframe with imputed values.
         to_impute_cols (List): The columns that were imputed.
diff --git a/src/imputation/imputation_main.py b/src/imputation/imputation_main.py
index 88adcdcde..1678be599 100644
--- a/src/imputation/imputation_main.py
+++ b/src/imputation/imputation_main.py
@@ -13,9 +13,8 @@ from src.imputation.sf_expansion import run_sf_expansion
 from src.imputation import manual_imputation as mimp
 from src.imputation.MoR import run_mor
-from src.construction.construction_main import run_construction
-from src.mapping.itl_mapping import join_itl_regions
 from src.outputs.outputs_helpers import create_output_df
+from src.utils.breakdown_validation import run_breakdown_validation
 
 ImputationMainLogger = logging.getLogger(__name__)
 
@@ -144,7 +143,9 @@ def run_imputation(
         f"{survey_year}_full_responses_imputed_{tdate}_v{run_id}.csv"
     )
     wrong_604_filename = f"{survey_year}_wrong_604_error_qa_{tdate}_v{run_id}.csv"
-    trimmed_counts_filename = f"{survey_year}_tmi_trim_count_qa_{tdate}_v{run_id}.csv"
+    trimmed_counts_filename = (
+        f"{survey_year}_tmi_trim_count_qa_{tdate}_v{run_id}.csv"
+    )
 
     # create trimming qa dataframe with required columns from schema
     schema_path = config["schema_paths"]["manual_trimming_schema"]
@@ -160,6 +161,9 @@ def run_imputation(
     # remove rows and columns no longer needed from the imputed dataframe
     imputed_df = hlp.tidy_imputation_dataframe(imputed_df, to_impute_cols)
 
+    # Check the imputed values are consistent with breakdown cols summing to totals.
+    run_breakdown_validation(imputed_df, config, check="imputed")
+
     # optionally output backdata for imputation
     if config["global"]["output_backdata"]:
         ImputationMainLogger.info("Outputting backdata for imputation.")
diff --git a/src/staging/staging_main.py b/src/staging/staging_main.py
index 13317894b..59af39d98 100644
--- a/src/staging/staging_main.py
+++ b/src/staging/staging_main.py
@@ -10,6 +10,8 @@
 import src.staging.staging_helpers as helpers
 from src.staging import validation as val
 
+# from src.utils.breakdown_validation import run_breakdown_validation
+
 StagingMainLogger = logging.getLogger(__name__)
 
 
@@ -160,6 +162,12 @@ def run_staging(  # noqa: C901
             rd_file_exists(postcode_mapper, raise_error=True)
             postcode_mapper = rd_read_csv(postcode_mapper)
 
+    # Staging of the main snapshot data is now complete
+    StagingMainLogger.info("Staging of main snapshot data complete.")
+    # run validation on the breakdowns
+    # run_breakdown_validation(full_responses, config, "staged")
+
+    # Staging of the additional data
     if config["global"]["load_manual_outliers"]:
         # Stage the manual outliers file
         StagingMainLogger.info("Loading Manual Outlier File")
@@ -180,10 +188,7 @@ def run_staging(  # noqa: C901
 
     # Get the latest manual trim file
     manual_trim_path = staging_dict["manual_imp_trim_path"]
-    if (
-        config["global"]["load_manual_imputation"] and
-        rd_file_exists(manual_trim_path)
-    ):
+    if config["global"]["load_manual_imputation"] and rd_file_exists(manual_trim_path):
         StagingMainLogger.info("Loading Imputation Manual Trimming File")
         wanted_cols = ["reference", "instance", "manual_trim"]
         manual_trim_df = rd_read_csv(manual_trim_path, wanted_cols)
diff --git a/src/user_config.yaml b/src/user_config.yaml
index b8e71a409..cfe9b43ec 100644
--- a/src/user_config.yaml
+++ b/src/user_config.yaml
@@ -29,7 +29,7 @@ global:
     output_imputation_qa: False
     output_auto_outliers: False
     output_outlier_qa : False
-    output_estimation_qa: True
+    output_estimation_qa: False
     output_apportionment_qa: False
     # Final output settings
     output_long_form: False
diff --git a/src/utils/breakdown_validation.py b/src/utils/breakdown_validation.py
index a2b4c84af..c8ef55509 100644
--- a/src/utils/breakdown_validation.py
+++ b/src/utils/breakdown_validation.py
@@ -1,38 +1,59 @@
 """Function to validate the breakdown totals."""
-import os
 import logging
 import pandas as pd
-import numpy as np
-from pandas import DataFrame as pandasDF
+
+from typing import Tuple
 
 BreakdownValidationLogger = logging.getLogger(__name__)
 
-# checks that need to be done: the last col in the list is the total col
-# the sum of the other cols should equal the total
-equals_checks = {
-    'check1': ['222', '223', '203'],
-    'check2': ['202', '203', '204'],
-    'check3': ['205', '206', '207', '204'],
-    'check4': ['219', '220', '209', '210'],
-    'check5': ['204', '210', '211'],
-    'check6': ['212', '214', '216', '242', '250', '243', '244', '245', '246', '247', '248', '249', '218'],
-    'check7': ['211', '218'],
-    'check8': ['225', '226', '227', '228', '229', '237', '218'],
-    'check9': ['302', '303', '304', '305'],
-    'check10': ['501', '503', '505', '507'],
-    'check11': ['502', '504', '506', '508'],
-    'check12': ['405', '407', '409', '411'],
-    'check13': ['406', '408', '410', '412'],
-}
-
-# checks that need to be done: the second value should not be greater than the first
-greater_than_checks = {
-    'check14': ['209', '221'],
-    'check15': ['211', '202'],
-}
-
-
-def replace_nulls_with_zero(df: pd.DataFrame) -> pd.DataFrame:
+
+def get_equality_dicts(config: dict, sublist: str = "default") -> dict:
"default") -> dict: + """ + Get the equality checks for the construction data. + + Args: + config (dict): The config dictionary. + + Returns: + dict + """ + # use the config to get breakdown totals + all_checks_dict = config["consistency_checks"] + + # isolate the relationships suitlable for checking in the construction module + if sublist == "default": + wanted_dicts = [key for key in all_checks_dict.keys() if "xx_totals" in key] + elif sublist == "imputation": + wanted_dicts = ["2xx_totals", "3xx_totals", "apportioned_totals"] + else: + wanted_dicts = list(all_checks_dict.keys()) + + # create a dictionary of the relationships to check + construction_equality_checks = {} + for item in wanted_dicts: + construction_equality_checks.update(all_checks_dict[item]) + + return construction_equality_checks + + +def get_all_wanted_columns(config: dict) -> list: + """ + Get all the columns that we want to check. + + Args: + config (dict): The config dictionary. + + Returns: + list: A list of all the columns to check. + """ + equals_checks = get_equality_dicts(config, "default") + all_columns = [] + for list_item in equals_checks.values(): + all_columns += list_item + return all_columns + + +def replace_nulls_with_zero(df: pd.DataFrame, equals_checks) -> pd.DataFrame: """ Replace nulls with zeros where the total is zero. @@ -43,6 +64,7 @@ def replace_nulls_with_zero(df: pd.DataFrame) -> pd.DataFrame: pd.DataFrame """ BreakdownValidationLogger.info("Replacing nulls with zeros where total zero") + for columns in equals_checks.values(): total_column = columns[-1] breakdown_columns = columns[:-1] @@ -52,58 +74,69 @@ def replace_nulls_with_zero(df: pd.DataFrame) -> pd.DataFrame: return df -def remove_all_nulls_rows(df: pd.DataFrame) -> pd.DataFrame: +def remove_all_nulls_rows(df: pd.DataFrame, config: dict) -> pd.DataFrame: """ Remove rows where all breakdown/total cols are null from validation. Args: df (pd.DataFrame): The dataframe to check. + config (dict): The pipeline config dictionary. Returns: pd.DataFrame """ BreakdownValidationLogger.info("Removing rows with all null values from validation") + wanted_cols = get_all_wanted_columns(config) rows_to_validate = df.dropna( - subset=[ - '222', '223', '203', '202', '204', '205', '206', '207', '221', '209', '219', - '220', '210', '204', '211', '212', '214', '216', '242', '250', '243', '244', - '245', '246', '218', '225', '226', '227', '228', '229', '237', '302', '303', - '304', '305', '501', '503', '505', '507', '502', '504', '506', '508', '405', - '407', '409', '411', '406', '408', '410', '412' - ], how='all').reset_index(drop=True) + subset=wanted_cols, + how="all", + ).reset_index(drop=True) + return rows_to_validate -def equal_validation(rows_to_validate: pd.DataFrame) -> pd.DataFrame: +def equal_validation( + rows_to_validate: pd.DataFrame, equals_checks: dict +) -> pd.DataFrame: """ Check where the sum of some columns should equal another column. Args: rows_to_validate (pd.DataFrame): The dataframe to check. + equals_checks (dict): The dictionary of columns to check. 
 
     Returns:
         tuple(str, int)
     """
     BreakdownValidationLogger.info("Doing breakdown total checks...")
+
     msg = ""
     count = 0
     for index, row in rows_to_validate.iterrows():
         for key, columns in equals_checks.items():
             total_column = columns[-1]
             breakdown_columns = columns[:-1]
-            if rows_to_validate[columns].isnull().all(axis=1).iloc[index] or (rows_to_validate[columns] == 0).all(axis=1).iloc[index]:
+            if (
+                rows_to_validate[columns].isnull().all(axis=1).iloc[index]
+                or (rows_to_validate[columns] == 0).all(axis=1).iloc[index]
+            ):
                 continue
-            if not (rows_to_validate[breakdown_columns].sum(axis=1) == rows_to_validate[total_column]).iloc[index]:
-                msg += f"Columns {breakdown_columns} do not equal column {total_column} for reference: {row['reference']}, instance {row['instance']}.\n "
+            if not (
+                rows_to_validate[breakdown_columns].sum(axis=1)
+                == rows_to_validate[total_column]
+            ).iloc[index]:
+                msg += (
+                    f"Columns {breakdown_columns} do not equal column"
+                    f" {total_column} for reference: {row['reference']}, instance"
+                    f" {row['instance']}.\n "
+                )
                 count += 1
     return msg, count
 
 
 def greater_than_validation(
-    rows_to_validate: pd.DataFrame,
-    msg: str,
-    count: int
-    ) -> pd.DataFrame:
+    rows_to_validate: pd.DataFrame, msg: str, count: int
+) -> Tuple[str, int]:
     """
     Check where one value should be greater than another.
 
     Args:
         rows_to_validate (pd.DataFrame): The dataframe to check.
+        msg (str): The running error message to append to.
+        count (int): The running error count.
 
     Returns:
-        pd.DataFrame
+        tuple(str, int)
     """
@@ -115,54 +148,227 @@
-    BreakdownValidationLogger.info("Doing checks for values that should be greater than...")
+    BreakdownValidationLogger.info(
+        "Doing checks for values that should be greater than..."
+    )
+    # the second value in each list should not be greater than the first
+    greater_than_checks = {
+        "check14": ["209", "221"],
+        "check15": ["211", "202"],
+    }
     for index, row in rows_to_validate.iterrows():
         for key, columns in greater_than_checks.items():
             should_be_greater = columns[0]
             should_not_be_greater = columns[1]
-            if (rows_to_validate[should_not_be_greater] > rows_to_validate[should_be_greater]).all():
-                msg += f"Column {should_not_be_greater} is greater than {should_be_greater} for reference: {row['reference']}, instance {row['instance']}.\n "
+            # check the current row only, mirroring equal_validation above
+            if (
+                rows_to_validate[should_not_be_greater]
+                > rows_to_validate[should_be_greater]
+            ).iloc[index]:
+                msg += (
+                    f"Column {should_not_be_greater} is greater than"
+                    f" {should_be_greater} for reference: {row['reference']}, instance"
+                    f" {row['instance']}.\n "
+                )
                 count += 1
     return msg, count
 
 
-def run_breakdown_validation(
-    df: pd.DataFrame,
-    check: str = 'all'
-    ) -> pd.DataFrame:
-    """
-    Function to validate the breakdown totals.
+def get_breakdown_errors(
+    df: pd.DataFrame, to_check: dict
+) -> Tuple[dict, pd.DataFrame]:
+    """Check that total columns remain consistent with their breakdown columns.
+
+    Args:
+        df (pd.DataFrame): The dataframe to check.
+        to_check (dict): The checks to run; the last col in each list is the total.
+
+    Returns:
+        dict: A dataframe of the failing rows for each check (empty if none fail).
+        pd.DataFrame: A QA dataframe of all failing rows and the relevant columns.
+    """
+    qa_df = df.copy()
+    wanted_refs = []  # a list of references that have errors
+    cols = []  # a list of columns that have errors
+
+    check_results_dict = {}
+    for key, columns in to_check.items():
+        total_column = columns[-1]
+        breakdown_columns = columns[:-1]
+        check_cond = abs(df[breakdown_columns].sum(axis=1) - df[total_column]) > 0.005
+        # if there are any errors for a particular check..
+        if any(check_cond):
+            # ...create a mini dataframe for the relevant rows and columns
+            wanted_cols = ["reference", "instance", "imp_class", "imp_marker"] + list(
+                columns
+            )
+            check_df = df.copy().loc[check_cond, wanted_cols]
+            check_df[f"{key}_diff"] = (
+                df.loc[check_cond, breakdown_columns].sum(axis=1)
+                - df.loc[check_cond, total_column]
+            )
+
+            # add the diff column to the qa dataframe
+            qa_df.loc[check_cond, f"{key}_diff"] = check_df[f"{key}_diff"]
+            wanted_refs += [
+                r for r in (check_df["reference"].tolist()) if r not in wanted_refs
+            ]
+            # the diff column is already in check_df, so it is picked up here too
+            cols += [c for c in check_df.columns if c not in cols]
+
+            # add the mini-dataframe to the dict for unit testing and logger messages
+            check_results_dict[key] = check_df
+        else:
+            check_results_dict[key] = pd.DataFrame()
+
+    # Filter and select so the qa dataframe has only relevant columns and rows
+    qa_df = qa_df.loc[df["reference"].isin(wanted_refs)][cols]
+
+    return check_results_dict, qa_df
+
+
+def log_errors_to_screen(check_results_dict: dict, check_type: str) -> None:
+    """Log any breakdown errors to the screen.
+
+    Args:
+        check_results_dict (dict): The dictionary of errors to log.
+        check_type (str): The type of check, used in the log messages.
+
+    Returns:
+        None
+    """
+    errors_found = False
+    for key, value in check_results_dict.items():
+        if not value.empty:
+            errors_found = True
+            BreakdownValidationLogger.error(
+                f"Breakdown validation failed for {key} columns"
+            )
+            BreakdownValidationLogger.error(value)
+    if not errors_found:
+        BreakdownValidationLogger.info(f"All {check_type} breakdown vals are valid.")
+
+
+def run_imputation_breakdown_validation(
+    df: pd.DataFrame, config: dict
+) -> pd.DataFrame:
+    """Run the breakdown validation for the imputed data.
+
+    Any failures are logged rather than raised.
+
+    Args:
+        df (pd.DataFrame): The dataframe to check.
+        config (dict): The config dictionary.
+
+    Returns:
+        pd.DataFrame: A QA dataframe of any rows that failed the checks.
+    """
+    to_check_dict = get_equality_dicts(config, "imputation")
+    check_results_dict, qa_df = get_breakdown_errors(df, to_check_dict)
+    # temp output of qa_df for debugging
+    # if not qa_df.empty:
+    #     print(qa_df)
+
+    log_errors_to_screen(check_results_dict, "imputation")
+    return qa_df
+
+
+def run_construction_breakdown_validation(
+    df: pd.DataFrame, config: dict
+) -> pd.DataFrame:
+    """Run the breakdown validation for the constructed data.
+
+    Few errors are expected, so any that exist will be logged to the screen.
+
+    Args:
+        df (pd.DataFrame): The dataframe to check.
+        config (dict): The config dictionary.
+
+    Returns:
+        pd.DataFrame
+    """
+    to_check_dict = get_equality_dicts(config)
+    df = replace_nulls_with_zero(df, to_check_dict)
 
-    validation_df = replace_nulls_with_zero(validation_df)
-    rows_to_validate = remove_all_nulls_rows(validation_df)
-    msg, count = equal_validation(rows_to_validate)
+    rows_to_validate = remove_all_nulls_rows(df, config)
+    msg, count = equal_validation(rows_to_validate, to_check_dict)
     msg, count = greater_than_validation(rows_to_validate, msg, count)
+    BreakdownValidationLogger.info(
+        f"There are {count} errors with the breakdown values"
+    )
 
     if not msg:
         BreakdownValidationLogger.info("All breakdown values are valid.")
+    # TODO: we will probably want to raise an error here when in production
+    # else:
+    #     raise ValueError(msg)
+
+    return df
+
+
+def run_staging_breakdown_validation(df: pd.DataFrame, config: dict) -> pd.DataFrame:
+    """Run the breakdown validation for the staged data.
+
+    Args:
+        df (pd.DataFrame): The dataframe to check.
+        config (dict): The config dictionary.
+
+    Returns:
+        pd.DataFrame: The dataframe with nulls zero-filled where the total is zero.
+    """
+    to_check_dict = get_equality_dicts(config)
+    df = replace_nulls_with_zero(df, to_check_dict)
+
+    check_results_dict, qa_df = get_breakdown_errors(df, to_check_dict)
+
+    # Note: this is a temporary implementation to show the QA output while we debug
+    # if not qa_df.empty:
+    #     print(qa_df)
+
+    log_errors_to_screen(check_results_dict, "staging")
+    return df
+
+
+def filter_on_condition(
+    df: pd.DataFrame, condition: pd.Series
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Filter a dataframe based on a condition.
+
+    Args:
+        df (pd.DataFrame): The dataframe to filter.
+        condition (pd.Series): The condition to filter on.
+
+    Returns:
+        tuple(pd.DataFrame, pd.DataFrame): The filtered dataframe and the dataframe
+            that was removed.
+    """
+    validation_df = df[condition].copy()
+    removed_df = df[~condition].copy()
+    return validation_df, removed_df
+
+
+def run_breakdown_validation(
+    df: pd.DataFrame, config: dict, check: str
+) -> pd.DataFrame:
+    """
+    Function to validate the breakdown totals.
+
+    Args:
+        df (pd.DataFrame): The dataframe to check.
+        config (dict): The config dictionary.
+        check (str): The type of validation to run. The values could be "staged",
+            "imputed", or "constructed".
+
+    Returns:
+        pd.DataFrame: The recombined dataframe, or a QA dataframe of failing rows
+            for the "imputed" check.
+    """
+    BreakdownValidationLogger.info(f"Running breakdown validation for {check} data")
+
-    if check == 'constructed':
-        BreakdownValidationLogger.info("Running breakdown validation for constructed data")
-        validation_df = df[df['is_constructed'] == True].copy()
-        not_for_validating_df = df[df['is_constructed'] == False].copy()
-    else:
-        validation_df = df.copy()
+    if check == "constructed":
+        cond = df.is_constructed.isin([True])
+        validation_df, remaining_df = filter_on_condition(df, cond)
+        validation_df = run_construction_breakdown_validation(validation_df, config)
 
-    if check != 'all':
-        df = pd.concat([validation_df, not_for_validating_df], ignore_index=True)
-    else:
-        df = validation_df
+    elif check == "imputed":
+        validation_df = df[df["imp_marker"].isin(["CF", "MoR", "TMI"])].copy()
+        qa_df = run_imputation_breakdown_validation(validation_df, config)
+        return qa_df
+
+    elif check == "staged":
+        cond = df.status.isin(["Clear", "Clear - overridden"])
+        validation_df, remaining_df = filter_on_condition(df, cond)
+        validation_df = run_staging_breakdown_validation(validation_df, config)
 
-    BreakdownValidationLogger.info(f"There are {count} errors with the breakdown values")
-
-    if not msg:
-        BreakdownValidationLogger.info("All breakdown values are valid.")
     else:
-        raise ValueError(msg)
+        raise ValueError("Check must be one of 'constructed', 'imputed', or 'staged'.")
 
+    df = pd.concat([validation_df, remaining_df], ignore_index=True)
     return df
diff --git a/tests/test_utils/test_breakdown_validation.py b/tests/test_utils/test_breakdown_validation.py
index 957454557..61daa4607 100644
--- a/tests/test_utils/test_breakdown_validation.py
+++ b/tests/test_utils/test_breakdown_validation.py
@@ -3,6 +3,8 @@ import logging
 
 from src.utils.breakdown_validation import (
+    get_equality_dicts,
+    get_all_wanted_columns,
     run_breakdown_validation,
     replace_nulls_with_zero,
     remove_all_nulls_rows,
@@ -10,8 +12,169 @@
     greater_than_validation,
 )
 
+
+@pytest.fixture(scope="module")
+def create_config():
+    """Create a config dictionary for the tests."""
+    test_config = {"consistency_checks": {
+        "2xx_totals": {
+            "purchases_split": ["222", "223", "203"],
+            "sal_oth_expend": ["202", "203", "204"],
+            "research_expend": ["205", "206", "207", "204"],
+            "capex": ["219", "220", "209", "210"],
+            "intram": ["204", "210", "211"],
+            "funding": ["212", "214", "216", "242", "250", "243", "244", "245", "246", "247", "248", "249", "218"],
+            "ownership": ["225", "226", "227", "228", "229", "237", "218"],
+            "equality": ["211", "218"],
+        },
+        "3xx_totals": {
+            "purchases": ["302", "303", "304", "305"],
+        },
+        "4xx_totals": {
+            "emp_civil": ["405", "407", "409", "411"],
+            "emp_defence": ["406", "408", "410", "412"],
+        },
+        "5xx_totals": {
+            "hc_res_m": ["501", "503", "505", "507"],
+            "hc_res_f": ["502", "504", "506", "508"],
+        },
+        "apportioned_totals": {
+            "employment": ["emp_researcher", "emp_technician", "emp_other", "emp_total"],
+            "hc_male": ["headcount_res_m", "headcount_tec_m", "headcount_oth_m", "headcount_tot_m"],
+            "hc_female": ["headcount_res_f", "headcount_tec_f", "headcount_oth_f", "headcount_tot_f"],
+            "hc_tot": ["headcount_tot_m", "headcount_tot_f", "headcount_total"],
+        },
+    }}
+    return test_config
+
+
+@pytest.fixture(scope="module")
+def create_equality_dict():
+    """The checks expected from the "default" sublist."""
+    equality_dict = {
+        "purchases_split": ["222", "223", "203"],
+        "sal_oth_expend": ["202", "203", "204"],
+        "research_expend": ["205", "206", "207", "204"],
+        "capex": ["219", "220", "209", "210"],
+        "intram": ["204", "210", "211"],
+        "funding": ["212", "214", "216", "242", "250", "243", "244", "245", "246", "247", "248", "249", "218"],
+        "ownership": ["225", "226", "227", "228", "229", "237", "218"],
+        "equality": ["211", "218"],
+        "purchases": ["302", "303", "304", "305"],
+        "emp_civil": ["405", "407", "409", "411"],
+        "emp_defence": ["406", "408", "410", "412"],
+        "hc_res_m": ["501", "503", "505", "507"],
+        "hc_res_f": ["502", "504", "506", "508"],
+    }
+
+    return equality_dict
+
+
+@pytest.fixture(scope="module")
+def create_equality_dict_imputation():
+    """The checks expected from the "imputation" sublist."""
+    equality_dict = {
"223", "203"], + "sal_oth_expend": ["202", "203", "204"], + "research_expend": ["205", "206", "207", "204"], + "capex": ["219", "220", "209", "210"], + "intram": ["204", "210", "211"], + "funding": ["212", "214", "216", "242", "250", "243", "244", "245", "246", "247", "248", "249", "218"], + "ownership": ["225", "226", "227", "228", "229", "237", "218"], + "equality": ["211", "218"], + "purchases": ['302', '303', '304', '305'], + "employment": ["emp_researcher", "emp_technician", "emp_other", "emp_total"], + "hc_male": ["headcount_res_m", "headcount_tec_m", "headcount_oth_m", "headcount_tot_m"], + "hc_female": ["headcount_res_f", "headcount_tec_f", "headcount_oth_f", "headcount_tot_f"], + "hc_tot": ["heacount_tot_m", "headcount_tot_f", "headcount_tot"] + } + + return equality_dict + + +def test_get_equality_dicts_construction(create_config, create_equality_dict): + """Test for get_equality_dicts function in the construction case.""" + config = create_config + expected_output = create_equality_dict + result = get_equality_dicts(config, "default") + assert result == expected_output + + +def test_get_equality_dicts_imputation(create_config, create_equality_dict_imputation): + """Test for get__equality_dicts function in the imputation case.""" + config = create_config + expected_output = create_equality_dict_imputation + result = get_equality_dicts(config, "imputation") + assert result == expected_output + + +def test_get_all_wanted_columns(create_config): + """Test for get_all_wanted_columns function.""" + config = create_config + expected_output = [ + '202', '203', '204', '205', '206', '207', "209", + "210", "211", "212", "214", "216", "218", "219", "220", + "225", "226", "227", "228", "229", "237", + "242", "243", "244", "245", "246", "247", "248", "249", "250", + "302", "303", "304", "305", + "501", "503", "505", "507", "502", "504", "506", "508", "405", + "407", "409", "411", "406", "408", "410", "412" + ] + print(list(expected_output)) + + result = get_all_wanted_columns(config) + + print([c for c in result if c not in expected_output]) + + assert set(result) == set(expected_output) -class TestRunBreakdownValidation: + +class TestRemoveAllNullRows: + """Unit tests for replace_nulls_with_zero function.""" + + def create_input_df(self): + """Create an input dataframe for the test.""" + input_cols = ["reference", "instance", "202", "203", "222", "223", "204", "205", "206", "207", "219", "220", "209", + "221", "210", "211", '212', '214', '216', '242', '250', '243', '244', '245', '246', '247', '248', '249', '218', + '225', '226', '227', '228', '229', '237', '302', '303', '304', '305', '501', '503', '505', '507', '502', '504', '506', + '508', '405', '407', '409', '411', '406', '408', '410', '412', "999"] + + data = [ + ['A', 1, 10, 30, 15, 15, 40, 10, 10, 20, 10, 15, 20, 10, 45, 85, 10, 10, 10, 10, 10, 10, 10, 10, 1, 1, 2, 1, 85, 20, 20, 20, 20, 2, 3, 50, 20, 10, 80, 50, 20, 25, 95, 60, 20, 10, 90, 80, 20, 10, 110, 80, 90, 20, 190, None], + ['B', 1, 1, 30, 15, 15, None, 10, 10, 20, 10, 15, 20, 10, 45, None, 10, 10, None, 10, 10, 10, 10, 10, 1, 1, 2, 1, 85, 20, 20, 20, 20, 2, 3, 50, 20, 10, 80, 50, 20, 25, 95, 60, 20, 10, 90, 80, 20, 10, 110, 80, 90, 20, 190, None], + ['C', None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, 
+        ]
+
+        input_df = pd.DataFrame(data=data, columns=input_cols)
+
+        return input_df
+
+    def create_expected_df(self):
+        """Create an expected dataframe for the test."""
+        input_cols = ["reference", "instance", "202", "203", "222", "223", "204", "205", "206", "207", "219", "220", "209",
+            "221", "210", "211", "212", "214", "216", "242", "250", "243", "244", "245", "246", "247", "248", "249", "218",
+            "225", "226", "227", "228", "229", "237", "302", "303", "304", "305", "501", "503", "505", "507", "502", "504", "506",
+            "508", "405", "407", "409", "411", "406", "408", "410", "412", "999"]
+
+        data = [
+            ["A", 1.0, 10.0, 30.0, 15.0, 15.0, 40.0, 10.0, 10.0, 20.0, 10.0, 15.0, 20.0, 10.0, 45.0, 85.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 1.0, 1.0, 2.0, 1.0, 85.0, 20.0, 20.0, 20.0, 20.0, 2.0, 3.0, 50.0, 20.0, 10.0, 80.0, 50.0, 20.0, 25.0, 95.0, 60.0, 20.0, 10.0, 90.0, 80.0, 20.0, 10.0, 110.0, 80.0, 90.0, 20.0, 190.0, None],
+            ["B", 1.0, 1.0, 30.0, 15.0, 15.0, None, 10.0, 10.0, 20.0, 10.0, 15.0, 20.0, 10.0, 45.0, None, 10.0, 10.0, None, 10.0, 10.0, 10.0, 10.0, 10.0, 1.0, 1.0, 2.0, 1.0, 85.0, 20.0, 20.0, 20.0, 20.0, 2.0, 3.0, 50.0, 20.0, 10.0, 80.0, 50.0, 20.0, 25.0, 95.0, 60.0, 20.0, 10.0, 90.0, 80.0, 20.0, 10.0, 110.0, 80.0, 90.0, 20.0, 190.0, None],
+        ]
+
+        expected_df = pd.DataFrame(data=data, columns=input_cols)
+
+        return expected_df
+
+    def test_remove_all_nulls_rows(self, caplog, create_config):
+        """Test for remove_all_nulls_rows function."""
+        input_df = self.create_input_df()
+        expected_df = self.create_expected_df()
+        config = create_config
+
+        with caplog.at_level(logging.INFO):
+            result_df = remove_all_nulls_rows(input_df, config)
+            assert "Removing rows with all null values from validation" in caplog.text
+            pd.testing.assert_frame_equal(result_df, expected_df)
+
+
+class TestRunBreakdownValidation:
     """Unit tests for run_breakdown_validation function."""
 
     def create_input_df(self):
@@ -29,44 +192,49 @@ def create_input_df(self):
         ]
 
         input_df = pd.DataFrame(data=data, columns=input_cols)
+        input_df["is_constructed"] = True
 
         return input_df
 
-    def test_breakdown_validation_success(self, caplog):
+    def test_breakdown_validation_success(self, caplog, create_config):
        """Test for run_breakdown_validation function where the values match."""
        input_df = self.create_input_df()
        input_df = input_df.loc[(input_df['reference'] == 'A')]
+        config = create_config
        msg = 'All breakdown values are valid.\n'
        with caplog.at_level(logging.INFO):
-            run_breakdown_validation(input_df)
+            run_breakdown_validation(input_df, config, "constructed")
            assert msg in caplog.text
 
-    def test_breakdown_validation_msg(self):
-        """Test for run_breakdown_validation function to check the returned message."""
-        input_df = self.create_input_df()
-        input_df = input_df.loc[(input_df['reference'] == 'B')]
-        msg = "Columns ['202', '203'] do not equal column 204 for reference: B, instance 1.\n "
-        with pytest.raises(ValueError) as e:
-            run_breakdown_validation(input_df)
-        assert str(e.value) == msg
-
-    def test_breakdown_validation_fail_all_null(self, caplog):
+    # TODO: we're currently not raising an error but will later put this back in
+    # def test_breakdown_validation_msg(self, create_config):
+    #     """Test for run_breakdown_validation function to check the returned message."""
+    #     input_df = self.create_input_df()
+    #     input_df = input_df.loc[(input_df['reference'] == 'B')]
+    #     config = create_config
+    #     msg = "Columns ['202', '203'] do not equal column 204 for reference: B, instance 1.\n "
+    #     with pytest.raises(ValueError) as e:
+    #         run_breakdown_validation(input_df, config, "constructed")
+    #     assert str(e.value) == msg
+
+    def test_breakdown_validation_fail_all_null(self, caplog, create_config):
         """Test for run_breakdown_validation function where there are no values."""
         input_df = self.create_input_df()
         input_df = input_df.loc[(input_df['reference'] == 'C')]
+        config = create_config
         msg = 'All breakdown values are valid.\n'
         with caplog.at_level(logging.INFO):
-            run_breakdown_validation(input_df)
+            run_breakdown_validation(input_df, config, "constructed")
             assert msg in caplog.text
 
-    def test_breakdown_validation_fail_totals_zero(self, caplog):
+    def test_breakdown_validation_fail_totals_zero(self, caplog, create_config):
         """Test for run_breakdown_validation function where there are zeros."""
         input_df = self.create_input_df()
         input_df = input_df.loc[(input_df['reference'] == 'D')]
+        config = create_config
         msg = 'All breakdown values are valid.\n'
         with caplog.at_level(logging.INFO):
-            run_breakdown_validation(input_df)
+            run_breakdown_validation(input_df, config, "constructed")
             assert msg in caplog.text
 
 
 class TestReplaceNullsWithZero:
@@ -89,14 +257,15 @@ def create_input_df(self):
 
         return input_df
 
-    def test_replace_nulls_with_zero(self, caplog):
+    def test_replace_nulls_with_zero(self, caplog, create_equality_dict):
         """Test for replace_nulls_with_zero function where nulls are replaced with zeros."""
         input_df = self.create_input_df()
         expected_df = input_df.copy()
+        equality_dict = create_equality_dict
         expected_df.loc[2, ["202", "205", "206", "207", "219", "220", "209", "210", "211", '212', '214', '216', '242', '250', '243', '244', '245', '246', '247', '248', '249', '218', '225', '226', '227', '228', '229', '237', '302', '303', '304', '305', '501', '503', '505', '507', '502', '504', '506', '508', '405', '407', '409', '411', '406', '408', '410', '412']] = 0
 
         with caplog.at_level(logging.INFO):
-            result_df = replace_nulls_with_zero(input_df)
+            result_df = replace_nulls_with_zero(input_df, equality_dict)
             assert "Replacing nulls with zeros where total zero" in caplog.text
             pd.testing.assert_frame_equal(result_df, expected_df)
 
@@ -137,13 +306,14 @@ def create_expected_df(self):
 
         return expected_df
 
-    def test_remove_all_nulls_rows(self, caplog):
+    def test_remove_all_nulls_rows(self, caplog, create_config):
         """Test for remove_all_nulls_rows function."""
         input_df = self.create_input_df()
         expected_df = self.create_expected_df()
+        config = create_config
 
         with caplog.at_level(logging.INFO):
-            result_df = remove_all_nulls_rows(input_df)
+            result_df = remove_all_nulls_rows(input_df, config)
             assert "Removing rows with all null values from validation" in caplog.text
             pd.testing.assert_frame_equal(result_df, expected_df)
 
@@ -168,41 +338,44 @@ def create_input_df(self):
 
         return input_df
 
-    def test_equal_validation_success(self, caplog):
+    def test_equal_validation_success(self, caplog, create_equality_dict):
         """Test for equal_validation function where the values match."""
         input_df = self.create_input_df()
         input_df = input_df.loc[(input_df['reference'] == 'A')]
+        equality_dict = create_equality_dict
         msg = ""
         count = 0
         with caplog.at_level(logging.INFO):
-            result_msg, result_count = equal_validation(input_df)
+            result_msg, result_count = equal_validation(input_df, equality_dict)
             assert "Doing breakdown total checks..." in caplog.text
             assert result_msg == msg
             assert result_count == count
 
-    def test_equal_validation_fail(self):
+    def test_equal_validation_fail(self, create_equality_dict):
         """Test for equal_validation function where the values do not meet the criteria."""
         input_df = self.create_input_df()
         input_df = input_df.loc[(input_df['reference'] == 'B')].reset_index(drop=True)
+        equality_dict = create_equality_dict
         msg = "Columns ['202', '203'] do not equal column 204 for reference: B, instance 1.\n "
         count = 1
-        result_msg, result_count = equal_validation(input_df)
+        result_msg, result_count = equal_validation(input_df, equality_dict)
         assert result_msg == msg
         assert result_count == count
 
-    def test_equal_validation_all_null(self, caplog):
+    def test_equal_validation_all_null(self, caplog, create_equality_dict):
         """Test for equal_validation function where all values are zero or null."""
         input_df = self.create_input_df()
         input_df = input_df.loc[(input_df['reference'] == 'C')].reset_index(drop=True)
+        equality_dict = create_equality_dict
         msg = ""
         count = 0
         with caplog.at_level(logging.INFO):
-            result_msg, result_count = equal_validation(input_df)
+            result_msg, result_count = equal_validation(input_df, equality_dict)
             assert "Doing breakdown total checks..." in caplog.text
             assert result_msg == msg
             assert result_count == count
 
-
 class TestGreaterThanValidation:
     """Unit tests for greater_than_validation function."""
@@ -256,5 +429,3 @@ def test_greater_than_validation_all_null(self, caplog):
             assert "Doing checks for values that should be greater than..." in caplog.text
             assert result_msg == msg
             assert result_count == count
-
-
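Usage sketch: the diff above moves the hard-coded check dictionaries into the `consistency_checks` block of dev_config.yaml and threads `config` through to `run_breakdown_validation`. The snippet below is a minimal illustration of that flow, not part of the diff itself; the cut-down config and toy dataframe are invented for the example, and only the "intram" check is populated (the "imputation" sublist looks up all three keys, so the empty ones must still be present).

import pandas as pd

from src.utils.breakdown_validation import run_breakdown_validation

# Toy config: in each list the breakdown cols come first, the total col last.
config = {
    "consistency_checks": {
        "2xx_totals": {"intram": ["204", "210", "211"]},
        "3xx_totals": {},
        "apportioned_totals": {},
    }
}

df = pd.DataFrame(
    {
        "reference": ["A", "B"],
        "instance": [1, 1],
        "imp_class": ["C1", "C1"],
        "imp_marker": ["TMI", "TMI"],
        "204": [40.0, 40.0],
        "210": [45.0, 45.0],
        "211": [85.0, 90.0],  # row B breaks the check: 40 + 45 != 90
    }
)

# Filters to imputed rows (TMI/CF/MoR markers), runs the configured checks,
# logs any failures, and returns a QA dataframe of failing rows with a
# per-check "<key>_diff" column showing the size of the discrepancy.
qa_df = run_breakdown_validation(df, config, check="imputed")
print(qa_df)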