Update log_posterior (and log_likelihood) functions to also return auxiliary information #10

Merged · 9 commits into main from aux-log-posterior · Feb 7, 2024

Conversation

SamDuffield (Contributor)

This PR changes the API for `log_posterior` and `log_likelihood` functions to be of the form

eval, aux = log_posterior(params, batch)

This means that the user can retain any useful information in the `aux` tensor tree, e.g. model predictions or additional metrics.
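As a minimal sketch of this return shape (the toy linear model, parameter name `w`, and aux key `"predictions"` below are hypothetical, not from the PR):

```python
import torch

def log_posterior(params, batch):
    x, y = batch
    y_pred = params["w"] * x                  # toy linear model
    log_lik = -((y - y_pred) ** 2).sum()      # Gaussian log-likelihood (up to a constant)
    log_prior = -(params["w"] ** 2).sum()     # normal prior (up to a constant)
    aux = {"predictions": y_pred}             # auxiliary tensor tree retained for the user
    return log_lik + log_prior, aux

params = {"w": torch.tensor(2.0)}
batch = (torch.tensor([1.0, 2.0]), torch.tensor([2.0, 4.0]))
eval_, aux = log_posterior(params, batch)
```

The first element is the scalar evaluation; the second can be any tensor tree the user cares about.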

@dpsimpson

My general preference is for a return class / named dictionary. I am not a fan of returning an optional tuple as it can break code in surprising ways if it's usually not used and requires user cognitive load to understand what the order is.

@SamDuffield (Contributor Author)

> My general preference is for a return class / named dictionary. I am not a fan of returning an optional tuple as it can break code in surprising ways if it's usually not used and requires user cognitive load to understand what the order is.

The tuple here is compulsory, and the aux can be any class, dict, whatever the user likes (as long as all leaves are tensors).

I'm not sure how an API like `arbitrary_output = log_posterior(params, batch)` would work with `torch.func.grad` (and friends), since it requires:

> func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If specified `has_aux` equals `True`, function can return a tuple of single-element Tensor and other auxiliary objects: `(output, aux)`.
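A quick sketch of how the `(output, aux)` contract composes with `torch.func.grad` (the function body and aux key below are made up for illustration):

```python
import torch
from torch.func import grad

def log_posterior(params, batch):
    out = -(params * batch).sum() ** 2   # single-element Tensor to differentiate
    aux = {"batch_sum": batch.sum()}     # auxiliary objects passed through untouched
    return out, aux

params = torch.tensor([1.0, 2.0])
batch = torch.tensor([3.0, 4.0])

# has_aux=True tells grad the function returns (output, aux): it differentiates
# output with respect to params and returns aux alongside the gradient.
grads, aux = grad(log_posterior, has_aux=True)(params, batch)
```

Without `has_aux=True`, `grad` would try to differentiate the whole tuple and fail, which is why the tuple return is compulsory rather than optional.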

@johnathanchiu (Contributor)

johnathanchiu commented Feb 2, 2024

@SamDuffield's previous answer is correct. I tested it myself. I think it is perfectly fine to assume that the return values are up to the user.

import torch
from torch.func import functional_call

# normal_log_prior and normal_log_likelihood are user-defined helpers
def log_posterior_n(params, batch, model, n_data):
    y_pred = functional_call(model, params, batch[0])
    return normal_log_prior(params) + normal_log_likelihood(
        batch[1], y_pred
    ) * n_data, {
        "b": torch.tensor([]),
        "c": torch.tensor([]),
        "d": torch.tensor([]),
        "e": torch.tensor([]),
    }


def log_posterior(p, b):
    a, aux_ret = log_posterior_n(p, b, model, 1)
    return a.mean(), aux_ret

Something like this extends flexibility and supports named dictionary returns. We can always assume the first return value is the scalar loss and the second is user-specified. This addresses the concern about returning a structured output.
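A self-contained variant of this pattern (dict params, dict aux; the toy regression and names here are stand-ins, not the PR's code) confirms it composes with `torch.func.grad(has_aux=True)` even when params are a tensor tree:

```python
import torch
from torch.func import grad

def log_posterior(params, batch):
    x, y = batch
    y_pred = params["w"] * x + params["b"]
    loss = -((y - y_pred) ** 2).mean()   # scalar first return value
    aux = {"y_pred": y_pred}             # user-specified second return value
    return loss, aux

params = {"w": torch.tensor(1.0), "b": torch.tensor(0.0)}
batch = (torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0]))

# grad differentiates the scalar with respect to the params dict and
# returns the gradients as a matching dict, with aux alongside.
grads, aux = grad(log_posterior, has_aux=True)(params, batch)
```

Here the fit is exact, so both gradients come back as zero tensors while `aux["y_pred"]` carries the predictions out of the call.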

@johnathanchiu (Contributor) left a comment

Can you just provide an example with a dictionary return type, with tensors as the auxiliary, just to show flexibility? Otherwise everything else LGTM.

@SamDuffield (Contributor Author)

SamDuffield commented Feb 2, 2024

The yelp examples actually now have the model output as the auxiliary information, which is a dict

Although this is kinda hidden lol

I'll rewrite the continual regression example to showcase this API

@SamDuffield (Contributor Author)

I've rewritten the continual regression notebook to use `uqlib.vi.diag` with a `log_posterior` in the new format.

I've also put the uqlib API, including `log_posterior`, front and centre in the README.

@johnathanchiu (Contributor) left a comment

Small changes proposed; everything else looks good. Feel free to merge after making the changes.

README.md Outdated
Comment on lines 32 to 34
> Further the output of `log_posterior` is a tuple containing the evaluation and
> an additional argument containing any auxiliary information we'd like to retain from
> the model call, here the model predictions.

Can we make this more explicit that the returns must be a tensor? I was personally misled by the torch docs and had to discover that by testing it myself.

```diff
@@ -10,8 +10,8 @@
 def batch_normal_log_prob(
     p: dict, batch: Any, mean: dict, sd_diag: dict
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, None]:
     return diag_normal_log_prob(p, mean, sd_diag)
```

The return type is not `None`; I believe it should be `Optional[torch.Tensor]`.


@SamDuffield merged commit 966f9be into main on Feb 7, 2024
2 checks passed
@SamDuffield deleted the aux-log-posterior branch on February 7, 2024 at 14:12
3 participants