Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Dec 12, 2024
1 parent 3ec3328 commit 2474cf7
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 6 deletions.
3 changes: 1 addition & 2 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.solver import O, OTSolver
from moscot.utils.tagged_array import Tag, TaggedArray
from tests._utils import ATOL, RTOL, Geom_t, create_lr_initializer
from tests._utils import ATOL, RTOL, Geom_t
from tests.plotting.conftest import PlotTester, PlotTesterMeta


Expand Down Expand Up @@ -52,7 +52,6 @@ def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool):
def test_solver_rank(self, y: Geom_t, rank: Optional[int], initializer: str):
eps = 1e-2
default_gamma_lr_sinhorn = 500
initializer = create_lr_initializer(initializer, rank=rank)
lr_sinkhorn = LRSinkhorn(rank=rank, initializer=initializer, gamma=default_gamma_lr_sinhorn)
problem = LinearProblem(PointCloud(y, epsilon=eps))
gt = lr_sinkhorn(problem)
Expand Down
2 changes: 0 additions & 2 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from moscot.backends.ott._utils import alpha_to_fused_penalty
from moscot.problems.space import AlignmentProblem
from moscot.utils.tagged_array import Tag, TaggedArray
from tests._utils import create_lr_initializer
from tests.problems.conftest import (
fgw_args_1,
fgw_args_2,
Expand Down Expand Up @@ -96,7 +95,6 @@ def test_solve_balanced(
should_raise: bool,
):
kwargs = {}
initializer = create_lr_initializer(initializer, rank=rank) if initializer is not None else None
if rank > -1:
kwargs["initializer"] = initializer
if initializer == "random":
Expand Down
3 changes: 1 addition & 2 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from moscot.backends.ott._utils import alpha_to_fused_penalty
from moscot.problems.space import MappingProblem
from moscot.utils.tagged_array import Tag, TaggedArray
from tests._utils import _adata_spatial_split, create_lr_initializer
from tests._utils import _adata_spatial_split
from tests.problems.conftest import (
fgw_args_1,
fgw_args_2,
Expand Down Expand Up @@ -114,7 +114,6 @@ def test_solve_balanced(
):
adataref, adatasp = _adata_spatial_split(adata_mapping)
kwargs = {}
initializer = create_lr_initializer(initializer, rank) if initializer is not None else None
if rank > -1:
kwargs["initializer"] = initializer
if initializer == "random":
Expand Down

0 comments on commit 2474cf7

Please sign in to comment.