diff --git a/CHANGELOG.md b/CHANGELOG.md index bb9404268a3d..ba78f94a00c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support transforms in more than 3 dimensions, which was previously the limit. See {jax-issue}`#25606` for more details. + * Added {func}`jax.random.multinomial`. * Deprecations * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` diff --git a/docs/jax.random.rst b/docs/jax.random.rst index 6c5427c05e66..837037a492a6 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -53,6 +53,7 @@ Random Samplers logistic lognormal maxwell + multinomial multivariate_normal normal orthogonal diff --git a/jax/_src/random.py b/jax/_src/random.py index 12aa5b93efbf..d1cf3a0ed649 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2627,6 +2627,56 @@ def binomial( batching.defvectorized(random_clone_p) mlir.register_lowering(random_clone_p, lambda _, k: [k]) + +def multinomial( + key: Array, + n: RealArray, + p: RealArray, + axis: int = -1, +): + r"""Sample from a multinomial distribution. + + The probability mass function is + + .. math:: + f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k} + + Args: + key: a PRNG key used as the random key. + n: a float array-like representing the number of trials. + p: a float array-like representing the probabilities of each outcome. + axis: axis along which probabilities are defined for each outcome. + + Returns: + An array of counts for each outcome. + """ + + key, _ = _check_prng_key("multinomial", key) + check_arraylike("multinomial", n, p) + + def f(remaining, p_rho_key): + p, rho, key = p_rho_key + count = binomial(key, remaining, p / rho) + count = jnp.where(rho == 0, 0, count) + return remaining - count, count + + p = jnp.moveaxis(p, axis, 0) + + rhos = jnp.flip(jnp.cumsum(jnp.flip(p, 0), 0), 0) + + keys = split(key, p.shape[0]) + + shape = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)[1:]) + n = jnp.broadcast_to(n, shape) + p = jnp.broadcast_to(p, (jnp.shape(p)[0],) + shape) + + remaining, counts = lax.scan(f, n, (p, rhos, keys), unroll=True) + + counts = jnp.moveaxis(counts, 0, axis) + + return counts + + def clone(key): """Clone a key for reuse diff --git a/jax/random.py b/jax/random.py index b99cd531f18c..f83ad808cd77 100644 --- a/jax/random.py +++ b/jax/random.py @@ -232,6 +232,7 @@ loggamma as loggamma, lognormal as lognormal, maxwell as maxwell, + multinomial as multinomial, multivariate_normal as multivariate_normal, normal as normal, orthogonal as orthogonal, diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index af7a3fb7a3b7..469de0175568 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1243,6 +1243,14 @@ def testBinomialCornerCases(self): self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False) self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False) + def testMultinomial(self): + key = random.key(0) + probs = jnp.array([[0.5, 0.2, 0.3], [0.1, 0.2, 0.7], [0.0, 1.0, 0.0]]) + trials = jnp.array(10**8).astype(float) + counts = random.multinomial(key, trials, probs) + freqs = counts / trials + self.assertAllClose(freqs, probs, atol=1e-3) + def test_batched_key_errors(self): keys = lambda: jax.random.split(self.make_key(0)) msg = "{} accepts a single key, but was given a key array of shape.*"