Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
hussain-jafari committed Feb 17, 2024
1 parent 904c340 commit 18fe82a
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 40 deletions.
76 changes: 57 additions & 19 deletions src/vivarium_inputs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pandas as pd
from gbd_mapping import Cause, Covariate, Etiology, RiskFactor, Sequela, causes
from loguru import logger

from vivarium_gbd_access.constants import MOST_RECENT_YEAR

from vivarium_inputs import extract, utilities, utility_data
from vivarium_inputs.globals import (
COVARIATE_VALUE_COLUMNS,
Expand Down Expand Up @@ -96,8 +96,12 @@ def get_data(entity, measure: str, location: Union[str, int], get_all_years: boo
return data


def get_raw_incidence_rate(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = extract.extract_data(entity, "incidence_rate", location_id, validate=True, get_all_years=get_all_years)
def get_raw_incidence_rate(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = extract.extract_data(
entity, "incidence_rate", location_id, validate=True, get_all_years=get_all_years
)
if entity.kind == "cause":
restrictions_entity = entity
else: # sequela
Expand All @@ -112,16 +116,22 @@ def get_raw_incidence_rate(entity: Union[Cause, Sequela], location_id: int, get_
return data


def get_incidence_rate(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def get_incidence_rate(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = get_data(entity, "raw_incidence_rate", location_id, get_all_years)
prevalence = get_data(entity, "prevalence", location_id)
# Convert from "True incidence" to the incidence rate among susceptibles
data /= 1 - prevalence
return data.fillna(0)


def get_prevalence(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = extract.extract_data(entity, "prevalence", location_id, validate=True, get_all_years=get_all_years)
def get_prevalence(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = extract.extract_data(
entity, "prevalence", location_id, validate=True, get_all_years=get_all_years
)
if entity.kind == "cause":
restrictions_entity = entity
else: # sequela
Expand All @@ -136,14 +146,20 @@ def get_prevalence(entity: Union[Cause, Sequela], location_id: int, get_all_year
return data


def get_birth_prevalence(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = extract.extract_data(entity, "birth_prevalence", location_id, validate=True, get_all_years=get_all_years)
def get_birth_prevalence(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = extract.extract_data(
entity, "birth_prevalence", location_id, validate=True, get_all_years=get_all_years
)
data = data.filter(["year_id", "sex_id", "location_id"] + DRAW_COLUMNS)
data = utilities.normalize(data, fill_value=0)
return data


def get_disability_weight(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def get_disability_weight(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
if entity.kind == "cause":
data = utility_data.get_demographic_dimensions(location_id, draws=True, value=0.0)
if not get_all_years:
Expand All @@ -154,23 +170,35 @@ def get_disability_weight(entity: Union[Cause, Sequela], location_id: int, get_a
if entity.sequelae:
for sequela in entity.sequelae:
try:
prevalence = get_data(sequela, "prevalence", location_id, get_all_years=get_all_years)
prevalence = get_data(
sequela, "prevalence", location_id, get_all_years=get_all_years
)
except DataDoesNotExistError:
# sequela prevalence does not exist so no point continuing with this sequela
continue
disability = get_data(sequela, "disability_weight", location_id, get_all_years=get_all_years)
disability = get_data(
sequela, "disability_weight", location_id, get_all_years=get_all_years
)
disability.index = disability.index.set_levels(
[location_id], level="location_id"
)
data += prevalence * disability
cause_prevalence = get_data(entity, "prevalence", location_id, get_all_years=get_all_years)
cause_prevalence = get_data(
entity, "prevalence", location_id, get_all_years=get_all_years
)
data = (data / cause_prevalence).fillna(0).reset_index()
else: # entity.kind == 'sequela'
try:
data = extract.extract_data(entity, "disability_weight", location_id, validate=True, get_all_years=get_all_years)
data = extract.extract_data(
entity,
"disability_weight",
location_id,
validate=True,
get_all_years=get_all_years,
)
# add year id with single year so normalization doesn't fill in all years
if not get_all_years:
data['year_id'] = MOST_RECENT_YEAR
data["year_id"] = MOST_RECENT_YEAR
data = utilities.normalize(data)

cause = [c for c in causes if c.sequelae and entity in c.sequelae][0]
Expand All @@ -186,8 +214,12 @@ def get_disability_weight(entity: Union[Cause, Sequela], location_id: int, get_a
return data


def get_remission_rate(entity: Cause, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = extract.extract_data(entity, "remission_rate", location_id, validate=True, get_all_years=get_all_years)
def get_remission_rate(
entity: Cause, location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = extract.extract_data(
entity, "remission_rate", location_id, validate=True, get_all_years=get_all_years
)
data = utilities.filter_data_by_restrictions(
data, entity, "yld", utility_data.get_age_group_ids()
)
Expand All @@ -196,15 +228,21 @@ def get_remission_rate(entity: Cause, location_id: int, get_all_years: bool = Fa
return data


def get_cause_specific_mortality_rate(entity: Cause, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
deaths = get_data(entity, "deaths", location_id, get_all_years) # population isn't by draws
def get_cause_specific_mortality_rate(
entity: Cause, location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
deaths = get_data(
entity, "deaths", location_id, get_all_years
) # population isn't by draws
pop = get_data(Population(), "structure", location_id, get_all_years)
data = deaths.join(pop, lsuffix="_deaths", rsuffix="_pop")
data[DRAW_COLUMNS] = data[DRAW_COLUMNS].divide(data.value, axis=0)
return data.drop(["value"], axis="columns")


def get_excess_mortality_rate(entity: Cause, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def get_excess_mortality_rate(
entity: Cause, location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
csmr = get_data(entity, "cause_specific_mortality_rate", location_id, get_all_years)
prevalence = get_data(entity, "prevalence", location_id, get_all_years)
data = (csmr / prevalence).fillna(0)
Expand Down
47 changes: 35 additions & 12 deletions src/vivarium_inputs/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import pandas as pd
from gbd_mapping import Cause, Covariate, Etiology, RiskFactor, Sequela

from vivarium_gbd_access.constants import MOST_RECENT_YEAR

import vivarium_inputs.validation.raw as validation
from vivarium_inputs.globals import (
MEASURES,
Expand Down Expand Up @@ -126,43 +126,64 @@ def extract_data(
for name, extractor in additional_extractors.items()
}
if not get_all_years:
additional_data['estimation_years'] = [MOST_RECENT_YEAR]
additional_data["estimation_years"] = [MOST_RECENT_YEAR]
validation.validate_raw_data(data, entity, measure, location_id, **additional_data)

return data


def extract_prevalence(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def extract_prevalence(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = gbd.get_incidence_prevalence(
entity_id=entity.gbd_id, location_id=location_id, entity_type=entity.kind, get_all_years=get_all_years,
entity_id=entity.gbd_id,
location_id=location_id,
entity_type=entity.kind,
get_all_years=get_all_years,
)
data = data[data.measure_id == MEASURES["Prevalence"]]
return data


def extract_incidence_rate(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def extract_incidence_rate(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = gbd.get_incidence_prevalence(
entity_id=entity.gbd_id, location_id=location_id, entity_type=entity.kind, get_all_years=get_all_years,
entity_id=entity.gbd_id,
location_id=location_id,
entity_type=entity.kind,
get_all_years=get_all_years,
)
data = data[data.measure_id == MEASURES["Incidence rate"]]
return data


def extract_birth_prevalence(entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def extract_birth_prevalence(
entity: Union[Cause, Sequela], location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = gbd.get_birth_prevalence(
entity_id=entity.gbd_id, location_id=location_id, entity_type=entity.kind, get_all_years=get_all_years,
entity_id=entity.gbd_id,
location_id=location_id,
entity_type=entity.kind,
get_all_years=get_all_years,
)
data = data[data.measure_id == MEASURES["Incidence rate"]]
return data


def extract_remission_rate(entity: Cause, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = gbd.get_modelable_entity_draws(entity.me_id, location_id, get_all_years=get_all_years)
def extract_remission_rate(
entity: Cause, location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = gbd.get_modelable_entity_draws(
entity.me_id, location_id, get_all_years=get_all_years
)
data = data[data.measure_id == MEASURES["Remission rate"]]
return data


def extract_disability_weight(entity: Sequela, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def extract_disability_weight(
entity: Sequela, location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
disability_weights = gbd.get_auxiliary_data(
"disability_weight", entity.kind, "all", location_id
)
Expand All @@ -172,7 +193,9 @@ def extract_disability_weight(entity: Sequela, location_id: int, get_all_years:
return data


def extract_deaths(entity: Cause, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def extract_deaths(
entity: Cause, location_id: int, get_all_years: bool = False
) -> pd.DataFrame:
data = gbd.get_codcorrect_draws(entity.gbd_id, location_id, get_all_years=get_all_years)
data = data[data.measure_id == MEASURES["Deaths"]]
return data
Expand Down
4 changes: 3 additions & 1 deletion src/vivarium_inputs/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from vivarium_inputs.globals import Population


def get_measure(entity: ModelableEntity, measure: str, location: str, get_all_years: bool = False) -> pd.DataFrame:
def get_measure(
entity: ModelableEntity, measure: str, location: str, get_all_years: bool = False
) -> pd.DataFrame:
"""Pull GBD data for measure and entity and prep for simulation input,
including scrubbing all GBD conventions to replace IDs with meaningful
values or ranges and expanding over all demographic dimensions. To pull data
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium_inputs/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def filter_data_by_restrictions(

start, end = get_age_group_ids_by_restriction(entity, which_age)
ages = get_restriction_age_ids(start, end, age_group_ids)
#ages = ages + [238, 388, 389]
# ages = ages + [238, 388, 389]
data = data[data.age_group_id.isin(ages)]
return data

Expand Down
2 changes: 1 addition & 1 deletion src/vivarium_inputs/utility_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_age_bins(*_, **__) -> pd.DataFrame:
age_bins = gbd.get_age_bins()[
["age_group_id", "age_group_name", "age_group_years_start", "age_group_years_end"]
].rename(columns={"age_group_years_start": "age_start", "age_group_years_end": "age_end"})
age_bins = age_bins.sort_values('age_start')
age_bins = age_bins.sort_values("age_start")
return age_bins


Expand Down
2 changes: 1 addition & 1 deletion src/vivarium_inputs/validation/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
causes,
)
from loguru import logger

from vivarium_gbd_access.gbd import get_age_bins

from vivarium_inputs import utility_data
from vivarium_inputs.globals import (
DEMOGRAPHIC_COLUMNS,
Expand Down
12 changes: 7 additions & 5 deletions src/vivarium_inputs/validation/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
Sequela,
causes,
)

from vivarium_gbd_access.constants import MOST_RECENT_YEAR

from vivarium_inputs import utilities, utility_data
from vivarium_inputs.globals import (
BOUNDARY_SPECIAL_CASES,
Expand Down Expand Up @@ -148,7 +148,9 @@ def validate_for_simulation(
if measure not in validators:
raise NotImplementedError()

context_args['years'] = pd.DataFrame({'year_start': MOST_RECENT_YEAR, 'year_end': MOST_RECENT_YEAR+1}, index=[0])
context_args["years"] = pd.DataFrame(
{"year_start": MOST_RECENT_YEAR, "year_end": MOST_RECENT_YEAR + 1}, index=[0]
)
context = SimulationValidationContext(location, **context_args)
validators[measure](data, entity, context)

Expand Down Expand Up @@ -1250,9 +1252,9 @@ def validate_theoretical_minimum_risk_life_expectancy(
error=DataTransformationError,
)
if not data.sort_values(by="age", ascending=False).value.is_monotonic:
raise DataTransformationError(
"Life expectancy data is not monotonically decreasing over age."
)
raise DataTransformationError(
"Life expectancy data is not monotonically decreasing over age."
)


def validate_age_bins(
Expand Down

0 comments on commit 18fe82a

Please sign in to comment.