Skip to content

Commit

Permalink
fix alpha tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Dec 10, 2024
1 parent f15c4e3 commit e13665e
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 0 additions & 2 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,6 @@ def _prepare(
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
if alpha is None:
alpha = 1.0 if xy is None else 0.5 # set defaults according to the data provided
if alpha <= 0.0:
raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None):
Expand Down
4 changes: 2 additions & 2 deletions tests/problems/base/test_general_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def test_set_graph_x_y(self, adata_x: AnnData, adata_y: AnnData, ts: Tuple[Optio
assert ta2.tag == Tag.GRAPH
assert ta2.cost == "geodesic"

prob1 = prob1.solve(epsilon=10.0)
prob1 = prob1.solve(epsilon=10.0, alpha=1.0)

prob2 = OTProblem(adata_x, adata_y)
prob2 = prob2.prepare(
Expand All @@ -313,7 +313,7 @@ def test_set_graph_x_y(self, adata_x: AnnData, adata_y: AnnData, ts: Tuple[Optio
assert ta2.tag == Tag.GRAPH
assert ta2.cost == "geodesic"

prob2 = prob2.solve(epsilon=10.0)
prob2 = prob2.solve(epsilon=10.0, alpha=1.0)

assert not np.allclose(prob1.solution._output.geom.cost_matrix, prob2.solution._output.geom.cost_matrix)

Expand Down
2 changes: 1 addition & 1 deletion tests/problems/time/test_temporal_base_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_posterior_growth_rates(self, adata_time_marginal_estimations: AnnData):
b=True,
marginal_kwargs={"proliferation_key": "proliferation"},
)
prob = prob.solve(max_iterations=10)
prob = prob.solve(max_iterations=10, alpha=1.0)
assert prob.delta == (t2 - t1)

gr = prob.posterior_growth_rates
Expand Down

0 comments on commit e13665e

Please sign in to comment.