From b3c536a647bfbc9a6acb398c741c428b69a5b84b Mon Sep 17 00:00:00 2001
From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com>
Date: Wed, 17 Jan 2024 23:56:27 +0100
Subject: [PATCH] Backport PR #2395: Fix repr for custom distributions, add
 optional constraint (#2398)

Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com>
---
 .gitignore                                    |   1 +
 docs/release_notes/index.md                   |   6 +
 scvi/distributions/_constraints.py            |  20 +++
 scvi/distributions/_negative_binomial.py      | 142 ++++++++++++------
 tests/distributions/test_negative_binomial.py |  15 ++
 5 files changed, 136 insertions(+), 48 deletions(-)
 create mode 100644 scvi/distributions/_constraints.py

diff --git a/.gitignore b/.gitignore
index da15580f13..1fffe607a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -124,3 +124,4 @@ ENV/
 poetry.lock
 .mypy_cache/
 .ruff_cache/
+.vscode/
diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md
index 352fa90e5c..8638c843a8 100644
--- a/docs/release_notes/index.md
+++ b/docs/release_notes/index.md
@@ -66,6 +66,12 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
 - Fix bug in {class}`scvi.external.GIMVI` where `batch_size` was not properly used in inference
   methods {pr}`2366`.
 - Fix error message formatting in {meth}`scvi.data.fields.LayerField.transfer_field` {pr}`2368`.
+- Fix ambiguous error raised in {meth}`scvi.distributions.NegativeBinomial.log_prob` and
+  {meth}`scvi.distributions.ZeroInflatedNegativeBinomial.log_prob` when `scale` not passed in
+  and value not in support {pr}`2395`.
+- Fix initialization of {class}`scvi.distributions.NegativeBinomial` and
+  {class}`scvi.distributions.ZeroInflatedNegativeBinomial` when `validate_args=True` and
+  optional parameters not passed in {pr}`2395`.
 
 #### Changed
 
diff --git a/scvi/distributions/_constraints.py b/scvi/distributions/_constraints.py
new file mode 100644
index 0000000000..cdc5dab89d
--- /dev/null
+++ b/scvi/distributions/_constraints.py
@@ -0,0 +1,20 @@
+import torch
+from torch.distributions.constraints import Constraint
+
+
+class _Optional(Constraint):
+    def __init__(self, constraint: Constraint):
+        self.constraint = constraint
+
+    def check(self, value: torch.Tensor) -> torch.Tensor:
+        if value is None:
+            return torch.ones(1, dtype=torch.bool)
+        return self.constraint.check(value)
+
+    def __repr__(self) -> str:
+        return f"Optional({self.constraint})"
+
+
+def optional_constraint(constraint: Constraint) -> Constraint:
+    """Returns a wrapped constraint that allows optional values."""
+    return _Optional(constraint)
diff --git a/scvi/distributions/_negative_binomial.py b/scvi/distributions/_negative_binomial.py
index 8108a68679..bfa3837b83 100644
--- a/scvi/distributions/_negative_binomial.py
+++ b/scvi/distributions/_negative_binomial.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
 import warnings
-from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -19,10 +20,16 @@
 
 from scvi import settings
 
+from ._constraints import optional_constraint
+
 
 def log_zinb_positive(
-    x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, pi: torch.Tensor, eps=1e-8
-):
+    x: torch.Tensor,
+    mu: torch.Tensor,
+    theta: torch.Tensor,
+    pi: torch.Tensor,
+    eps: float = 1e-8,
+) -> torch.Tensor:
     """Log likelihood (scalar) of a minibatch according to a zinb model.
 
     Parameters
@@ -73,13 +80,13 @@ def log_zinb_positive(
 
 
 def log_nb_positive(
-    x: Union[torch.Tensor, jnp.ndarray],
-    mu: Union[torch.Tensor, jnp.ndarray],
-    theta: Union[torch.Tensor, jnp.ndarray],
+    x: torch.Tensor | jnp.ndarray,
+    mu: torch.Tensor | jnp.ndarray,
+    theta: torch.Tensor | jnp.ndarray,
     eps: float = 1e-8,
     log_fn: callable = torch.log,
     lgamma_fn: callable = torch.lgamma,
-):
+) -> torch.Tensor | jnp.ndarray:
     """Log likelihood (scalar) of a minibatch according to a nb model.
 
     Parameters
@@ -118,8 +125,8 @@ def log_mixture_nb(
     theta_1: torch.Tensor,
     theta_2: torch.Tensor,
     pi_logits: torch.Tensor,
-    eps=1e-8,
-):
+    eps: float = 1e-8,
+) -> torch.Tensor:
     """Log likelihood (scalar) of a minibatch according to a mixture nb model.
 
     pi_logits is the probability (logits) to be in the first component.
@@ -183,7 +190,11 @@ def log_mixture_nb(
     return log_mixture_nb
 
 
-def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
+def _convert_mean_disp_to_counts_logits(
+    mu: torch.Tensor,
+    theta: torch.Tensor,
+    eps: float = 1e-6,
+) -> tuple[torch.Tensor, torch.Tensor]:
     r"""NB parameterizations conversion.
 
     Parameters
@@ -210,7 +221,9 @@ def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
     return total_count, logits
 
 
-def _convert_counts_logits_to_mean_disp(total_count, logits):
+def _convert_counts_logits_to_mean_disp(
+    total_count: torch.Tensor, logits: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
     """NB parameterizations conversion.
 
     Parameters
@@ -231,7 +244,7 @@ def _convert_counts_logits_to_mean_disp(total_count, logits):
     return mu, theta
 
 
-def _gamma(theta, mu):
+def _gamma(theta: torch.Tensor, mu: torch.Tensor) -> Gamma:
     concentration = theta
     rate = theta / mu
     # Important remark: Gamma is parametrized by the rate = 1/scale!
@@ -258,12 +271,23 @@
     def __init__(
         self,
         rate: torch.Tensor,
-        validate_args: Optional[bool] = None,
-        scale: Optional[torch.Tensor] = None,
+        validate_args: bool | None = None,
+        scale: torch.Tensor | None = None,
     ):
         super().__init__(rate=rate, validate_args=validate_args)
         self.scale = scale
 
+    def __repr__(self) -> str:
+        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
+        args_string = ", ".join(
+            [
+                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
+                for p in param_names
+                if self.__dict__[p] is not None
+            ]
+        )
+        return self.__class__.__name__ + "(" + args_string + ")"
+
 
 class NegativeBinomial(Distribution):
     r"""Negative binomial distribution.
@@ -297,20 +321,20 @@ class NegativeBinomial(Distribution):
     """
 
     arg_constraints = {
-        "mu": constraints.greater_than_eq(0),
-        "theta": constraints.greater_than_eq(0),
-        "scale": constraints.greater_than_eq(0),
+        "mu": optional_constraint(constraints.greater_than_eq(0)),
+        "theta": optional_constraint(constraints.greater_than_eq(0)),
+        "scale": optional_constraint(constraints.greater_than_eq(0)),
     }
     support = constraints.nonnegative_integer
 
     def __init__(
         self,
-        total_count: Optional[torch.Tensor] = None,
-        probs: Optional[torch.Tensor] = None,
-        logits: Optional[torch.Tensor] = None,
-        mu: Optional[torch.Tensor] = None,
-        theta: Optional[torch.Tensor] = None,
-        scale: Optional[torch.Tensor] = None,
+        total_count: torch.Tensor | None = None,
+        probs: torch.Tensor | None = None,
+        logits: torch.Tensor | None = None,
+        mu: torch.Tensor | None = None,
+        theta: torch.Tensor | None = None,
+        scale: torch.Tensor | None = None,
         validate_args: bool = False,
     ):
         self._eps = 1e-8
@@ -335,17 +359,17 @@ def __init__(
         super().__init__(validate_args=validate_args)
 
     @property
-    def mean(self):
+    def mean(self) -> torch.Tensor:
         return self.mu
 
     @property
-    def variance(self):
+    def variance(self) -> torch.Tensor:
         return self.mean + (self.mean**2) / self.theta
 
     @torch.inference_mode()
     def sample(
         self,
-        sample_shape: Optional[Union[torch.Size, tuple]] = None,
+        sample_shape: torch.Size | tuple | None = None,
     ) -> torch.Tensor:
         """Sample from the distribution."""
         sample_shape = sample_shape or torch.Size()
@@ -373,9 +397,20 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
 
         return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps)
 
-    def _gamma(self):
+    def _gamma(self) -> Gamma:
         return _gamma(self.theta, self.mu)
 
+    def __repr__(self) -> str:
+        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
+        args_string = ", ".join(
+            [
+                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
+                for p in param_names
+                if self.__dict__[p] is not None
+            ]
+        )
+        return self.__class__.__name__ + "(" + args_string + ")"
+
 
 class ZeroInflatedNegativeBinomial(NegativeBinomial):
     r"""Zero-inflated negative binomial distribution.
@@ -411,22 +446,22 @@ class ZeroInflatedNegativeBinomial(NegativeBinomial):
     """
 
     arg_constraints = {
-        "mu": constraints.greater_than_eq(0),
-        "theta": constraints.greater_than_eq(0),
-        "zi_logits": constraints.real,
-        "scale": constraints.greater_than_eq(0),
+        "mu": optional_constraint(constraints.greater_than_eq(0)),
+        "theta": optional_constraint(constraints.greater_than_eq(0)),
+        "zi_logits": optional_constraint(constraints.real),
+        "scale": optional_constraint(constraints.greater_than_eq(0)),
     }
     support = constraints.nonnegative_integer
 
     def __init__(
         self,
-        total_count: Optional[torch.Tensor] = None,
-        probs: Optional[torch.Tensor] = None,
-        logits: Optional[torch.Tensor] = None,
-        mu: Optional[torch.Tensor] = None,
-        theta: Optional[torch.Tensor] = None,
-        zi_logits: Optional[torch.Tensor] = None,
-        scale: Optional[torch.Tensor] = None,
+        total_count: torch.Tensor | None = None,
+        probs: torch.Tensor | None = None,
+        logits: torch.Tensor | None = None,
+        mu: torch.Tensor | None = None,
+        theta: torch.Tensor | None = None,
+        zi_logits: torch.Tensor | None = None,
+        scale: torch.Tensor | None = None,
         validate_args: bool = False,
     ):
         super().__init__(
@@ -443,12 +478,12 @@ def __init__(
         )
 
     @property
-    def mean(self):
+    def mean(self) -> torch.Tensor:
         pi = self.zi_probs
         return (1 - pi) * self.mu
 
     @property
-    def variance(self):
+    def variance(self) -> None:
         raise NotImplementedError
 
     @lazy_property
@@ -463,7 +498,7 @@ def zi_probs(self) -> torch.Tensor:
     @torch.inference_mode()
     def sample(
         self,
-        sample_shape: Optional[Union[torch.Size, tuple]] = None,
+        sample_shape: torch.Size | tuple | None = None,
     ) -> torch.Tensor:
         """Sample from the distribution."""
         sample_shape = sample_shape or torch.Size()
@@ -522,7 +557,7 @@ def __init__(
         mu2: torch.Tensor,
         theta1: torch.Tensor,
         mixture_logits: torch.Tensor,
-        theta2: Optional[torch.Tensor] = None,
+        theta2: torch.Tensor | None = None,
         validate_args: bool = False,
     ):
         (
@@ -540,7 +575,7 @@ def __init__(
         self.theta2 = None
 
     @property
-    def mean(self):
+    def mean(self) -> torch.Tensor:
         pi = self.mixture_probs
         return pi * self.mu1 + (1 - pi) * self.mu2
 
@@ -551,7 +586,7 @@ def mixture_probs(self) -> torch.Tensor:
     @torch.inference_mode()
     def sample(
         self,
-        sample_shape: Optional[Union[torch.Size, tuple]] = None,
+        sample_shape: torch.Size | tuple | None = None,
    ) -> torch.Tensor:
         """Sample from the distribution."""
         sample_shape = sample_shape or torch.Size()
@@ -593,6 +628,17 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
             eps=1e-08,
         )
 
+    def __repr__(self) -> str:
+        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
+        args_string = ", ".join(
+            [
+                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
+                for p in param_names
+                if self.__dict__[p] is not None
+            ]
+        )
+        return self.__class__.__name__ + "(" + args_string + ")"
+
 
 class JaxNegativeBinomialMeanDisp(dist.NegativeBinomial2):
     """Negative binomial parameterized by mean and inverse dispersion."""
@@ -607,7 +653,7 @@ def __init__(
         self,
         mean: jnp.ndarray,
         inverse_dispersion: jnp.ndarray,
-        validate_args: Optional[bool] = None,
+        validate_args: bool | None = None,
         eps: float = 1e-8,
     ):
         self._inverse_dispersion, self._mean = promote_shapes(inverse_dispersion, mean)
@@ -615,15 +661,15 @@ def __init__(
         super().__init__(mean, inverse_dispersion, validate_args=validate_args)
 
     @property
-    def mean(self):
+    def mean(self) -> jnp.ndarray:
         return self._mean
 
     @property
-    def inverse_dispersion(self):
+    def inverse_dispersion(self) -> jnp.ndarray:
         return self._inverse_dispersion
 
     @validate_sample
-    def log_prob(self, value):
+    def log_prob(self, value: jnp.ndarray) -> jnp.ndarray:
         """Log probability."""
         # theta is inverse_dispersion
         theta = self._inverse_dispersion
diff --git a/tests/distributions/test_negative_binomial.py b/tests/distributions/test_negative_binomial.py
index 524816c086..e0a72146d3 100644
--- a/tests/distributions/test_negative_binomial.py
+++ b/tests/distributions/test_negative_binomial.py
@@ -69,3 +69,18 @@ def test_zinb_distribution():
         dist1.log_prob(-x)  # ensures neg values raise warning
     with pytest.warns(UserWarning):
         dist2.log_prob(0.5 * x)  # ensures float values raise warning
+
+    # test with no scale
+    dist1 = ZeroInflatedNegativeBinomial(
+        mu=mu, theta=theta, zi_logits=pi, validate_args=True
+    )
+    dist2 = NegativeBinomial(mu=mu, theta=theta, validate_args=True)
+    dist1.__repr__()
+    dist2.__repr__()
+    assert dist1.log_prob(x).shape == size
+    assert dist2.log_prob(x).shape == size
+
+    with pytest.warns(UserWarning):
+        dist1.log_prob(-x)
+    with pytest.warns(UserWarning):
+        dist2.log_prob(0.5 * x)
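
A minimal usage sketch of the behavior this backport fixes (not part of the
patch itself; it assumes an scvi-tools build that includes this change):

    import warnings

    import torch

    from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial

    mu = torch.rand(4, 2) * 10
    theta = torch.rand(4, 2) + 0.5
    x = torch.randint(1, 20, (4, 2)).float()

    # Construction with validate_args=True no longer fails when optional
    # parameters such as `scale` are left unset: the wrapped constraints
    # now skip None values instead of comparing them against a bound.
    nb = NegativeBinomial(mu=mu, theta=theta, validate_args=True)

    # The custom __repr__ likewise skips unset parameters, printing sizes
    # for multi-element tensors, e.g.
    # NegativeBinomial(mu: torch.Size([4, 2]), theta: torch.Size([4, 2]))
    print(nb)

    zinb = ZeroInflatedNegativeBinomial(
        mu=mu, theta=theta, zi_logits=torch.zeros_like(mu), validate_args=True
    )
    print(zinb.log_prob(x).shape)  # torch.Size([4, 2])

    # Values outside the support now produce a clear UserWarning rather
    # than an ambiguous error, as exercised in the test hunk above.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        nb.log_prob(-x)  # negative counts are outside the support
        print(caught[-1].category)  # <class 'UserWarning'>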
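
A second sketch exercising the new helper on its own; note that
scvi/distributions/_constraints.py is a private module, so importing from it
directly is an internal detail and the path may change between releases:

    import torch
    from torch.distributions import constraints

    from scvi.distributions._constraints import optional_constraint

    # Wrap an ordinary torch constraint so that None is treated as valid.
    nonneg = optional_constraint(constraints.greater_than_eq(0))

    print(nonneg.check(None))                # tensor([True]); None passes
    print(nonneg.check(torch.tensor(3.0)))   # tensor(True)
    print(nonneg.check(torch.tensor(-1.0)))  # tensor(False)
    print(nonneg)                            # e.g. Optional(GreaterThanEq(lower_bound=0))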