masked autoregressive flow with mixed transformer types #161
I think your approach works, but it would have a bit of extra overhead: as you said, the masked autoregressive network will still produce a set of (unused) parameters for the identity-transformed variables. If you want to avoid that, here's another possibility. What I have done is wrap a masked autoregressive bijection whose dimension matches the dimensionality of the transformed variables, and pass the identity-transformed variables in as its conditioning input:

```python
from typing import ClassVar

from flowjax.bijections.bijection import AbstractBijection
from flowjax.bijections.masked_autoregressive import MaskedAutoregressive


class IdentityFirstMaskedAutoregressive(AbstractBijection):
    masked_autoregressive: MaskedAutoregressive
    identity_dim: int
    shape: tuple[int, ...]
    cond_shape: ClassVar[None] = None

    def __init__(self, masked_autoregressive: MaskedAutoregressive):
        self.masked_autoregressive = masked_autoregressive
        self.identity_dim = masked_autoregressive.cond_shape[0]
        self.shape = (self.identity_dim + self.masked_autoregressive.shape[0],)

    def transform(self, x, condition=None):
        y = self.masked_autoregressive.transform(
            x[self.identity_dim :], condition=x[: self.identity_dim]
        )
        return x.at[self.identity_dim :].set(y)

    def transform_and_log_det(self, x, condition=None):
        y, log_det = self.masked_autoregressive.transform_and_log_det(
            x[self.identity_dim :], condition=x[: self.identity_dim]
        )
        return x.at[self.identity_dim :].set(y), log_det

    def inverse(self, y, condition=None):
        x = self.masked_autoregressive.inverse(
            y[self.identity_dim :], condition=y[: self.identity_dim]
        )
        return y.at[self.identity_dim :].set(x)

    def inverse_and_log_det(self, y, condition=None):
        x, log_det = self.masked_autoregressive.inverse_and_log_det(
            y[self.identity_dim :], condition=y[: self.identity_dim]
        )
        return y.at[self.identity_dim :].set(x), log_det
```

If you need to support a conditional version of this, it should be possible with some concatenating and adjusting of shapes. In general it could be possible to add support for a mix of transformer types, but if, for example, we assume a list of heterogeneous transformers, then compilation speed might become an issue, as we can no longer just rely on `vmap` and would have to loop. Thanks for the support, and let me know if you have any questions/issues!
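The idea behind the wrapper can be illustrated with a toy stand-in for the masked autoregressive bijection (plain numpy; `toy_autoregressive_transform` is a hypothetical affine transform, not flowjax's actual network): the identity dimensions pass through unchanged and contribute zero to the log-determinant, while the remaining dimensions are transformed conditioned on them.

```python
import numpy as np


def toy_autoregressive_transform(x_rest, condition):
    # Hypothetical stand-in for MaskedAutoregressive.transform_and_log_det:
    # scale the remaining variables by a factor depending on the
    # conditioning (identity-transformed) variables.
    scale = 1.0 + np.abs(condition).sum()
    y = scale * x_rest
    log_det = x_rest.size * np.log(scale)
    return y, log_det


def identity_first_transform(x, identity_dim):
    # Identity dims pass through unchanged (zero log-det contribution);
    # the rest are transformed, conditioned on the identity dims.
    y_rest, log_det = toy_autoregressive_transform(
        x[identity_dim:], condition=x[:identity_dim]
    )
    y = x.copy()
    y[identity_dim:] = y_rest
    return y, log_det


x = np.array([0.5, -1.0, 2.0])
y, log_det = identity_first_transform(x, identity_dim=1)
# y[0] is unchanged; y[1:] is scaled by 1.5; log_det = 2 * log(1.5)
```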
---

This is a bit late, but another option is to define individual bijections for the non-transformed variables and the remaining ones, and then stack them together into a single bijection:

```python
import jax
import jax.numpy as jnp

from flowjax.bijections import (
    Concatenate,
    Identity,
    MaskedAutoregressive,
    RationalQuadraticSpline,
)
from flowjax.distributions import Transformed, Uniform

N = 5
base_dist = Uniform(minval=-jnp.ones(N), maxval=jnp.ones(N))

bijections = [
    Identity(shape=(1,)),
    MaskedAutoregressive(
        key=jax.random.PRNGKey(0),
        transformer=RationalQuadraticSpline(knots=5, interval=1.0),
        dim=N - 1,
        nn_width=10,
        nn_depth=1,
    ),
]

# Use Concatenate, as it stacks bijections along an *existing* axis.
bijection = Concatenate(bijections)
flow = Transformed(base_dist, bijection)
```

You could wrap this in a constructor.
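The mechanics of `Concatenate` here — each bijection acts on its own slice along an existing axis, and the log-determinants add — can be sketched in plain numpy (a toy stand-in under assumed semantics, not the flowjax implementation; `toy_spline_like` is hypothetical):

```python
import numpy as np


def identity(x):
    # Identity bijection: unchanged, zero log-det.
    return x, 0.0


def toy_spline_like(x):
    # Hypothetical stand-in for an elementwise scalar bijection.
    y = 2.0 * x
    log_det = x.size * np.log(2.0)
    return y, log_det


def concatenate(bijections_and_sizes, x):
    # Split x into per-bijection slices, transform each, sum the log-dets.
    ys, total_log_det, start = [], 0.0, 0
    for fn, size in bijections_and_sizes:
        y, ld = fn(x[start : start + size])
        ys.append(y)
        total_log_det += ld
        start += size
    return np.concatenate(ys), total_log_det


x = np.array([0.3, 1.0, -2.0, 0.5, 4.0])
y, log_det = concatenate([(identity, 1), (toy_spline_like, 4)], x)
# y[0] is untouched; y[1:] is doubled; log_det = 4 * log(2)
```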
---

You can do that, but note that the transform of the transformed dimensions will be independent of the identity-transformed variables if you do.
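This caveat can be seen directly in a toy sketch (plain numpy, hypothetical transforms, not flowjax code): with the `Concatenate` construction the later dimensions are a function of those dimensions alone, so changing the first variable leaves them untouched, whereas the wrapper approach conditions on it.

```python
import numpy as np


def concat_style(x):
    # Concatenate approach: dims 1.. transformed independently of x[0].
    y = x.copy()
    y[1:] = np.tanh(x[1:])
    return y


def wrapper_style(x):
    # Identity-first wrapper approach: dims 1.. conditioned on x[0]
    # (a hypothetical conditioning, just for illustration).
    y = x.copy()
    y[1:] = np.tanh(x[1:] + x[0])
    return y


x_a = np.array([0.0, 1.0, -1.0])
x_b = np.array([5.0, 1.0, -1.0])  # only the first (identity) variable differs

# Concatenate: the later dims are identical for x_a and x_b;
# wrapper: the later dims change with x[0].
```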
---

It could be possible to support a transformer with shape/dimension matching the shape of the total bijection (rather than only scalar bijections), in which case you could …
---

I am looking into a modification of a regular masked autoregressive flow where the base distribution is an N-dimensional uniform and the first variable does not get transformed, while the rest of the variables get transformed via a rational quadratic spline. I have removed the shuffling in the `masked_autoregressive_flow` function by removing `_add_default_permute`, and modified `_flat_params_to_transformer` in the `MaskedAutoregressive` class to apply an `Identity` transformer to the first dimension.

My understanding is that in this way the `masked_autoregressive_mlp` will still produce a set of spline parameters for the first variable that then never get used, and that this should be harmless. My experiments seem to produce the expected results, but I am not sure this is the most efficient way to go about it, or whether I am disregarding anything relevant, so I would love to hear your opinion on how to make the best use of your package. Thanks again for all the amazing work!
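The "unused parameters are harmless" point can be illustrated with a toy sketch (plain numpy; `apply_transformers` and the scalar transform are hypothetical stand-ins, not flowjax internals): the network produces parameters for every dimension, but the first dimension uses an identity transformer, so its parameters are simply discarded — wasted compute, but no effect on the output.

```python
import numpy as np


def apply_transformers(x, params):
    # Toy version of the idea: parameters exist for every dimension,
    # but the first dimension is passed through unchanged, so
    # params[0] is never used (harmless, just wasted compute).
    y = x.copy()
    y[1:] = x[1:] * np.exp(params[1:])  # hypothetical scalar transformer
    return y


x = np.array([1.0, 2.0, 3.0])
params = np.array([99.0, 0.0, np.log(2.0)])  # params[0] is ignored
y = apply_transformers(x, params)
# y = [1.0, 2.0, 6.0], regardless of the value of params[0]
```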