Skip to content

Commit

Permalink
Sbachmei/mic 5668/bugfix handle empty years for means (#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier authored Dec 26, 2024
1 parent 19d920e commit 69687db
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 20 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**5.2.3 - 12/26/24**

- Bugfix: better handle 'all' years when requesting mean data

**5.2.2 - 11/13/24**

- Bugfix to implement data type requests for remaining interface tests
Expand Down
27 changes: 13 additions & 14 deletions src/vivarium_inputs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def get_raw_incidence_rate(
data = utilities.filter_data_by_restrictions(
data, restrictions_entity, "yld", utility_data.get_age_group_ids()
)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
return data


Expand All @@ -197,8 +197,8 @@ def get_prevalence(
data = utilities.filter_data_by_restrictions(
data, restrictions_entity, "yld", utility_data.get_age_group_ids()
)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
return data


Expand Down Expand Up @@ -267,13 +267,13 @@ def get_disability_weight(
data = extract.extract_data(
entity, "disability_weight", location_id, years, data_type
)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
data = utilities.normalize(data, data_type.value_columns)

cause = [c for c in causes if c.sequelae and entity in c.sequelae][0]
data = utilities.clear_disability_weight_outside_restrictions(
data, cause, 0.0, utility_data.get_age_group_ids()
)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
except (IndexError, DataDoesNotExistError):
logger.warning(
f"{entity.name.capitalize()} has no disability weight data. All values will be 0."
Expand All @@ -300,8 +300,8 @@ def get_remission_rate(
data = utilities.filter_data_by_restrictions(
data, entity, "yld", utility_data.get_age_group_ids()
)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
return data


Expand Down Expand Up @@ -381,8 +381,8 @@ def get_deaths(
data = utilities.filter_data_by_restrictions(
data, entity, "yll", utility_data.get_age_group_ids()
)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
return data


Expand Down Expand Up @@ -413,6 +413,8 @@ def get_exposure(
data, entity, "outer", utility_data.get_age_group_ids()
)

data = data.filter(DEMOGRAPHIC_COLUMNS + value_columns + ["parameter"])

if entity.distribution in ["dichotomous", "ordered_polytomous", "unordered_polytomous"]:
tmrel_cat = utility_data.get_tmrel_category(entity)
exposed = data[data.parameter != tmrel_cat]
Expand All @@ -437,7 +439,6 @@ def get_exposure(
)
else:
data = utilities.normalize(data, value_columns, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + value_columns + ["parameter"])
return data


Expand All @@ -462,8 +463,8 @@ def get_exposure_standard_deviation(
valid_age_groups = utilities.get_exposure_and_restriction_ages(exposure, entity)
data = data[data.age_group_id.isin(valid_age_groups)]

data = utilities.normalize(data, data_type.value_columns, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
return data


Expand Down Expand Up @@ -493,6 +494,7 @@ def get_exposure_distribution_weights(
copied["age_group_id"] = age_id
df.append(copied)
data = pd.concat(df)
data = data.filter(DEMOGRAPHIC_COLUMNS + DISTRIBUTION_COLUMNS)
data = utilities.normalize(data, DISTRIBUTION_COLUMNS, fill_value=0)
if years != "all":
if years:
Expand All @@ -501,7 +503,6 @@ def get_exposure_distribution_weights(
else:
most_recent_year = utility_data.get_most_recent_year()
data = data.query(f"year_id=={most_recent_year}")
data = data.filter(DEMOGRAPHIC_COLUMNS + DISTRIBUTION_COLUMNS)
data = utilities.wide_to_long(data, DISTRIBUTION_COLUMNS, var_name="parameter")
return data

Expand All @@ -512,7 +513,6 @@ def get_relative_risk(
years: int | str | list[int] | None,
data_type: utilities.DataType,
) -> pd.DataFrame:

if data_type.type != "draws":
raise utilities.DataTypeNotImplementedError(
f"Data type(s) {data_type.type} are not supported for this function."
Expand All @@ -525,7 +525,6 @@ def get_relative_risk(
)

data = extract.extract_data(entity, "relative_risk", location_id, years, data_type)

# FIXME: We don't currently support YLL-only causes, so I'm dropping them; in some cases their data is
# badly malformed, with mort = morb = 1 (e.g., aortic aneurysm in the RR data for high systolic bp).
# - 2/8/19 K.W.
Expand Down Expand Up @@ -629,14 +628,14 @@ def get_population_attributable_fraction(
data["measure_id"] == MEASURES["YLLs"], "affected_measure"
] = "excess_mortality_rate"
data.loc[data["measure_id"] == MEASURES["YLDs"], "affected_measure"] = "incidence_rate"
data = data.filter(
DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure"] + value_columns
)
data = (
data.groupby(["affected_entity", "affected_measure"])
.apply(utilities.normalize, cols_to_fill=value_columns, fill_value=0)
.reset_index(drop=True)
)
data = data.filter(
DEMOGRAPHIC_COLUMNS + ["affected_entity", "affected_measure"] + value_columns
)
return data


Expand Down Expand Up @@ -674,8 +673,8 @@ def get_utilization_rate(
data_type: utilities.DataType,
) -> pd.DataFrame:
data = extract.extract_data(entity, "utilization_rate", location_id, years, data_type)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + data_type.value_columns)
data = utilities.normalize(data, data_type.value_columns, fill_value=0)
return data


Expand Down
13 changes: 10 additions & 3 deletions src/vivarium_inputs/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def extract_data(

validation.check_metadata(entity, measure)

year_id = _get_year_id(years)
year_id = _get_year_id(years, data_type)

try:
main_extractor, additional_extractors = extractors[measure]
Expand Down Expand Up @@ -146,6 +146,10 @@ def extract_data(
extra_draw_cols = [col for col in existing_draw_cols if col not in DRAW_COLUMNS]
data = data.drop(columns=extra_draw_cols, errors="ignore")

# Drop get_outputs rows for years earlier than the first estimation year
if data_type.type == "means":
data = data.loc[data["year_id"] >= min(utility_data.get_estimation_years())]

if validate:
additional_data = {
name: extractor(entity, location_id, year_id, data_type)
Expand All @@ -167,11 +171,14 @@ def extract_data(
####################


def _get_year_id(years):
def _get_year_id(years, data_type):
if years is None: # default to most recent year
year_id = utility_data.get_most_recent_year()
elif years == "all":
year_id = None
if data_type.type == "draws":
year_id = None
else: # means
year_id = "full"
else:
year_id = years if isinstance(years, list) else [years]
estimation_years = utility_data.get_estimation_years()
Expand Down
5 changes: 3 additions & 2 deletions src/vivarium_inputs/validation/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ class RawValidationContext:
def __init__(self, location_id, **additional_data):
self.context_data = {"location_id": location_id}
self.context_data.update(additional_data)

if "estimation_years" not in self.context_data:
if "estimation_years" not in self.context_data or self.context_data[
"estimation_years"
] == ["full"]:
self.context_data["estimation_years"] = utility_data.get_estimation_years()
if "age_group_ids" not in self.context_data:
self.context_data["age_group_ids"] = utility_data.get_age_group_ids()
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_get_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def test_get_measure_risklike(
and not mock_gbd
and data_type == "draws"
):
pytest.skip("FIXME: [mic-5543] continuous rrs cannot validate")
pytest.skip("FIXME: [mic-5542] continuous rrs cannot validate")

# Handle not implemented
is_unimplemented = isinstance(data_type, list) or data_type == "means"
Expand Down
Loading

0 comments on commit 69687db

Please sign in to comment.