Skip to content

Commit

Permalink
Add jax.random.multinomial.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Dec 27, 2024
1 parent 6dbda90 commit de5c506
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Added {func}`jax.random.multinomial`.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
Expand Down
1 change: 1 addition & 0 deletions docs/jax.random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Random Samplers
logistic
lognormal
maxwell
multinomial
multivariate_normal
normal
orthogonal
Expand Down
51 changes: 51 additions & 0 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,57 @@ def binomial(
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, lambda _, k: [k])


def multinomial(
key: Array,
n: RealArray,
p: RealArray,
axis: int = -1,
):
r"""Sample from a multinomial distribution.
The probability mass function is
.. math::
f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}
Args:
key: a PRNG key used as the random key.
n: a float array-like representing the number of trials.
p: a float array-like representing the probabilities of each outcome.
axis: axis along which probabilities are defined for each outcome.
Returns:
An array of counts for each outcome.
"""

key, _ = _check_prng_key("multinomial", key)
check_arraylike("multinomial", n, p)

def f(remaining, p_r_key):
p, r, key = p_r_key
count = binomial(key, remaining, p / r)
count = jnp.where(r == 0, 0, count)
return remaining - count, count

p = jnp.moveaxis(p, axis, 0)

# remaining probability
r = jnp.flip(jnp.cumsum(jnp.flip(p, 0), 0), 0)

keys = split(key, jnp.shape(p)[0])

shape = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)[1:])
n = jnp.broadcast_to(n, shape)
p = jnp.broadcast_to(p, (jnp.shape(p)[0],) + shape)

remaining, counts = lax.scan(f, n, (p, r, keys), unroll=True)

counts = jnp.moveaxis(counts, 0, axis)

return counts


def clone(key):
"""Clone a key for reuse
Expand Down
1 change: 1 addition & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
loggamma as loggamma,
lognormal as lognormal,
maxwell as maxwell,
multinomial as multinomial,
multivariate_normal as multivariate_normal,
normal as normal,
orthogonal as orthogonal,
Expand Down
8 changes: 8 additions & 0 deletions tests/random_lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,14 @@ def testBinomialCornerCases(self):
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)

def testMultinomial(self):
key = random.key(0)
probs = jnp.array([[0.5, 0.2, 0.3], [0.1, 0.2, 0.7], [0.0, 1.0, 0.0]])
trials = jnp.array(10**8).astype(float)
counts = random.multinomial(key, trials, probs)
freqs = counts / trials
self.assertAllClose(freqs, probs, atol=1e-3)

def test_batched_key_errors(self):
keys = lambda: jax.random.split(self.make_key(0))
msg = "{} accepts a single key, but was given a key array of shape.*"
Expand Down

0 comments on commit de5c506

Please sign in to comment.