Skip to content

Commit

Permalink
update tests and implement str as initializer input
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Dec 12, 2024
1 parent f06d4c6 commit 3ec3328
Show file tree
Hide file tree
Showing 14 changed files with 121 additions and 73 deletions.
7 changes: 5 additions & 2 deletions src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
Numeric_t = Union[int, float] # type of `time_key` arguments
Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata
Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition`
SinkhornInitializerTag_t = Literal["default", "gaussian", "sorting"]
LRInitializerTag_t = Literal["random", "rank2", "k-means", "generalized-k-means"]

SinkhornInitializer_t = Optional[Union[SinkhornInitializer, LRInitializer]]
LRInitializer_t = Optional[Union[LRInitializer, LRInitializerTag_t]]
SinkhornInitializer_t = Optional[Union[SinkhornInitializer, SinkhornInitializerTag_t]]
QuadInitializer_t = Optional[Union[BaseQuadraticInitializer]]

Initializer_t = Union[SinkhornInitializer_t, QuadInitializer_t]
Initializer_t = Union[SinkhornInitializer_t, QuadInitializer_t, LRInitializer_t]
ProblemStage_t = Literal["prepared", "solved"]
Device_t = Union[Literal["cpu", "gpu", "tpu"], str]

Expand Down
88 changes: 88 additions & 0 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
from ott.initializers.linear import initializers as init_lib
from ott.initializers.linear import initializers_lr as lr_init_lib
from ott.neural import datasets
from ott.solvers import utils as solver_utils
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
Expand All @@ -21,6 +23,90 @@
__all__ = ["sinkhorn_divergence"]


class InitializerAdapter:
"""Adapter class for creating various OT solver initializers.
This class provides static methods to create and manage different types of
initializers used in optimal transport solvers, including low-rank, k-means,
and standard Sinkhorn initializers.
"""

@staticmethod
def lr_from_str(
initializer: str,
rank: int,
**kwargs: Any,
) -> lr_init_lib.LRInitializer:
"""Create a low-rank initializer from a string specification.
Parameters
----------
initializer : str
Either existing initializer instance or string specifier.
rank : int
Rank for the initialization.
**kwargs : Any
Additional keyword arguments for initializer creation.
Returns
-------
LRInitializer
Configured low-rank initializer.
Raises
------
NotImplementedError
If requested initializer type is not implemented.
"""
if isinstance(initializer, lr_init_lib.LRInitializer):
return initializer
if initializer == "k-means":
return lr_init_lib.KMeansInitializer(rank=rank, **kwargs)
if initializer == "generalized-k-means":
return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs)
if initializer == "random":
return lr_init_lib.RandomInitializer(rank=rank, **kwargs)
if initializer == "rank2":
return lr_init_lib.Rank2Initializer(rank=rank, **kwargs)
raise NotImplementedError(f"Initializer `{initializer}` is not implemented.")

@staticmethod
def from_str(
initializer: str,
**kwargs: Any,
) -> init_lib.SinkhornInitializer:
"""Create a Sinkhorn initializer from a string specification.
Parameters
----------
initializer : str
String specifier for initializer type.
**kwargs : Any
Additional keyword arguments for initializer creation.
Returns
-------
SinkhornInitializer
Configured Sinkhorn initializer.
Raises
------
NotImplementedError
If requested initializer type is not implemented.
"""
if isinstance(initializer, init_lib.SinkhornInitializer):
return initializer
if initializer == "default":
return init_lib.DefaultInitializer(**kwargs)
if initializer == "gaussian":
return init_lib.GaussianInitializer(**kwargs)
if initializer == "sorting":
return init_lib.SortingInitializer(**kwargs)
if initializer == "subsample":
return init_lib.SubsampleInitializer(**kwargs)
raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")


def sinkhorn_divergence(
point_cloud_1: ArrayLike,
point_cloud_2: ArrayLike,
Expand All @@ -47,6 +133,8 @@ def sinkhorn_divergence(
b=b,
scale_cost=scale_cost,
epsilon=epsilon,
tau_a=tau_a,
tau_b=tau_b,
**kwargs,
)[1]
xy_conv, xx_conv, *yy_conv = output.converged
Expand Down
12 changes: 11 additions & 1 deletion src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
from moscot._logging import logger
from moscot._types import (
ArrayLike,
LRInitializer_t,
ProblemKind_t,
QuadInitializer_t,
SinkhornInitializer_t,
)
from moscot.backends.ott._utils import (
InitializerAdapter,
Loader,
MultiLoader,
_instantiate_geodesic_cost,
Expand Down Expand Up @@ -286,8 +288,12 @@ def __init__(
eps = kwargs.get("epsilon")
if eps is not None and eps > 0.0:
logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.")
if isinstance(initializer, str):
initializer = InitializerAdapter.lr_from_str(initializer, rank=rank)
self._solver = sinkhorn_lr.LRSinkhorn(rank=rank, epsilon=epsilon, initializer=initializer, **kwargs)
else:
if isinstance(initializer, str):
initializer = InitializerAdapter.from_str(initializer)
self._solver = sinkhorn.Sinkhorn(initializer=initializer, **kwargs)

def _prepare(
Expand Down Expand Up @@ -389,7 +395,7 @@ def __init__(
self,
jit: bool = True,
rank: int = -1,
initializer: QuadInitializer_t | None = None,
initializer: QuadInitializer_t | LRInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
Expand All @@ -401,13 +407,17 @@ def __init__(
eps = kwargs.get("epsilon")
if eps is not None and eps > 0.0:
logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.")
if isinstance(initializer, str):
initializer = InitializerAdapter.lr_from_str(initializer, rank=rank)
self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
initializer=initializer,
**kwargs,
)
else:
linear_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
if isinstance(initializer, str):
raise ValueError("Expected `initializer` to be `None` or `ott.initializers.quadratic.initializers`.")
self._solver = gromov_wasserstein.GromovWasserstein(
linear_solver=linear_solver,
initializer=initializer,
Expand Down
37 changes: 0 additions & 37 deletions tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import numpy as np
import pandas as pd
from ott.initializers.linear import initializers as init_lib
from ott.initializers.linear import initializers_lr as lr_init_lib
from scipy.sparse import csr_matrix

from anndata import AnnData
Expand Down Expand Up @@ -103,38 +101,3 @@ def _base_problem_type(self) -> Type[B]:
@property
def _valid_policies(self) -> Tuple[str, ...]:
return ()


def create_lr_initializer(
initializer,
rank,
**kwargs,
) -> lr_init_lib.LRInitializer: # noqa: D102
if isinstance(initializer, lr_init_lib.LRInitializer):
return initializer
if initializer == "random":
return lr_init_lib.RandomInitializer(rank=rank, **kwargs)
if initializer == "rank2":
return lr_init_lib.Rank2Initializer(rank=rank, **kwargs)
if initializer == "k-means":
return lr_init_lib.KMeansInitializer(rank=rank, **kwargs)
if initializer == "generalized-k-means":
return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs)
raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")


def create_fr_initializer(
initializer,
**kwargs,
) -> init_lib.SinkhornInitializer: # noqa: D102
if isinstance(initializer, init_lib.SinkhornInitializer):
return initializer
if initializer == "default":
return init_lib.DefaultInitializer(**kwargs)
if initializer == "gaussian":
return init_lib.GaussianInitializer(**kwargs)
if initializer == "sorting":
return init_lib.SortingInitializer(**kwargs)
if initializer == "subsample":
return init_lib.SubsampleInitializer(**kwargs)
raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")
4 changes: 1 addition & 3 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ott.geometry.geometry import Geometry
from ott.geometry.low_rank import LRCGeometry
from ott.geometry.pointcloud import PointCloud
from ott.initializers.linear import initializers_lr as lr_init_lib
from ott.problems.linear.linear_problem import LinearProblem
from ott.problems.quadratic import quadratic_problem
from ott.problems.quadratic.quadratic_problem import QuadraticProblem
Expand Down Expand Up @@ -156,8 +155,7 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f
def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None:
thresh, eps = 1e-2, 1e-2
if rank > -1:
initializer = lr_init_lib.RandomInitializer(rank=rank)
gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh, initializer=initializer)(
gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh, initializer="rank2")(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)

Expand Down
9 changes: 4 additions & 5 deletions tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import numpy as np
import pandas as pd
from ott.initializers.linear import initializers as init_lib
from ott.initializers.linear import initializers_lr as lr_init_lib
from sklearn.metrics import pairwise_distances

import anndata as ad
Expand Down Expand Up @@ -77,7 +75,7 @@ def marginal_keys(request):
"tau_a": 1.0,
"tau_b": 1.0,
"rank": 7,
"initializer": lr_init_lib.RandomInitializer(rank=7),
"initializer": "rank2",
"initializer_kwargs": {},
"jit": False,
"threshold": 2e-3,
Expand All @@ -99,7 +97,7 @@ def marginal_keys(request):
"tau_b": 0.8,
"rank": -1,
"batch_size": 125,
"initializer": init_lib.GaussianInitializer(),
"initializer": "gaussian",
"initializer_kwargs": {},
"jit": True,
"threshold": 3e-3,
Expand Down Expand Up @@ -159,7 +157,8 @@ def marginal_keys(request):
"scale_cost": "max_cost",
"rank": 7,
"batch_size": 123,
"initializer": lr_init_lib.RandomInitializer(rank=7),
"initializer": "rank2",
"initializer_kwargs": {},
"jit": False,
"threshold": 2e-3,
"min_iterations": 2,
Expand Down
4 changes: 1 addition & 3 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,8 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData],
solver = tp[key].solver.solver
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
if arg == "initializer" and args_to_check["rank"] == -1:
if arg == "initializer":
assert isinstance(getattr(solver, val), Callable)
else:
assert getattr(solver, val) == args_to_check[arg], arg

sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
Expand Down
4 changes: 1 addition & 3 deletions tests/problems/generic/test_fgw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
solver = problem[key].solver.solver
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
if args_to_check["rank"] == -1 and arg == "initializer":
if arg == "initializer":
assert isinstance(getattr(solver, val), Callable)
else:
assert getattr(solver, val, object()) == args_to_check[arg], arg

sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
Expand Down
4 changes: 1 addition & 3 deletions tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,8 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
if arg == "initializer" and args_to_check["rank"] == -1:
if arg == "initializer":
assert isinstance(getattr(solver, val), Callable)
else:
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
Expand Down
7 changes: 3 additions & 4 deletions tests/problems/generic/test_sinkhorn_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,9 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A
solver = problem[(0, 1)].solver.solver
args = sinkhorn_solver_args if args_to_check["rank"] == -1 else lr_sinkhorn_solver_args
for arg, val in args.items():
if val != "initializer_kwargs":
assert hasattr(solver, val), val
el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val)
assert el == args_to_check[arg], arg
assert hasattr(solver, val), val
el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val)
assert el == args_to_check[arg], arg

lin_prob = problem[(0, 1)]._solver._problem
for arg, val in lin_prob_args.items():
Expand Down
3 changes: 1 addition & 2 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
if arg != "initializer":
assert getattr(solver, val) == args_to_check[arg]
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
Expand Down
4 changes: 1 addition & 3 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,8 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
if arg == "initializer" and args_to_check["rank"] == -1:
if arg == "initializer":
assert isinstance(getattr(solver, val), Callable)
else:
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
Expand Down
4 changes: 1 addition & 3 deletions tests/problems/time/test_lineage_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,8 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert hasattr(solver, val)
if arg == "initializer" and args_to_check["rank"] == -1:
if arg == "initializer":
assert isinstance(getattr(solver, val), Callable)
else:
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
Expand Down
7 changes: 3 additions & 4 deletions tests/problems/time/test_temporal_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,9 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A
solver = problem[key].solver.solver
args = sinkhorn_solver_args if args_to_check["rank"] == -1 else lr_sinkhorn_solver_args
for arg, val in args.items():
if val != "initializer_kwargs":
assert hasattr(solver, val)
el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val)
assert el == args_to_check[arg]
assert hasattr(solver, val)
el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val)
assert el == args_to_check[arg]

lin_prob = problem[key]._solver._problem
for arg, val in lin_prob_args.items():
Expand Down

0 comments on commit 3ec3328

Please sign in to comment.