diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ca29928c..4ac805d3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**5.2.3 - 12/26/24** + + - Bugfix: better handle 'all' years when requesting mean data + **5.2.2 - 11/13/24** - Bugfix to implement data type requests for remaining interface tests diff --git a/src/vivarium_inputs/core.py b/src/vivarium_inputs/core.py index c89454ad..ceb39bd0 100644 --- a/src/vivarium_inputs/core.py +++ b/src/vivarium_inputs/core.py @@ -176,8 +176,8 @@ def get_raw_incidence_rate( data = utilities.filter_data_by_restrictions( data, restrictions_entity, "yld", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -197,8 +197,8 @@ def get_prevalence( data = utilities.filter_data_by_restrictions( data, restrictions_entity, "yld", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -267,13 +267,13 @@ def get_disability_weight( data = extract.extract_data( entity, "disability_weight", location_id, years, data_type ) + data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) data = utilities.normalize(data, data_type.value_columns) cause = [c for c in causes if c.sequelae and entity in c.sequelae][0] data = utilities.clear_disability_weight_outside_restrictions( data, cause, 0.0, utility_data.get_age_group_ids() ) - data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) except (IndexError, DataDoesNotExistError): logger.warning( f"{entity.name.capitalize()} has no disability weight data. All values will be 0." 
@@ -300,8 +300,8 @@ def get_remission_rate( data = utilities.filter_data_by_restrictions( data, entity, "yld", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -381,8 +381,8 @@ def get_deaths( data = utilities.filter_data_by_restrictions( data, entity, "yll", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -413,6 +413,8 @@ def get_exposure( data, entity, "outer", utility_data.get_age_group_ids() ) + data = data.filter(DEMOGRAPHIC_COLUMNS + value_columns + ["parameter"]) + if entity.distribution in ["dichotomous", "ordered_polytomous", "unordered_polytomous"]: tmrel_cat = utility_data.get_tmrel_category(entity) exposed = data[data.parameter != tmrel_cat] @@ -437,7 +439,6 @@ def get_exposure( ) else: data = utilities.normalize(data, value_columns, fill_value=0) - data = data.filter(DEMOGRAPHIC_COLUMNS + value_columns + ["parameter"]) return data @@ -462,8 +463,8 @@ def get_exposure_standard_deviation( valid_age_groups = utilities.get_exposure_and_restriction_ages(exposure, entity) data = data[data.age_group_id.isin(valid_age_groups)] - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -493,6 +494,7 @@ def get_exposure_distribution_weights( copied["age_group_id"] = age_id df.append(copied) data = pd.concat(df) + data = data.filter(DEMOGRAPHIC_COLUMNS + DISTRIBUTION_COLUMNS) data = utilities.normalize(data, DISTRIBUTION_COLUMNS, fill_value=0) if years != "all": if years: 
@@ -501,7 +503,6 @@ def get_exposure_distribution_weights( else: most_recent_year = utility_data.get_most_recent_year() data = data.query(f"year_id=={most_recent_year}") - data = data.filter(DEMOGRAPHIC_COLUMNS + DISTRIBUTION_COLUMNS) data = utilities.wide_to_long(data, DISTRIBUTION_COLUMNS, var_name="parameter") return data @@ -512,7 +513,6 @@ def get_relative_risk( years: int | str | list[int] | None, data_type: utilities.DataType, ) -> pd.DataFrame: - if data_type.type != "draws": raise utilities.DataTypeNotImplementedError( f"Data type(s) {data_type.type} are not supported for this function." @@ -525,7 +525,6 @@ def get_relative_risk( ) data = extract.extract_data(entity, "relative_risk", location_id, years, data_type) - # FIXME: we don't currently support yll-only causes so I'm dropping them because the data in some cases is # very messed up, with mort = morb = 1 (e.g., aortic aneurysm in the RR data for high systolic bp) - # 2/8/19 K.W. @@ -629,14 +628,14 @@ def get_population_attributable_fraction( data["measure_id"] == MEASURES["YLLs"], "affected_measure" ] = "excess_mortality_rate" data.loc[data["measure_id"] == MEASURES["YLDs"], "affected_measure"] = "incidence_rate" + data = data.filter( + DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure"] + value_columns + ) data = ( data.groupby(["affected_entity", "affected_measure"]) .apply(utilities.normalize, cols_to_fill=value_columns, fill_value=0) .reset_index(drop=True) ) - data = data.filter( - DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure"] + value_columns - ) return data @@ -674,8 +673,8 @@ def get_utilization_rate( data_type: utilities.DataType, ) -> pd.DataFrame: data = extract.extract_data(entity, "utilization_rate", location_id, years, data_type) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data diff --git 
a/src/vivarium_inputs/extract.py b/src/vivarium_inputs/extract.py index a24935a5..43692bf6 100644 --- a/src/vivarium_inputs/extract.py +++ b/src/vivarium_inputs/extract.py @@ -112,7 +112,7 @@ def extract_data( validation.check_metadata(entity, measure) - year_id = _get_year_id(years) + year_id = _get_year_id(years, data_type) try: main_extractor, additional_extractors = extractors[measure] @@ -146,6 +146,10 @@ def extract_data( extra_draw_cols = [col for col in existing_draw_cols if col not in DRAW_COLUMNS] data = data.drop(columns=extra_draw_cols, errors="ignore") + # drop get_outputs data earlier than the estimation years + if data_type.type == "means": + data = data.loc[data["year_id"] >= min(utility_data.get_estimation_years())] + if validate: additional_data = { name: extractor(entity, location_id, year_id, data_type) @@ -167,11 +171,14 @@ def extract_data( #################### -def _get_year_id(years): +def _get_year_id(years, data_type): if years is None: # default to most recent year year_id = utility_data.get_most_recent_year() elif years == "all": - year_id = None + if data_type.type == "draws": + year_id = None + else: # means + year_id = "full" else: year_id = years if isinstance(years, list) else [years] estimation_years = utility_data.get_estimation_years() diff --git a/src/vivarium_inputs/validation/raw.py b/src/vivarium_inputs/validation/raw.py index ef7a583c..f81f959d 100644 --- a/src/vivarium_inputs/validation/raw.py +++ b/src/vivarium_inputs/validation/raw.py @@ -59,8 +59,9 @@ class RawValidationContext: def __init__(self, location_id, **additional_data): self.context_data = {"location_id": location_id} self.context_data.update(additional_data) - - if "estimation_years" not in self.context_data: + if "estimation_years" not in self.context_data or self.context_data[ + "estimation_years" + ] == ["full"]: self.context_data["estimation_years"] = utility_data.get_estimation_years() if "age_group_ids" not in self.context_data: 
self.context_data["age_group_ids"] = utility_data.get_age_group_ids() diff --git a/tests/e2e/test_get_measure.py b/tests/e2e/test_get_measure.py index 739802e4..aed7b3cd 100644 --- a/tests/e2e/test_get_measure.py +++ b/tests/e2e/test_get_measure.py @@ -275,7 +275,7 @@ def test_get_measure_risklike( and not mock_gbd and data_type == "draws" ): -        pytest.skip("FIXME: [mic-5543] continuous rrs cannot validate") +        pytest.skip("FIXME: [mic-5542] continuous rrs cannot validate") # Handle not implemented is_unimplemented = isinstance(data_type, list) or data_type == "means" diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py new file mode 100644 index 00000000..afc83be1 --- /dev/null +++ b/tests/unit/test_core.py @@ -0,0 +1,252 @@ +import pytest +from gbd_mapping import ModelableEntity, causes, covariates, risk_factors + +from tests.conftest import NO_GBD_ACCESS +from vivarium_inputs import core, utility_data +from vivarium_inputs.utilities import DataType, DataTypeNotImplementedError + +pytestmark = pytest.mark.skipif( +    NO_GBD_ACCESS, reason="Cannot run these tests without vivarium_gbd_access" +) + + +def check_year_in_data(entity, measure, location, years, data_type): +    if _is_not_implemented(data_type, measure): +        with pytest.raises(DataTypeNotImplementedError): +            data_type = DataType(measure, data_type) +            core.get_data(entity, measure, location, years, data_type) +    else: +        data_type = DataType(measure, data_type) +        if isinstance(years, list): +            df = core.get_data(entity, measure, location, years, data_type) +            assert set(df.index.get_level_values("year_id")) == set(years) +        # years expected to be 1900, 2019, None, or "all" +        elif years != 1900: +            df = core.get_data(entity, measure, location, years, data_type) +            if years is None: +                assert set(df.index.get_level_values("year_id")) == set([2021]) +            elif years == "all": +                assert set(df.index.get_level_values("year_id")) == set(range(1990, 2023)) +            else:  # a single (non-1900) year +                assert 
set(df.index.get_level_values("year_id")) == set([years]) + else: + with pytest.raises(ValueError, match="years must be in"): + core.get_data(entity, measure, location, years, data_type) + + +ENTITIES_C = [ + ( + causes.measles, + [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "disability_weight", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + ], + ), + ( + causes.diarrheal_diseases, + [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "disability_weight", + "remission_rate", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + ], + ), + ( + causes.diabetes_mellitus_type_2, + [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "disability_weight", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + ], + ), +] +MEASURES_C = [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "birth_prevalence", + "disability_weight", + "remission_rate", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", +] +LOCATIONS_C = ["India"] + + +@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_C, ids=lambda x: x) +@pytest.mark.parametrize( + "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_causelike(entity_details, measure, location, years, data_type): + entity, entity_expected_measures = entity_details + if measure in entity_expected_measures: + check_year_in_data(entity, measure, location, years, data_type) + + +ENTITIES_R = [ + ( + risk_factors.high_systolic_blood_pressure, + [ + "exposure", + "exposure_standard_deviation", + "exposure_distribution_weights", + "relative_risk", + "population_attributable_fraction", + ], + ), + ( + 
risk_factors.low_birth_weight_and_short_gestation, + [ + "exposure", + "relative_risk", + "population_attributable_fraction", + ], + ), +] +MEASURES_R = [ + "exposure", + "exposure_standard_deviation", + "exposure_distribution_weights", + "relative_risk", + "population_attributable_fraction", +] +LOCATIONS_R = ["India"] + + +@pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_R, ids=lambda x: x) +@pytest.mark.parametrize( + "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_risklike(entity_details, measure, location, years, data_type): + if measure in ["relative_risk", "population_attributable_fraction"]: + pytest.skip("TODO: mic-5245 punting until later b/c these are soooo slow") + entity, entity_expected_measures = entity_details + if measure in entity_expected_measures: + check_year_in_data(entity, measure, location, years, data_type) + + +ENTITIES_COV = [ + covariates.systolic_blood_pressure_mmhg, +] +MEASURES_COV = ["estimate"] +LOCATIONS_COV = ["India"] + + +@pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) +@pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_COV, ids=lambda x: x) +@pytest.mark.parametrize( + "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_covariatelike(entity, measure, location, years, data_type): + check_year_in_data(entity, measure, location, years, data_type) + + +@pytest.mark.parametrize("measure", ["structure", "demographic_dimensions"], ids=lambda x: x) +@pytest.mark.parametrize( + "years", [None, 2019, 
1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_population(measure, years, data_type): + pop = ModelableEntity("ignored", "population", None) + location = utility_data.get_location_id("India") + check_year_in_data(pop, measure, location, years, data_type) + + +@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x) +@pytest.mark.parametrize( + "locations", + [ + "Ethiopia", # 179 + 179, + ["Ethiopia", "Nigeria"], # [179, 214] + [179, 214], + [179, "Nigeria"], + ], + ids=lambda x: str(x), +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_multiple_locations_causelike(entity_details, measure, locations, data_type): + year = 2021 + location_id_mapping = { + "Ethiopia": 179, + "Nigeria": 214, + } + entity, entity_expected_measures = entity_details + if _is_not_implemented(data_type, measure): + with pytest.raises(DataTypeNotImplementedError): + data_type = DataType(measure, data_type) + core.get_data(entity, measure, locations, year, data_type) + else: + data_type = DataType(measure, data_type) + if measure not in entity_expected_measures: + with pytest.raises(Exception): + core.get_data(entity, measure, locations, year, data_type) + else: + df = core.get_data(entity, measure, locations, year, data_type) + if not isinstance(locations, list): + locations = [locations] + location_ids = { + ( + location_id_mapping[item] + if isinstance(item, str) and item in location_id_mapping + else item + ) + for item in locations + } + assert set(df.index.get_level_values("location_id")) == set(location_ids) + + +# TODO: Should we add the location tests for other entity types? 
+ + +def _is_not_implemented(data_type: str | list[str], measure: str) -> bool: + return isinstance(data_type, list) or ( + data_type == "means" + and measure + in [ + "disability_weight", + "remission_rate", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + "exposure", + "low_birth_weight_and_short_gestation", + "exposure_standard_deviation", + "exposure_distribution_weights", + "estimate", + "structure", + ] + )