Skip to content

Commit

Permalink
undo non-cause changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hussain-jafari committed Feb 16, 2024
1 parent f2a4204 commit 0e9c28d
Showing 1 changed file with 24 additions and 32 deletions.
56 changes: 24 additions & 32 deletions src/vivarium_inputs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,11 @@ def get_deaths(entity: Cause, location_id: int, get_all_years: bool = False) ->
return data



def get_exposure(
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int, get_all_years: bool = False
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int
) -> pd.DataFrame:
data = extract.extract_data(entity, "exposure", location_id, validate=True, get_all_years=get_all_years)
draw_cols_to_drop = [f"draw_{i}" for i in range(500,1000)]
data = data.drop(draw_cols_to_drop, axis=1)
data = extract.extract_data(entity, "exposure", location_id)
data = data.drop("modelable_entity_id", "columns")

if entity.name in EXTRA_RESIDUAL_CATEGORY:
Expand Down Expand Up @@ -269,27 +268,26 @@ def get_exposure(


def get_exposure_standard_deviation(
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int, get_all_years: bool = False
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int
) -> pd.DataFrame:
data = extract.extract_data(entity, "exposure_standard_deviation", location_id, validate=True, get_all_years=get_all_years)
data = extract.extract_data(entity, "exposure_standard_deviation", location_id)
data = data.drop("modelable_entity_id", "columns")

exposure = extract.extract_data(entity, "exposure", location_id, validate=True, get_all_years=get_all_years)
exposure = extract.extract_data(entity, "exposure", location_id)
valid_age_groups = utilities.get_exposure_and_restriction_ages(exposure, entity)
data = data[data.age_group_id.isin(valid_age_groups)]
draw_cols_to_drop = [f"draw_{i}" for i in range(500,1000)]
data = data.drop(draw_cols_to_drop, axis=1)

data = utilities.normalize(data, fill_value=0)
data = data.filter(DEMOGRAPHIC_COLUMNS + DRAW_COLUMNS)
return data


def get_exposure_distribution_weights(
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int, get_all_years: bool = False
entity: Union[RiskFactor, AlternativeRiskFactor], location_id: int
) -> pd.DataFrame:
data = extract.extract_data(entity, "exposure_distribution_weights", location_id, validate=True, get_all_years=get_all_years)
data = extract.extract_data(entity, "exposure_distribution_weights", location_id)

exposure = extract.extract_data(entity, "exposure", location_id, validate=True, get_all_years=get_all_years)
exposure = extract.extract_data(entity, "exposure", location_id)
valid_ages = utilities.get_exposure_and_restriction_ages(exposure, entity)

data.drop("age_group_id", axis=1, inplace=True)
Expand All @@ -300,8 +298,6 @@ def get_exposure_distribution_weights(
df.append(copied)
data = pd.concat(df)
data = utilities.normalize(data, fill_value=0, cols_to_fill=DISTRIBUTION_COLUMNS)
if not get_all_years:
data = data.query("year_id==@MOST_RECENT_YEAR")
data = data.filter(DEMOGRAPHIC_COLUMNS + DISTRIBUTION_COLUMNS)
data = utilities.wide_to_long(data, DISTRIBUTION_COLUMNS, var_name="parameter")
return data
Expand Down Expand Up @@ -330,8 +326,8 @@ def filter_relative_risk_to_cause_restrictions(data: pd.DataFrame) -> pd.DataFra
return data


def get_relative_risk(entity: RiskFactor, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = extract.extract_data(entity, "relative_risk", location_id, validate=True, get_all_years=get_all_years)
def get_relative_risk(entity: RiskFactor, location_id: int) -> pd.DataFrame:
data = extract.extract_data(entity, "relative_risk", location_id)

# FIXME: we don't currently support yll-only causes so I'm dropping them because the data in some cases is
# very messed up, with mort = morb = 1 (e.g., aortic aneurysm in the RR data for high systolic bp) -
Expand All @@ -352,7 +348,6 @@ def get_relative_risk(entity: RiskFactor, location_id: int, get_all_years: bool
+ ["affected_entity", "affected_measure", "parameter"]
+ DRAW_COLUMNS
)

data = (
data.groupby(["affected_entity", "parameter"])
.apply(utilities.normalize, fill_value=1)
Expand All @@ -379,12 +374,12 @@ def filter_by_relative_risk(df: pd.DataFrame, relative_risk: pd.DataFrame) -> pd


def get_population_attributable_fraction(
entity: Union[RiskFactor, Etiology], location_id: int, get_all_years: bool = False
entity: Union[RiskFactor, Etiology], location_id: int
) -> pd.DataFrame:
causes_map = {c.gbd_id: c for c in causes}
if entity.kind == "risk_factor":
data = extract.extract_data(entity, "population_attributable_fraction", location_id, validate=True, get_all_years=get_all_years)
relative_risk = extract.extract_data(entity, "relative_risk", validate=True, get_all_years=get_all_years)
data = extract.extract_data(entity, "population_attributable_fraction", location_id)
relative_risk = extract.extract_data(entity, "relative_risk", location_id)

# FIXME: we don't currently support yll-only causes so I'm dropping them because the data in some cases is
# very messed up, with mort = morb = 1 (e.g., aortic aneurysm in the RR data for high systolic bp) -
Expand Down Expand Up @@ -412,7 +407,7 @@ def get_population_attributable_fraction(

else: # etiology
data = extract.extract_data(
entity, "etiology_population_attributable_fraction", location_id, validate=True, get_all_years=get_all_years
entity, "etiology_population_attributable_fraction", location_id
)
cause = [c for c in causes if entity in c.etiologies][0]
data = utilities.filter_data_by_restrictions(
Expand Down Expand Up @@ -442,8 +437,8 @@ def get_population_attributable_fraction(
return data


def get_estimate(entity: Covariate, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = extract.extract_data(entity, "estimate", location_id, validate=True, get_all_years=get_all_years)
def get_estimate(entity: Covariate, location_id: int) -> pd.DataFrame:
data = extract.extract_data(entity, "estimate", location_id)

key_columns = ["location_id", "year_id"]
if entity.by_age:
Expand All @@ -457,40 +452,37 @@ def get_estimate(entity: Covariate, location_id: int, get_all_years: bool = Fals
return data


#TODO: Remove?
def get_utilization_rate(entity: HealthcareEntity, location_id: int) -> pd.DataFrame:
    """Extract utilization-rate data for ``entity`` at ``location_id``.

    The raw extract is passed through ``utilities.normalize`` with a fill
    value of 0 and then restricted to the demographic and draw columns.
    """
    raw = extract.extract_data(entity, "utilization_rate", location_id)
    normalized = utilities.normalize(raw, fill_value=0)
    # Keep only the demographic index columns plus the draw columns.
    keep = DEMOGRAPHIC_COLUMNS + DRAW_COLUMNS
    return normalized.filter(keep)


def get_structure(entity: Population, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
data = extract.extract_data(entity, "structure", location_id, validate=True, get_all_years=get_all_years)
def get_structure(entity: Population, location_id: int) -> pd.DataFrame:
data = extract.extract_data(entity, "structure", location_id)
data = data.drop("run_id", axis="columns").rename(columns={"population": "value"})
data = utilities.normalize(data)
return data


def get_theoretical_minimum_risk_life_expectancy(
entity: Population, location_id: int, get_all_years: bool = False
entity: Population, location_id: int
) -> pd.DataFrame:
data = extract.extract_data(
entity, "theoretical_minimum_risk_life_expectancy", location_id, validate=True, get_all_years=get_all_years
entity, "theoretical_minimum_risk_life_expectancy", location_id
)
data = data.rename(columns={"age": "age_start", "life_expectancy": "value"})
data["age_end"] = data.age_start.shift(-1).fillna(125.0)
return data


def get_age_bins(entity: Population, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def get_age_bins(entity: Population, location_id: int) -> pd.DataFrame:
age_bins = utility_data.get_age_bins()[["age_group_name", "age_start", "age_end"]]
return age_bins


def get_demographic_dimensions(entity: Population, location_id: int, get_all_years: bool = False) -> pd.DataFrame:
def get_demographic_dimensions(entity: Population, location_id: int) -> pd.DataFrame:
demographic_dimensions = utility_data.get_demographic_dimensions(location_id)
if not get_all_years:
demographic_dimensions = demographic_dimensions.query("year_id==@MOST_RECENT_YEAR")
demographic_dimensions = utilities.normalize(demographic_dimensions)
return demographic_dimensions

0 comments on commit 0e9c28d

Please sign in to comment.