diff --git a/src/regtech_data_validator/data_formatters.py b/src/regtech_data_validator/data_formatters.py
index cdb8b7e..62bd790 100644
--- a/src/regtech_data_validator/data_formatters.py
+++ b/src/regtech_data_validator/data_formatters.py
@@ -234,12 +234,12 @@ def df_to_table(df: pl.DataFrame) -> str:
     return tabulate(df, headers='keys', showindex=True, tablefmt='rounded_outline')  # type: ignore
 
 
-def df_to_json(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 100) -> str:
+def df_to_json(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 200) -> str:
     results = df_to_dicts(df, max_records, max_group_size)
     return ujson.dumps(results, indent=4, escape_forward_slashes=False)
 
 
-def df_to_dicts(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 100) -> list[dict]:
+def df_to_dicts(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 200) -> list[dict]:
     json_results = []
     if not df.is_empty():
         # polars str columns sort by entry, not lexigraphical sorting like we'd expect, so cast the column to use
@@ -264,7 +264,7 @@ def df_to_dicts(df: pl.DataFrame, max_records: int = 10000, max_group_size: int
 # So this function uses the group error counts to truncate on record numbers
 def truncate_validation_group_records(group, group_size):
     need_to_truncate = group.select(pl.col('row').n_unique()).item() > group_size
-    unique_record_nos = group.select('row').unique().limit(group_size)
+    unique_record_nos = group.select('row').unique(maintain_order=True).limit(group_size)
     truncated_group = group.filter(pl.col('row').is_in(unique_record_nos['row']))
     return truncated_group, need_to_truncate
 
diff --git a/src/regtech_data_validator/validator.py b/src/regtech_data_validator/validator.py
index 508d60c..6663fbc 100644
--- a/src/regtech_data_validator/validator.py
+++ b/src/regtech_data_validator/validator.py
@@ -44,7 +44,6 @@ def _get_check_fields(check: Check, primary_column: str) -> list[str]:
 # Retrieves the row data from the original dataframe that threw errors/warnings, and pulls out the fields/values
 # from the original row data that caused the error/warning
 def _filter_valid_records(df: pl.DataFrame, check_output: pl.Series, fields: list[str]) -> pl.DataFrame:
-    sorted_check_output = check_output["index"]
     fields = ["index"] + fields
-    filtered_df = df.filter(pl.col('index').is_in(sorted_check_output))
+    filtered_df = df.filter(pl.col('index').is_in(check_output["index"]))
 
@@ -69,7 +68,7 @@ def _add_validation_metadata(failed_check_fields_df: pl.DataFrame, check: SBLChe
     return validation_fields_df
 
 
-def validate(schema: pa.DataFrameSchema, submission_df: pl.LazyFrame, process_errors: bool) -> pl.DataFrame:
+def validate(schema: pa.DataFrameSchema, submission_df: pl.LazyFrame, row_start: int, process_errors: bool) -> pl.DataFrame:
     """
     validate received dataframe with schema and return list of schema errors
 
@@ -85,7 +84,7 @@ def validate(schema: pa.DataFrameSchema, submission_df: pl.LazyFrame, process_er
     try:
         # since polars dataframes don't normally have an index column, add it, so that we can match
         # up original submission rows with rows found with errors/warnings
-        submission_df = submission_df.with_row_index()
+        submission_df = submission_df.with_row_index(offset=row_start)
         schema(submission_df, lazy=True)
     except SchemaErrors as err:
         check_findings = []
@@ -187,7 +186,7 @@ def validate_batch_csv(
     # than reading in the whole csv and just selecting on the UID column (currently our only register level check data)
     uids = pl.scan_csv(real_path, infer_schema_length=0, missing_utf8_is_empty_string=True).select("uid").collect()
     register_schema = get_register_schema(context)
-    validation_results = validate(register_schema, uids, True)
+    validation_results = validate(register_schema, uids, 0, True)
     if not validation_results.findings.is_empty():
         validation_results.findings = format_findings(
             validation_results.findings,
@@ -215,9 +214,10 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks):
     batches = reader.next_batches(batch_count)
     process_errors = True
     total_count = 0
+    row_start = 0
     while batches:
         df = pl.concat(batches)
-        validation_results = validate(schema, df, process_errors)
+        validation_results = validate(schema, df, row_start, process_errors)
         if not validation_results.findings.is_empty():
             validation_results.findings = format_findings(
                 validation_results.findings, validation_results.phase.value, checks
@@ -230,6 +230,7 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks):
             head_count = validation_results.findings.height - (total_count - max_errors)
             validation_results.findings = validation_results.findings.head(head_count)
 
+        row_start += df.height
         batches = reader.next_batches(batch_count)
         yield validation_results
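
Why `maintain_order=True` matters in `truncate_validation_group_records`: polars makes no ordering guarantee for `unique()`, so the following `limit(group_size)` could keep a different subset of records on every run, making truncated validation reports nondeterministic. A minimal sketch of the behavior, with made-up data:

```python
import polars as pl

# Toy findings group: each logical record ("row") appears once per
# field involved in its validation error.
group = pl.DataFrame(
    {
        "row": [7, 7, 2, 2, 5],
        "field_name": ["uid", "ct_credit_product", "uid", "action_taken", "uid"],
    }
)

# unique() alone returns {7, 2, 5} in no guaranteed order, so limit(2)
# could keep a different pair of records from run to run;
# maintain_order=True pins the result to first-seen order.
kept = group.select("row").unique(maintain_order=True).limit(2)
print(kept["row"].to_list())  # [7, 2] on every run
```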
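Why the `row_start` offset: `validate_chunks` re-indexes every chunk, and without an offset each chunk restarted at 0, so findings from the second batch onward pointed at the wrong rows of the submission. Passing the running offset to `with_row_index(offset=row_start)` anchors each chunk's index to its position in the full file. A sketch with made-up values:

```python
import polars as pl

chunk = pl.DataFrame({"uid": ["123456789TESTBANK1A1", "123456789TESTBANK1A2"]})
row_start = 50_000  # rows already consumed by earlier batches

# "index" now reflects position in the original submission,
# not position within this chunk.
indexed = chunk.with_row_index(offset=row_start)
print(indexed["index"].to_list())  # [50000, 50001]
```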
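How the offset advances: after each concatenated batch is validated, `row_start += df.height` moves the offset by exactly the number of rows consumed, so indices stay aligned even when `next_batches` returns a varying number of frames. Roughly, with an illustrative path and batch sizes:

```python
import polars as pl

reader = pl.read_csv_batched("register.csv", infer_schema_length=0, batch_size=10_000)

row_start = 0
batches = reader.next_batches(5)
while batches:
    df = pl.concat(batches)
    # validate(schema, df, row_start, process_errors) runs here;
    # its findings carry indices already offset by row_start.
    row_start += df.height  # advance by the rows this pass consumed
    batches = reader.next_batches(5)
```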