Skip to content

Commit

Permalink
Reformat files using pre-commit :grimace-face:
Browse files Browse the repository at this point in the history
  • Loading branch information
kokbent committed Nov 3, 2023
1 parent fffc0a9 commit a2982a0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none"
}
}
47 changes: 34 additions & 13 deletions mechanistic_compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@ def __init__(self, **kwargs):
self.POPULATION * self.INIT_WANING_DIST.transpose()
).transpose()

initial_infectious_count = self.INITIAL_INFECTIONS * self.INIT_INFECTED_DIST
initial_exposed_count = self.INITIAL_INFECTIONS * self.INIT_EXPOSED_DIST
initial_infectious_count = (
self.INITIAL_INFECTIONS * self.INIT_INFECTED_DIST
)
initial_exposed_count = (
self.INITIAL_INFECTIONS * self.INIT_EXPOSED_DIST
)
self.INITIAL_STATE = (
inital_suseptible_count, # s
initial_exposed_count, # e
Expand All @@ -107,7 +111,10 @@ def get_args(self, sample=False):
for example functions f() in charge of disease dynamics see the model_odes folder.
"""
if sample:
beta = utils.sample_r0(self.STRAIN_SPECIFIC_R0) / self.infectious_period
beta = (
utils.sample_r0(self.STRAIN_SPECIFIC_R0)
/ self.infectious_period
)
waning_protections = utils.sample_waning_protections(
self.WANING_PROTECTIONS
)
Expand All @@ -120,9 +127,9 @@ def get_args(self, sample=False):
waning_rate = 1 / self.WANING_TIME
# default to no cross immunity, setting diagnal to 0
# TODO use priors informed by https://www.sciencedirect.com/science/article/pii/S2352396423002992
suseptibility_matrix = jnp.ones((self.NUM_STRAINS, self.NUM_STRAINS)) * (
1 - jnp.diag(jnp.array([1] * self.NUM_STRAINS))
)
suseptibility_matrix = jnp.ones(
(self.NUM_STRAINS, self.NUM_STRAINS)
) * (1 - jnp.diag(jnp.array([1] * self.NUM_STRAINS)))
# if your model expects added parameters, add them here
args = {
"beta": beta,
Expand Down Expand Up @@ -244,7 +251,9 @@ def run(
----------
Diffrax.Solution object as described by https://docs.kidger.site/diffrax/api/solution/
"""
term = ODETerm(lambda t, state, parameters: model(state, t, parameters))
term = ODETerm(
lambda t, state, parameters: model(state, t, parameters)
)
solver = Tsit5()
t0 = 0.0
dt0 = 0.1
Expand Down Expand Up @@ -301,7 +310,9 @@ def plot_diffrax_solution(
]
get_indexes.append(index_slice)
else:
get_indexes.append(self.IDX.__getitem__(compartment.strip().upper()))
get_indexes.append(
self.IDX.__getitem__(compartment.strip().upper())
)

fig, ax = plt.subplots(1)
for compartment, idx in zip(plot_compartments, get_indexes):
Expand Down Expand Up @@ -345,7 +356,9 @@ def load_waning_and_recovered_distributions(self):
sero_data = pd.read_csv(download_link)
os.makedirs(self.SEROLOGICAL_DATA, exist_ok=True)
sero_data.to_csv(sero_path, index=False)
pop_path = self.DEMOGRAPHIC_DATA + "population_rescaled_age_distributions/"
pop_path = (
self.DEMOGRAPHIC_DATA + "population_rescaled_age_distributions/"
)
(
self.INIT_RECOVERED_DIST,
self.INIT_WANING_DIST,
Expand Down Expand Up @@ -411,7 +424,9 @@ def load_init_infection_infected_and_exposed_dist(self):
# TODO initialize infections by age based on the seroprevalence by age.
# since we are assuming similar dynamics in short time frames
# we expect to see similar proportions of each age bin in new infections as recovered
self.INIT_INFECTION_DIST = self.INIT_RECOVERED_DIST[:, self.STRAIN_IDX.omicron]
self.INIT_INFECTION_DIST = self.INIT_RECOVERED_DIST[
:, self.STRAIN_IDX.omicron
]
# eig_data = np.linalg.eig(self.CONTACT_MATRIX)
# max_index = np.argmax(eig_data[0])
# self.INIT_INFECTION_DIST = abs(eig_data[1][:, max_index])
Expand All @@ -425,8 +440,12 @@ def load_init_infection_infected_and_exposed_dist(self):
# [0.27707683 0.45785665 0.1815728 0.08349373] contact matrix method, 4 bins

# ratio of gamma / sigma defines our infected to exposed ratio at any given time
exposed_to_infected_ratio = self.EXPOSED_TO_INFECTIOUS / self.INFECTIOUS_PERIOD
self.INIT_EXPOSED_DIST = exposed_to_infected_ratio * self.INIT_INFECTION_DIST
exposed_to_infected_ratio = (
self.EXPOSED_TO_INFECTIOUS / self.INFECTIOUS_PERIOD
)
self.INIT_EXPOSED_DIST = (
exposed_to_infected_ratio * self.INIT_INFECTION_DIST
)
# INIT_EXPOSED_DIST is not strain stratified, put infected into the omicron strain via indicator vec
self.INIT_EXPOSED_DIST = self.INIT_EXPOSED_DIST[:, None] * np.array(
[0] * self.STRAIN_IDX.omicron + [1]
Expand Down Expand Up @@ -456,7 +475,9 @@ def default(self, obj):
if isinstance(obj, np.ndarray) or isinstance(obj, jnp.ndarray):
return obj.tolist()
if isinstance(obj, EnumMeta):
return {str(e): idx for e, idx in zip(obj, range(len(obj)))}
return {
str(e): idx for e, idx in zip(obj, range(len(obj)))
}
return json.JSONEncoder.default(self, obj)

return json.dump(self.__dict__, file, indent=4, cls=CustomEncoder)
Expand Down
12 changes: 9 additions & 3 deletions model_odes/seir_model_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def _seirw_ode(state, _, parameters):
p = Parameters(parameters)

# TODO when adding birth and deaths just create it as a compartment
force_of_infection = p.beta * p.contact_matrix.dot(i) / p.population[:, None]
force_of_infection = (
p.beta * p.contact_matrix.dot(i) / p.population[:, None]
)
ds_to_e = force_of_infection * s[:, None]

ds_to_w = s * p.vax_rate # vaccination of suseptibles
Expand All @@ -65,7 +67,9 @@ def _seirw_ode(state, _, parameters):
effective_ws_by_age = ws_by_age * (
1 - (p.waning_protections * (1 - partial_susceptibility))
)
ws_exposed = force_of_infection_strain[:, None] * effective_ws_by_age
ws_exposed = (
force_of_infection_strain[:, None] * effective_ws_by_age
)
# element wise subtraction of exposed w_s from strain_target dw
dw = dw.at[:, strain_target_idx, :].add(-ws_exposed)
# element wise addition of exposed w_s into de
Expand Down Expand Up @@ -111,7 +115,9 @@ def _seirw_ode(state, _, parameters):
# only top waning compartment receives people from "r"

# sum ds_to_e since s does not split by subtype
ds = jnp.add(jnp.zeros(s.shape), jnp.add(-jnp.sum(ds_to_e, axis=1), -ds_to_w))
ds = jnp.add(
jnp.zeros(s.shape), jnp.add(-jnp.sum(ds_to_e, axis=1), -ds_to_w)
)
de = jnp.add(de, -de_to_i + ds_to_e)
di = jnp.add(jnp.zeros(i.shape), jnp.add(de_to_i, -di_to_r))
dr = jnp.add(jnp.zeros(r.shape), jnp.add(di_to_r, -dr_to_w))
Expand Down

0 comments on commit a2982a0

Please sign in to comment.