Backport PR #2395: Fix repr for custom distributions, add optional constraint (#2398)

Co-authored-by: Martin Kim <[email protected]>
meeseeksmachine and martinkim0 authored Jan 17, 2024
1 parent fc75cfe commit b3c536a
Showing 5 changed files with 136 additions and 48 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -124,3 +124,4 @@ ENV/
poetry.lock
.mypy_cache/
.ruff_cache/
.vscode/
6 changes: 6 additions & 0 deletions docs/release_notes/index.md
@@ -66,6 +66,12 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
- Fix bug in {class}`scvi.external.GIMVI` where `batch_size` was not
properly used in inference methods {pr}`2366`.
- Fix error message formatting in {meth}`scvi.data.fields.LayerField.transfer_field` {pr}`2368`.
- Fix ambiguous error raised in {meth}`scvi.distributions.NegativeBinomial.log_prob` and
{meth}`scvi.distributions.ZeroInflatedNegativeBinomial.log_prob` when `scale` not passed in
and value not in support {pr}`2395`.
- Fix initialization of {class}`scvi.distributions.NegativeBinomial` and
{class}`scvi.distributions.ZeroInflatedNegativeBinomial` when `validate_args=True` and
optional parameters not passed in {pr}`2395`.

#### Changed

20 changes: 20 additions & 0 deletions scvi/distributions/_constraints.py
@@ -0,0 +1,20 @@
import torch
from torch.distributions.constraints import Constraint


class _Optional(Constraint):
def __init__(self, constraint: Constraint):
self.constraint = constraint

def check(self, value: torch.Tensor) -> torch.Tensor:
if value is None:
return torch.ones(1, dtype=torch.bool)
return self.constraint.check(value)

def __repr__(self) -> str:
return f"Optional({self.constraint})"


def optional_constraint(constraint: Constraint) -> Constraint:
"""Returns a wrapped constraint that allows optional values."""
return _Optional(constraint)
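
A minimal usage sketch (not part of the diff) of the new `optional_constraint` wrapper, assuming the private module path introduced above stays importable; the printed reprs are illustrative:

    import torch
    from torch.distributions import constraints
    from scvi.distributions._constraints import optional_constraint

    opt_nonneg = optional_constraint(constraints.greater_than_eq(0))

    # None is treated as valid, so validate_args=True no longer fails when an
    # optional parameter such as `scale` is left unset.
    print(opt_nonneg.check(None))                      # tensor([True])
    print(opt_nonneg.check(torch.tensor([1.0, 2.0])))  # tensor([True, True])
    print(opt_nonneg)                                   # e.g. Optional(GreaterThanEq(lower_bound=0))
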
142 changes: 94 additions & 48 deletions scvi/distributions/_negative_binomial.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import Optional, Union

import jax
import jax.numpy as jnp
@@ -19,10 +20,16 @@

from scvi import settings

from ._constraints import optional_constraint


def log_zinb_positive(
x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, pi: torch.Tensor, eps=1e-8
):
x: torch.Tensor,
mu: torch.Tensor,
theta: torch.Tensor,
pi: torch.Tensor,
eps: float = 1e-8,
) -> torch.Tensor:
"""Log likelihood (scalar) of a minibatch according to a zinb model.
Parameters
@@ -73,13 +80,13 @@ def log_zinb_positive(


def log_nb_positive(
x: Union[torch.Tensor, jnp.ndarray],
mu: Union[torch.Tensor, jnp.ndarray],
theta: Union[torch.Tensor, jnp.ndarray],
x: torch.Tensor | jnp.ndarray,
mu: torch.Tensor | jnp.ndarray,
theta: torch.Tensor | jnp.ndarray,
eps: float = 1e-8,
log_fn: callable = torch.log,
lgamma_fn: callable = torch.lgamma,
):
) -> torch.Tensor | jnp.ndarray:
"""Log likelihood (scalar) of a minibatch according to a nb model.
Parameters
@@ -118,8 +125,8 @@ def log_mixture_nb(
theta_1: torch.Tensor,
theta_2: torch.Tensor,
pi_logits: torch.Tensor,
eps=1e-8,
):
eps: float = 1e-8,
) -> torch.Tensor:
"""Log likelihood (scalar) of a minibatch according to a mixture nb model.
pi_logits is the probability (logits) to be in the first component.
@@ -183,7 +190,11 @@ def log_mixture_nb(
return log_mixture_nb


def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
def _convert_mean_disp_to_counts_logits(
mu: torch.Tensor,
theta: torch.Tensor,
eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""NB parameterizations conversion.
Parameters
@@ -210,7 +221,9 @@ def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
return total_count, logits


def _convert_counts_logits_to_mean_disp(total_count, logits):
def _convert_counts_logits_to_mean_disp(
total_count: torch.Tensor, logits: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""NB parameterizations conversion.
Parameters
@@ -231,7 +244,7 @@ def _convert_counts_logits_to_mean_disp(total_count, logits):
return mu, theta


def _gamma(theta, mu):
def _gamma(theta: torch.Tensor, mu: torch.Tensor) -> Gamma:
concentration = theta
rate = theta / mu
# Important remark: Gamma is parametrized by the rate = 1/scale!
@@ -258,12 +271,23 @@ class Poisson(PoissonTorch):
def __init__(
self,
rate: torch.Tensor,
validate_args: Optional[bool] = None,
scale: Optional[torch.Tensor] = None,
validate_args: bool | None = None,
scale: torch.Tensor = None,
):
super().__init__(rate=rate, validate_args=validate_args)
self.scale = scale

def __repr__(self) -> str:
param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
args_string = ", ".join(
[
f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
for p in param_names
if self.__dict__[p] is not None
]
)
return self.__class__.__name__ + "(" + args_string + ")"


class NegativeBinomial(Distribution):
r"""Negative binomial distribution.
@@ -297,20 +321,20 @@ class NegativeBinomial(Distribution):
"""

arg_constraints = {
"mu": constraints.greater_than_eq(0),
"theta": constraints.greater_than_eq(0),
"scale": constraints.greater_than_eq(0),
"mu": optional_constraint(constraints.greater_than_eq(0)),
"theta": optional_constraint(constraints.greater_than_eq(0)),
"scale": optional_constraint(constraints.greater_than_eq(0)),
}
support = constraints.nonnegative_integer

def __init__(
self,
total_count: Optional[torch.Tensor] = None,
probs: Optional[torch.Tensor] = None,
logits: Optional[torch.Tensor] = None,
mu: Optional[torch.Tensor] = None,
theta: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
total_count: torch.Tensor | None = None,
probs: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
mu: torch.Tensor | None = None,
theta: torch.Tensor | None = None,
scale: torch.Tensor | None = None,
validate_args: bool = False,
):
self._eps = 1e-8
@@ -335,17 +359,17 @@ def __init__(
super().__init__(validate_args=validate_args)

@property
def mean(self):
def mean(self) -> torch.Tensor:
return self.mu

@property
def variance(self):
def variance(self) -> torch.Tensor:
return self.mean + (self.mean**2) / self.theta

@torch.inference_mode()
def sample(
self,
sample_shape: Optional[Union[torch.Size, tuple]] = None,
sample_shape: torch.Size | tuple | None = None,
) -> torch.Tensor:
"""Sample from the distribution."""
sample_shape = sample_shape or torch.Size()
@@ -373,9 +397,20 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:

return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps)

def _gamma(self):
def _gamma(self) -> Gamma:
return _gamma(self.theta, self.mu)

def __repr__(self) -> str:
param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
args_string = ", ".join(
[
f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
for p in param_names
if self.__dict__[p] is not None
]
)
return self.__class__.__name__ + "(" + args_string + ")"


class ZeroInflatedNegativeBinomial(NegativeBinomial):
r"""Zero-inflated negative binomial distribution.
@@ -411,22 +446,22 @@ class ZeroInflatedNegativeBinomial(NegativeBinomial):
"""

arg_constraints = {
"mu": constraints.greater_than_eq(0),
"theta": constraints.greater_than_eq(0),
"zi_logits": constraints.real,
"scale": constraints.greater_than_eq(0),
"mu": optional_constraint(constraints.greater_than_eq(0)),
"theta": optional_constraint(constraints.greater_than_eq(0)),
"zi_logits": optional_constraint(constraints.real),
"scale": optional_constraint(constraints.greater_than_eq(0)),
}
support = constraints.nonnegative_integer

def __init__(
self,
total_count: Optional[torch.Tensor] = None,
probs: Optional[torch.Tensor] = None,
logits: Optional[torch.Tensor] = None,
mu: Optional[torch.Tensor] = None,
theta: Optional[torch.Tensor] = None,
zi_logits: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
total_count: torch.Tensor | None = None,
probs: torch.Tensor | None = None,
logits: torch.Tensor | None = None,
mu: torch.Tensor | None = None,
theta: torch.Tensor | None = None,
zi_logits: torch.Tensor | None = None,
scale: torch.Tensor | None = None,
validate_args: bool = False,
):
super().__init__(
@@ -443,12 +478,12 @@ def __init__(
)

@property
def mean(self):
def mean(self) -> torch.Tensor:
pi = self.zi_probs
return (1 - pi) * self.mu

@property
def variance(self):
def variance(self) -> None:
raise NotImplementedError

@lazy_property
@@ -463,7 +498,7 @@ def zi_probs(self) -> torch.Tensor:
@torch.inference_mode()
def sample(
self,
sample_shape: Optional[Union[torch.Size, tuple]] = None,
sample_shape: torch.Size | tuple | None = None,
) -> torch.Tensor:
"""Sample from the distribution."""
sample_shape = sample_shape or torch.Size()
@@ -522,7 +557,7 @@ def __init__(
mu2: torch.Tensor,
theta1: torch.Tensor,
mixture_logits: torch.Tensor,
theta2: Optional[torch.Tensor] = None,
theta2: torch.Tensor = None,
validate_args: bool = False,
):
(
@@ -540,7 +575,7 @@ def __init__(
self.theta2 = None

@property
def mean(self):
def mean(self) -> torch.Tensor:
pi = self.mixture_probs
return pi * self.mu1 + (1 - pi) * self.mu2

@@ -551,7 +586,7 @@ def mixture_probs(self) -> torch.Tensor:
@torch.inference_mode()
def sample(
self,
sample_shape: Optional[Union[torch.Size, tuple]] = None,
sample_shape: torch.Size | tuple | None = None,
) -> torch.Tensor:
"""Sample from the distribution."""
sample_shape = sample_shape or torch.Size()
@@ -593,6 +628,17 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
eps=1e-08,
)

def __repr__(self) -> str:
param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
args_string = ", ".join(
[
f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
for p in param_names
if self.__dict__[p] is not None
]
)
return self.__class__.__name__ + "(" + args_string + ")"


class JaxNegativeBinomialMeanDisp(dist.NegativeBinomial2):
"""Negative binomial parameterized by mean and inverse dispersion."""
@@ -607,23 +653,23 @@ def __init__(
self,
mean: jnp.ndarray,
inverse_dispersion: jnp.ndarray,
validate_args: Optional[bool] = None,
validate_args: bool | None = None,
eps: float = 1e-8,
):
self._inverse_dispersion, self._mean = promote_shapes(inverse_dispersion, mean)
self._eps = eps
super().__init__(mean, inverse_dispersion, validate_args=validate_args)

@property
def mean(self):
def mean(self) -> jnp.ndarray:
return self._mean

@property
def inverse_dispersion(self):
def inverse_dispersion(self) -> jnp.ndarray:
return self._inverse_dispersion

@validate_sample
def log_prob(self, value):
def log_prob(self, value) -> jnp.ndarray:
"""Log probability."""
# theta is inverse_dispersion
theta = self._inverse_dispersion
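
A short sketch (not part of the diff) of the behavior these changes target; the repr in the comment follows the `__repr__` format above, but the exact string may differ:

    import torch
    from scvi.distributions import NegativeBinomial

    mu = torch.ones(2, 3)
    theta = torch.ones(2, 3)

    # Previously, validate_args=True could fail on optional parameters left as
    # None (e.g. `scale`); with the optional constraints this now constructs.
    dist = NegativeBinomial(mu=mu, theta=theta, validate_args=True)

    # The custom __repr__ lists only the parameters that were provided, showing
    # the value for one-element tensors and the shape otherwise, e.g.
    # NegativeBinomial(mu: torch.Size([2, 3]), theta: torch.Size([2, 3]))
    print(dist)

    # An out-of-support value now emits a UserWarning instead of the previous
    # ambiguous error when `scale` was not passed in.
    dist.log_prob(-torch.ones(2, 3))
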
15 changes: 15 additions & 0 deletions tests/distributions/test_negative_binomial.py
@@ -69,3 +69,18 @@ def test_zinb_distribution():
dist1.log_prob(-x) # ensures neg values raise warning
with pytest.warns(UserWarning):
dist2.log_prob(0.5 * x) # ensures float values raise warning

# test with no scale
dist1 = ZeroInflatedNegativeBinomial(
mu=mu, theta=theta, zi_logits=pi, validate_args=True
)
dist2 = NegativeBinomial(mu=mu, theta=theta, validate_args=True)
dist1.__repr__()
dist2.__repr__()
assert dist1.log_prob(x).shape == size
assert dist2.log_prob(x).shape == size

with pytest.warns(UserWarning):
dist1.log_prob(-x)
with pytest.warns(UserWarning):
dist2.log_prob(0.5 * x)
