Update log_posterior (and log_likelihood) functions to also return auxiliary information #10
Conversation
My general preference is for a return class / named dictionary. I am not a fan of returning an optional tuple, as it can break code in surprising ways if it's usually not used, and it requires user cognitive load to understand what the order is.
The tuple here is compulsory, and the aux can be any class, dict, whatever the user likes (as long as all leaves are tensors). I'm not sure how an API like
@SamDuffield's previous answer is correct. I tested it myself. I think it is perfectly fine to assume that the return values are up to the user.
Something like this extends flexibility and allows named dictionary returns. We can always assume the first return value is the scalar loss, the second is user specified. This addresses the concern of returning a structured output.
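For illustration, a minimal sketch of that convention with a named aux structure (the function, model, and field names here are hypothetical, not taken from this PR):

```python
from typing import NamedTuple, Tuple

import torch


class Aux(NamedTuple):
    # User-chosen auxiliary information; every leaf is a tensor.
    predictions: torch.Tensor
    per_sample_log_lik: torch.Tensor


def log_likelihood(
    params: dict, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, Aux]:
    x, y = batch
    preds = x @ params["weights"]                    # toy linear model
    per_sample = -0.5 * (preds - y) ** 2             # Gaussian log-likelihood up to a constant
    return per_sample.sum(), Aux(preds, per_sample)  # scalar first, named aux second
```

Code that only needs the evaluation can unpack the tuple and discard the aux, so the structured output does not get in the way.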
Can you just provide an example with a dictionary return type with tensors as the auxiliary, just to show flexibility? Otherwise everything else lgtm.
The yelp examples actually now have the model output as the auxiliary information, which is a dict. Although this is kinda hidden lol. I'll rewrite the continual regression example to showcase this API.
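For reference, a minimal sketch of a dictionary-valued aux with tensor leaves (illustrative only; this is not the actual yelp or continual regression code, and the parameter/batch names are assumptions):

```python
from typing import Dict, Tuple

import torch


def log_posterior(
    params: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    logits = batch["inputs"] @ params["weights"]
    log_lik = torch.distributions.Categorical(logits=logits).log_prob(batch["labels"]).sum()
    log_prior = sum(torch.distributions.Normal(0.0, 1.0).log_prob(p).sum() for p in params.values())
    # The auxiliary information is a plain dict whose leaves are all tensors.
    aux = {"logits": logits, "log_lik": log_lik.detach()}
    return log_lik + log_prior, aux
```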
I've rewritten the continual regression notebook to use. I've also put the
Small changes proposed, everything else looks good otherwise. Feel free to merge after making changes.
README.md
Outdated
Further the output of `log_posterior` is a tuple containing the evaluation and an additional argument containing any auxiliary information we'd like to retain from the model call, here the model predictions.
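As a usage sketch (reusing the hypothetical `log_posterior` from the snippet above, so the parameter and batch names are still assumptions), the returned predictions can then be consumed without a second forward pass:

```python
import torch

# Hypothetical parameters and batch matching the earlier sketch.
params = {"weights": torch.randn(3, 4, requires_grad=True)}
batch = {"inputs": torch.randn(8, 3), "labels": torch.randint(0, 4, (8,))}

log_post_value, aux = log_posterior(params, batch)  # evaluation first, aux second
accuracy = (aux["logits"].argmax(-1) == batch["labels"]).float().mean()
```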
Can we make this more explicit that the returns must be a tensor? I was personally misled by the torch docs actually and had to discover that by testing it myself.
tests/vi/test_diag.py
Outdated
```diff
@@ -10,8 +10,8 @@
 def batch_normal_log_prob(
     p: dict, batch: Any, mean: dict, sd_diag: dict
-) -> torch.Tensor:
-    return diag_normal_log_prob(p, mean, sd_diag)
+) -> Tuple[torch.Tensor, None]:
```
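For context, a sketch of what the updated test helper presumably looks like after this diff; the new return statement is cut off in the excerpt above, so the trailing `None` is an assumption, and `diag_normal_log_prob` is imported from the library under test in the real file:

```python
from typing import Any, Tuple

import torch

# Assumed import in the real test file (path omitted here):
# diag_normal_log_prob evaluates a diagonal-Gaussian log-density over a parameter dict.


def batch_normal_log_prob(
    p: dict, batch: Any, mean: dict, sd_diag: dict
) -> Tuple[torch.Tensor, None]:
    # Assumed: the evaluation itself is unchanged; None is returned as the (unused) aux.
    return diag_normal_log_prob(p, mean, sd_diag), None
```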
The return type is not `None`, I believe it's `Optional[torch.Tensor, None]`.
This PR changes the API for `log_posterior` and `log_likelihood` functions to be of the form sketched below. This means that the user can retain any useful information in the `aux` tensor tree, e.g. model predictions or additional metrics.
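The code block that showed the new form is not reproduced in this excerpt. As a hedged sketch of the presumed shape (signatures and docstrings are illustrative only):

```python
from typing import Any, Tuple

import torch


def log_posterior(params: Any, batch: Any) -> Tuple[torch.Tensor, Any]:
    """Returns (scalar log-posterior evaluation, auxiliary tensor tree)."""
    ...


def log_likelihood(params: Any, batch: Any) -> Tuple[torch.Tensor, Any]:
    """Returns (scalar log-likelihood evaluation, auxiliary tensor tree)."""
    ...
```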