Fixing issue "Samples are outside the support for DiscreteUniform dist…" #1835

Open · wants to merge 8 commits into master
Conversation

@Deathn0t (Author)

This fixes issue #1834, where MixedHMC sampling with a DiscreteUniform distribution produced samples outside the support because the proposals were not mapped through enumerate_support.

    lambda idx, support: support[idx],
    z_discrete,
    self._support_enumerates,
)
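For context, a minimal sketch of the index-to-value mapping this excerpt performs, assuming z_discrete holds per-site indices into the enumerated supports (the names and values below are illustrative):

    import jax.numpy as jnp
    from jax import tree_util

    # Hypothetical data for one scalar site: z_discrete stores an index into
    # the enumerated support rather than an actual support value.
    z_discrete = {"x0": jnp.array(2)}
    support_enumerates = {"x0": jnp.array([10, 11, 12])}  # DiscreteUniform(10, 12)

    # Map each index to its in-support value, as in the excerpt above.
    z_values = tree_util.tree_map(
        lambda idx, support: support[idx], z_discrete, support_enumerates
    )
    print(z_values)  # {'x0': Array(12, dtype=int32)}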
@fehiepsi (Member)

Doing this might return in-support values, but I worry that the algorithm would be wrong: to compute the potential energy correctly, we need to work with in-support values throughout. I think you can pass support_enumerates into self._discrete_proposal_fn and change the proposal logic there:

    proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
    # z_new_flat = z_discrete_flat.at[idx].set(proposal)
    z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])

or, for the modified RW proposal:

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    # proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    # proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    proposal_index = jnp.where(support_size[i] == z_discrete_flat[idx], support_size - 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_size[proposal_index])
    z_new_flat = z_discrete_flat.at[idx].set(proposal)

or, at the discrete Gibbs proposal:

    proposal_index = jnp.where(support_enumerate[i] == z_init_flat[idx], support_size - 1, i)
    z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])

@Deathn0t (Author)

Ok, thank you for the feedback. I will try this.

@Deathn0t (Author)

@fehiepsi how do you debug in numpyro? I tried jax.debug. but nothing happens.

@fehiepsi (Member)

I use print most of the time. When actual values are needed, I sometimes use jax.disable_jit().
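For example, a minimal sketch of the jax.disable_jit() approach (the function and values are made up):

    import jax
    import jax.numpy as jnp

    def step(x):
        print("x =", x)  # under jit this prints a tracer, not a concrete value
        return x + 1

    # Run eagerly so the print shows concrete values; jit is a no-op here.
    with jax.disable_jit():
        jax.jit(step)(jnp.array(1))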

@Deathn0t (Author)

@fehiepsi I have issues with passing the enumerated supports around as traced values, since the support arrays can have different sizes. I was thinking maybe to just pass the "lower bound of the support" as an offset; combined with support_sizes, that should do the trick. Are there discrete variables whose support is not a simple integer range with step 1 between values?
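A minimal sketch of the offset idea, assuming every support is a unit-spaced integer range (the numbers are illustrative):

    import jax.numpy as jnp

    # Reconstruct the support of DiscreteUniform(10, 12) from its lower bound
    # ("offset") and its size, assuming consecutive integer values.
    low, support_size = 10, 3
    support = low + jnp.arange(support_size)
    print(support)  # [10 11 12]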

@Deathn0t (Author)

For modified_rw_proposal I think you used support_size in place of support_enumerate; shouldn't it be:

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    # proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    # proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    proposal_index = jnp.where(support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index])
    z_new_flat = z_discrete_flat.at[idx].set(proposal)

@fehiepsi (Member)

Thanks! Your solutions are super cool! I hadn't thought of different support sizes previously.

    self._support_enumerates = np.zeros(
        (len(self._support_sizes), max_length_support_enumerates), dtype=int
    )
    for i, (name, site) in enumerate(self._prototype_trace.items()):
@fehiepsi (Member)

fehiepsi commented Jul 27, 2024

Great solution! I just have a couple of comments:

  • it might be better to loop over the names in support_sizes and get the site via site = self._prototype_trace[name]
  • we use ravel_pytree to flatten support_sizes, so we might want to keep the same behavior here. I don't have a great solution for this; maybe:
    support_enumerates = {}
    for name, support_size in self._support_sizes.items():
        site = self._prototype_trace[name]
        enumerate_support = site["fn"].enumerate_support(False)
        padded_enumerate_support = np.pad(
            enumerate_support,
            (0, max_length_support_enumerates - enumerate_support.shape[0]),
        )
        padded_enumerate_support = np.broadcast_to(
            padded_enumerate_support,
            support_size.shape + (max_length_support_enumerates,),
        )
        support_enumerates[name] = padded_enumerate_support

    self._support_enumerates = jax.vmap(
        lambda x: ravel_pytree(x)[0], in_axes=1, out_axes=1
    )(support_enumerates)

@Deathn0t (Author)

@fehiepsi it worked fine with ravel_pytree as well; I just had to adapt it to in_axes=0.

@fehiepsi (Member)

I think we need to ravel along the first axis. The second axis (which corresponds to max_length_support_enumerates) is the batch dimension. The current code might run, but I guess things are mixed up.

    for site in self._prototype_trace.values()
    if site["type"] == "sample"
    and site["fn"].has_enumerate_support
    and not site["is_observed"]
@fehiepsi (Member)

nit: it is better to loop over support_sizes: for name, site in self._prototype_trace.items() if name in support_sizes

@Deathn0t (Author)

> I think we need to ravel along the first axis. The second axis (which corresponds to max_length_support_enumerates) is the batch dimension. The current code might run, but I guess things are mixed up.

the first axis is in_axes=0?

@fehiepsi (Member)

fehiepsi commented Jul 29, 2024

we vmap over the batch axis, which is the second axis, i.e. in_axes=1
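A minimal sketch of that vmapping, with made-up sites already padded to a common max support length along the second axis (assuming each entry carries a leading flattened-site dimension):

    import jax
    import jax.numpy as jnp
    from jax.flatten_util import ravel_pytree

    # Two scalar sites, each of shape (1, 4): flattened-site dim first,
    # enumerate axis (padded to max length 4) second.
    support_enumerates = {
        "x0": jnp.array([[10, 11, 12, 0]]),
        "x1": jnp.array([[0, 1, 2, 3]]),
    }
    # vmap over the second (enumerate) axis, raveling the per-site leaves
    # at each position into one flat vector.
    flat = jax.vmap(lambda x: ravel_pytree(x)[0], in_axes=1, out_axes=1)(
        support_enumerates
    )
    print(flat.shape)  # (2, 4): flattened latents x max support length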

@fehiepsi (Member)

Could you also add a simple test (as in the issue) for this? You can run make lint and make format to fix lint issues.

@Deathn0t (Author)

Deathn0t commented Jul 30, 2024

I applied the lint/format and I added a test.

> we vmap over the batch axis, which is the second axis, i.e. in_axes=1

OK, but the support_size values have shape ().

So the following line:

    support_size.shape + (max_length_support_enumerates,),

is just equivalent to (max_length_support_enumerates,), and therefore in_axes=1 fails. What format should support_size and enumerate_support have?

Maybe you have an example where the support_size values have a shape different from a scalar ()? I can't think of one.
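As model_4 further down shows, a batched site does give non-scalar entries. A sketch, assuming support_sizes broadcasts each site's support size to the shape of the site value (mirroring the expression at hmc_gibbs.py line 467 in the traceback below; not necessarily how NumPyro builds it):

    import numpy as np
    import jax
    import jax.numpy as jnp
    import numpyro.distributions as dist

    # A batched Categorical site: batch shape (3,), support size 4.
    fn = dist.Categorical(0.25 * jnp.ones((3, 4)))
    value = fn.sample(jax.random.PRNGKey(0))  # shape (3,)
    support_size = np.broadcast_to(
        fn.enumerate_support(False).shape[0], jnp.shape(value)
    )
    print(support_size)  # [4 4 4]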

@fehiepsi (Member)

fehiepsi commented Jul 30, 2024

That is a good point. I thought the support sizes contained flattened arrays. Sorry for the confusion. I guess we need to move the enumerate dimension to the first axis before vmapping, like you did:

    support_enumerates[name] = np.moveaxis(padded_enumerate_support, -1, 0)

@Deathn0t (Author)

Deathn0t commented Jul 30, 2024

I tried the following direction:

        max_length_support_enumerates = np.max(
            [size for size in self._support_sizes.values()]
        )

        support_enumerates = {}
        for name, support_size in self._support_sizes.items():
            site = self._prototype_trace[name]
            enumerate_support = site["fn"].enumerate_support(True).T
            # Only the last dimension that corresponds to support size is padded
            pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
                (0, max_length_support_enumerates - enumerate_support.shape[-1])
            ]
            padded_enumerate_support = np.pad(enumerate_support, pad_width)

            support_enumerates[name] = padded_enumerate_support

        self._support_enumerates = jax.vmap(
            lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1
        )(support_enumerates)

which works with the following cases:

    def model_1():
        numpyro.sample("x0", dist.DiscreteUniform(10, 12))
        numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))

    def model_2():
        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,))))
        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,))))

    def model_3():
        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10))))

But it fails when I try to batch DiscreteUniform:

    def model_4():
        numpyro.sample("x1", dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,))))

with the following exception, which is raised before the code I added (when self._support_sizes is created):

Traceback (most recent call last):
  File "/Users/romainegele/Documents/Argonne/numpyro/test/test_distributions.py", line 3512, in <module>
    test_discrete_uniform_with_mixedhmc()
  File "/Users/romainegele/Documents/Argonne/numpyro/test/test_distributions.py", line 3501, in test_discrete_uniform_with_mixedhmc
    samples = sample_mixedhmc(model_4, num_samples, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/romainegele/Documents/Argonne/numpyro/test/test_distributions.py", line 3438, in sample_mixedhmc
    mcmc.run(key)
  File "/Users/romainegele/Documents/Argonne/numpyro/numpyro/infer/mcmc.py", line 682, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/romainegele/Documents/Argonne/numpyro/numpyro/infer/mcmc.py", line 443, in _single_chain_mcmc
    new_init_state = self.sampler.init(
                     ^^^^^^^^^^^^^^^^^^
  File "/Users/romainegele/Documents/Argonne/numpyro/numpyro/infer/mixed_hmc.py", line 88, in init
    state = super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/romainegele/Documents/Argonne/numpyro/numpyro/infer/hmc_gibbs.py", line 467, in init
    site["fn"].enumerate_support(False).shape[0], jnp.shape(site["value"])
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/romainegele/Documents/Argonne/numpyro/numpyro/distributions/discrete.py", line 472, in enumerate_support
    values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape(
              ~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  File "/Users/romainegele/miniforge3/envs/dh-3.12-240724/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
    return binary_op(*args)
           ^^^^^^^^^^^^^^^^
  File "/Users/romainegele/miniforge3/envs/dh-3.12-240724/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 102, in fn
    return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
           ^^^^^^^^^^^^^^
TypeError: add got incompatible shapes for broadcasting: (3,), (10,).

@fehiepsi (Member)

The in_axes=len(support_size.shape) might not be the same across different latent variables. I think you can move the batch dimension to the front like in my last comment.

By the way, maybe we need to use size.reshape(-1)[0] instead of size in:

    max_length_support_enumerates = np.max(
        [size for size in self._support_sizes.values()]
    )
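Presumably the point is that the support_sizes entries are broadcast arrays of differing shapes, so np.max over the ragged list can fail; a sketch of the suggested fix with made-up entries:

    import numpy as np

    # Entries are broadcast to each site's value shape, so the list is ragged.
    support_sizes = {"x0": np.broadcast_to(4, (3,)), "x1": np.broadcast_to(10, ())}

    # Collapse each entry to a scalar before taking the max.
    max_length = np.max([np.reshape(size, -1)[0] for size in support_sizes.values()])
    print(max_length)  # 10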

@fehiepsi (Member)

Hmm, there seems to be a bug in DiscreteUniform.enumerate_support: self.low should be jnp.reshape(self.low, -1)[0].
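A minimal sketch of why the batched case breaks and what the suggested change does (the actual fix landed in #1859):

    import numpy as np
    import jax.numpy as jnp

    # Batched DiscreteUniform(10 * ones(3), 19 * ones(3)): low/high have shape (3,).
    low = 10 * jnp.ones((3,), dtype=jnp.int32)
    high = 19 * jnp.ones((3,), dtype=jnp.int32)

    # The old code added low (shape (3,)) to a length-10 arange, which fails to
    # broadcast. The suggested change collapses low to a scalar first:
    values = jnp.reshape(low, -1)[0] + jnp.arange(np.amax(high - low) + 1)
    print(values)  # [10 11 12 13 14 15 16 17 18 19]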

@fehiepsi (Member)

fehiepsi commented Sep 5, 2024

@Deathn0t The fix is in #1859. Could you test whether the change works now?

@Deathn0t (Author)

Deathn0t commented Sep 5, 2024

@fehiepsi sorry for the delay... other things came up and I couldn't follow up. Yes, let me test this now!

@Deathn0t (Author)

Deathn0t commented Sep 5, 2024

@fehiepsi the 4 cases I put in the test are now passing, assuming changes from #1859 are used!
