From 539ad02c4eb0700d3e35d01920e28dcb7755c50b Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Thu, 19 Dec 2024 08:35:37 -0800 Subject: [PATCH 1/5] copy test_core.py to tests/unit/ folder --- tests/unit/test_core.py | 256 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 tests/unit/test_core.py diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py new file mode 100644 index 00000000..f637b511 --- /dev/null +++ b/tests/unit/test_core.py @@ -0,0 +1,256 @@ +import pytest +from gbd_mapping import ModelableEntity, causes, covariates, risk_factors + +from tests.conftest import NO_GBD_ACCESS +from vivarium_inputs import core, utility_data +from vivarium_inputs.mapping_extension import healthcare_entities + +pytestmark = pytest.mark.skipif( + NO_GBD_ACCESS, reason="Cannot run these tests without vivarium_gbd_access" +) + + +def success_expected(entity_name, measure_name, location): + df = core.get_data(entity_name, measure_name, location) + return df + + +def fail_expected(entity_name, measure_name, location): + with pytest.raises(Exception): + _df = core.get_data(entity_name, measure_name, location) + + +def check_year_in_data(entity, measure, location, years): + if isinstance(years, list): + df = core.get_data(entity, measure, location, years=years) + assert set(df.reset_index()["year_id"]) == set(years) + # years expected to be 1900, 2019, None, or "all" + elif years != 1900: + df = core.get_data(entity, measure, location, years=years) + if years == None: + assert set(df.reset_index()["year_id"]) == set([2021]) + elif years == 2019: + assert set(df.reset_index()["year_id"]) == set([2019]) + elif years == "all": + assert set(df.reset_index()["year_id"]) == set(range(1990, 2023)) + else: + with pytest.raises(ValueError, match="years must be in"): + df = core.get_data(entity, measure, location, years=years) + + +ENTITIES_C = [ + ( + causes.measles, + [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "disability_weight", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + ], + ), + ( + causes.diarrheal_diseases, + [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "disability_weight", + "remission_rate", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + ], + ), + ( + causes.diabetes_mellitus_type_2, + [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "disability_weight", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + ], + ), +] +MEASURES_C = [ + "incidence_rate", + "raw_incidence_rate", + "prevalence", + "birth_prevalence", + "disability_weight", + "remission_rate", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", +] +LOCATIONS_C = ["India"] + + +@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) +@pytest.mark.parametrize("location", LOCATIONS_C) +def test_core_causelike(entity_details, measure, location): + entity, entity_expected_measures = entity_details + tester = success_expected if measure in entity_expected_measures else fail_expected + _df = tester(entity, measure, utility_data.get_location_id(location)) + + +@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) +@pytest.mark.parametrize("location", LOCATIONS_C) +@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) +def test_year_id_causelike(entity_details, measure, location, years): + entity, entity_expected_measures = entity_details + if measure in entity_expected_measures: + check_year_in_data(entity, measure, location, years=years) + + +ENTITIES_R = [ + ( + risk_factors.high_systolic_blood_pressure, + [ + "exposure", + "exposure_standard_deviation", + "exposure_distribution_weights", + "relative_risk", + "population_attributable_fraction", + ], + ), + ( + risk_factors.low_birth_weight_and_short_gestation, + [ + "exposure", + "relative_risk", + "population_attributable_fraction", + ], + ), +] +MEASURES_R = [ + "exposure", + "exposure_standard_deviation", + "exposure_distribution_weights", + "relative_risk", + "population_attributable_fraction", +] +LOCATIONS_R = ["India"] + + +@pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x[0]) +@pytest.mark.parametrize("location", LOCATIONS_R) +def test_core_risklike(entity_details, measure, location): + entity, entity_expected_measures = entity_details + if ( + entity.name == risk_factors.high_systolic_blood_pressure.name + and measure == "population_attributable_fraction" + ): + pytest.skip("MIC-4891") + tester = success_expected if measure in entity_expected_measures else fail_expected + _df = tester(entity, measure, utility_data.get_location_id(location)) + + +@pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x[0]) +@pytest.mark.parametrize("location", LOCATIONS_R) +@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) +def test_year_id_risklike(entity_details, measure, location, years): + entity, entity_expected_measures = entity_details + # exposure-parametrized RRs for all years requires a lot of time and memory to process + if ( + entity == risk_factors.high_systolic_blood_pressure + and measure == "relative_risk" + and years == "all" + ): + pytest.skip(reason="need --runslow option to run") + if measure in entity_expected_measures: + check_year_in_data(entity, measure, location, years=years) + + +@pytest.mark.slow # this test requires a lot of time and memory to run +@pytest.mark.parametrize("location", LOCATIONS_R) +def test_slow_year_id_risklike(location): + check_year_in_data( + risk_factors.high_systolic_blood_pressure, "relative_risk", location, years="all" + ) + + +ENTITIES_COV = [ + covariates.systolic_blood_pressure_mmhg, +] +MEASURES_COV = ["estimate"] +LOCATIONS_COV = ["India"] + + +@pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) +@pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_COV) +def test_core_covariatelike(entity, measure, location): + _df = core.get_data(entity, measure, utility_data.get_location_id(location)) + + +@pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) +@pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_COV) +@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) +def test_year_id_covariatelike(entity, measure, location, years): + check_year_in_data(entity, measure, location, years=years) + + +@pytest.mark.parametrize( + "measures", + [ + "structure", + "age_bins", + "demographic_dimensions", + "theoretical_minimum_risk_life_expectancy", + ], +) +def test_core_population(measures): + pop = ModelableEntity("ignored", "population", None) + _df = core.get_data(pop, measures, utility_data.get_location_id("India")) + + +@pytest.mark.parametrize("measure", ["structure", "demographic_dimensions"]) +@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) +def test_year_id_population(measure, years): + pop = ModelableEntity("ignored", "population", None) + location = utility_data.get_location_id("India") + check_year_in_data(pop, measure, location, years=years) + + +# TODO - Underlying problem with gbd access. Remove when corrected. +ENTITIES_HEALTH_SYSTEM = [ + healthcare_entities.outpatient_visits, +] +MEASURES_HEALTH_SYSTEM = ["utilization_rate"] +LOCATIONS_HEALTH_SYSTEM = ["India"] + + +@pytest.mark.skip(reason="Underlying problem with gbd access. Remove when corrected.") +@pytest.mark.parametrize("entity", ENTITIES_HEALTH_SYSTEM, ids=lambda x: x.name) +@pytest.mark.parametrize("measure", MEASURES_HEALTH_SYSTEM, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_HEALTH_SYSTEM) +def test_core_healthsystem(entity, measure, location): + _df = core.get_data(entity, measure, utility_data.get_location_id(location)) + + +@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) +@pytest.mark.parametrize( + "locations", + [ + [164, 165, 175], + ["Ethiopia", "Nigeria"], + [164, "Nigeria"], + ], +) +def test_pulling_multiple_locations(entity_details, measure, locations): + entity, entity_expected_measures = entity_details + measure_name, measure_id = measure + tester = success_expected if (entity_expected_measures & measure_id) else fail_expected + _df = tester(entity, measure_name, locations) From 115aef892998febaa5ba3e9799a0b176eb381aeb Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Thu, 19 Dec 2024 10:52:07 -0800 Subject: [PATCH 2/5] remove unwanted tests; fix broken ones --- src/vivarium_inputs/core.py | 36 ++-- src/vivarium_inputs/extract.py | 13 +- src/vivarium_inputs/utilities.py | 6 +- src/vivarium_inputs/validation/raw.py | 5 +- tests/e2e/test_get_measure.py | 2 +- tests/unit/test_core.py | 246 +++++++++++++------------- 6 files changed, 157 insertions(+), 151 deletions(-) diff --git a/src/vivarium_inputs/core.py b/src/vivarium_inputs/core.py index c89454ad..0598c5fa 100644 --- a/src/vivarium_inputs/core.py +++ b/src/vivarium_inputs/core.py @@ -176,8 +176,8 @@ def get_raw_incidence_rate( data = utilities.filter_data_by_restrictions( data, restrictions_entity, "yld", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -197,8 +197,8 @@ def get_prevalence( data = utilities.filter_data_by_restrictions( data, restrictions_entity, "yld", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -267,13 +267,13 @@ def get_disability_weight( data = extract.extract_data( entity, "disability_weight", location_id, years, data_type ) + data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) data = utilities.normalize(data, data_type.value_columns) cause = [c for c in causes if c.sequelae and entity in c.sequelae][0] data = utilities.clear_disability_weight_outside_restrictions( data, cause, 0.0, utility_data.get_age_group_ids() ) - data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) except (IndexError, DataDoesNotExistError): logger.warning( f"{entity.name.capitalize()} has no disability weight data. All values will be 0." @@ -300,8 +300,8 @@ def get_remission_rate( data = utilities.filter_data_by_restrictions( data, entity, "yld", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -381,8 +381,8 @@ def get_deaths( data = utilities.filter_data_by_restrictions( data, entity, "yll", utility_data.get_age_group_ids() ) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -413,6 +413,8 @@ def get_exposure( data, entity, "outer", utility_data.get_age_group_ids() ) + data = data.filter(DEMOGRAPHIC_COLUMNS + value_columns + ["parameter"]) + if entity.distribution in ["dichotomous", "ordered_polytomous", "unordered_polytomous"]: tmrel_cat = utility_data.get_tmrel_category(entity) exposed = data[data.parameter != tmrel_cat] @@ -437,7 +439,6 @@ def get_exposure( ) else: data = utilities.normalize(data, value_columns, fill_value=0) - data = data.filter(DEMOGRAPHIC_COLUMNS + value_columns + ["parameter"]) return data @@ -462,8 +463,8 @@ def get_exposure_standard_deviation( valid_age_groups = utilities.get_exposure_and_restriction_ages(exposure, entity) data = data[data.age_group_id.isin(valid_age_groups)] - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data @@ -493,6 +494,7 @@ def get_exposure_distribution_weights( copied["age_group_id"] = age_id df.append(copied) data = pd.concat(df) + data = data.filter(DEMOGRAPHIC_COLUMNS + DISTRIBUTION_COLUMNS) data = utilities.normalize(data, DISTRIBUTION_COLUMNS, fill_value=0) if years != "all": if years: @@ -501,7 +503,6 @@ def get_exposure_distribution_weights( else: most_recent_year = utility_data.get_most_recent_year() data = data.query(f"year_id=={most_recent_year}") - data = data.filter(DEMOGRAPHIC_COLUMNS + DISTRIBUTION_COLUMNS) data = utilities.wide_to_long(data, DISTRIBUTION_COLUMNS, var_name="parameter") return data @@ -512,7 +513,6 @@ def get_relative_risk( years: int | str | list[int] | None, data_type: utilities.DataType, ) -> pd.DataFrame: - if data_type.type != "draws": raise utilities.DataTypeNotImplementedError( f"Data type(s) {data_type.type} are not supported for this function." @@ -524,8 +524,9 @@ def get_relative_risk( f"{location_id}." ) + breakpoint() data = extract.extract_data(entity, "relative_risk", location_id, years, data_type) - + breakpoint() # FIXME: we don't currently support yll-only causes so I'm dropping them because the data in some cases is # very messed up, with mort = morb = 1 (e.g., aortic aneurysm in the RR data for high systolic bp) - # 2/8/19 K.W. @@ -545,6 +546,7 @@ def get_relative_risk( + ["affected_entity", "affected_measure", "parameter"] + value_columns ) + breakpoint() data = ( data.groupby(["affected_entity", "parameter"]) .apply(utilities.normalize, cols_to_fill=value_columns, fill_value=1) @@ -625,18 +627,18 @@ def get_population_attributable_fraction( data = data.where(data[value_columns] > 0, 0).reset_index() data = utilities.convert_affected_entity(data, "cause_id") - data.loc[ - data["measure_id"] == MEASURES["YLLs"], "affected_measure" - ] = "excess_mortality_rate" + data.loc[data["measure_id"] == MEASURES["YLLs"], "affected_measure"] = ( + "excess_mortality_rate" + ) data.loc[data["measure_id"] == MEASURES["YLDs"], "affected_measure"] = "incidence_rate" + data = data.filter( + DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure"] + value_columns + ) data = ( data.groupby(["affected_entity", "affected_measure"]) .apply(utilities.normalize, cols_to_fill=value_columns, fill_value=0) .reset_index(drop=True) ) - data = data.filter( - DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure"] + value_columns - ) return data @@ -674,8 +676,8 @@ def get_utilization_rate( data_type: utilities.DataType, ) -> pd.DataFrame: data = extract.extract_data(entity, "utilization_rate", location_id, years, data_type) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) + data = utilities.normalize(data, data_type.value_columns, fill_value=0) return data diff --git a/src/vivarium_inputs/extract.py b/src/vivarium_inputs/extract.py index a24935a5..43692bf6 100644 --- a/src/vivarium_inputs/extract.py +++ b/src/vivarium_inputs/extract.py @@ -112,7 +112,7 @@ def extract_data( validation.check_metadata(entity, measure) - year_id = _get_year_id(years) + year_id = _get_year_id(years, data_type) try: main_extractor, additional_extractors = extractors[measure] @@ -146,6 +146,10 @@ def extract_data( extra_draw_cols = [col for col in existing_draw_cols if col not in DRAW_COLUMNS] data = data.drop(columns=extra_draw_cols, errors="ignore") + # drop get_outputs data earlier than the estimation years + if data_type.type == "means": + data = data.loc[data["year_id"] >= min(utility_data.get_estimation_years())] + if validate: additional_data = { name: extractor(entity, location_id, year_id, data_type) @@ -167,11 +171,14 @@ def extract_data( #################### -def _get_year_id(years): +def _get_year_id(years, data_type): if years is None: # default to most recent year year_id = utility_data.get_most_recent_year() elif years == "all": - year_id = None + if data_type.type == "draws": + year_id = None + else: # means + year_id = "full" else: year_id = years if isinstance(years, list) else [years] estimation_years = utility_data.get_estimation_years() diff --git a/src/vivarium_inputs/utilities.py b/src/vivarium_inputs/utilities.py index f51bf40f..1da8a114 100644 --- a/src/vivarium_inputs/utilities.py +++ b/src/vivarium_inputs/utilities.py @@ -393,9 +393,9 @@ def clear_disability_weight_outside_restrictions( start, end = get_age_group_ids_by_restriction(cause, "yld") ages = get_restriction_age_ids(start, end, age_group_ids) - data.loc[ - (~data.sex_id.isin(sexes)) | (~data.age_group_id.isin(ages)), DRAW_COLUMNS - ] = fill_value + data.loc[(~data.sex_id.isin(sexes)) | (~data.age_group_id.isin(ages)), DRAW_COLUMNS] = ( + fill_value + ) return data diff --git a/src/vivarium_inputs/validation/raw.py b/src/vivarium_inputs/validation/raw.py index ef7a583c..f81f959d 100644 --- a/src/vivarium_inputs/validation/raw.py +++ b/src/vivarium_inputs/validation/raw.py @@ -59,8 +59,9 @@ class RawValidationContext: def __init__(self, location_id, **additional_data): self.context_data = {"location_id": location_id} self.context_data.update(additional_data) - - if "estimation_years" not in self.context_data: + if "estimation_years" not in self.context_data or self.context_data[ + "estimation_years" + ] == ["full"]: self.context_data["estimation_years"] = utility_data.get_estimation_years() if "age_group_ids" not in self.context_data: self.context_data["age_group_ids"] = utility_data.get_age_group_ids() diff --git a/tests/e2e/test_get_measure.py b/tests/e2e/test_get_measure.py index 739802e4..aed7b3cd 100644 --- a/tests/e2e/test_get_measure.py +++ b/tests/e2e/test_get_measure.py @@ -275,7 +275,7 @@ def test_get_measure_risklike( and not mock_gbd and data_type == "draws" ): - pytest.skip("FIXME: [mic-5543] continuous rrs cannot validate") + pytest.skip("FIXME: [mic-5542] continuous rrs cannot validate") # Handle not implemented is_unimplemented = isinstance(data_type, list) or data_type == "means" diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index f637b511..afc83be1 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -3,39 +3,35 @@ from tests.conftest import NO_GBD_ACCESS from vivarium_inputs import core, utility_data -from vivarium_inputs.mapping_extension import healthcare_entities +from vivarium_inputs.utilities import DataType, DataTypeNotImplementedError pytestmark = pytest.mark.skipif( NO_GBD_ACCESS, reason="Cannot run these tests without vivarium_gbd_access" ) -def success_expected(entity_name, measure_name, location): - df = core.get_data(entity_name, measure_name, location) - return df - - -def fail_expected(entity_name, measure_name, location): - with pytest.raises(Exception): - _df = core.get_data(entity_name, measure_name, location) - - -def check_year_in_data(entity, measure, location, years): - if isinstance(years, list): - df = core.get_data(entity, measure, location, years=years) - assert set(df.reset_index()["year_id"]) == set(years) - # years expected to be 1900, 2019, None, or "all" - elif years != 1900: - df = core.get_data(entity, measure, location, years=years) - if years == None: - assert set(df.reset_index()["year_id"]) == set([2021]) - elif years == 2019: - assert set(df.reset_index()["year_id"]) == set([2019]) - elif years == "all": - assert set(df.reset_index()["year_id"]) == set(range(1990, 2023)) +def check_year_in_data(entity, measure, location, years, data_type): + if _is_not_implemented(data_type, measure): + with pytest.raises(DataTypeNotImplementedError): + data_type = DataType(measure, data_type) + core.get_data(entity, measure, location, years, data_type) else: - with pytest.raises(ValueError, match="years must be in"): - df = core.get_data(entity, measure, location, years=years) + data_type = DataType(measure, data_type) + if isinstance(years, list): + df = core.get_data(entity, measure, location, years, data_type) + assert set(df.index.get_level_values("year_id")) == set(years) + # years expected to be 1900, 2019, None, or "all" + elif years != 1900: + df = core.get_data(entity, measure, location, years, data_type) + if years == None: + assert set(df.index.get_level_values("year_id")) == set([2021]) + elif years == "all": + assert set(df.index.get_level_values("year_id")) == set(range(1990, 2023)) + else: # a single (non-1900) year + assert set(df.index.get_level_values("year_id")) == set([years]) + else: + with pytest.raises(ValueError, match="years must be in"): + core.get_data(entity, measure, location, years, data_type) ENTITIES_C = [ @@ -92,22 +88,18 @@ def check_year_in_data(entity, measure, location, years): @pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_C) -def test_core_causelike(entity_details, measure, location): - entity, entity_expected_measures = entity_details - tester = success_expected if measure in entity_expected_measures else fail_expected - _df = tester(entity, measure, utility_data.get_location_id(location)) - - -@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_C) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_causelike(entity_details, measure, location, years): +@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_C, ids=lambda x: x) +@pytest.mark.parametrize( + "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_causelike(entity_details, measure, location, years, data_type): entity, entity_expected_measures = entity_details if measure in entity_expected_measures: - check_year_in_data(entity, measure, location, years=years) + check_year_in_data(entity, measure, location, years, data_type) ENTITIES_R = [ @@ -141,42 +133,20 @@ def test_year_id_causelike(entity_details, measure, location, years): @pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_R) -def test_core_risklike(entity_details, measure, location): - entity, entity_expected_measures = entity_details - if ( - entity.name == risk_factors.high_systolic_blood_pressure.name - and measure == "population_attributable_fraction" - ): - pytest.skip("MIC-4891") - tester = success_expected if measure in entity_expected_measures else fail_expected - _df = tester(entity, measure, utility_data.get_location_id(location)) - - -@pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_R) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_risklike(entity_details, measure, location, years): +@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x) +@pytest.mark.parametrize("location", LOCATIONS_R, ids=lambda x: x) +@pytest.mark.parametrize( + "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_risklike(entity_details, measure, location, years, data_type): + if measure in ["relative_risk", "population_attributable_fraction"]: + pytest.skip("TODO: mic-5245 punting until later b/c these are soooo slow") entity, entity_expected_measures = entity_details - # exposure-parametrized RRs for all years requires a lot of time and memory to process - if ( - entity == risk_factors.high_systolic_blood_pressure - and measure == "relative_risk" - and years == "all" - ): - pytest.skip(reason="need --runslow option to run") if measure in entity_expected_measures: - check_year_in_data(entity, measure, location, years=years) - - -@pytest.mark.slow # this test requires a lot of time and memory to run -@pytest.mark.parametrize("location", LOCATIONS_R) -def test_slow_year_id_risklike(location): - check_year_in_data( - risk_factors.high_systolic_blood_pressure, "relative_risk", location, years="all" - ) + check_year_in_data(entity, measure, location, years, data_type) ENTITIES_COV = [ @@ -188,69 +158,95 @@ def test_slow_year_id_risklike(location): @pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) @pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_COV) -def test_core_covariatelike(entity, measure, location): - _df = core.get_data(entity, measure, utility_data.get_location_id(location)) - - -@pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) -@pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_COV) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_covariatelike(entity, measure, location, years): - check_year_in_data(entity, measure, location, years=years) - - +@pytest.mark.parametrize("location", LOCATIONS_COV, ids=lambda x: x) @pytest.mark.parametrize( - "measures", - [ - "structure", - "age_bins", - "demographic_dimensions", - "theoretical_minimum_risk_life_expectancy", - ], + "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) ) -def test_core_population(measures): - pop = ModelableEntity("ignored", "population", None) - _df = core.get_data(pop, measures, utility_data.get_location_id("India")) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_covariatelike(entity, measure, location, years, data_type): + check_year_in_data(entity, measure, location, years, data_type) -@pytest.mark.parametrize("measure", ["structure", "demographic_dimensions"]) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_population(measure, years): +@pytest.mark.parametrize("measure", ["structure", "demographic_dimensions"], ids=lambda x: x) +@pytest.mark.parametrize( + "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) +) +def test_year_id_population(measure, years, data_type): pop = ModelableEntity("ignored", "population", None) location = utility_data.get_location_id("India") - check_year_in_data(pop, measure, location, years=years) - - -# TODO - Underlying problem with gbd access. Remove when corrected. -ENTITIES_HEALTH_SYSTEM = [ - healthcare_entities.outpatient_visits, -] -MEASURES_HEALTH_SYSTEM = ["utilization_rate"] -LOCATIONS_HEALTH_SYSTEM = ["India"] - - -@pytest.mark.skip(reason="Underlying problem with gbd access. Remove when corrected.") -@pytest.mark.parametrize("entity", ENTITIES_HEALTH_SYSTEM, ids=lambda x: x.name) -@pytest.mark.parametrize("measure", MEASURES_HEALTH_SYSTEM, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_HEALTH_SYSTEM) -def test_core_healthsystem(entity, measure, location): - _df = core.get_data(entity, measure, utility_data.get_location_id(location)) + check_year_in_data(pop, measure, location, years, data_type) @pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) +@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x) @pytest.mark.parametrize( "locations", [ - [164, 165, 175], - ["Ethiopia", "Nigeria"], - [164, "Nigeria"], + "Ethiopia", # 179 + 179, + ["Ethiopia", "Nigeria"], # [179, 214] + [179, 214], + [179, "Nigeria"], ], + ids=lambda x: str(x), +) +@pytest.mark.parametrize( + "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) ) -def test_pulling_multiple_locations(entity_details, measure, locations): +def test_multiple_locations_causelike(entity_details, measure, locations, data_type): + year = 2021 + location_id_mapping = { + "Ethiopia": 179, + "Nigeria": 214, + } entity, entity_expected_measures = entity_details - measure_name, measure_id = measure - tester = success_expected if (entity_expected_measures & measure_id) else fail_expected - _df = tester(entity, measure_name, locations) + if _is_not_implemented(data_type, measure): + with pytest.raises(DataTypeNotImplementedError): + data_type = DataType(measure, data_type) + core.get_data(entity, measure, locations, year, data_type) + else: + data_type = DataType(measure, data_type) + if measure not in entity_expected_measures: + with pytest.raises(Exception): + core.get_data(entity, measure, locations, year, data_type) + else: + df = core.get_data(entity, measure, locations, year, data_type) + if not isinstance(locations, list): + locations = [locations] + location_ids = { + ( + location_id_mapping[item] + if isinstance(item, str) and item in location_id_mapping + else item + ) + for item in locations + } + assert set(df.index.get_level_values("location_id")) == set(location_ids) + + +# TODO: Should we add the location tests for other entity types? + + +def _is_not_implemented(data_type: str | list[str], measure: str) -> bool: + return isinstance(data_type, list) or ( + data_type == "means" + and measure + in [ + "disability_weight", + "remission_rate", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + "exposure", + "low_birth_weight_and_short_gestation", + "exposure_standard_deviation", + "exposure_distribution_weights", + "estimate", + "structure", + ] + ) From 9176fa295d98019f21e34d9ad0d51bf61016f44f Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Thu, 19 Dec 2024 11:36:15 -0800 Subject: [PATCH 3/5] black --- src/vivarium_inputs/core.py | 6 +++--- src/vivarium_inputs/utilities.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/vivarium_inputs/core.py b/src/vivarium_inputs/core.py index 0598c5fa..e942e05b 100644 --- a/src/vivarium_inputs/core.py +++ b/src/vivarium_inputs/core.py @@ -627,9 +627,9 @@ def get_population_attributable_fraction( data = data.where(data[value_columns] > 0, 0).reset_index() data = utilities.convert_affected_entity(data, "cause_id") - data.loc[data["measure_id"] == MEASURES["YLLs"], "affected_measure"] = ( - "excess_mortality_rate" - ) + data.loc[ + data["measure_id"] == MEASURES["YLLs"], "affected_measure" + ] = "excess_mortality_rate" data.loc[data["measure_id"] == MEASURES["YLDs"], "affected_measure"] = "incidence_rate" data = data.filter( DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure"] + value_columns diff --git a/src/vivarium_inputs/utilities.py b/src/vivarium_inputs/utilities.py index 1da8a114..f51bf40f 100644 --- a/src/vivarium_inputs/utilities.py +++ b/src/vivarium_inputs/utilities.py @@ -393,9 +393,9 @@ def clear_disability_weight_outside_restrictions( start, end = get_age_group_ids_by_restriction(cause, "yld") ages = get_restriction_age_ids(start, end, age_group_ids) - data.loc[(~data.sex_id.isin(sexes)) | (~data.age_group_id.isin(ages)), DRAW_COLUMNS] = ( - fill_value - ) + data.loc[ + (~data.sex_id.isin(sexes)) | (~data.age_group_id.isin(ages)), DRAW_COLUMNS + ] = fill_value return data From acde40915f9ee9659b55137ca7f5c11cca27f346 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Thu, 19 Dec 2024 11:38:29 -0800 Subject: [PATCH 4/5] remove breakpoints --- src/vivarium_inputs/core.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/vivarium_inputs/core.py b/src/vivarium_inputs/core.py index e942e05b..ceb39bd0 100644 --- a/src/vivarium_inputs/core.py +++ b/src/vivarium_inputs/core.py @@ -524,9 +524,7 @@ def get_relative_risk( f"{location_id}." ) - breakpoint() data = extract.extract_data(entity, "relative_risk", location_id, years, data_type) - breakpoint() # FIXME: we don't currently support yll-only causes so I'm dropping them because the data in some cases is # very messed up, with mort = morb = 1 (e.g., aortic aneurysm in the RR data for high systolic bp) - # 2/8/19 K.W. @@ -546,7 +544,6 @@ def get_relative_risk( + ["affected_entity", "affected_measure", "parameter"] + value_columns ) - breakpoint() data = ( data.groupby(["affected_entity", "parameter"]) .apply(utilities.normalize, cols_to_fill=value_columns, fill_value=1) From 5a9257661dbcdd42c3c6d4176ccbd6822a0dadca Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Thu, 26 Dec 2024 11:20:42 -0800 Subject: [PATCH 5/5] changelog --- CHANGELOG.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ca29928c..4ac805d3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**5.2.3 - 12/26/24** + + - Bugfix: better handle 'all' years when requesting mean data + **5.2.2 - 11/13/24** - Bugfix to implement data type requests for remaining interface tests