Add inverse_and_gradient method to bijection #191
Conversation
Force-pushed from 7299fad to 68e42ab.
Review comment on flowjax/bijections/bijection.py (outdated):
x_grad: The gradient of the log density at `x`.
x_logp: The log density on the untransformed space at `x`.
"""
x, logdet = self.inverse_and_log_det(y, condition)
I assume you don't need the log_det in the inverse? You could instead do
x = bijection.inverse(y, condition)
(_, fwd_log_det), pull_grad_fn = jax.vjp(...)
and add fwd_log_det instead of subtracting the inverse log det.
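A minimal sketch of what this suggestion amounts to (assuming flowjax's usual AbstractBijection API with inverse and transform_and_log_det; the exact wiring in the PR may differ):
import jax
import jax.numpy as jnp

def inverse_gradient_and_val(bijection, y, y_grad, y_logp, condition=None):
    # Invert without computing the (unneeded) inverse log det.
    x = bijection.inverse(y, condition)
    # A vjp through the forward map yields the forward log det directly,
    # so it can be added instead of subtracting the inverse log det.
    (_, fwd_log_det), pull_grad_fn = jax.vjp(
        lambda x: bijection.transform_and_log_det(x, condition), x
    )
    # Pull back the transformed-space gradient together with a unit
    # cotangent for the log det term.
    (x_grad,) = pull_grad_fn((y_grad, jnp.ones_like(fwd_log_det)))
    return x, x_grad, y_logp + fwd_log_det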
Thanks, that is better.
Nice, thanks for the contribution. I'll have to have a think to understand it a bit better and decide whether this is in the scope of the library, as it adds a reasonable amount of code complexity for a somewhat niche feature (at least it's not something I have come across in other similar libraries before). Without really understanding the use case, I wonder how frequently this is needed, i.e. over
x = dist.bijection.inverse(y)
log_prob, grad = jax.value_and_grad(dist.base_dist.log_prob)(x)
If you get a chance to describe your use case a bit, maybe it would help me understand a bit better.
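For reference, a sketch wrapping this alternative into a jitted helper (hypothetical helper name; assumes a flowjax transformed distribution exposing the bijection and base_dist attributes used above):
import equinox as eqx
import jax

@eqx.filter_jit
def grad_via_base_dist(dist, y):
    # Pull y back to the untransformed space, then differentiate the base
    # distribution's log density at that point.
    x = dist.bijection.inverse(y)
    log_prob, grad = jax.value_and_grad(dist.base_dist.log_prob)(x)
    return x, grad, log_prob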
At least on CPU, whilst the specialized methods are faster, the difference seems to be pretty small. Did you find situations where the specialized versions are substantially faster?
import flowjax.bijections as bij
from flowjax.bijections import Coupling, Affine
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
dim = 500
has_specialized = [
bij.Coupling(
jr.key(0),
transformer=Affine(),
dim=dim,
untransformed_dim=dim//2,
nn_width=50,
nn_depth=1,
),
bij.Affine(jnp.ones(dim), jnp.ones(dim)),
bij.Permute(jr.permutation(jr.key(0), jnp.arange(dim))),
bij.Chain([bij.Affine(jnp.ones(dim)), bij.Affine(jnp.ones(dim))])
]
for bijection in has_specialized:
print(bijection.__class__.__name__)
y = jnp.linspace(0, 1, dim)
y_grad = jnp.linspace(0, 1, dim)
y_logp = jnp.array(0)
specialized = eqx.filter_jit(bijection.inverse_gradient_and_val)
result = specialized(y, y_grad, y_logp)[1].block_until_ready()
%timeit specialized(y, y_grad, y_logp)[1].block_until_ready()
naive = eqx.filter_jit(bij.AbstractBijection.inverse_gradient_and_val)
naive_result = naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
%timeit naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
Thank you for having a look. I was seeing a much larger improvement in my earlier benchmarks, but turns out I made a mistake in those, and funnily enough there is essentially the same issue in the code you just posted: the specialized timing jits the bound method eqx.filter_jit(bijection.inverse_gradient_and_val), which closes over the bijection and bakes its parameters in as constants, while the naive timing passes the bijection in as a traced argument. Passing the bijection explicitly in both cases (and using a deeper Chain) gives a fairer comparison:
import flowjax.bijections as bij
from flowjax.bijections import Coupling, Affine
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
dim = 500
has_specialized = [
bij.Coupling(
jr.key(0),
transformer=Affine(),
dim=dim,
untransformed_dim=dim//2,
nn_width=50,
nn_depth=1,
),
bij.Affine(jnp.ones(dim), jnp.ones(dim)),
bij.Permute(jr.permutation(jr.key(0), jnp.arange(dim))),
bij.Chain([bij.Affine(jnp.ones(dim)), bij.Affine(jnp.ones(dim))] * 10)
]
for bijection in has_specialized:
print(bijection.__class__.__name__)
y = jnp.linspace(0, 1, dim)
y_grad = jnp.linspace(0, 1, dim)
y_logp = jnp.array(0)
specialized = eqx.filter_jit(lambda bijection, *args: bijection.inverse_gradient_and_val(*args))
result = specialized(bijection, y, y_grad, y_logp)[1].block_until_ready()
%timeit specialized(bijection, y, y_grad, y_logp)[1].block_until_ready()
naive = eqx.filter_jit(bij.AbstractBijection.inverse_gradient_and_val)
naive_result = naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
%timeit naive(bijection, y, y_grad, y_logp)[1].block_until_ready()
In a more involved benchmark that is closer to what I'm actually interested in, I still see a speedup, but nowhere near as big as before:
import flowjax
import flowjax.train
import jax
import flowjax.bijections
import jax.numpy as jnp
import equinox as eqx
from paramax import Parameterize
import numpy as np
scale = Parameterize(
lambda x: jnp.exp(jnp.arcsinh(x)), jnp.array(0.0)
)
affine = eqx.tree_at(
where=lambda aff: aff.scale,
pytree=flowjax.bijections.Affine(),
replace=scale,
)
layer = flowjax.bijections.Chain([
flowjax.bijections.coupling.Coupling(
jax.random.PRNGKey(i),
dim=5000,
untransformed_dim=2000,
nn_width=1000,
nn_depth=1,
transformer=affine,
)
for i in range(10)
])
np.random.seed(1)
y = jnp.array(np.random.randn(5000))
y_grad = jnp.array(np.random.randn(5000))
y_logp = jnp.array(1.5)
@eqx.filter_jit
def non_naive(layer, y, y_grad, y_logp):
def cost(y, y_grad):
y, y_grad, _ = layer.inverse_gradient_and_val(y, y_grad, y_logp)
return ((y + y_grad) ** 2).sum()
return jax.value_and_grad(cost)(y, y_grad)
non_naive(layer, y, y_grad, y_logp)
%timeit non_naive(layer, y, y_grad, y_logp)[1].block_until_ready()
@eqx.filter_jit
def naive(layer, y, y_grad, y_logp):
def cost(y, y_grad):
y, y_grad, _ = flowjax.bijections.AbstractBijection.inverse_gradient_and_val(layer, y, y_grad, y_logp)
return ((y + y_grad) ** 2).sum()
return jax.value_and_grad(cost)(y, y_grad)
naive(layer, y, y_grad, y_logp)
%timeit naive(layer, y, y_grad, y_logp)[1].block_until_ready()
This is not nothing, but right now I'm not sure if that is worth the effort. I'll have another look tomorrow (and I really hope I can still find a way to make it faster, because it's no fun if it is as slow as this), but if the whole thing was just based on some incorrect benchmarks, I'm sorry I wasted your time...
Thanks. Ah yes, I think you are right about the benchmark, good catch!
This adds an extra method to bijections that I use for a project of mine. I'm not sure whether it is of enough general interest to include in the library, but including it would certainly make my life easier, and there might be other use cases around.
If there is a known density defined on the transformed space of some bijection, we can compute the logp and the gradient of that logp on the transformed space. In the code I call those y (the position), y_grad (the gradient of the log density at y) and y_logp (the log density at y). We can then compute these quantities on the untransformed space using some autodiff, where we need to take the Jacobian determinant into account.
The PR contains a general implementation in the AbstractBijection class. The code that this generates is not always the best, however, so for some bijections I also implement special cases (Affine, Chain, Permute and Coupling). The special implementations are tested against the general implementation, and the general implementation is tested against analytical solutions for Affine and Exp transformations.
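Spelling out the change of variables this relies on (a standard result, with f the forward transform of the bijection and J_f its Jacobian):
\log p_X(x) = \log p_Y(f(x)) + \log\lvert\det J_f(x)\rvert
\nabla_x \log p_X(x) = J_f(x)^\top \,\nabla_y \log p_Y(y)\big|_{y=f(x)} + \nabla_x \log\lvert\det J_f(x)\rvert
So a single vjp through the forward map, with cotangent y_grad for the output and 1 for the log det, gives x_grad, and x_logp is y_logp plus the forward log det.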