diff --git a/src/vivarium_inputs/core.py b/src/vivarium_inputs/core.py index d49e6d9b..16ff4868 100644 --- a/src/vivarium_inputs/core.py +++ b/src/vivarium_inputs/core.py @@ -160,7 +160,9 @@ def get_disability_weight( entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False ) -> pd.DataFrame: if entity.kind == "cause": - data = utility_data.get_demographic_dimensions(location_id, get_all_years, draws=True, value=0.0) + data = utility_data.get_demographic_dimensions( + location_id, get_all_years, draws=True, value=0.0 + ) data = data.set_index( utilities.get_ordered_index_cols(data.columns.difference(DRAW_COLUMNS)) ) @@ -204,7 +206,9 @@ def get_disability_weight( logger.warning( f"{entity.name.capitalize()} has no disability weight data. All values will be 0." ) - data = utility_data.get_demographic_dimensions(location_id, get_all_years, draws=True, value=0.0) + data = utility_data.get_demographic_dimensions( + location_id, get_all_years, draws=True, value=0.0 + ) return data @@ -517,7 +521,11 @@ def get_age_bins(entity: Population, location_id: int) -> pd.DataFrame: return age_bins -def get_demographic_dimensions(entity: Population, location_id: int, get_all_years: bool = False) -> pd.DataFrame: - demographic_dimensions = utility_data.get_demographic_dimensions(location_id, get_all_years) +def get_demographic_dimensions( + entity: Population, location_id: int, get_all_years: bool = False +) -> pd.DataFrame: + demographic_dimensions = utility_data.get_demographic_dimensions( + location_id, get_all_years + ) demographic_dimensions = utilities.normalize(demographic_dimensions) return demographic_dimensions diff --git a/src/vivarium_inputs/utility_data.py b/src/vivarium_inputs/utility_data.py index 54966c94..5389c001 100644 --- a/src/vivarium_inputs/utility_data.py +++ b/src/vivarium_inputs/utility_data.py @@ -4,7 +4,6 @@ from gbd_mapping import RiskFactor from vivarium_inputs.globals import NON_MAX_TMREL, NUM_DRAWS, SEXES, gbd -from vivarium_inputs.utility_data import get_most_recent_year def get_estimation_years(*_, **__) -> pd.Series: @@ -60,7 +59,7 @@ def get_demographic_dimensions( estimation_years = get_estimation_years() years = range(min(estimation_years), max(estimation_years) + 1) else: - years = get_most_recent_year() + years = [get_most_recent_year()] sexes = [SEXES["Male"], SEXES["Female"]] location = [location_id] values = [location, sexes, ages, years]