diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index 72f6455a..9bdc7f4e 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -24,7 +24,7 @@ from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.solver import O, OTSolver from moscot.utils.tagged_array import Tag, TaggedArray -from tests._utils import ATOL, RTOL, Geom_t, create_lr_initializer +from tests._utils import ATOL, RTOL, Geom_t from tests.plotting.conftest import PlotTester, PlotTesterMeta @@ -52,7 +52,6 @@ def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool): def test_solver_rank(self, y: Geom_t, rank: Optional[int], initializer: str): eps = 1e-2 default_gamma_lr_sinhorn = 500 - initializer = create_lr_initializer(initializer, rank=rank) lr_sinkhorn = LRSinkhorn(rank=rank, initializer=initializer, gamma=default_gamma_lr_sinhorn) problem = LinearProblem(PointCloud(y, epsilon=eps)) gt = lr_sinkhorn(problem) diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 7b720b43..31e3de91 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -13,7 +13,6 @@ from moscot.backends.ott._utils import alpha_to_fused_penalty from moscot.problems.space import AlignmentProblem from moscot.utils.tagged_array import Tag, TaggedArray -from tests._utils import create_lr_initializer from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -96,7 +95,6 @@ def test_solve_balanced( should_raise: bool, ): kwargs = {} - initializer = create_lr_initializer(initializer, rank=rank) if initializer is not None else None if rank > -1: kwargs["initializer"] = initializer if initializer == "random": diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index d01e9726..b0645e28 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -17,7 +17,7 @@ from moscot.backends.ott._utils import alpha_to_fused_penalty from moscot.problems.space import MappingProblem from moscot.utils.tagged_array import Tag, TaggedArray -from tests._utils import _adata_spatial_split, create_lr_initializer +from tests._utils import _adata_spatial_split from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -114,7 +114,6 @@ def test_solve_balanced( ): adataref, adatasp = _adata_spatial_split(adata_mapping) kwargs = {} - initializer = create_lr_initializer(initializer, rank) if initializer is not None else None if rank > -1: kwargs["initializer"] = initializer if initializer == "random":