Skip to content

Commit

Permalink
sigma type & doc (#72)
Browse files Browse the repository at this point in the history
* docstrings OTPlanSampler

* ensure sigma type to be float
  • Loading branch information
guillaumehu authored Nov 14, 2023
1 parent 21cd0c8 commit 7cb209d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
19 changes: 10 additions & 9 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# License: MIT License

import math
from typing import Union

import torch

Expand All @@ -31,7 +32,7 @@ def pad_t_like_x(t, x):
t: Vector (bs)
pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
"""
if isinstance(t, float):
if isinstance(t, (float, int)):
return t
return t.reshape(-1, *([1] * (x.dim() - 1)))

Expand All @@ -47,12 +48,12 @@ class ConditionalFlowMatcher:
- score function $\nabla log p_t(x|x0, x1)$
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
Parameters
----------
sigma : float
sigma : Union[float, int]
"""
self.sigma = sigma

Expand Down Expand Up @@ -215,15 +216,15 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):
It overrides the sample_location_and_conditional_flow.
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
"""
self.sigma = sigma
super().__init__(sigma)
self.ot_sampler = OTPlanSampler(method="exact")

def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
Expand Down Expand Up @@ -382,16 +383,16 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):
sample_location_and_conditional_flow functions.
"""

def __init__(self, sigma: float = 1.0, ot_method="exact"):
def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper-
parameter $\sigma$ and the entropic OT map.
Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
"""
self.sigma = sigma
super().__init__(sigma)
self.ot_method = ot_method
self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)

Expand Down
15 changes: 15 additions & 0 deletions torchcfm/optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ def __init__(
normalize_cost=False,
**kwargs,
):
r"""Initialize the OTPlanSampler class.
Parameters
----------
method : str
The method used to compute the OT plan. Can be one of "exact", "sinkhorn",
"unbalanced", or "partial".
reg : float (default : 0.05)
Entropic regularization coefficients.
reg_m : float (default : 1.0)
Marginal relaxation term for unbalanced OT (`method='unbalanced'`).
normalize_cost : bool (default : False)
Whether to normalize the cost matrix by its maximum value.
It should be set to `False` when using minibatches.
"""
# ot_fn should take (a, b, M) as arguments where a, b are marginals and
# M is a cost matrix
if method == "exact":
Expand Down

0 comments on commit 7cb209d

Please sign in to comment.