
Comments (5)

zoepiran avatar zoepiran commented on June 3, 2024

@LaetitiaPapaxanthos here you can also observe the performance of the unbalanced case with $\tau_a = \tau_b$.
As I show, the problem is indeed in gw_unbalanced_correction=True (with False it works).
You can obviously play with everything there :)
[two attached images]


michalk8 avatar michalk8 commented on June 3, 2024

For tau_a = tau_b = 0.9, I noticed that the total transported mass is very low (1e-5), whereas if only one of them is unbalanced, it's fairly high (0.9). Code to reproduce:

import jax
import numpy as np

import ott
from ott.geometry.pointcloud import PointCloud

# Enable 64-bit precision before any arrays are created.
jax.config.update("jax_enable_x64", True)

np.random.seed(0)
x = np.random.normal(size=(64, 3))
y = np.random.normal(size=(128, 3))
xx = np.random.normal(size=(64, 3))
yy = np.random.normal(size=(128, 3))

o, scale_cost = True, 'max_cost'
geom_x = PointCloud(x, online=o, scale_cost=scale_cost)
geom_y = PointCloud(y, online=o, scale_cost=scale_cost)
geom_xy = PointCloud(xx, yy, online=o, scale_cost=scale_cost)

solver = ott.core.gromov_wasserstein.GromovWasserstein(jit=False, epsilon=1e-2, lse_mode=False)
prob = ott.core.quad_problems.QuadraticProblem(geom_x, geom_y,
                                               geom_xy,
                                               tau_a=0.8, tau_b=0.8,
                                               gw_unbalanced_correction=True)

# Manually run a single outer iteration of the solver.
iteration = 0
state = solver.init_state(prob, -1)
linear_pb = prob.update_linearization(state.linear_state, solver.epsilon, state.old_transport_mass)

# Solve the linearized problem and record the mass of the previous coupling.
out = solver.linear_ot_solver(linear_pb)
old_transport_mass = jax.lax.stop_gradient(
    state.linear_state.transport_mass()
)
state = state.update(
    iteration, out, linear_pb, solver.store_inner_errors, old_transport_mass
)
# Total transported mass after one iteration.
print(state.linear_state.marginal(0).sum())  # 1.883714535238546e-05
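
For readers less familiar with what tau_a and tau_b control, here is a minimal NumPy sketch of the standard unbalanced Sinkhorn scaling iteration for a linear (not quadratic) problem, where the scaling updates are raised to the powers tau_a and tau_b. It is not OTT's implementation and does not reproduce the bug above; the function name `unbalanced_sinkhorn_mass`, the problem sizes, and epsilon=0.1 are all assumptions chosen for illustration.

```python
import numpy as np


def unbalanced_sinkhorn_mass(cost, a, b, epsilon=0.1, tau_a=1.0, tau_b=1.0,
                             n_iters=200):
    """Run scaling iterations and return the total mass of the coupling.

    tau = 1.0 enforces the corresponding marginal (balanced case);
    tau < 1.0 relaxes it, so the coupling may carry a different total mass.
    """
    kernel = np.exp(-cost / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = (a / (kernel @ v)) ** tau_a
        v = (b / (kernel.T @ u)) ** tau_b
    coupling = u[:, None] * kernel * v[None, :]
    return coupling.sum()


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3))
y = rng.normal(size=(32, 3))

# Squared Euclidean cost, rescaled by its maximum entry.
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
cost /= cost.max()

# Uniform weights on both point clouds.
a = np.full(16, 1.0 / 16)
b = np.full(32, 1.0 / 32)

mass_balanced = unbalanced_sinkhorn_mass(cost, a, b)
mass_unbalanced = unbalanced_sinkhorn_mass(cost, a, b, tau_a=0.8, tau_b=0.8)
print(mass_balanced, mass_unbalanced)
```

In the balanced case the last update enforces the second marginal, so the total mass equals the sum of b. With both tau values below 1 neither marginal is enforced, which is why the total mass becomes a quantity worth monitoring.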


michalk8 avatar michalk8 commented on June 3, 2024

In the next iteration, the solution to the linearized problems contains infs; this also causes the transport mass sum to be 0 (and makes the scale between the old and the new transport mass NaN). This only happens when gw_unbalanced_correction=True.
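
The NaN described above is what floating-point arithmetic produces for a ratio of two zeros. The sketch below only illustrates that mechanism: it assumes the scale is a plain ratio of the old and new masses (the thread does not show how OTT computes it), and both `mass_ratio` and `safe_mass_ratio` are hypothetical helpers, not the fix that was eventually merged.

```python
import numpy as np


def mass_ratio(old_mass, new_mass):
    """Plain ratio of two masses; yields NaN or inf when new_mass is zero."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.float64(old_mass) / np.float64(new_mass)


def safe_mass_ratio(old_mass, new_mass, fallback=1.0):
    """Hypothetical guard: use `fallback` when the ratio is not finite."""
    ratio = mass_ratio(old_mass, new_mass)
    return ratio if np.isfinite(ratio) else fallback


print(mass_ratio(0.0, 0.0))       # zero over zero
print(mass_ratio(1e-5, 0.0))      # positive over zero
print(safe_mass_ratio(0.0, 0.0))  # guarded version
```

A guard like this only hides the symptom; the infs in the linearized solution still need to be addressed at their source.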


marcocuturi avatar marcocuturi commented on June 3, 2024

Maybe we can close this now?


michalk8 avatar michalk8 commented on June 3, 2024

Completed via #128.
