Commit

Fixed ordering of records
jcadam14 committed Oct 5, 2024
1 parent f2c9412 commit 566e63c
Showing 2 changed files with 9 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/regtech_data_validator/data_formatters.py
@@ -234,12 +234,12 @@ def df_to_table(df: pl.DataFrame) -> str:
     return tabulate(df, headers='keys', showindex=True, tablefmt='rounded_outline')  # type: ignore


-def df_to_json(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 100) -> str:
+def df_to_json(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 200) -> str:
     results = df_to_dicts(df, max_records, max_group_size)
     return ujson.dumps(results, indent=4, escape_forward_slashes=False)


-def df_to_dicts(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 100) -> list[dict]:
+def df_to_dicts(df: pl.DataFrame, max_records: int = 10000, max_group_size: int = 200) -> list[dict]:
     json_results = []
     if not df.is_empty():
         # polars str columns sort by entry, not lexigraphical sorting like we'd expect, so cast the column to use
@@ -264,7 +264,7 @@ def df_to_dicts(df: pl.DataFrame, max_records: int = 10000, max_group_size: int
 # So this function uses the group error counts to truncate on record numbers
 def truncate_validation_group_records(group, group_size):
     need_to_truncate = group.select(pl.col('row').n_unique()).item() > group_size
-    unique_record_nos = group.select('row').unique().limit(group_size)
+    unique_record_nos = group.select('row').unique(maintain_order=True).limit(group_size)
     truncated_group = group.filter(pl.col('row').is_in(unique_record_nos['row']))
     return truncated_group, need_to_truncate

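The commit message refers to this line: polars' unique() gives no guarantee about the order of the rows it returns, so truncating a validation group to its first group_size record numbers could previously keep an arbitrary subset. maintain_order=True makes unique() preserve first-occurrence order. A minimal polars sketch of the difference (illustrative only, not code from this repository):

import polars as pl

# a validation "group" where record numbers repeat and arrive in submission order
group = pl.DataFrame({"row": [7, 3, 7, 1, 3, 9], "field_name": ["uid"] * 6})

# default unique() gives no ordering guarantee, so limit(2) may keep any two record numbers
arbitrary_two = group.select("row").unique().limit(2)

# maintain_order=True keeps first-occurrence order, so limit(2) keeps rows 7 and 3
first_two = group.select("row").unique(maintain_order=True).limit(2)

truncated = group.filter(pl.col("row").is_in(first_two["row"]))
print(truncated)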
11 changes: 6 additions & 5 deletions src/regtech_data_validator/validator.py
@@ -44,7 +44,6 @@ def _get_check_fields(check: Check, primary_column: str) -> list[str]:
 # Retrieves the row data from the original dataframe that threw errors/warnings, and pulls out the fields/values
 # from the original row data that caused the error/warning
 def _filter_valid_records(df: pl.DataFrame, check_output: pl.Series, fields: list[str]) -> pl.DataFrame:
-
     sorted_check_output = check_output["index"]
     fields = ["index"] + fields
     filtered_df = df.filter(pl.col('index').is_in(sorted_check_output))
@@ -69,7 +68,7 @@ def _add_validation_metadata(failed_check_fields_df: pl.DataFrame, check: SBLChe
     return validation_fields_df


-def validate(schema: pa.DataFrameSchema, submission_df: pl.LazyFrame, process_errors: bool) -> pl.DataFrame:
+def validate(schema: pa.DataFrameSchema, submission_df: pl.LazyFrame, row_start: int, process_errors: bool) -> pl.DataFrame:
     """
     validate received dataframe with schema and return list of
     schema errors
@@ -85,7 +84,7 @@ def validate(schema: pa.DataFrameSchema, submission_df: pl.LazyFrame, process_er
     try:
         # since polars dataframes don't normally have an index column, add it, so that we can match
         # up original submission rows with rows found with errors/warnings
-        submission_df = submission_df.with_row_index()
+        submission_df = submission_df.with_row_index(offset=row_start)
         schema(submission_df, lazy=True)
     except SchemaErrors as err:
         check_findings = []
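Without an offset, with_row_index() numbers every frame from zero, so when a submission is validated in chunks each chunk's findings would report row numbers relative to that chunk rather than to the full file. The new row_start parameter threads the absolute starting row into the index. A small sketch of the offset behaviour (illustrative only, not code from this repository):

import polars as pl

# pretend this is the second chunk of a submission whose first chunk held 2 records
chunk = pl.LazyFrame({"uid": ["C", "D", "E"]})

local_index = chunk.with_row_index().collect()           # index column: 0, 1, 2
global_index = chunk.with_row_index(offset=2).collect()  # index column: 2, 3, 4

print(global_index["index"].to_list())  # [2, 3, 4]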
@@ -187,7 +186,7 @@ def validate_batch_csv(
     # than reading in the whole csv and just selecting on the UID column (currently our only register level check data)
     uids = pl.scan_csv(real_path, infer_schema_length=0, missing_utf8_is_empty_string=True).select("uid").collect()
     register_schema = get_register_schema(context)
-    validation_results = validate(register_schema, uids, True)
+    validation_results = validate(register_schema, uids, 0, True)
     if not validation_results.findings.is_empty():
         validation_results.findings = format_findings(
             validation_results.findings,
@@ -215,9 +214,10 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks):
     batches = reader.next_batches(batch_count)
     process_errors = True
     total_count = 0
+    row_start = 0
     while batches:
         df = pl.concat(batches)
-        validation_results = validate(schema, df, process_errors)
+        validation_results = validate(schema, df, row_start, process_errors)
         if not validation_results.findings.is_empty():
             validation_results.findings = format_findings(
                 validation_results.findings, validation_results.phase.value, checks
@@ -230,6 +230,7 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks):
                 head_count = validation_results.findings.height - (total_count - max_errors)
                 validation_results.findings = validation_results.findings.head(head_count)

+        row_start += df.height
         batches = reader.next_batches(batch_count)
         yield validation_results

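The running offset is what ties the two files' changes together: each batch is validated with the number of rows already processed as its starting index, then row_start advances by the batch's height. A condensed sketch of that loop shape, with hypothetical stand-ins for the batch reader (not the project's API):

import polars as pl

def fake_batches():
    # hypothetical stand-in for reader.next_batches(): two chunks of one submission
    yield pl.DataFrame({"uid": ["A", "B", "C"]})
    yield pl.DataFrame({"uid": ["D", "E"]})

row_start = 0
for df in fake_batches():
    indexed = df.with_row_index(offset=row_start)
    print(indexed["index"].to_list())  # [0, 1, 2] then [3, 4]
    row_start += df.height             # the next chunk continues where this one ended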