diff --git a/src/pseudopeople/interface.py b/src/pseudopeople/interface.py
index bdd63dda..4b771977 100644
--- a/src/pseudopeople/interface.py
+++ b/src/pseudopeople/interface.py
@@ -68,46 +68,39 @@ def _generate_dataset(
             "Please provide the path to the unmodified root data directory."
         )
     validate_data_path_suffix(data_paths)
-    noised_dataset = []
-    iterator = (
-        tqdm(data_paths, desc="Noising data", leave=False)
-        if len(data_paths) > 1
-        else data_paths
-    )
+    all_data = []
+    iterator = tqdm(data_paths, desc="Loading data") if len(data_paths) > 1 else data_paths
 
-    for data_path_index, data_path in enumerate(iterator):
+    for data_path in iterator:
         logger.debug(f"Loading data from {data_path}.")
         data = _load_data_from_path(data_path, user_filters)
         if data.empty:
             continue
-        data = _reformat_dates_for_noising(data, dataset)
-        data = _coerce_dtypes(data, dataset)
-        # Use a different seed for each data file/shard, otherwise the randomness will duplicate
-        # and the Nth row in each shard will get the same noise
-        data_path_seed = f"{seed}_{data_path_index}"
-        noised_data = noise_dataset(dataset, data, configuration_tree, data_path_seed)
-        noised_data = _extract_columns(dataset.columns, noised_data)
-        noised_dataset.append(noised_data)
+        # FIXME: Right now, Categorical columns in the Rhode Island data
+        # contain a very large number of unnecessary categories. We want
+        # to get rid of these during this loop so that they are never all
+        # in memory at the same time.
+        # TODO: Remove this when we stop Categorical encoding.
+        data = _remove_unused_categories(data, dataset)
+        all_data.append(data)
 
     # Check if all shards for the dataset are empty
-    if len(noised_dataset) == 0:
+    if len(all_data) == 0:
         raise ValueError(
             "Invalid value provided for 'state' or 'year'. No data found with "
-            f"the user provided 'state' or 'year' filters at {data_path}."
+            f"the user provided 'state' or 'year' filters at {source}."
         )
-    noised_dataset = pd.concat(noised_dataset, ignore_index=True)
-
-    # Known pandas bug: pd.concat does not preserve category dtypes so we coerce
-    # again after concat (https://github.com/pandas-dev/pandas/issues/51362)
-    noised_dataset = _coerce_dtypes(
-        noised_dataset,
-        dataset,
-        cleanse_int_cols=True,
-    )
+
+    all_data = pd.concat(all_data, ignore_index=True)
+    _reformat_dates_for_noising(all_data, dataset)
+    all_data = _coerce_dtypes(all_data, dataset)
+    all_data = noise_dataset(dataset, all_data, configuration_tree, seed)
+    all_data = _extract_columns(dataset.columns, all_data)
+    all_data = _coerce_dtypes(all_data, dataset, cleanse_int_cols=True)
 
     logger.debug("*** Finished ***")
 
-    return noised_dataset
+    return all_data
 
 
 def validate_source_compatibility(source: Path, dataset: Dataset):
@@ -168,15 +161,27 @@ def _coerce_dtypes(
     return data
 
 
+def _remove_unused_categories(data: pd.DataFrame, dataset: Dataset) -> pd.DataFrame:
+    for col in data.columns:
+        if data[col].dtype.name == "category" and (
+            # NOTE: We want to avoid dropping categories that just happen not to be used
+            # in columns that are returned as Categorical to the user such as event_type
+            col not in dataset.columns
+            or dataset.columns[col].dtype_name != "category"
+        ):
+            data[col] = data[col].cat.remove_unused_categories()
+
+    return data
+
+
 def _load_data_from_path(data_path: Path, user_filters: List[Tuple]) -> pd.DataFrame:
     """Load data from a data file given a data_path and a year_filter."""
     data = load_standard_dataset_file(data_path, user_filters)
     return data
 
 
-def _reformat_dates_for_noising(data: pd.DataFrame, dataset: Dataset):
+def _reformat_dates_for_noising(data: pd.DataFrame, dataset: Dataset) -> None:
     """Formats date columns so they can be noised as strings."""
-    data = data.copy()
 
     for date_column in [COLUMNS.dob.name, COLUMNS.ssa_event_date.name]:
         # Format both the actual column, and the shadow version that will be used
@@ -204,8 +209,6 @@ def _reformat_dates_for_noising(data: pd.DataFrame, dataset: Dataset):
             data[column] = pd.Series(np.nan, dtype=str)
             data.loc[~is_na, column] = result
 
-    return data
-
 
 def _zfill_fast(col: pd.Series, desired_length: int) -> pd.Series:
     """Performs the same operation as col.str.zfill(desired_length), but vectorized."""
diff --git a/src/pseudopeople/noise.py b/src/pseudopeople/noise.py
index 09bedb09..112100f9 100644
--- a/src/pseudopeople/noise.py
+++ b/src/pseudopeople/noise.py
@@ -56,7 +56,7 @@ def noise_dataset(
     # except for the leave_blank kind which is special-cased below
     missingness = (dataset_data == "") | (dataset_data.isna())
 
-    for noise_type in tqdm(NOISE_TYPES, desc="Applying noise", unit="type", leave=False):
+    for noise_type in tqdm(NOISE_TYPES, desc="Applying noise", unit="type"):
         if isinstance(noise_type, RowNoiseType):
             if (
                 Keys.ROW_NOISE in noise_configuration
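A note on why the per-shard category pruning helps with memory: a pandas Categorical stores its full category index on the dtype, so every loaded shard pays for all categories even if it uses only a handful of them. Below is a minimal sketch of the effect the FIXME describes; the column name and category count are made up for illustration, not taken from the Rhode Island data:

```python
import pandas as pd

# Hypothetical shard: the dtype carries 100,000 categories even though
# this shard only uses two of them.
all_streets = [f"street_{i}" for i in range(100_000)]
shard = pd.DataFrame(
    {"street_name": pd.Categorical(["street_0", "street_1"], categories=all_streets)}
)
print(len(shard["street_name"].cat.categories))  # 100000

# Pruning inside the loading loop, as the new _remove_unused_categories
# helper does, drops unused categories before shards accumulate in memory.
shard["street_name"] = shard["street_name"].cat.remove_unused_categories()
print(len(shard["street_name"].cat.categories))  # 2
```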
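The pruning also interacts with the `pd.concat` behavior that the deleted comment cited (pandas issue #51362): once each shard's category set has been trimmed independently, the shards no longer agree on categories, and concat falls back to object dtype. That is presumably why the patch still re-runs `_coerce_dtypes` after the concat. A small sketch, assuming two shards whose pruned category sets differ (the column and values are hypothetical):

```python
import pandas as pd

shard_a = pd.DataFrame({"city": pd.Categorical(["providence"])})
shard_b = pd.DataFrame({"city": pd.Categorical(["warwick"])})

# With differing category sets, concat upcasts to object dtype
# (https://github.com/pandas-dev/pandas/issues/51362).
combined = pd.concat([shard_a, shard_b], ignore_index=True)
print(combined["city"].dtype)  # object, not category

# Re-coercing restores the categorical dtype, analogous to the
# _coerce_dtypes call that follows pd.concat in the patch.
combined["city"] = combined["city"].astype("category")
print(combined["city"].dtype)  # category
```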
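On dropping the per-shard seed suffix: the removed comment warned that reusing one seed across shards makes the Nth row of every shard draw identical noise. Concatenating all shards first and calling `noise_dataset` once sidesteps that, since a single pass consumes one random stream. Illustrated here with plain numpy generators rather than pseudopeople's actual randomness machinery:

```python
import numpy as np

# Same seed per shard => identical draws per shard; the Nth row of every
# shard would have received the same perturbation.
shard_1_noise = np.random.default_rng(123).random(3)
shard_2_noise = np.random.default_rng(123).random(3)
assert np.array_equal(shard_1_noise, shard_2_noise)

# One pass over the concatenated data uses a single stream, so one seed
# suffices and the shards' rows no longer line up.
one_pass = np.random.default_rng(123).random(6)
assert not np.array_equal(one_pass[:3], one_pass[3:])
```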