Negative loss in unbalanced Sinkhorn (ott issue, closed)

MUCDK avatar MUCDK commented on June 3, 2024
Negative loss in unbalanced Sinkhorn

from ott.

Comments (4)

marcocuturi avatar marcocuturi commented on June 3, 2024

Hi Dominik, thanks for this.

If you look at the primal cost, which is returned by reg_ot_cost, it includes KL terms (those should be ≥ 0) but also a contribution corresponding to minus the entropy.

In OTT, we compute the objective using the dual formulation

def ent_reg_cost(

and that also includes some negative contributions, notably in the div_a and div_b terms, and the total_sum term subtracted at the end.

So, having a negative ent_reg_cost is not a contradiction per se. However, this does not preclude a bug. Have you looked at the transportation matrix that is returned? Do you see any unusual entries?

MUCDK avatar MUCDK commented on June 3, 2024

Hi Marco,

Thanks for your response. I might be wrong, but reading equation 24 in https://arxiv.org/pdf/1910.12958.pdf, the term multiplied by epsilon in the entropic penalization should be non-positive, because f_i + g_j <= C_{ij} and hence every element of the right matrix in the inner product is non-positive, whereas each element of the left matrix is non-negative.

This holds true for total_sum in the code, in the example above total_sum = -11.217885.

On the other hand, we add this term in the return statement, i.e. the return value is div_a + div_b + ot_prob.epsilon * (jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum). Hence, I don't know where the minus before the epsilon in equation (24) is incorporated.

Hence, I would assume that the return statement starting in line 198 should read

div_a + div_b - ot_prob.epsilon * (
      jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum
  )

Apologies if I am wrong.

The transport matrix does not seem to be completely wrong, at least: in multiple scenarios the PCC with the implementations from WOT and POT is larger than 0.98. If I am not mistaken, they use a stabilized version whereas OTT-JAX does not.

Moreover, if I understand correctly, OTT-JAX does not incorporate the scale of the marginals into the stopping criterion. This could be easily adapted and might help the user.

marcocuturi avatar marcocuturi commented on June 3, 2024

Thanks for checking this more closely. I must admit that part of the code was not extensively reviewed, so I am very grateful to you for taking a closer look!

Here total_sum corresponds to $\langle\alpha \otimes \beta, e^{\frac{f\oplus g-C}{\varepsilon}}\rangle$, since the elementwise product of these two matrices, i.e. $[\alpha_i \beta_j e^{\frac{f_i+ g_j-C_{ij}}{\varepsilon}}]_{ij}$, corresponds to the transport matrix.

Since that term has a minus in front of it in equation 24, I think this should still be a $-\varepsilon \times$ total_sum.

Similarly, jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) corresponds to $\langle\alpha \otimes \beta, 1\rangle$, and that should be added (with +) because of the two minus signs. So, after a quick check, I think there is still a + in front of ot_prob.epsilon. Anything mistaken on my end? Thanks again for checking!
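The sign argument above can be checked numerically. The following is a minimal sketch, not OTT's actual code: it assumes the entropic term of equation 24 has the form $-\varepsilon\langle\alpha \otimes \beta, e^{\frac{f\oplus g-C}{\varepsilon}} - 1\rangle$, as the discussion suggests, and the function names entropic_dual_term and entropic_dual_term_direct are hypothetical. It computes the term once as epsilon * (sum(a) * sum(b) - total_sum), mirroring the return statement, and once directly from the inner product.

```python
import jax.numpy as jnp


def entropic_dual_term(a, b, f, g, cost, epsilon):
    """Return (term, total_sum), with term = epsilon * (sum(a) * sum(b) - total_sum)."""
    # Entries a_i * b_j * exp((f_i + g_j - C_ij) / epsilon); all are non-negative.
    kernel = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    total_sum = jnp.sum(a[:, None] * b[None, :] * kernel)
    term = epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
    return term, total_sum


def entropic_dual_term_direct(a, b, f, g, cost, epsilon):
    """Same quantity, written as -epsilon * <a (x) b, exp((f (+) g - C) / epsilon) - 1>."""
    kernel = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    return -epsilon * jnp.sum(a[:, None] * b[None, :] * (kernel - 1.0))


# Toy inputs (arbitrary values, chosen only for illustration).
a = jnp.array([0.3, 0.7])
b = jnp.array([0.4, 0.6, 1.0])
f = jnp.array([0.1, -0.2])
g = jnp.array([0.0, 0.3, -0.1])
cost = jnp.array([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0]])
epsilon = 0.5

term, total_sum = entropic_dual_term(a, b, f, g, cost, epsilon)
direct = entropic_dual_term_direct(a, b, f, g, cost, epsilon)
```

Both expressions agree up to floating-point error, which is consistent with keeping the + in front of ot_prob.epsilon.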

As for marginals, I think you are right, there is definitely some scaling factor needed here in the stopping criterion. I think that's a very good idea. For instance, we could have a tolerance equal by default to jnp.sum(a) * tolerance. I hope this does not have complicated side effects when optimizing over $a$ in an unbalanced setting, but I think that's reasonable.
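To make the suggestion concrete, here is a minimal sketch of a stopping test whose tolerance is scaled by the total mass jnp.sum(a). The helpers scaled_tolerance and has_converged are hypothetical and are not part of the OTT API; the L1 marginal error used here is also an assumption made for illustration.

```python
import jax.numpy as jnp


def scaled_tolerance(a, tolerance):
    """Tolerance scaled by the total mass of the marginal a."""
    return jnp.sum(a) * tolerance


def has_converged(marginal, a, tolerance):
    """Compare the L1 marginal error against the mass-scaled tolerance."""
    error = jnp.sum(jnp.abs(marginal - a))
    return error < scaled_tolerance(a, tolerance)


# Unnormalized marginal with total mass 20 (toy values).
a = jnp.array([10.0, 10.0])
close_marginal = jnp.array([10.01, 9.99])
far_marginal = jnp.array([12.0, 8.0])

converged_close = has_converged(close_marginal, a, tolerance=1e-2)
converged_far = has_converged(far_marginal, a, tolerance=1e-2)
```

With an unscaled tolerance of 1e-2, the first marginal would not pass an absolute L1 test even though its relative error is small; the scaled test accepts it and still rejects the second one.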

MUCDK avatar MUCDK commented on June 3, 2024

Hi Marco,

Thanks, yeah, I spotted my mistake and your reasoning makes perfect sense!
