Negative cost in OTTOutput #769

Open
matthieuheitz opened this issue Dec 6, 2024 · 4 comments

@matthieuheitz

matthieuheitz commented Dec 6, 2024

I'm computing pairwise distances between 6 spatial transcriptomics slices. Some of the pairs have negative costs, and some even have absurdly large values, even though the solver reports that they converged.

I prepared the problem like this:
```python
stp = stp.prepare(time_key="Batch_idx", spatial_key="spatial", joint_attr=joint_attr,
                  cost="sq_euclidean", policy="triu")
```
where `joint_attr="X_pca"` is a global PCA.
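With `policy="triu"`, one problem is set up for every pair of slices (i, j) with i < j. This matches the 15 keys in the outputs for 6 slices. Below is a plain-Python sketch of that enumeration. It is an illustration only and not moscot's internal code.

```python
from itertools import combinations

n_slices = 6  # slices indexed 0..5 by the "Batch_idx" time key

# Upper-triangular policy: one OT problem per pair (i, j) with i < j
pairs = list(combinations(range(n_slices), 2))
print(len(pairs))  # 15
print(pairs[:3])   # [(0, 1), (0, 2), (0, 3)]
```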

Here's my call to the solver, and the output:

```python
from ott.geometry import epsilon_scheduler
from ott.solvers.linear import acceleration

stp = stp.solve(epsilon=epsilon_scheduler.Epsilon(target=1e-3, init=100, decay=0.99),
                alpha=0.5,
                linear_solver_kwargs={"momentum": acceleration.Momentum(start=300)})
```

```
[((0, 1), OTTOutput[shape=(147, 292), cost=0.9862, converged=True]),
 ((2, 4), OTTOutput[shape=(441, 824), cost=0.9677, converged=True]),
 ((1, 2), OTTOutput[shape=(292, 441), cost=0.6782, converged=True]),
 ((0, 4), OTTOutput[shape=(147, 824), cost=0.9633, converged=True]),
 ((3, 4), OTTOutput[shape=(1169, 824), cost=0.9322, converged=True]),
 ((1, 5), OTTOutput[shape=(292, 744), cost=-0.0029, converged=True]),
 ((0, 3), OTTOutput[shape=(147, 1169), cost=-56.5436, converged=True]),
 ((1, 4), OTTOutput[shape=(292, 824), cost=-28217723322368.0, converged=True]),
 ((2, 3), OTTOutput[shape=(441, 1169), cost=0.9884, converged=True]),
 ((0, 2), OTTOutput[shape=(147, 441), cost=-3458.1289, converged=True]),
 ((4, 5), OTTOutput[shape=(824, 744), cost=-0.4292, converged=True]),
 ((0, 5), OTTOutput[shape=(147, 744), cost=1.1268, converged=True]),
 ((2, 5), OTTOutput[shape=(441, 744), cost=0.9649, converged=True]),
 ((1, 3), OTTOutput[shape=(292, 1169), cost=0.8918, converged=True]),
 ((3, 5), OTTOutput[shape=(1169, 744), cost=0.8036, converged=True])]
```
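The run above decays epsilon toward `target=1e-3`. The sketch below shows how many iterations that decay would take. It rests on an **assumption** that has not been checked here: ott-jax's `Epsilon` is taken to compute epsilon at iteration t as `target * max(init * decay**t, 1)`, which makes `init` a multiplier of the target and not an absolute starting value. Verify this against the installed ott-jax version. The helper functions are hypothetical.

```python
def epsilon_at(t, target=1e-3, init=100.0, decay=0.99):
    """ASSUMED schedule: epsilon_t = target * max(init * decay**t, 1)."""
    return target * max(init * decay**t, 1.0)


def first_iteration_at_target(target=1e-3, init=100.0, decay=0.99, max_iter=10_000):
    """First iteration t at which the assumed schedule has decayed to `target`."""
    for t in range(max_iter):
        if epsilon_at(t, target, init, decay) <= target:
            return t
    return None


print(epsilon_at(0))                # ~0.1: starting epsilon under this assumption
print(first_iteration_at_target())  # 459
```

Under this assumption, epsilon only reaches 1e-3 after 459 iterations. A solver that stops earlier would therefore be working at a larger epsilon than the target.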

I then changed some of the solver parameters to see whether they matter, and it seems they do.
Without the epsilon decay, only 3 costs are negative (plus one nan).

```python
stp = stp.solve(epsilon=1e-3, alpha=0.5, linear_solver_kwargs={"momentum": acceleration.Momentum(start=200)})
```

```
[((0, 1), OTTOutput[shape=(147, 292), cost=0.9831, converged=True]),
 ((2, 4), OTTOutput[shape=(441, 824), cost=nan, converged=False]),
 ((1, 2), OTTOutput[shape=(292, 441), cost=0.678, converged=True]),
 ((0, 4), OTTOutput[shape=(147, 824), cost=-11.4962, converged=True]),
 ((3, 4), OTTOutput[shape=(1169, 824), cost=0.9322, converged=True]),
 ((1, 5), OTTOutput[shape=(292, 744), cost=1.0901, converged=True]),
 ((0, 3), OTTOutput[shape=(147, 1169), cost=-55.8705, converged=True]),
 ((1, 4), OTTOutput[shape=(292, 824), cost=0.9495, converged=True]),
 ((2, 3), OTTOutput[shape=(441, 1169), cost=0.9883, converged=True]),
 ((0, 2), OTTOutput[shape=(147, 441), cost=-2565.3809, converged=True]),
 ((4, 5), OTTOutput[shape=(824, 744), cost=0.671, converged=True]),
 ((0, 5), OTTOutput[shape=(147, 744), cost=1.1225, converged=True]),
 ((2, 5), OTTOutput[shape=(441, 744), cost=0.9571, converged=True]),
 ((1, 3), OTTOutput[shape=(292, 1169), cost=1.0639, converged=True]),
 ((3, 5), OTTOutput[shape=(1169, 744), cost=0.7674, converged=True])]
```

Without the epsilon decay and with a fixed momentum value, there are still 4 negative costs. Many problems also didn't converge, but I'm not sure that's relevant.

```python
stp = stp.solve(epsilon=1e-3, alpha=0.5, linear_solver_kwargs={"momentum": acceleration.Momentum(value=1.6)})
```

```
[((0, 1), OTTOutput[shape=(147, 292), cost=0.9871, converged=True]),
 ((2, 4), OTTOutput[shape=(441, 824), cost=0.9025, converged=False]),
 ((1, 2), OTTOutput[shape=(292, 441), cost=0.6788, converged=False]),
 ((0, 4), OTTOutput[shape=(147, 824), cost=1.0402, converged=True]),
 ((3, 4), OTTOutput[shape=(1169, 824), cost=0.9117, converged=False]),
 ((1, 5), OTTOutput[shape=(292, 744), cost=1.0907, converged=True]),
 ((0, 3), OTTOutput[shape=(147, 1169), cost=-43.8229, converged=False]),
 ((1, 4), OTTOutput[shape=(292, 824), cost=0.9567, converged=True]),
 ((2, 3), OTTOutput[shape=(441, 1169), cost=0.9883, converged=False]),
 ((0, 2), OTTOutput[shape=(147, 441), cost=-1995.7041, converged=True]),
 ((4, 5), OTTOutput[shape=(824, 744), cost=0.6685, converged=False]),
 ((0, 5), OTTOutput[shape=(147, 744), cost=-9.9394, converged=False]),
 ((2, 5), OTTOutput[shape=(441, 744), cost=0.9292, converged=True]),
 ((1, 3), OTTOutput[shape=(292, 1169), cost=0.9441, converged=True]),
 ((3, 5), OTTOutput[shape=(1169, 744), cost=-2.0191, converged=False])]
```
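The run above fixes the momentum value at 1.6. This sketch **assumes** that ott-jax's `Momentum` over-relaxes the Sinkhorn updates and that `value` plays the role of the relaxation factor ω. Under that assumption, each dual potential is pushed past its plain Sinkhorn update by a factor ω. Below is a textbook log-domain version in NumPy. It is not ott-jax's code, and the helpers `lse`, `relaxed_sinkhorn` and `marginal_error` are hypothetical.

```python
import numpy as np


def lse(M, axis):
    """Numerically stable log-sum-exp of M along `axis`."""
    mx = M.max(axis=axis, keepdims=True)
    return (mx + np.log(np.exp(M - mx).sum(axis=axis, keepdims=True))).squeeze(axis)


def relaxed_sinkhorn(C, a, b, eps, omega=1.0, n_iter=500):
    """Log-domain Sinkhorn with over-relaxation factor `omega` (omega=1: plain Sinkhorn)."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iter):
        # Plain Sinkhorn update for f (row marginals), then extrapolate by omega
        f_new = eps * np.log(a) - eps * lse((g[None, :] - C) / eps, axis=1)
        f = (1.0 - omega) * f + omega * f_new
        # Same for g (column marginals), using the freshly relaxed f
        g_new = eps * np.log(b) - eps * lse((f[:, None] - C) / eps, axis=0)
        g = (1.0 - omega) * g + omega * g_new
    # Coupling recovered from the dual potentials
    return np.exp((f[:, None] + g[None, :] - C) / eps)


def marginal_error(P, a, b):
    """Largest deviation of P's row/column sums from the target marginals."""
    return max(np.abs(P.sum(axis=1) - a).max(), np.abs(P.sum(axis=0) - b).max())


rng = np.random.default_rng(0)
x, y = rng.normal(size=(30, 2)), rng.normal(size=(40, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
C /= C.max()  # scale costs into [0, 1]
a, b = np.full(30, 1 / 30), np.full(40, 1 / 40)

for omega in (1.0, 1.6):
    P = relaxed_sinkhorn(C, a, b, eps=0.5, omega=omega)
    print(omega, marginal_error(P, a, b))
```

On this small, well-conditioned toy problem both settings converge. In general, a fixed ω that is poorly matched to the problem can slow convergence or fail to converge. This sketch does not show that case.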

What's strange is that the negative costs don't always appear in the same problems. (0,3) and (0,2) are negative in all three runs, while the others are on and off.
Any idea what could be causing this?
Thanks!
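You can check this pattern directly by collecting the reported costs from the three runs and intersecting the sets of negative pairs. The plain-Python sketch below does that. The costs are copied from the outputs above, and the dictionary layout and run names are only for illustration.

```python
nan = float("nan")

# Reported costs from the three runs above, keyed by (source, target) slice pair
runs = {
    "decaying epsilon": {
        (0, 1): 0.9862, (2, 4): 0.9677, (1, 2): 0.6782, (0, 4): 0.9633, (3, 4): 0.9322,
        (1, 5): -0.0029, (0, 3): -56.5436, (1, 4): -28217723322368.0, (2, 3): 0.9884,
        (0, 2): -3458.1289, (4, 5): -0.4292, (0, 5): 1.1268, (2, 5): 0.9649,
        (1, 3): 0.8918, (3, 5): 0.8036,
    },
    "fixed epsilon": {
        (0, 1): 0.9831, (2, 4): nan, (1, 2): 0.678, (0, 4): -11.4962, (3, 4): 0.9322,
        (1, 5): 1.0901, (0, 3): -55.8705, (1, 4): 0.9495, (2, 3): 0.9883,
        (0, 2): -2565.3809, (4, 5): 0.671, (0, 5): 1.1225, (2, 5): 0.9571,
        (1, 3): 1.0639, (3, 5): 0.7674,
    },
    "fixed epsilon + fixed momentum": {
        (0, 1): 0.9871, (2, 4): 0.9025, (1, 2): 0.6788, (0, 4): 1.0402, (3, 4): 0.9117,
        (1, 5): 1.0907, (0, 3): -43.8229, (1, 4): 0.9567, (2, 3): 0.9883,
        (0, 2): -1995.7041, (4, 5): 0.6685, (0, 5): -9.9394, (2, 5): 0.9292,
        (1, 3): 0.9441, (3, 5): -2.0191,
    },
}

# nan < 0 is False, so nan costs are not counted as negative
negative = {name: {p for p, c in costs.items() if c < 0} for name, costs in runs.items()}
for name, pairs in negative.items():
    print(f"{name}: {len(pairs)} negative -> {sorted(pairs)}")

always_negative = set.intersection(*negative.values())
print("negative in every run:", sorted(always_negative))  # [(0, 2), (0, 3)]
```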

@MUCDK
Collaborator

MUCDK commented Dec 6, 2024

Hi @matthieuheitz,

Thanks for opening this issue.

A negative cost is possible because we (following ott-jax) report the entropy-regularized optimal transport cost, and its entropy term can push the value below zero.
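The NumPy sketch below shows why an entropy-regularized cost can be negative even when every entry of the cost matrix is non-negative. It **assumes** the regularized objective is <C, P> - ε·H(P), where H(P) = -Σ P log P is the Shannon entropy. ott-jax's exact normalization of the entropy term may differ, and the helpers `sinkhorn_plan` and `entropic_cost` are hypothetical. With uniform marginals, H(P) ≥ log(max(n, m)) for any coupling P. So once ε·log(max(n, m)) exceeds max(C), the regularized cost is negative whatever the coupling.

```python
import numpy as np


def sinkhorn_plan(C, a, b, eps, n_iter=1000):
    """Plain (kernel-space) Sinkhorn iterations; returns the coupling P."""
    K = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def entropic_cost(C, P, eps):
    """ASSUMED convention: <C, P> - eps * H(P), with H(P) = -sum P log P."""
    p = P[P > 0]  # 0 * log 0 is taken as 0
    entropy = -np.sum(p * np.log(p))
    return np.sum(C * P) - eps * entropy


rng = np.random.default_rng(0)
x, y = rng.normal(size=(50, 2)), rng.normal(size=(60, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
C /= C.max()  # every cost entry lies in [0, 1]
a, b = np.full(50, 1 / 50), np.full(60, 1 / 60)

for eps in (0.01, 1.0):
    P = sinkhorn_plan(C, a, b, eps)
    print(f"eps={eps}: linear cost={np.sum(C * P):.4f}, "
          f"entropic cost={entropic_cost(C, P, eps):.4f}")
```

Under the same convention, H(P) ≤ log(nm), so the entropic term shifts the cost by at most ε·log(nm). For ε = 1e-3 and the largest problem above (1169 × 824), that is about 0.014.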

Unfortunately, `cost=nan` (with `converged=False`) can happen.

For the case `((1, 4), OTTOutput[shape=(292, 824), cost=-28217723322368.0, converged=True])`, could you try plotting the costs (https://moscot.readthedocs.io/en/latest/genapi/moscot.backends.ott.OTTOutput.plot_costs.html)? As for the errors, let's discuss them in issue #768.

MUCDK self-assigned this Dec 6, 2024
@matthieuheitz
Author

The cost plot doesn't seem crazy...
```python
stp.solutions[(1,4)].plot_costs()
```
[Image: output of `plot_costs()` for problem (1, 4)]

@matthieuheitz
Author

Yeah, I don't know where this -28217723322368.0 value comes from. It appears nowhere in `stp.solutions[(1,4)]._costs`:

```
Array([ 1.7440577,  1.2362485,  0.9642093,  0.9561533,  0.9549437,
        0.9557142, -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ],      dtype=float32)
```

It looks like the cost attribute is not always being set correctly.
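The trailing -1 entries in `_costs` appear to be placeholders for iterations that were never run. If so, you can read off the last recorded cost by dropping them and taking the final entry, then compare it with the reported `cost`. Below is a minimal NumPy sketch using the array printed above. `last_recorded_cost` is a hypothetical helper and is not part of moscot or ott-jax.

```python
import numpy as np


def last_recorded_cost(costs, pad_value=-1.0):
    """Return the last entry of `costs` that is not the padding value (hypothetical helper)."""
    costs = np.asarray(costs)
    recorded = costs[costs != pad_value]
    if recorded.size == 0:
        raise ValueError("no recorded costs")
    return float(recorded[-1])


# The `_costs` array printed above: 6 recorded values followed by 44 entries of -1
costs_1_4 = np.concatenate([
    np.array([1.7440577, 1.2362485, 0.9642093, 0.9561533, 0.9549437, 0.9557142],
             dtype=np.float32),
    -np.ones(44, dtype=np.float32),
])
print(last_recorded_cost(costs_1_4))  # about 0.9557, not -28217723322368.0
```

One caveat: a genuine cost of exactly -1 would be indistinguishable from the padding and would also be dropped.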

@MUCDK
Collaborator

MUCDK commented Dec 9, 2024

I'd guess this is some kind of numerical overflow. Are you running with float32 or float64? If float32, could you try float64?
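JAX computes in float32 by default. You can enable float64 with `jax.config.update("jax_enable_x64", True)` at startup, before any arrays are created. The NumPy sketch below illustrates why precision matters at ε = 1e-3: it evaluates the Gibbs kernel exp(-C/ε) for a few hypothetical cost values in both precisions. It has not been established that this mechanism explains the -28217723322368.0 value. Note also that log-domain solvers avoid forming this kernel explicitly. `gibbs_kernel` is a hypothetical helper.

```python
import numpy as np


def gibbs_kernel(C, eps, dtype):
    """Elementwise exp(-C / eps), evaluated in the given floating-point dtype."""
    C = np.asarray(C, dtype=dtype)
    return np.exp(-C / dtype(eps))


C = [0.05, 0.2, 0.5]  # hypothetical cost entries
k32 = gibbs_kernel(C, 1e-3, np.float32)
k64 = gibbs_kernel(C, 1e-3, np.float64)
print(k32)  # exp(-200) and exp(-500) underflow to 0.0 in float32
print(k64)  # all three entries stay positive in float64
```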

@michalk8, do you have any idea where this might come from?
