From 4c10046865f04069d563a77ddfa725e9c5eacf71 Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:01:10 -0700 Subject: [PATCH] Sbachmei/mic 5245/handle unused tests (#388) * remove unwanted tests; fix broken ones * delete TEST_TO_IMPLEMENT/ dir * move mocked_gbd.py higher; other minor changes * refactor mocked_gbd to accept locations and years (partial implementation) * mock all the things --- TESTS_TO_IMPLEMENT/extract/test_core.py | 256 ---------------- TESTS_TO_IMPLEMENT/extract/test_extract.py | 134 --------- src/vivarium_inputs/core.py | 13 - tests/conftest.py | 21 ++ tests/e2e/test_get_measure.py | 50 ++-- tests/{e2e => }/mocked_gbd.py | 329 ++++++++++++++------- tests/unit/test_core.py | 250 ++++++++++------ 7 files changed, 421 insertions(+), 632 deletions(-) delete mode 100644 TESTS_TO_IMPLEMENT/extract/test_core.py delete mode 100644 TESTS_TO_IMPLEMENT/extract/test_extract.py rename tests/{e2e => }/mocked_gbd.py (63%) diff --git a/TESTS_TO_IMPLEMENT/extract/test_core.py b/TESTS_TO_IMPLEMENT/extract/test_core.py deleted file mode 100644 index f637b511..00000000 --- a/TESTS_TO_IMPLEMENT/extract/test_core.py +++ /dev/null @@ -1,256 +0,0 @@ -import pytest -from gbd_mapping import ModelableEntity, causes, covariates, risk_factors - -from tests.conftest import NO_GBD_ACCESS -from vivarium_inputs import core, utility_data -from vivarium_inputs.mapping_extension import healthcare_entities - -pytestmark = pytest.mark.skipif( - NO_GBD_ACCESS, reason="Cannot run these tests without vivarium_gbd_access" -) - - -def success_expected(entity_name, measure_name, location): - df = core.get_data(entity_name, measure_name, location) - return df - - -def fail_expected(entity_name, measure_name, location): - with pytest.raises(Exception): - _df = core.get_data(entity_name, measure_name, location) - - -def check_year_in_data(entity, measure, location, years): - if isinstance(years, list): - df = core.get_data(entity, measure, location, years=years) - assert set(df.reset_index()["year_id"]) == set(years) - # years expected to be 1900, 2019, None, or "all" - elif years != 1900: - df = core.get_data(entity, measure, location, years=years) - if years == None: - assert set(df.reset_index()["year_id"]) == set([2021]) - elif years == 2019: - assert set(df.reset_index()["year_id"]) == set([2019]) - elif years == "all": - assert set(df.reset_index()["year_id"]) == set(range(1990, 2023)) - else: - with pytest.raises(ValueError, match="years must be in"): - df = core.get_data(entity, measure, location, years=years) - - -ENTITIES_C = [ - ( - causes.measles, - [ - "incidence_rate", - "raw_incidence_rate", - "prevalence", - "disability_weight", - "cause_specific_mortality_rate", - "excess_mortality_rate", - "deaths", - ], - ), - ( - causes.diarrheal_diseases, - [ - "incidence_rate", - "raw_incidence_rate", - "prevalence", - "disability_weight", - "remission_rate", - "cause_specific_mortality_rate", - "excess_mortality_rate", - "deaths", - ], - ), - ( - causes.diabetes_mellitus_type_2, - [ - "incidence_rate", - "raw_incidence_rate", - "prevalence", - "disability_weight", - "cause_specific_mortality_rate", - "excess_mortality_rate", - "deaths", - ], - ), -] -MEASURES_C = [ - "incidence_rate", - "raw_incidence_rate", - "prevalence", - "birth_prevalence", - "disability_weight", - "remission_rate", - "cause_specific_mortality_rate", - "excess_mortality_rate", - "deaths", -] -LOCATIONS_C = ["India"] - - -@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_C) -def test_core_causelike(entity_details, measure, location): - entity, entity_expected_measures = entity_details - tester = success_expected if measure in entity_expected_measures else fail_expected - _df = tester(entity, measure, utility_data.get_location_id(location)) - - -@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_C) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_causelike(entity_details, measure, location, years): - entity, entity_expected_measures = entity_details - if measure in entity_expected_measures: - check_year_in_data(entity, measure, location, years=years) - - -ENTITIES_R = [ - ( - risk_factors.high_systolic_blood_pressure, - [ - "exposure", - "exposure_standard_deviation", - "exposure_distribution_weights", - "relative_risk", - "population_attributable_fraction", - ], - ), - ( - risk_factors.low_birth_weight_and_short_gestation, - [ - "exposure", - "relative_risk", - "population_attributable_fraction", - ], - ), -] -MEASURES_R = [ - "exposure", - "exposure_standard_deviation", - "exposure_distribution_weights", - "relative_risk", - "population_attributable_fraction", -] -LOCATIONS_R = ["India"] - - -@pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_R) -def test_core_risklike(entity_details, measure, location): - entity, entity_expected_measures = entity_details - if ( - entity.name == risk_factors.high_systolic_blood_pressure.name - and measure == "population_attributable_fraction" - ): - pytest.skip("MIC-4891") - tester = success_expected if measure in entity_expected_measures else fail_expected - _df = tester(entity, measure, utility_data.get_location_id(location)) - - -@pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_R) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_risklike(entity_details, measure, location, years): - entity, entity_expected_measures = entity_details - # exposure-parametrized RRs for all years requires a lot of time and memory to process - if ( - entity == risk_factors.high_systolic_blood_pressure - and measure == "relative_risk" - and years == "all" - ): - pytest.skip(reason="need --runslow option to run") - if measure in entity_expected_measures: - check_year_in_data(entity, measure, location, years=years) - - -@pytest.mark.slow # this test requires a lot of time and memory to run -@pytest.mark.parametrize("location", LOCATIONS_R) -def test_slow_year_id_risklike(location): - check_year_in_data( - risk_factors.high_systolic_blood_pressure, "relative_risk", location, years="all" - ) - - -ENTITIES_COV = [ - covariates.systolic_blood_pressure_mmhg, -] -MEASURES_COV = ["estimate"] -LOCATIONS_COV = ["India"] - - -@pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) -@pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_COV) -def test_core_covariatelike(entity, measure, location): - _df = core.get_data(entity, measure, utility_data.get_location_id(location)) - - -@pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) -@pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_COV) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_covariatelike(entity, measure, location, years): - check_year_in_data(entity, measure, location, years=years) - - -@pytest.mark.parametrize( - "measures", - [ - "structure", - "age_bins", - "demographic_dimensions", - "theoretical_minimum_risk_life_expectancy", - ], -) -def test_core_population(measures): - pop = ModelableEntity("ignored", "population", None) - _df = core.get_data(pop, measures, utility_data.get_location_id("India")) - - -@pytest.mark.parametrize("measure", ["structure", "demographic_dimensions"]) -@pytest.mark.parametrize("years", [None, 2019, 1900, [2019], [2019, 2020, 2021], "all"]) -def test_year_id_population(measure, years): - pop = ModelableEntity("ignored", "population", None) - location = utility_data.get_location_id("India") - check_year_in_data(pop, measure, location, years=years) - - -# TODO - Underlying problem with gbd access. Remove when corrected. -ENTITIES_HEALTH_SYSTEM = [ - healthcare_entities.outpatient_visits, -] -MEASURES_HEALTH_SYSTEM = ["utilization_rate"] -LOCATIONS_HEALTH_SYSTEM = ["India"] - - -@pytest.mark.skip(reason="Underlying problem with gbd access. Remove when corrected.") -@pytest.mark.parametrize("entity", ENTITIES_HEALTH_SYSTEM, ids=lambda x: x.name) -@pytest.mark.parametrize("measure", MEASURES_HEALTH_SYSTEM, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_HEALTH_SYSTEM) -def test_core_healthsystem(entity, measure, location): - _df = core.get_data(entity, measure, utility_data.get_location_id(location)) - - -@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) -@pytest.mark.parametrize( - "locations", - [ - [164, 165, 175], - ["Ethiopia", "Nigeria"], - [164, "Nigeria"], - ], -) -def test_pulling_multiple_locations(entity_details, measure, locations): - entity, entity_expected_measures = entity_details - measure_name, measure_id = measure - tester = success_expected if (entity_expected_measures & measure_id) else fail_expected - _df = tester(entity, measure_name, locations) diff --git a/TESTS_TO_IMPLEMENT/extract/test_extract.py b/TESTS_TO_IMPLEMENT/extract/test_extract.py deleted file mode 100644 index f6b9707d..00000000 --- a/TESTS_TO_IMPLEMENT/extract/test_extract.py +++ /dev/null @@ -1,134 +0,0 @@ -import pytest -from gbd_mapping import ModelableEntity, causes, covariates, risk_factors - -from tests.conftest import NO_GBD_ACCESS -from vivarium_inputs import extract, utility_data - -pytestmark = pytest.mark.skipif( - NO_GBD_ACCESS, reason="Cannot run these tests without vivarium_gbd_access" -) - - -VALIDATE_FLAG = False - - -def success_expected(entity_name, measure_name, location): - df = extract.extract_data(entity_name, measure_name, location, validate=VALIDATE_FLAG) - return df - - -def fail_expected(entity_name, measure_name, location): - with pytest.raises(Exception): - _df = extract.extract_data( - entity_name, measure_name, location, validate=VALIDATE_FLAG - ) - - -ENTITIES_C = [ - ( - causes.hiv_aids, - ["incidence_rate", "prevalence", "birth_prevalence", "remission_rate", "deaths"], - ), - ( - causes.neural_tube_defects, - ["incidence_rate", "prevalence", "birth_prevalence", "deaths"], - ), -] -MEASURES_C = [ - "incidence_rate", - "prevalence", - "birth_prevalence", - "disability_weight", - "remission_rate", - "deaths", -] -LOCATIONS_C = ["India"] - - -@pytest.mark.parametrize("entity", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_C) -def test_extract_causelike(entity, measure, location): - entity_name, entity_expected_measures = entity - tester = success_expected if measure in entity_expected_measures else fail_expected - _df = tester(entity_name, measure, utility_data.get_location_id(location)) - - -ENTITIES_R = [ - ( - risk_factors.high_fasting_plasma_glucose, - [ - "exposure", - "exposure_standard_deviation", - "exposure_distribution_weights", - "relative_risk", - "population_attributable_fraction", - "etiology_population_attributable_fraction", - "mediation_factors", - ], - ), - ( - risk_factors.low_birth_weight_and_short_gestation, - [ - "exposure", - "relative_risk", - "population_attributable_fraction", - "etiology_population_attributable_fraction", - ], - ), -] -MEASURES_R = [ - "exposure", - "exposure_standard_deviation", - "exposure_distribution_weights", - # "relative_risk", # TODO: Add back in with Mic-4936 - "population_attributable_fraction", - "etiology_population_attributable_fraction", - "mediation_factors", -] -LOCATIONS_R = ["India"] - - -@pytest.mark.parametrize("entity", ENTITIES_R, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x[0]) -@pytest.mark.parametrize("location", LOCATIONS_R) -def test_extract_risklike(entity, measure, location): - entity_name, entity_expected_measures = entity - tester = success_expected if measure in entity_expected_measures else fail_expected - _df = tester(entity_name, measure, utility_data.get_location_id(location)) - - -entity_cov = [ - covariates.systolic_blood_pressure_mmhg, -] -measures_cov = ["estimate"] -locations_cov = ["India"] - - -@pytest.mark.parametrize("entity", entity_cov) -@pytest.mark.parametrize("measure", measures_cov) -@pytest.mark.parametrize("location", locations_cov) -def test_extract_covariatelike(entity, measure, location): - _df = extract.extract_data( - entity, measure, utility_data.get_location_id(location), validate=VALIDATE_FLAG - ) - - -@pytest.mark.parametrize( - "measures", ["structure", "theoretical_minimum_risk_life_expectancy"] -) -def test_extract_population(measures): - pop = ModelableEntity("ignored", "population", None) - _df = extract.extract_data( - pop, measures, utility_data.get_location_id("India"), validate=VALIDATE_FLAG - ) - - -# TODO: Remove with Mic-4936 -@pytest.mark.parametrize("entity", ENTITIES_R, ids=lambda x: x[0].name) -@pytest.mark.parametrize("location", LOCATIONS_R) -@pytest.mark.xfail(reason="New relative risk data is not set up for processing yet") -def test_extract_relative_risk(entity, location): - measure_name = "relative_risk" - entity_name, _entity_expected_measures = entity - _df = extract.extract_data(entity_name, measure_name, location) diff --git a/src/vivarium_inputs/core.py b/src/vivarium_inputs/core.py index ceb39bd0..d7eaf29d 100644 --- a/src/vivarium_inputs/core.py +++ b/src/vivarium_inputs/core.py @@ -665,19 +665,6 @@ def get_estimate( return data -# FIXME: can this be deleted? It's not in the get_data() mapping. -def get_utilization_rate( - entity: HealthcareEntity, - location_id: list[int], - years: int | str | list[int] | None, - data_type: utilities.DataType, -) -> pd.DataFrame: - data = extract.extract_data(entity, "utilization_rate", location_id, years, data_type) - data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns) - data = utilities.normalize(data, data_type.value_columns, fill_value=0) - return data - - def get_structure( entity: Population, location_id: list[int], diff --git a/tests/conftest.py b/tests/conftest.py index 9bc1fc1c..dbddc6d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,3 +64,24 @@ def _no_gbd_access(): NO_GBD_ACCESS = _no_gbd_access() + + +def is_not_implemented(data_type: str | list[str], measure: str) -> bool: + return isinstance(data_type, list) or ( + data_type == "means" + and measure + in [ + "disability_weight", + "remission_rate", + "cause_specific_mortality_rate", + "excess_mortality_rate", + "deaths", + "exposure", + "exposure_standard_deviation", + "exposure_distribution_weights", + "relative_risk", + "population_attributable_fraction", + "estimate", + "structure", + ] + ) diff --git a/tests/e2e/test_get_measure.py b/tests/e2e/test_get_measure.py index aed7b3cd..a2950e29 100644 --- a/tests/e2e/test_get_measure.py +++ b/tests/e2e/test_get_measure.py @@ -20,10 +20,9 @@ from layered_config_tree import LayeredConfigTree from pytest_mock import MockerFixture -from tests.conftest import NO_GBD_ACCESS -from tests.e2e.mocked_gbd import ( - LOCATION, - YEAR, +from tests.conftest import NO_GBD_ACCESS, is_not_implemented +from tests.mocked_gbd import ( + MOST_RECENT_YEAR, get_mocked_age_bins, mock_vivarium_gbd_access, ) @@ -35,6 +34,8 @@ SLOW_TEST_DAY = "Sunday" # Day to run very slow tests, e.g. PAFs and RRs +LOCATION = "India" + # TODO [MIC-5448]: Move to vivarium_testing_utilties @pytest.fixture(autouse=True) def no_cache(mocker: MockerFixture) -> None: @@ -50,6 +51,7 @@ def no_cache(mocker: MockerFixture) -> None: CAUSES = [ # (entity, applicable_measures) # NOTE: 'raw_incidence_rate' and 'deaths' should not be called directly from `get_measure()` + # because there are no implemented validations for them. ( causes.measles, [ @@ -129,16 +131,7 @@ def test_get_measure_causelike( if measure == "birth_prevalence": pytest.skip("FIXME: need to find causes with birth prevalence") - # Handle not implemented - is_unimplemented_means = data_type == "means" and measure in [ - "disability_weight", - "remission_rate", - "cause_specific_mortality_rate", - "excess_mortality_rate", - ] - is_unimplemented = isinstance(data_type, list) or is_unimplemented_means - - run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker, is_unimplemented) + run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker) SEQUELAE = [ @@ -197,13 +190,7 @@ def test_get_measure_sequelalike( if measure == "birth_prevalence": pytest.skip("FIXME: need to find sequelae with birth prevalence") - # Handle not implemented - is_unimplemented_means = data_type == "means" and measure in [ - "disability_weight", - ] - is_unimplemented = isinstance(data_type, list) or is_unimplemented_means - - run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker, is_unimplemented) + run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker) RISK_FACTORS = [ @@ -277,10 +264,7 @@ def test_get_measure_risklike( ): pytest.skip("FIXME: [mic-5542] continuous rrs cannot validate") - # Handle not implemented - is_unimplemented = isinstance(data_type, list) or data_type == "means" - - run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker, is_unimplemented) + run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker) COVARIATES = [(covariates.systolic_blood_pressure_mmhg, ["estimate"])] @@ -320,10 +304,7 @@ def test_get_measure_covariatelike( mocked tests are not slow but unmocked tests are). """ - # Handle not implemented - is_unimplemented = isinstance(data_type, list) or data_type == "means" - - run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker, is_unimplemented) + run_test(entity_details, measure, data_type, mock_gbd, runslow, mocker) # TODO [MIC-5550]: Add tests for etiologies and alternative risk factors @@ -369,7 +350,6 @@ def run_test( mock_gbd: bool, runslow: bool, mocker: MockerFixture, - is_unimplemented: bool, raise_type: Exception = None, ): entity, entity_expected_measures = entity_details @@ -388,6 +368,8 @@ def run_test( pytest.skip(f"Only run full PAF and RR tests on {SLOW_TEST_DAY}s") tester = success_expected if measure in entity_expected_measures else fail_expected + is_unimplemented = is_not_implemented(data_type, measure) + if is_unimplemented: # This should trigger first tester = partial(fail_expected, raise_type=DataTypeNotImplementedError) elif raise_type: @@ -402,7 +384,9 @@ def run_test( pytest.skip("Cannot mock data for unimplemented features.") if tester == fail_expected: pytest.skip("Do mock data for expected failed calls.") - mocked_funcs = mock_vivarium_gbd_access(entity, measure, data_type, mocker) + mocked_funcs = mock_vivarium_gbd_access( + entity, measure, LOCATION, MOST_RECENT_YEAR, data_type, mocker + ) tester(entity, measure, utility_data.get_location_id(LOCATION), data_type) if mock_gbd: @@ -440,8 +424,8 @@ def check_data( "sex": {"Male", "Female"}, "age_start": set(age_bins["age_group_years_start"]), "age_end": set(age_bins["age_group_years_end"]), - "year_start": {YEAR}, - "year_end": {YEAR + 1}, + "year_start": {MOST_RECENT_YEAR}, + "year_end": {MOST_RECENT_YEAR + 1}, } if not getattr(entity, "by_age", True): # Some entities do not have ages diff --git a/tests/e2e/mocked_gbd.py b/tests/mocked_gbd.py similarity index 63% rename from tests/e2e/mocked_gbd.py rename to tests/mocked_gbd.py index 711296ad..b5db3c86 100644 --- a/tests/e2e/mocked_gbd.py +++ b/tests/mocked_gbd.py @@ -9,13 +9,13 @@ import numpy as np import pandas as pd +from gbd_mapping import ModelableEntity from pytest_mock import MockerFixture from vivarium_inputs import utility_data from vivarium_inputs.globals import DRAW_COLUMNS, MEAN_COLUMNS, MEASURES -LOCATION = "India" -YEAR = 2021 # most recent year (used when default None is provided to get_measure()) +MOST_RECENT_YEAR = 2021 DUMMY_INT = 1234 DUMMY_STR = "foo" DUMMY_FLOAT = 123.4 @@ -24,6 +24,8 @@ def mock_vivarium_gbd_access( entity, measure: str, + locations: str | int | list[str | int], + years: int | list[int] | str | None, data_type: str | list[str], mocker: MockerFixture, ) -> list["mocker.Mock"]: @@ -50,7 +52,9 @@ def mock_vivarium_gbd_access( # TODO [MIC-5461]: Speed up mocked data generation # Generic/small mocks that may or may not actually be called for this specific test - mocker.patch("vivarium_inputs.utility_data.get_most_recent_year", return_value=YEAR) + mocker.patch( + "vivarium_inputs.utility_data.get_most_recent_year", return_value=MOST_RECENT_YEAR + ) mocker.patch( "vivarium_inputs.utility_data.get_estimation_years", return_value=get_mocked_estimation_years(), @@ -80,58 +84,91 @@ def mock_vivarium_gbd_access( "draws": mocked_get_draws, }[data_type] + gbd_id = int(entity.gbd_id) if entity.gbd_id else None entity_specific_metadata_mapper = { "Cause": { - "cause_id": int(entity.gbd_id), + "cause_id": gbd_id, "acause": DUMMY_STR, "cause_name": DUMMY_STR, }, "Sequela": { - "sequela_id": int(entity.gbd_id), + "sequela_id": gbd_id, "sequela_name": DUMMY_STR, }, "RiskFactor": { - "rei_id": int(entity.gbd_id), + "rei_id": gbd_id, }, "Covariate": { "covariate_id": DUMMY_INT, "covariate_name_short": DUMMY_STR, }, } - entity_specific_metadata = entity_specific_metadata_mapper[entity.__class__.__name__] - if measure == "incidence_rate": + entity_specific_metadata = entity_specific_metadata_mapper.get( + entity.__class__.__name__, {} + ) + + # Convert years and locations to lists of IDs + if not years: + year_ids = [MOST_RECENT_YEAR] + elif years == "all": + estimation_years = get_mocked_estimation_years() + year_ids = list(range(min(estimation_years), max(estimation_years) + 1)) + elif not isinstance(years, list): + year_ids = [int(years)] + else: + year_ids = years + + location_ids = locations if isinstance(locations, list) else [locations] + location_ids = [ + utility_data.get_location_id(loc) if isinstance(loc, str) else loc + for loc in location_ids + ] + + if measure in ["incidence_rate", "raw_incidence_rate"]: mocked_extract_incidence_rate = mocker.patch( "vivarium_inputs.extract.extract_incidence_rate", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_extract_prevalence = mocker.patch( "vivarium_inputs.extract.extract_prevalence", - return_value=mocked_data_func("prevalence", entity, **entity_specific_metadata), + return_value=mocked_data_func( + "prevalence", entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mocked_extract_incidence_rate, mocked_extract_prevalence] elif measure == "prevalence": mock = mocker.patch( "vivarium_inputs.extract.extract_prevalence", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mock] elif measure == "disability_weight": mock = mocker.patch( "vivarium_inputs.extract.extract_disability_weight", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) combined_metadata = entity_specific_metadata_mapper["Cause"] combined_metadata.update(entity_specific_metadata_mapper["Sequela"]) mocked_extract_prevalence = mocker.patch( "vivarium_inputs.extract.extract_prevalence", - return_value=mocked_data_func("prevalence", entity, **combined_metadata), + return_value=mocked_data_func( + "prevalence", entity, location_ids, year_ids, **combined_metadata + ), ) mocked_funcs = [mock] elif measure == "remission_rate": mock = mocker.patch( "vivarium_inputs.extract.extract_remission_rate", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mock] elif measure == "cause_specific_mortality_rate": @@ -139,16 +176,18 @@ def mock_vivarium_gbd_access( del entity_specific_metadata["cause_name"] mocked_extract_deaths = mocker.patch( "vivarium_inputs.extract.extract_deaths", - return_value=mocked_data_func("deaths", entity, **entity_specific_metadata), + return_value=mocked_data_func( + "deaths", entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_extract_structure = mocker.patch( "vivarium_inputs.extract.extract_structure", - return_value=mocked_data_func("structure", entity), + return_value=mocked_data_func("structure", entity, location_ids, year_ids), ) mocked_funcs = [mocked_extract_deaths, mocked_extract_structure] elif measure == "excess_mortality_rate": mocked_prevalence_data = mocked_data_func( - "prevalence", entity, **entity_specific_metadata + "prevalence", entity, location_ids, year_ids, **entity_specific_metadata ) mocked_extract_prevalence = mocker.patch( "vivarium_inputs.extract.extract_prevalence", @@ -161,11 +200,13 @@ def mock_vivarium_gbd_access( del death_kwargs["cause_name"] mocked_extract_deaths = mocker.patch( "vivarium_inputs.extract.extract_deaths", - return_value=mocked_data_func("deaths", entity, **entity_specific_metadata), + return_value=mocked_data_func( + "deaths", entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_extract_structure = mocker.patch( "vivarium_inputs.extract.extract_structure", - return_value=mocked_data_func("structure", entity), + return_value=mocked_data_func("structure", entity, location_ids, year_ids), ) mocked_funcs = [ mocked_extract_prevalence, @@ -175,63 +216,103 @@ def mock_vivarium_gbd_access( elif measure == "exposure": mock = mocker.patch( "vivarium_inputs.extract.extract_exposure", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mock] elif measure == "exposure_standard_deviation": mocked_exposure_sd = mocker.patch( "vivarium_inputs.extract.extract_exposure_standard_deviation", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_exposure = mocker.patch( "vivarium_inputs.extract.extract_exposure", - return_value=mocked_data_func("exposure", entity, **entity_specific_metadata), + return_value=mocked_data_func( + "exposure", entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mocked_exposure_sd, mocked_exposure] elif measure == "exposure_distribution_weights": mocked_exposure_distribution_weights = mocker.patch( "vivarium_inputs.extract.extract_exposure_distribution_weights", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_exposure = mocker.patch( "vivarium_inputs.extract.extract_exposure", - return_value=mocked_data_func("exposure", entity, **entity_specific_metadata), + return_value=mocked_data_func( + "exposure", entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mocked_exposure_distribution_weights, mocked_exposure] elif measure == "relative_risk": mocked_rr = mocker.patch( "vivarium_inputs.extract.extract_relative_risk", return_value=mocked_data_func( - "relative_risk", entity, **entity_specific_metadata + "relative_risk", entity, location_ids, year_ids, **entity_specific_metadata ), ) mocked_exposure = mocker.patch( "vivarium_inputs.extract.extract_exposure", - return_value=mocked_data_func("exposure", entity, **entity_specific_metadata), + return_value=mocked_data_func( + "exposure", entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mocked_rr, mocked_exposure] elif measure == "population_attributable_fraction": mocked_pafs = mocker.patch( "vivarium_inputs.extract.extract_population_attributable_fraction", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_exposure = mocker.patch( "vivarium_inputs.extract.extract_exposure", - return_value=mocked_data_func("exposure", entity, **entity_specific_metadata), + return_value=mocked_data_func( + "exposure", entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_rr = mocker.patch( "vivarium_inputs.extract.extract_relative_risk", return_value=mocked_data_func( - "relative_risk", entity, **entity_specific_metadata + "relative_risk", entity, location_ids, year_ids, **entity_specific_metadata ), ) mocked_funcs = [mocked_pafs, mocked_exposure, mocked_rr] elif measure == "estimate": mock = mocker.patch( "vivarium_inputs.extract.extract_estimate", - return_value=mocked_data_func(measure, entity, **entity_specific_metadata), + return_value=mocked_data_func( + measure, entity, location_ids, year_ids, **entity_specific_metadata + ), ) mocked_funcs = [mock] + elif measure == "structure": + mock = mocker.patch( + "vivarium_inputs.extract.extract_structure", + return_value=mocked_data_func( + "structure", entity, location_ids, year_ids, **entity_specific_metadata + ), + ) + mocked_funcs = [mock] + elif measure == "demographic_dimensions": + mocked_funcs = [] + elif measure == "deaths": + mocked_extract_deaths = mocker.patch( + "vivarium_inputs.extract.extract_deaths", + return_value=mocked_data_func( + "deaths", entity, location_ids, year_ids, **entity_specific_metadata + ), + ) + mocked_extract_structure = mocker.patch( + "vivarium_inputs.extract.extract_structure", + return_value=mocked_data_func("structure", entity, location_ids, year_ids), + ) + mocked_funcs = [mocked_extract_deaths, mocked_extract_structure] else: raise NotImplementedError(f"Unexpected measure: {measure}") return mocked_funcs @@ -267,12 +348,19 @@ def get_mocked_location_ids() -> pd.DataFrame: ######################### -def mocked_get_draws(measure: str, entity, **entity_specific_metadata) -> pd.DataFrame: +def mocked_get_draws( + measure: str, + entity: ModelableEntity, + locations: list[int], + years: list[int], + **entity_specific_metadata, +) -> pd.DataFrame: """Mocked vivarium_gbd_access get_draws() data for testing.""" # Get the common data for the specific measure (regardless of entity type) df = { "incidence_rate": get_mocked_incidence_rate_get_draws, + "raw_incidence_rate": get_mocked_incidence_rate_get_draws, # mock same as not-raw "prevalence": get_mocked_prevalence_get_draws, "disability_weight": get_mocked_dw_get_draws, "remission_rate": get_mocked_remission_rate_get_draws, @@ -284,7 +372,7 @@ def mocked_get_draws(measure: str, entity, **entity_specific_metadata) -> pd.Dat "population_attributable_fraction": get_mocked_pafs_get_draws, "relative_risk": partial(get_mocked_rr_get_draws, entity), "estimate": get_mocked_estimate_get_draws, - }[measure]() + }[measure](locations, years) # Add on entity-specific metadata columns for key, value in entity_specific_metadata.items(): @@ -293,19 +381,19 @@ def mocked_get_draws(measure: str, entity, **entity_specific_metadata) -> pd.Dat return df -def get_mocked_incidence_rate_get_draws() -> pd.DataFrame: +def get_mocked_incidence_rate_get_draws( + locations: list[int], years: list[int] +) -> pd.DataFrame: age_group_ids = get_mocked_age_bins()["age_group_id"] sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["measure_id"] = 6 # incidence df["metric_id"] = 3 # rate df["version_id"] = DUMMY_INT @@ -315,19 +403,17 @@ def get_mocked_incidence_rate_get_draws() -> pd.DataFrame: return df -def get_mocked_prevalence_get_draws() -> pd.DataFrame: +def get_mocked_prevalence_get_draws(locations: list[int], years: list[int]) -> pd.DataFrame: age_group_ids = get_mocked_age_bins()["age_group_id"] sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["measure_id"] = 5 # prevalence df["metric_id"] = 3 # rate df["version_id"] = DUMMY_INT @@ -337,39 +423,46 @@ def get_mocked_prevalence_get_draws() -> pd.DataFrame: return df -def get_mocked_dw_get_draws() -> pd.DataFrame: +def get_mocked_dw_get_draws(locations: list[int], years: list[int]) -> pd.DataFrame: + # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - { - "location_id": utility_data.get_location_id(LOCATION), - "year_id": YEAR, - "age_group_id": 22, - "sex_id": 3, - "measure": "disability_weight", - "healthstate_id": DUMMY_FLOAT, - "healthstate": DUMMY_STR, - }, - index=[0], + list(itertools.product(locations, years)), + columns=["location_id", "year_id"], ) + + # Add on other metadata columns + df["age_group_id"] = 22 + df["sex_id"] = 3 + df["measure"] = "disability_weight" + df["healthstate_id"] = DUMMY_FLOAT + df["healthstate"] = DUMMY_STR + # We set the values here very low to avoid validation errors _add_value_columns(df, DRAW_COLUMNS, 0.0, 0.1) return df -def get_mocked_remission_rate_get_draws() -> pd.DataFrame: +def get_mocked_remission_rate_get_draws( + locations: list[int], years: list[int] +) -> pd.DataFrame: age_bins = get_mocked_age_bins() age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2] + # HACK: Remission rates are binned by year. If `years`` is a full continuous range of + # the estimation years, then likely the user requested data for "all" years. + # If this is the case, we set years to the estimation years. + estimation_years = get_mocked_estimation_years() + if years == list(range(min(estimation_years), max(estimation_years) + 1)): + years = estimation_years # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["measure_id"] = 7 # remission df["metric_id"] = 3 # rate df["model_version_id"] = DUMMY_INT @@ -380,20 +473,18 @@ def get_mocked_remission_rate_get_draws() -> pd.DataFrame: return df -def get_mocked_deaths_get_draws() -> pd.DataFrame: +def get_mocked_deaths_get_draws(locations: list[int], years: list[int]) -> pd.DataFrame: age_bins = get_mocked_age_bins() age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["measure_id"] = 1 # deaths df["metric_id"] = 1 # number df["version_id"] = DUMMY_INT @@ -405,22 +496,18 @@ def get_mocked_deaths_get_draws() -> pd.DataFrame: return df -def get_mocked_structure_get_draws() -> pd.DataFrame: - # Populations is difficult to mock at the age-group level so just load it - # return pd.read_csv(f"tests/fixture_data/population_{LOCATION.lower()}_{YEAR}.csv") +def get_mocked_structure_get_draws(locations: list[int], years: list[int]) -> pd.DataFrame: age_bins = get_mocked_age_bins() age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2, 3] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["run_id"] = DUMMY_INT _add_value_columns(df, ["population"], 1.0e6, 100.0e6) @@ -428,15 +515,17 @@ def get_mocked_structure_get_draws() -> pd.DataFrame: return df -def get_mocked_exposure_get_draws(entity) -> pd.DataFrame: +def get_mocked_exposure_get_draws( + entity, locations: list[int], years: list[int] +) -> pd.DataFrame: if entity.name == "low_birth_weight_and_short_gestation": age_group_ids = [2, 3] sex_ids = [1, 2] parameters = list(entity.categories.to_dict()) # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids, parameters)), - columns=["age_group_id", "sex_id", "parameter"], + list(itertools.product(age_group_ids, sex_ids, parameters, locations, years)), + columns=["age_group_id", "sex_id", "parameter", "location_id", "year_id"], ) # Add on other metadata columns df["modelable_entity_id"] = DUMMY_FLOAT # b/c nans come in @@ -447,8 +536,8 @@ def get_mocked_exposure_get_draws(entity) -> pd.DataFrame: sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns df["modelable_entity_id"] = DUMMY_INT @@ -457,28 +546,24 @@ def get_mocked_exposure_get_draws(entity) -> pd.DataFrame: else: raise NotImplementedError(f"{entity.name} not implemented in mocked_gbd.py") - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["measure_id"] = 19 # continuous df["metric_id"] = 3 # rate return df -def get_mocked_exposure_sd_get_draws() -> pd.DataFrame: +def get_mocked_exposure_sd_get_draws(locations: list[int], years: list[int]) -> pd.DataFrame: age_bins = get_mocked_age_bins() age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["modelable_entity_id"] = DUMMY_INT df["measure_id"] = 19 # continuous df["metric_id"] = 3 # rate @@ -489,9 +574,11 @@ def get_mocked_exposure_sd_get_draws() -> pd.DataFrame: return df -def get_mocked_exposure_distribution_weights_get_draws() -> pd.DataFrame: +def get_mocked_exposure_distribution_weights_get_draws( + locations: list[int], years: list[int] +) -> pd.DataFrame: - # We simply copy/paste the data from the call here. + # We simply copy/paste the data from the call here (year 2021, location 163) return pd.DataFrame( { "exp": 0.0012511270939698, @@ -517,7 +604,7 @@ def get_mocked_exposure_distribution_weights_get_draws() -> pd.DataFrame: ) -def get_mocked_pafs_get_draws() -> pd.DataFrame: +def get_mocked_pafs_get_draws(locations: list[int], years: list[int]) -> pd.DataFrame: age_bins = get_mocked_age_bins() age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2] @@ -525,13 +612,11 @@ def get_mocked_pafs_get_draws() -> pd.DataFrame: # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids, measure_ids)), - columns=["age_group_id", "sex_id", "measure_id"], + list(itertools.product(age_group_ids, sex_ids, measure_ids, locations, years)), + columns=["age_group_id", "sex_id", "measure_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["cause_id"] = 495 # Needs to be a valid cause_id df["metric_id"] = 2 # percent df["version_id"] = DUMMY_INT @@ -541,23 +626,28 @@ def get_mocked_pafs_get_draws() -> pd.DataFrame: return df -def get_mocked_rr_get_draws(entity) -> pd.DataFrame: +def get_mocked_rr_get_draws(entity, locations: list[int], years: list[int]) -> pd.DataFrame: age_bins = get_mocked_age_bins() if entity.name == "high_systolic_blood_pressure": # high sbp is only for >=25 years age_bins = age_bins[age_bins["age_group_years_start"] >= 25] age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2] + # HACK: Relative risks are binned by year. If `years`` is a full continuous range of + # the estimation years, then likely the user requested data for "all" years. + # If this is the case, we set years to the estimation years. + estimation_years = get_mocked_estimation_years() + if years == list(range(min(estimation_years), max(estimation_years) + 1)): + years = estimation_years # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, years)), + columns=["age_group_id", "sex_id", "year_id"], ) # Add on other metadata columns df["location_id"] = 1 # Most relative risks are global - df["year_id"] = YEAR df["modelable_entity_id"] = DUMMY_INT df["cause_id"] = 495 # Needs to be a valid cause_id df["mortality"] = 1 @@ -571,27 +661,30 @@ def get_mocked_rr_get_draws(entity) -> pd.DataFrame: return df -def get_mocked_estimate_get_draws() -> pd.DataFrame: +def get_mocked_estimate_get_draws(locations: list[int], years: list[int]) -> pd.DataFrame: age_group_ids = [27] sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns - df["location_id"] = utility_data.get_location_id(LOCATION) - df["location_name"] = LOCATION - df["year_id"] = YEAR + location_ids = ( + get_mocked_location_ids()[["location_id", "location_name"]] + .set_index("location_id") + .squeeze() + ) + df["location_name"] = df["location_id"].map(location_ids) df["covariate_id"] = DUMMY_INT df["model_version_id"] = DUMMY_INT df["age_group_name"] = DUMMY_STR df["sex"] = DUMMY_STR # Estimates don't play by the rules - df["mean_value"] = [DUMMY_FLOAT, DUMMY_FLOAT] + df["mean_value"] = DUMMY_FLOAT df["lower_value"] = 0.9 * df["mean_value"] df["upper_value"] = 1.1 * df["mean_value"] @@ -603,22 +696,32 @@ def get_mocked_estimate_get_draws() -> pd.DataFrame: ########################### -def mocked_get_outputs(measure: str, entity, **entity_specific_metadata) -> pd.DataFrame: +def mocked_get_outputs( + measure: str, + entity: ModelableEntity, + locations: list[int], + years: list[int], + **entity_specific_metadata, +) -> pd.DataFrame: """Mocked vivarium_gbd_access get_outputs() data for testing.""" # Get the common data for the specific measure (regardless of entity type) df = { "incidence_rate": get_mocked_incidence_rate_get_outputs, + "raw_incidence_rate": get_mocked_incidence_rate_get_outputs, # Same as non-raw "prevalence": get_mocked_prevalence_get_outputs, - }[measure]() + }[measure](locations, years) # Add on common metadata (note that this may overwrite existing columns, e.g. # from loading a population static file) - df["location_id"] = utility_data.get_location_id(LOCATION) - df["year_id"] = YEAR df["expected"] = False - df["location_name"] = "India" - df["location_type"] = "admin0" + location_ids = ( + get_mocked_location_ids()[["location_id", "location_name"]] + .set_index("location_id") + .squeeze() + ) + df["location_name"] = df["location_id"].map(location_ids) + df["location_type"] = "admin0" # brittle age_bins = get_mocked_age_bins() df["age_group_name"] = df["age_group_id"].map( dict(age_bins[["age_group_id", "age_group_name"]].values) @@ -636,15 +739,17 @@ def mocked_get_outputs(measure: str, entity, **entity_specific_metadata) -> pd.D return df -def get_mocked_incidence_rate_get_outputs() -> pd.DataFrame: +def get_mocked_incidence_rate_get_outputs( + locations: list[int], years: list[int] +) -> pd.DataFrame: age_bins = get_mocked_age_bins() age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns @@ -659,15 +764,15 @@ def get_mocked_incidence_rate_get_outputs() -> pd.DataFrame: return df -def get_mocked_prevalence_get_outputs() -> pd.DataFrame: +def get_mocked_prevalence_get_outputs(locations: list[int], years: list[int]) -> pd.DataFrame: age_bins = get_mocked_age_bins() age_group_ids = list(age_bins["age_group_id"]) sex_ids = [1, 2] # Initiate df with all possible combinations of variable metadata columns df = pd.DataFrame( - list(itertools.product(age_group_ids, sex_ids)), - columns=["age_group_id", "sex_id"], + list(itertools.product(age_group_ids, sex_ids, locations, years)), + columns=["age_group_id", "sex_id", "location_id", "year_id"], ) # Add on other metadata columns diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index afc83be1..57e67170 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -1,40 +1,20 @@ import pytest from gbd_mapping import ModelableEntity, causes, covariates, risk_factors +from pytest_mock import MockerFixture -from tests.conftest import NO_GBD_ACCESS -from vivarium_inputs import core, utility_data -from vivarium_inputs.utilities import DataType, DataTypeNotImplementedError - -pytestmark = pytest.mark.skipif( - NO_GBD_ACCESS, reason="Cannot run these tests without vivarium_gbd_access" +from tests.conftest import is_not_implemented +from tests.mocked_gbd import ( + MOST_RECENT_YEAR, + get_mocked_location_ids, + mock_vivarium_gbd_access, ) +from vivarium_inputs import core +from vivarium_inputs.globals import Population +from vivarium_inputs.utilities import DataType, DataTypeNotImplementedError +LOCATION = "India" -def check_year_in_data(entity, measure, location, years, data_type): - if _is_not_implemented(data_type, measure): - with pytest.raises(DataTypeNotImplementedError): - data_type = DataType(measure, data_type) - core.get_data(entity, measure, location, years, data_type) - else: - data_type = DataType(measure, data_type) - if isinstance(years, list): - df = core.get_data(entity, measure, location, years, data_type) - assert set(df.index.get_level_values("year_id")) == set(years) - # years expected to be 1900, 2019, None, or "all" - elif years != 1900: - df = core.get_data(entity, measure, location, years, data_type) - if years == None: - assert set(df.index.get_level_values("year_id")) == set([2021]) - elif years == "all": - assert set(df.index.get_level_values("year_id")) == set(range(1990, 2023)) - else: # a single (non-1900) year - assert set(df.index.get_level_values("year_id")) == set([years]) - else: - with pytest.raises(ValueError, match="years must be in"): - core.get_data(entity, measure, location, years, data_type) - - -ENTITIES_C = [ +CAUSES = [ ( causes.measles, [ @@ -73,7 +53,7 @@ def check_year_in_data(entity, measure, location, years, data_type): ], ), ] -MEASURES_C = [ +CAUSE_MEASURES = [ "incidence_rate", "raw_incidence_rate", "prevalence", @@ -84,25 +64,36 @@ def check_year_in_data(entity, measure, location, years, data_type): "excess_mortality_rate", "deaths", ] -LOCATIONS_C = ["India"] -@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_C, ids=lambda x: x) +@pytest.mark.parametrize("entity_details", CAUSES, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", CAUSE_MEASURES, ids=lambda x: x) @pytest.mark.parametrize( "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) ) @pytest.mark.parametrize( "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) ) -def test_year_id_causelike(entity_details, measure, location, years, data_type): +def test_year_id_causelike( + entity_details: tuple[ModelableEntity, list[str]], + measure: str, + years: int | list[int] | str | None, + data_type: str | list[str], + mocker: MockerFixture, +): + if years == "all" and measure == "remission_rate": + # remission rates and relative risks have binned years and for + # "all" years will use central comp's `core-maths` library which, like + # vivarium_gbd_access, is hosted on bitbucket and so cannot be accessed + # from github-actions. + pytest.skip("Expected to fail - see test_xfailed test") + entity, entity_expected_measures = entity_details if measure in entity_expected_measures: - check_year_in_data(entity, measure, location, years, data_type) + check_year_in_data(entity, measure, LOCATION, years, data_type, mocker) -ENTITIES_R = [ +RISKS = [ ( risk_factors.high_systolic_blood_pressure, [ @@ -122,51 +113,95 @@ def test_year_id_causelike(entity_details, measure, location, years, data_type): ], ), ] -MEASURES_R = [ +RISK_MEASURES = [ "exposure", "exposure_standard_deviation", "exposure_distribution_weights", "relative_risk", "population_attributable_fraction", ] -LOCATIONS_R = ["India"] -@pytest.mark.parametrize("entity_details", ENTITIES_R, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_R, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_R, ids=lambda x: x) +@pytest.mark.parametrize("entity_details", RISKS, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", RISK_MEASURES, ids=lambda x: x) @pytest.mark.parametrize( "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) ) @pytest.mark.parametrize( "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) ) -def test_year_id_risklike(entity_details, measure, location, years, data_type): - if measure in ["relative_risk", "population_attributable_fraction"]: - pytest.skip("TODO: mic-5245 punting until later b/c these are soooo slow") +def test_year_id_risklike( + entity_details: tuple[ModelableEntity, list[str]], + measure: str, + years: int | list[int] | str | None, + data_type: str | list[str], + mocker: MockerFixture, +): + if years == "all" and measure == "relative_risk": + # remission rates and relative risks have binned years and for + # "all" years will use central comp's `core-maths` library which, like + # vivarium_gbd_access, is hosted on bitbucket and so cannot be accessed + # from github-actions. + pytest.skip("Expected to fail - see test_xfailed test") + if ( + measure == "relative_risk" + and entity_details[0].name == "high_systolic_blood_pressure" + and data_type == "draws" + ): + pytest.skip("FIXME: [mic-5542] continuous rrs cannot validate") entity, entity_expected_measures = entity_details if measure in entity_expected_measures: - check_year_in_data(entity, measure, location, years, data_type) + check_year_in_data(entity, measure, LOCATION, years, data_type, mocker) + + +@pytest.mark.xfail(raises=ModuleNotFoundError, reason="Cannot import core-maths", strict=True) +@pytest.mark.parametrize( + "entity, measure", + [ + [causes.diarrheal_diseases, "remission_rate"], + [risk_factors.high_systolic_blood_pressure, "relative_risk"], + [risk_factors.low_birth_weight_and_short_gestation, "relative_risk"], + ], +) +def test_xfailed( + entity: ModelableEntity, + measure: str, + mocker: MockerFixture, +): + """We expect failures when trying to interpolate 'all' years for binned measures + + Notes + ----- + These test parameterizations are a subset of others in this test module + that are simply marked as 'skip'. + """ + if measure == "relative_risk" and entity.name == "high_systolic_blood_pressure": + pytest.skip("FIXME: [mic-5542] continuous rrs cannot validate") + check_year_in_data(entity, measure, LOCATION, "all", "draws", mocker) -ENTITIES_COV = [ +COVARIATES = [ covariates.systolic_blood_pressure_mmhg, ] -MEASURES_COV = ["estimate"] -LOCATIONS_COV = ["India"] +COVARIATE_MEASURES = ["estimate"] -@pytest.mark.parametrize("entity", ENTITIES_COV, ids=lambda x: x.name) -@pytest.mark.parametrize("measure", MEASURES_COV, ids=lambda x: x) -@pytest.mark.parametrize("location", LOCATIONS_COV, ids=lambda x: x) +@pytest.mark.parametrize("entity", COVARIATES, ids=lambda x: x.name) +@pytest.mark.parametrize("measure", COVARIATE_MEASURES, ids=lambda x: x) @pytest.mark.parametrize( "years", [None, 2019, 1900, [2019], [2019, 2020], "all"], ids=lambda x: str(x) ) @pytest.mark.parametrize( "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) ) -def test_year_id_covariatelike(entity, measure, location, years, data_type): - check_year_in_data(entity, measure, location, years, data_type) +def test_year_id_covariatelike( + entity: ModelableEntity, + measure: str, + years: int | list[int] | str | None, + data_type: str | list[str], + mocker: MockerFixture, +): + check_year_in_data(entity, measure, LOCATION, years, data_type, mocker) @pytest.mark.parametrize("measure", ["structure", "demographic_dimensions"], ids=lambda x: x) @@ -176,14 +211,18 @@ def test_year_id_covariatelike(entity, measure, location, years, data_type): @pytest.mark.parametrize( "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) ) -def test_year_id_population(measure, years, data_type): - pop = ModelableEntity("ignored", "population", None) - location = utility_data.get_location_id("India") - check_year_in_data(pop, measure, location, years, data_type) +def test_year_id_population( + measure: str, + years: int | list[int] | str | None, + data_type: str | list[str], + mocker: MockerFixture, +): + pop = Population() + check_year_in_data(pop, measure, LOCATION, years, data_type, mocker) -@pytest.mark.parametrize("entity_details", ENTITIES_C, ids=lambda x: x[0].name) -@pytest.mark.parametrize("measure", MEASURES_C, ids=lambda x: x) +@pytest.mark.parametrize("entity_details", CAUSES, ids=lambda x: x[0].name) +@pytest.mark.parametrize("measure", CAUSE_MEASURES, ids=lambda x: x) @pytest.mark.parametrize( "locations", [ @@ -198,24 +237,40 @@ def test_year_id_population(measure, years, data_type): @pytest.mark.parametrize( "data_type", ["draws", "means", ["draws", "means"]], ids=lambda x: str(x) ) -def test_multiple_locations_causelike(entity_details, measure, locations, data_type): - year = 2021 +def test_multiple_locations_causelike( + entity_details: tuple[ModelableEntity, list[str]], + measure: str, + locations: str | int | list[str | int], + data_type: str | list[str], + mocker: MockerFixture, +): + year = MOST_RECENT_YEAR location_id_mapping = { "Ethiopia": 179, "Nigeria": 214, } entity, entity_expected_measures = entity_details - if _is_not_implemented(data_type, measure): + if is_not_implemented(data_type, measure): with pytest.raises(DataTypeNotImplementedError): data_type = DataType(measure, data_type) + mocker.patch( + "vivarium_inputs.utility_data.get_raw_location_ids", + return_value=get_mocked_location_ids(), + ) core.get_data(entity, measure, locations, year, data_type) else: - data_type = DataType(measure, data_type) if measure not in entity_expected_measures: with pytest.raises(Exception): + data_type = DataType(measure, data_type) core.get_data(entity, measure, locations, year, data_type) else: + mocked_funcs = mock_vivarium_gbd_access( + entity, measure, locations, year, data_type, mocker + ) + data_type = DataType(measure, data_type) df = core.get_data(entity, measure, locations, year, data_type) + for mocked_func in mocked_funcs: + assert mocked_func.called_once() if not isinstance(locations, list): locations = [locations] location_ids = { @@ -232,21 +287,48 @@ def test_multiple_locations_causelike(entity_details, measure, locations, data_t # TODO: Should we add the location tests for other entity types? -def _is_not_implemented(data_type: str | list[str], measure: str) -> bool: - return isinstance(data_type, list) or ( - data_type == "means" - and measure - in [ - "disability_weight", - "remission_rate", - "cause_specific_mortality_rate", - "excess_mortality_rate", - "deaths", - "exposure", - "low_birth_weight_and_short_gestation", - "exposure_standard_deviation", - "exposure_distribution_weights", - "estimate", - "structure", - ] - ) +#################### +# Helper functions # +#################### + + +def check_year_in_data( + entity: ModelableEntity, + measure: str, + location: str, + years: int | list[int] | str | None, + data_type: str | list[str], + mocker: MockerFixture, +): + if is_not_implemented(data_type, measure): + with pytest.raises(DataTypeNotImplementedError): + data_type = DataType(measure, data_type) + mocker.patch( + "vivarium_inputs.utility_data.get_raw_location_ids", + return_value=get_mocked_location_ids(), + ) + core.get_data(entity, measure, location, years, data_type) + else: + mocked_funcs = mock_vivarium_gbd_access( + entity, measure, location, years, data_type, mocker + ) + data_type = DataType(measure, data_type) + if isinstance(years, list): + df = core.get_data(entity, measure, location, years, data_type) + assert set(df.index.get_level_values("year_id")) == set(years) + for mocked_func in mocked_funcs: + assert mocked_func.called_once() + # years expected to be 1900, 2019, None, or "all" + elif years != 1900: + df = core.get_data(entity, measure, location, years, data_type) + if years == None: + assert set(df.index.get_level_values("year_id")) == set([MOST_RECENT_YEAR]) + elif years == "all": + assert set(df.index.get_level_values("year_id")) == set(range(1990, 2023)) + else: # a single (non-1900) year + assert set(df.index.get_level_values("year_id")) == set([years]) + for mocked_func in mocked_funcs: + assert mocked_func.called_once() + else: + with pytest.raises(ValueError, match="years must be in"): + core.get_data(entity, measure, location, years, data_type)