diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index 3e518ef..c93c4e2 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -6,6 +6,7 @@ # License: MIT License import math +from typing import Union import torch @@ -31,7 +32,7 @@ def pad_t_like_x(t, x): t: Vector (bs) pad_t_like_x(t, x): Tensor (bs, 1, 1, 1) """ - if isinstance(t, float): + if isinstance(t, (float, int)): return t return t.reshape(-1, *([1] * (x.dim() - 1))) @@ -47,12 +48,12 @@ class ConditionalFlowMatcher: - score function $\nabla log p_t(x|x0, x1)$ """ - def __init__(self, sigma: float = 0.0): + def __init__(self, sigma: Union[float, int] = 0.0): r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. Parameters ---------- - sigma : float + sigma : Union[float, int] """ self.sigma = sigma @@ -215,15 +216,15 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher): It overrides the sample_location_and_conditional_flow. """ - def __init__(self, sigma: float = 0.0): + def __init__(self, sigma: Union[float, int] = 0.0): r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. Parameters ---------- - sigma : float + sigma : Union[float, int] ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]). """ - self.sigma = sigma + super().__init__(sigma) self.ot_sampler = OTPlanSampler(method="exact") def sample_location_and_conditional_flow(self, x0, x1, return_noise=False): @@ -382,16 +383,16 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher): sample_location_and_conditional_flow functions. """ - def __init__(self, sigma: float = 1.0, ot_method="exact"): + def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"): r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper- parameter $\sigma$ and the entropic OT map. Parameters ---------- - sigma : float + sigma : Union[float, int] ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]). """ - self.sigma = sigma + super().__init__(sigma) self.ot_method = ot_method self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2) diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index 4f563ba..ebb54d1 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -19,6 +19,21 @@ def __init__( normalize_cost=False, **kwargs, ): + r"""Initialize the OTPlanSampler class. + + Parameters + ---------- + method : str + The method used to compute the OT plan. Can be one of "exact", "sinkhorn", + "unbalanced", or "partial". + reg : float (default : 0.05) + Entropic regularization coefficients. + reg_m : float (default : 1.0) + Marginal relaxation term for unbalanced OT (`method='unbalanced'`). + normalize_cost : bool (default : False) + Whether to normalize the cost matrix by its maximum value. + It should be set to `False` when using minibatches. + """ # ot_fn should take (a, b, M) as arguments where a, b are marginals and # M is a cost matrix if method == "exact":