Gradient of the log probability #244

Open
arrjon opened this issue Nov 6, 2024 · 2 comments

arrjon commented Nov 6, 2024

It would be nice to be able to compute the gradient of the log probability (in the new BayesFlow version).

For example, with the JAX backend, I would like to be able to do something along these lines:

from jax import grad

def partial_objective(theta):
    log_prob = approximator.log_prob(
        data={'theta': theta, 'x': x},
        batch_size=1,
    )
    return log_prob[0]

grad(partial_objective)(theta)

The main reason this does not work at the moment is the adapter:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1].
Any idea how this could be done?

Edit: also, just evaluating the log_prob currently only works if one does not use standardize in the adapter. I think this is related to #233. However, calling grad(approximator._log_prob(...)), which circumvents the adapter, works.

@stefanradev93

Agreed that this is an important feature to have! We may want to implement an internal function log_prob_and_grad that does not trace the adapter transformations, but the naive implementation would give gradients wrt the transformed inputs, which may not be what one wants (e.g., gradients would also have a different scale).
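A rough sketch of what such an internal helper could look like with the JAX backend (the name log_prob_and_grad, the use of jax.value_and_grad, and the key names after adaptation are assumptions here, not part of the current API). It makes the caveat concrete: the gradient comes back with respect to the adapter-transformed parameters, not the original ones:

```python
import jax
import keras

def log_prob_and_grad(approximator, data):
    # Apply the adapter eagerly, so JAX never traces its numpy-based transforms.
    adapted = approximator.adapter(data, strict=False)
    adapted = keras.tree.map_structure(keras.ops.convert_to_tensor, adapted)
    # Assumption: the adapter keeps the key "theta" (it may rename it in practice).
    theta_t = adapted.pop("theta")

    def objective(theta_transformed):
        return approximator._log_prob(theta=theta_transformed, **adapted)[0]

    # Caveat from above: this gradient is w.r.t. the transformed theta,
    # so it lives on a different scale than a gradient w.r.t. the raw parameters.
    return jax.value_and_grad(objective)(theta_t)
```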

LarsKue commented Nov 7, 2024

I will look into what options we have to allow something like this. For now, you can use the protected _log_prob as a workaround. This returns a keras.backend.Tensor, so the gradient should be enabled. The complete equivalent code should be:

import keras

def partial_objective(theta):
    # Apply the adapter manually, then call the protected _log_prob directly.
    data = approximator.adapter(data={"theta": theta, "x": x}, strict=False)
    data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
    log_prob = approximator._log_prob(**data)
    return log_prob[0]
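A possible way to use this with the JAX backend (a hedged sketch, assuming theta and x are plain arrays). If the adapter call inside partial_objective still trips over traced arrays, as in the original report, it would need to be pulled out of the differentiated function, similar to the sketch above:

```python
import jax

# Differentiate the workaround objective w.r.t. theta.
theta_grad = jax.grad(partial_objective)(theta)
```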

paul-buerkner added the feature (New feature or request) label Nov 8, 2024
paul-buerkner added this to the BayesFlow 2.0 milestone Nov 8, 2024