Skip to content

Commit

Permalink
use list comprehension in test_infectionsrtfeedback
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Sep 12, 2024
1 parent 396c2d7 commit 8197059
Showing 1 changed file with 43 additions and 11 deletions.
54 changes: 43 additions & 11 deletions test/test_infectionsrtfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,55 @@ def _infection_w_feedback_alt(
T, -1
) # coerce from jax to use numpy-like operations
len_gen = len(gen_int)
I_vec = np.concatenate([I0.reshape(T, -1), np.zeros(Rt.shape)])
infs = np.concatenate([I0.reshape(T, -1), np.zeros(Rt.shape)])
Rt_adj = np.zeros(Rt.shape)
inf_feedback_strength = np.array(inf_feedback_strength).reshape(T, -1)

for n in range(Rt.shape[1]):
for t in range(Rt.shape[0]):
Rt_adj[t, n] = Rt[t, n] * np.exp(
inf_feedback_strength[t, n]
* np.dot(I_vec[t : t + len_gen, n], np.flip(inf_feedback_pmf))
)
def compute_Rt_adj(
Rt, inf_feedback_strength, infs, inf_feedback_pmf, len_gen, t, n
): # numpydoc ignore=GL08
return Rt[t, n] * np.exp(
inf_feedback_strength[t, n]
* np.dot(infs[t : t + len_gen, n], np.flip(inf_feedback_pmf))
)

Rt_adj = np.array(
[
[
compute_Rt_adj(
Rt,
inf_feedback_strength,
infs,
inf_feedback_pmf,
len_gen,
t,
n,
)
for n in range(Rt.shape[1])
]
for t in range(Rt.shape[0])
]
)

I_vec[t + len_gen, n] = Rt_adj[t, n] * np.dot(
I_vec[t : t + len_gen, n], np.flip(gen_int)
)
def compute_infections(
Rt_adj, infs, len_gen, gen_int, t, n
): # numpydoc ignore=GL08
return Rt_adj[t, n] * np.dot(
infs[t : t + len_gen, n], np.flip(gen_int)
)

infs[len_gen : T + len_gen] = np.array(
[
[
compute_infections(Rt_adj, infs, len_gen, gen_int, t, n)
for n in range(Rt.shape[1])
]
for t in range(Rt.shape[0])
]
)

return {
"post_initialization_infections": np.squeeze(I_vec[I0.shape[0] :]),
"post_initialization_infections": np.squeeze(infs[I0.shape[0] :]),
"rt": np.squeeze(Rt_adj),
}

Expand Down

0 comments on commit 8197059

Please sign in to comment.