Comments (4)
Hi Dominik, thanks for this.
If you look at the primal cost, which is returned by `reg_ot_cost`, it includes KL terms (those should be ≥ 0) but also a term corresponding to minus the entropy.
In OTT, we compute the objective using the dual formulation (line 142 in d7521fd), and that also includes some negative contributions, notably in the `div_a` and `div_b` terms, and in the `total_sum` term subtracted at the end.
So, having a negative `ent_reg_cost` is not a contradiction per se. However, this does not preclude a bug. Have you looked at the transportation matrix that is returned? Do you see entries that are unusual?
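One quick way to look for unusual entries, sketched in JAX: the helper below (`inspect_transport` is a hypothetical name, not an OTT-JAX function) takes a transport matrix `P` and the weights `a`, `b`, and reports non-finite or negative entries, the total mass, and the L1 gap between the marginals of `P` and `a`, `b`. It is a generic sanity check, not a diagnosis of this particular issue.

```python
import jax.numpy as jnp


def inspect_transport(P, a, b):
    """Summarize a transport matrix P for a quick sanity check.

    Hypothetical helper, not part of OTT-JAX. For unbalanced problems the
    marginal errors are not expected to be zero.
    """
    return {
        "all_finite": bool(jnp.all(jnp.isfinite(P))),
        "min_entry": float(jnp.min(P)),  # should be >= 0
        "max_entry": float(jnp.max(P)),
        "total_mass": float(jnp.sum(P)),
        # L1 gap between the row sums of P and a, and between column sums and b.
        "row_marginal_err": float(jnp.sum(jnp.abs(P.sum(axis=1) - a))),
        "col_marginal_err": float(jnp.sum(jnp.abs(P.sum(axis=0) - b))),
    }


# Usage on the independent coupling a ⊗ b, whose marginals match exactly.
a = jnp.array([0.5, 0.5])
b = jnp.array([0.25, 0.75])
P = jnp.outer(a, b)
print(inspect_transport(P, a, b))
```

For an unbalanced solution, large marginal errors are expected by design; NaNs, negative entries, or a total mass far from what the weights suggest are the more telling signals.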
Hi Marco,
Thanks for your response. I might be wrong, but reading equation 24 in https://arxiv.org/pdf/1910.12958.pdf, the term multiplied by epsilon in the entropic penalization should be non-positive, because f_i + g_j <= C_{ij}: every element of the right matrix in the inner product is non-positive, whereas every element of the left matrix is non-negative.
This holds for `total_sum` in the code; in the example above, `total_sum = -11.217885`.
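To make the sign argument concrete, here is a toy numerical check in JAX. It assumes that the inner product in question has the form ⟨a ⊗ b, exp((f ⊕ g − C)/ε) − 1⟩ (a reading of the argument above, not copied from the paper or from OTT), and it builds potentials that satisfy f_i + g_j ≤ C_ij by construction.

```python
import jax
import jax.numpy as jnp

key_c, key_a, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
n, m, eps = 4, 5, 0.1
C = jax.random.uniform(key_c, (n, m))  # toy cost matrix, entries in [0, 1)
a = jax.random.uniform(key_a, (n,))    # non-negative weights (unnormalized)
b = jax.random.uniform(key_b, (m,))

# Potentials satisfying f_i + g_j <= C_ij by construction: f = 0, g_j = min_i C_ij.
f = jnp.zeros(n)
g = jnp.min(C, axis=0)

right = jnp.exp((f[:, None] + g[None, :] - C) / eps) - 1.0  # every entry <= 0
left = a[:, None] * b[None, :]                              # every entry >= 0
inner = jnp.sum(left * right)                               # hence inner <= 0
print(inner)
```

Under this assumption the inner product is non-positive; what the overall sign of the ε-term is then depends on the sign placed in front of ε in the objective.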
On the other hand, the return statement adds this term, i.e. the return value (line 198 in d7521fd) is `div_a + div_b + ot_prob.epsilon * (jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum)`. Hence, I don't see where the minus sign before the epsilon in equation (24) is incorporated.
So I would assume that the return statement starting on line 198 should read `div_a + div_b - ot_prob.epsilon * (jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum)`.
Apologies if I am wrong.
The transport matrix does not seem to be completely wrong, at least: for multiple scenarios, the Pearson correlation (PCC) with the transport matrices from WOT and POT is larger than 0.98. If I am not mistaken, they use a stabilized version, whereas OTT-JAX does not.
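For context on this kind of comparison, here is a small JAX sketch (not how WOT, POT or OTT-JAX implement it) that runs a plain scaling-form Sinkhorn and a log-domain Sinkhorn on the same balanced toy problem and computes the PCC between the two flattened transport matrices. Log-domain iterations are one common way of stabilizing Sinkhorn; the schemes WOT and POT actually use may differ, and the function names here are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_scaling(C, a, b, eps, n_iter=500):
    """Plain balanced Sinkhorn on the scalings u, v (no stabilization)."""
    K = jnp.exp(-C / eps)  # underflows to 0 when C / eps is large

    def body(_, uv):
        u, v = uv
        u = a / (K @ v)
        v = b / (K.T @ u)
        return u, v

    u, v = jax.lax.fori_loop(0, n_iter, body, (jnp.ones_like(a), jnp.ones_like(b)))
    return u[:, None] * K * v[None, :]


def sinkhorn_log(C, a, b, eps, n_iter=500):
    """Balanced Sinkhorn in the log domain, on the potentials f, g."""

    def body(_, fg):
        f, g = fg
        f = eps * jnp.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
        g = eps * jnp.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, n_iter, body, (jnp.zeros_like(a), jnp.zeros_like(b)))
    return jnp.exp((f[:, None] + g[None, :] - C) / eps)


# Toy balanced problem: squared Euclidean costs between two small point clouds.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (6, 2))
y = jax.random.normal(ky, (7, 2)) + 1.0
C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
C = C / C.max()
a, b = jnp.full(6, 1 / 6), jnp.full(7, 1 / 7)

P_plain = sinkhorn_scaling(C, a, b, eps=0.05)
P_log = sinkhorn_log(C, a, b, eps=0.05)
# Pearson correlation between the flattened transport matrices.
pcc = jnp.corrcoef(P_plain.ravel(), P_log.ravel())[0, 1]
print(pcc)
```

In exact arithmetic both functions produce the same iterates (u = exp(f/ε), v = exp(g/ε)), so the PCC here should be essentially 1. For much smaller `eps`, `jnp.exp(-C / eps)` can underflow, which is where the plain version runs into trouble and the log-domain form pays off.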
Moreover, if I understand correctly, OTT-JAX does not incorporate the scale of the marginals into the stopping criterion. This could easily be adapted and might help the user.
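A minimal sketch of what incorporating the scale of the marginals could look like, assuming an L1 marginal-error criterion (the names `marginal_error` and `has_converged` are hypothetical, and this is not how OTT-JAX computes its error): multiplying the threshold by the total mass `jnp.sum(a)` makes the test behave the same when both `P` and `a` are rescaled.

```python
import jax.numpy as jnp


def marginal_error(P, a):
    """L1 distance between the row sums of P and the target marginal a."""
    return jnp.sum(jnp.abs(P.sum(axis=1) - a))


def has_converged(P, a, threshold, scale_by_mass=True):
    """Hypothetical stopping test, not OTT-JAX's actual criterion.

    With scale_by_mass=True the threshold is multiplied by the total mass
    sum(a), so the test is invariant to rescaling both P and a.
    """
    tol = threshold * jnp.sum(a) if scale_by_mass else threshold
    return bool(marginal_error(P, a) < tol)


a = jnp.array([0.5, 0.5])
P = jnp.array([[0.3, 0.2], [0.25, 0.26]])  # row sums (0.5, 0.51): L1 error 0.01
for mass in (1.0, 1000.0):
    print(mass,
          has_converged(mass * P, mass * a, 0.02),
          has_converged(mass * P, mass * a, 0.02, scale_by_mass=False))
# prints "1.0 True True", then "1000.0 True False"
```

The same 1% relative mismatch passes at unit mass either way, but with an unscaled threshold it never passes once the marginals sum to 1000.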
Thanks for checking this more closely. I must admit that part of the code was not extensively reviewed, so I am very grateful to you for taking a closer look!
Here `total_sum` corresponds to one part of the ε-term in equation 24. Since that part has a minus in front of it in equation 24, I think this should still be `- total_sum`.
Similarly, `jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b)` corresponds to the other part, which enters with a `+`, hence the `+` in front of `ot_prob.epsilon`. Is anything mistaken on my end? Thanks again for checking!
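This mapping can be checked numerically. The JAX sketch below assumes that the ε-term of equation 24 has the form −ε⟨a ⊗ b, exp((f ⊕ g − C)/ε) − 1⟩, and uses a hypothetical stand-in `total_sum` = ⟨a ⊗ b, exp((f ⊕ g − C)/ε)⟩. In OTT-JAX the potentials may absorb the weights, so the exact correspondence with its `total_sum` depends on the implementation. Expanding the inner product shows where the minus sign goes.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(1), 5)
n, m, eps = 3, 4, 0.5
C = jax.random.uniform(k1, (n, m))
a = jax.random.uniform(k2, (n,))
b = jax.random.uniform(k3, (m,))
f = 0.1 * jax.random.normal(k4, (n,))
g = 0.1 * jax.random.normal(k5, (m,))

kernel = jnp.exp((f[:, None] + g[None, :] - C) / eps)
ab = a[:, None] * b[None, :]
total_sum = jnp.sum(ab * kernel)  # hypothetical stand-in for OTT's total_sum

# Form used in the return statement: + eps * (sum(a) * sum(b) - total_sum).
code_form = eps * (jnp.sum(a) * jnp.sum(b) - total_sum)
# Assumed form of the eps-term in the dual: - eps * <a ⊗ b, kernel - 1>.
paper_form = -eps * jnp.sum(ab * (kernel - 1.0))
print(code_form, paper_form)  # equal up to float rounding
```

Since ⟨a ⊗ b, 1⟩ = sum(a) · sum(b), the minus sign in front of ε is absorbed by writing `sum(a) * sum(b) - total_sum`, so the `+` in the return statement is consistent with the dual.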
As for marginals, I think you are right: some scaling factor is definitely needed in the stopping criterion, and I think that's a very good idea. For instance, the default tolerance could be `jnp.sum(a) * tolerance`. I hope this does not have complicated side effects when optimizing over
Hi Marco,
Thanks, yeah, I spotted my mistake and your reasoning makes perfect sense!
Related Issues (20)
- `rank2` lr initializer not reproducible between 0.4.3 and 0.4.4
- Implement the principled initialisation
- AttributeError -> with: sinkhorn_divergence - when: passing in sinkhorn_kwargs={''rank"=#someInteger}
- AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'
- how to compute the (bures) wasserstein distance between gaussians of different dimensions
- Remove legacy `bool` option when scaling the cost matrix
- Add CITATION.cff
- `converged` flag compatibility with `min_iterations` logic
- Sinkhorn iteration is not converging in A100 GPU
- bug in documentation
- Effects not supported in `custom_vjp` error when using GromovWasserstein
- Increased GPU memory usage when using a cost_fn different from costs.SqEuclidean()
- `compute_sparse_laplacian` gives int32 vs int64 index mismatch when input is from scipy csr
- Modify neural methods's `__call__` to do just one step
- Misnumbered equation
- linear and quadratic part get mixed up in genot
- Unbalanced FGW doesn't converge when margins are provided
- Add new geometry class for triangulated meshes
- Potential bug in ott.geometry.segment._segment_interface
- Role of `fused_penalty`