Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements SampledValue #262

Merged
merged 54 commits into from
Jul 26, 2024
Merged

Implements SampledValue #262

merged 54 commits into from
Jul 26, 2024

Conversation

gvegayon
Copy link
Member

@gvegayon gvegayon commented Jul 11, 2024

  • Introduces the SampledValye class with value, t_start, and t_unit attributes (time-related attributes are optional).
  • Updates all RandomVariable.sample() calls to return tuples/namedtuples with SampledValue instead of ArrayLike objects.
  • Tweaks one of the tutorials so that the figures show reasonable values.

image

The latter is the new version of the last plot, the previous figure looks like the following:

image

- It fixes a hidden bug in the hospital admissions model tutorial where the day-of-the-week effect was not properly used.

For later

Also, in the extending pyrenew tutorial

image
New version of the hospital admissions plot. vs:
image

This is not implemented as the section on using the day-of-week effect was removed from the tutorial.

Copy link

codecov bot commented Jul 11, 2024

Codecov Report

Attention: Patch coverage is 98.52941% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.80%. Comparing base (7fd138d) to head (c1111fe).

Files Patch % Lines
model/src/pyrenew/metaclass.py 90.90% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #262      +/-   ##
==========================================
- Coverage   92.80%   92.80%   -0.01%     
==========================================
  Files          39       39              
  Lines         904      917      +13     
==========================================
+ Hits          839      851      +12     
- Misses         65       66       +1     
Flag Coverage Δ
unittests 92.80% <98.52%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@gvegayon gvegayon mentioned this pull request Jul 15, 2024
@gvegayon
Copy link
Member Author

Just found a bug in the code. Particularly the definition of the day of the week effect in the hospital_admissions_model.qmd tutorial. The current implementation does not use all the day-of-the-week effect samples but rather the first observation (Monday/Sunday, depending on where you live). Here is the code:

https://github.com/CDCgov/multisignal-epi-inference/blob/ce055afcaa2998d152ce39b3efd248d98eb3ddef/docs/source/tutorials/hospital_admissions_model.qmd#L502

In other words, instead of returning a tuple/namedtuple, it returns a jax array, which then, when used within latent.HospitalAdmissions.sample() fails to use the full vector:

https://github.com/CDCgov/multisignal-epi-inference/blob/ce055afcaa2998d152ce39b3efd248d98eb3ddef/model/src/pyrenew/latent/hospitaladmissions.py#L192-L195

The example should instead return a tuple:

         return (jnp.tile(ans, self.nweeks)[: self.len],)

But then, the code fails because the day-of-the-week RV samples a vector of length dat.shape[0] = 90 and tries multiplying it by a vector of infections of length dat.shape[0] + padding (21) + seeding (15) = 126.

For this PR, I'll fix the example by returning a tuple with a TimeArray (as all RVs in the PR), and specifying the sample size to be 126. A better solution will be to refactor latent.HospitalAdmissions to use a TimeArray instead of a ArrayLike for the day-of-the-week effect, passing the latent_admissions as a TimeArray so it will figure out what is the proper length (and alignment) automatically.

This could have been identified if we checked that all calls to RandomVariable.sample() return a namedtuple or a tuple of TimeArrays. I suggest leveraging we now have the RandomVariable.__call__() method to introduce a check there and have RandomVariable.sample() be renamed to RandomVariable._sample_() so it is now considered to be internal.

attn @damonbayer @dylanhmorris

@gvegayon gvegayon marked this pull request as ready for review July 18, 2024 22:09
@damonbayer
Copy link
Collaborator

damonbayer commented Jul 22, 2024

@gvegayon did you consider implementing this with __jax_array__? I think you did, but I couldn’t remember.

@gvegayon
Copy link
Member Author

@gvegayon did you consider implementing this with __jax_array__? I think you did, but I couldn’t remember.

I have not. I was checking, and there's not enough documentation to make good use of it. Looking at the source code of jax, it seems we would have to implement __jax_array__ as a function call. I suggest adding this as an issue and moving forward; mostly because it is not properly documented yet.

@dylanhmorris dylanhmorris mentioned this pull request Jul 26, 2024
@dylanhmorris
Copy link
Collaborator

Ready for re-review, @damonbayer

@dylanhmorris dylanhmorris requested a review from damonbayer July 26, 2024 19:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants