diff --git a/bayesflow/distributions/diagonal_student_t.py b/bayesflow/distributions/diagonal_student_t.py index 8c6a3a7ee..a02798c4d 100644 --- a/bayesflow/distributions/diagonal_student_t.py +++ b/bayesflow/distributions/diagonal_student_t.py @@ -4,9 +4,10 @@ import math import numpy as np -from scipy.stats import t as scipy_student_t from bayesflow.types import Shape, Tensor +from bayesflow.utils import expand_tile + from .distribution import Distribution @@ -20,6 +21,7 @@ def __init__( loc: int | float | np.ndarray | Tensor = 0.0, scale: int | float | np.ndarray | Tensor = 1.0, use_learnable_parameters: bool = False, + seed_generator: keras.random.SeedGenerator = None, **kwargs, ): super().__init__(**kwargs) @@ -33,6 +35,11 @@ def __init__( self.use_learnable_parameters = use_learnable_parameters + if seed_generator is None: + seed_generator = keras.random.SeedGenerator() + + self.seed_generator = seed_generator + def build(self, input_shape: Shape) -> None: self.dim = int(input_shape[-1]) @@ -78,9 +85,15 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: return result def sample(self, batch_shape: Shape) -> Tensor: - # TODO: use reparameterization trick instead of scipy - # TODO: use the seed generator state - dist = scipy_student_t(df=self.df, loc=self.loc, scale=self.scale) - samples = dist.rvs(size=batch_shape + (self.dim,)) + # As of writing this code, keras does not support the chi-square distribution + # nor does it support a scale or rate parameter in Gamma. Hence, we use the relation: + # chi-square(df) = Gamma(shape = 0.5 * df, scale = 2) = Gamma(shape = 0.5 * df, scale = 1) * 2 + chi2_samples = keras.random.gamma(batch_shape, alpha=0.5 * self.df, seed=self.seed_generator) * 2.0 + + # The chi-quare samples need to be repeated across self.dim + # since for each element of batch_shape only one sample is created. + chi2_samples = expand_tile(chi2_samples, n=self.dim, axis=-1) + + normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator) - return keras.ops.convert_to_tensor(samples) + return self.loc + self.scale * normal_samples * keras.ops.sqrt(self.df / chi2_samples)