
Add inverse_and_gradient method to bijection #191

Status: Open · wants to merge 3 commits into main
Conversation

@aseyboldt (Author)

This adds an extra method to bijections that I use for a project of mine. I'm not sure whether it is of enough general interest to include in the library, but having it here would certainly make my life easier, and there might be other use cases around.

If there is a known density defined on the transformed space of some bijection, we can compute the logp and the gradient of that logp on the transformed space. In the code I call these `y` (the position), `y_grad` (the gradient of the log density at `y`) and `y_logp` (the log density at `y`).

We can then compute the corresponding quantities on the untransformed space using autodiff, where we need to take the Jacobian determinant into account.
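Concretely, for a bijection T with x = T⁻¹(y), the change of variables gives log p_X(x) = log p_Y(T(x)) + log|det J_T(x)|, so the gradient on the untransformed space is x_grad = J_T(x)ᵀ · y_grad + ∇ₓ log|det J_T(x)|, which can be computed with a single vector-Jacobian product through the pair (T, log|det J_T|) using the cotangent (y_grad, 1).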

The PR contains a general implementation in the `AbstractBijection` class. The code this generates is not always optimal, however, so for some bijections I also implement special cases (`Affine`, `Chain`, `Permute` and `Coupling`).

The special implementations are tested against the general implementation, and the general implementation is tested against analytical solutions for the `Affine` and `Exp` transformations.

Review comment on this part of the diff:

x_grad: The gradient of the log density at `x`.
x_logp: The log density on the untransformed space at `x`.
"""
x, logdet = self.inverse_and_log_det(y, condition)
@danielward27 (Owner)

I assume you don't need the log_det in the inverse?

x = bijection.inverse(y, condition)
(_, fwd_log_det), pull_grad_fn = jax.vjp(...)

and add fwd_log_det instead of subtracting the inverse log det.
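For reference, a minimal sketch of what that restructuring could look like (not the code in this PR; it assumes flowjax-style `inverse` and `transform_and_log_det` methods on the bijection):

import jax
import jax.numpy as jnp

def inverse_gradient_and_val(bijection, y, y_grad, y_logp, condition=None):
    # Map the position back to the untransformed space.
    x = bijection.inverse(y, condition)
    # VJP through the forward map together with its log-det, so the
    # gradient of the Jacobian correction comes out of the same pass.
    (_, fwd_log_det), pull_grad_fn = jax.vjp(
        lambda x_: bijection.transform_and_log_det(x_, condition), x
    )
    # Chain rule for log p_X(x) = log p_Y(T(x)) + log|det J_T(x)|:
    # the cotangent (y_grad, 1) pulls back both terms at once.
    (x_grad,) = pull_grad_fn((y_grad, jnp.ones_like(fwd_log_det)))
    # Add the forward log-det instead of subtracting the inverse one.
    return x, x_grad, y_logp + fwd_log_det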

@aseyboldt (Author)

Thanks, that is better.

@danielward27 (Owner)

Nice, thanks for the contribution. I'll have to think about it a bit to understand it better and decide whether it is in scope for the library, as it adds a reasonable amount of code complexity for a somewhat niche feature (at least it's not something I have come across in other, similar libraries before). Without really understanding the use case, I wonder how frequently this is needed, i.e. over

x = dist.bijection.inverse(y)
log_prob, grad = jax.value_and_grad(dist.base_dist.log_prob)(x)

If you get a chance to describe your use case, maybe that would help me understand it a bit better.

@danielward27 (Owner) commented Oct 22, 2024

At least on CPU, whilst the specialized methods are faster, the difference seems to be pretty small. Did you find situations where the specialized versions are substantially faster?

import flowjax.bijections as bij
from flowjax.bijections import Coupling, Affine
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr


dim = 500

has_specialized = [
    bij.Coupling(
        jr.key(0),
        transformer=Affine(),
        dim=dim,
        untransformed_dim=dim//2,
        nn_width=50,
        nn_depth=1,
    ),
    bij.Affine(jnp.ones(dim), jnp.ones(dim)),
    bij.Permute(jr.permutation(jr.key(0), jnp.arange(dim))),
    bij.Chain([bij.Affine(jnp.ones(dim)), bij.Affine(jnp.ones(dim))])
]
for bijection in has_specialized:
    print(bijection.__class__.__name__)
    y = jnp.linspace(0, 1, dim)
    y_grad = jnp.linspace(0, 1, dim)
    y_logp = jnp.array(0)

    specialized = eqx.filter_jit(bijection.inverse_gradient_and_val)
    result = specialized(y, y_grad, y_logp)[1].block_until_ready()
    %timeit specialized(y, y_grad, y_logp)[1].block_until_ready()

    naive = eqx.filter_jit(bij.AbstractBijection.inverse_gradient_and_val)
    naive_result = naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
    %timeit naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
Coupling
306 μs ± 32.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
361 μs ± 13.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Affine
306 μs ± 10.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
372 μs ± 39.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Permute
302 μs ± 7.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
354 μs ± 9.18 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Chain
330 μs ± 6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
417 μs ± 36.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

@aseyboldt (Author)

Thank you for having a look.

I was seeing a much larger improvement in my earlier benchmarks, but it turns out I made a mistake in them, and funnily enough there is essentially the same issue in the code you just posted:

In the specialized case the bijection is a compile-time constant, while in the general case it is an input to the jitted function. If I fix that, I see practically no performance difference between the general and the specialized implementations in that particular benchmark:

import flowjax.bijections as bij
from flowjax.bijections import Coupling, Affine
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

dim = 500

has_specialized = [
    bij.Coupling(
        jr.key(0),
        transformer=Affine(),
        dim=dim,
        untransformed_dim=dim//2,
        nn_width=50,
        nn_depth=1,
    ),
    bij.Affine(jnp.ones(dim), jnp.ones(dim)),
    bij.Permute(jr.permutation(jr.key(0), jnp.arange(dim))),
    bij.Chain([bij.Affine(jnp.ones(dim)), bij.Affine(jnp.ones(dim))] * 10)
]
for bijection in has_specialized:
    print(bijection.__class__.__name__)
    y = jnp.linspace(0, 1, dim)
    y_grad = jnp.linspace(0, 1, dim)
    y_logp = jnp.array(0)

    specialized = eqx.filter_jit(lambda bijection, *args: bijection.inverse_gradient_and_val(*args))
    result = specialized(bijection, y, y_grad, y_logp)[1].block_until_ready()
    %timeit specialized(bijection, y, y_grad, y_logp)[1].block_until_ready()

    naive = eqx.filter_jit(bij.AbstractBijection.inverse_gradient_and_val)
    naive_result = naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
    %timeit naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
Coupling
295 μs ± 3.94 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
302 μs ± 6.06 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Affine
259 μs ± 1.91 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
269 μs ± 4.88 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Permute
257 μs ± 8.25 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
260 μs ± 6.91 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Chain
594 μs ± 8.43 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
597 μs ± 9.84 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In a more involved benchmark that is closer to what I'm actually interested in, I still see a speedup, but nowhere near as big as before:

import flowjax
import flowjax.train
import jax
import flowjax.bijections
import jax.numpy as jnp
import equinox as eqx
from paramax import Parameterize
import numpy as np

scale = Parameterize(
    lambda x: jnp.exp(jnp.arcsinh(x)), jnp.array(0.0)
)
affine = eqx.tree_at(
    where=lambda aff: aff.scale,
    pytree=flowjax.bijections.Affine(),
    replace=scale,
)

layer = flowjax.bijections.Chain([
    flowjax.bijections.coupling.Coupling(
        jax.random.PRNGKey(i),
        dim=5000,
        untransformed_dim=2000,
        nn_width=1000,
        nn_depth=1,
        transformer=affine,
    )
    for i in range(10)
])

np.random.seed(1)
y = jnp.array(np.random.randn(5000))
y_grad = jnp.array(np.random.randn(5000))
y_logp = jnp.array(1.5)

@eqx.filter_jit
def non_naive(layer, y, y_grad, y_logp):
    def cost(y, y_grad):
        y, y_grad, _ = layer.inverse_gradient_and_val(y, y_grad, y_logp)
        return ((y + y_grad) ** 2).sum()
    return jax.value_and_grad(cost)(y, y_grad)

non_naive(layer, y, y_grad, y_logp)
%timeit non_naive(layer, y, y_grad, y_logp)[1].block_until_ready()

@eqx.filter_jit
def naive(layer, y, y_grad, y_logp):
    def cost(y, y_grad):
        y, y_grad, _ = flowjax.bijections.AbstractBijection.inverse_gradient_and_val(layer, y, y_grad, y_logp)
        return ((y + y_grad) ** 2).sum()
    return jax.value_and_grad(cost)(y, y_grad)

naive(layer, y, y_grad, y_logp)
%timeit naive(layer, y, y_grad, y_logp)[1].block_until_ready()
81.3 ms ± 1.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
95.9 ms ± 2.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

This is not nothing, but right now I'm not sure it is worth the effort. I'll have another look tomorrow (and I really hope I can still find a way to make it faster, because it's no fun if it is as slow as this), but if the whole thing was just based on some incorrect benchmarks, I'm sorry I wasted your time...

@danielward27 (Owner)

Thanks. Ah yes, I think you are right about the benchmark, good catch!
