
ott-jax / ott


Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.

Home Page: https://ott-jax.readthedocs.io

License: Apache License 2.0

Python 100.00%
optimal-transport automatic-differentiation jax sinkhorn gromov-wasserstein

ott's Introduction


Optimal Transport Tools (OTT)


See the full documentation at https://ott-jax.readthedocs.io.

What is OTT-JAX?

A JAX-powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes the fastest implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, momentum, acceleration, initializations) and extensions (low-rank, entropic maps). They can be used directly between two datasets, or within more advanced problems (Gromov-Wasserstein, barycenters). Several JAX features, including JIT compilation, auto-vectorization and implicit differentiation, work towards the goal of end-to-end differentiable outputs. OTT-JAX is led by a team of researchers at Apple, with contributions from Google and Meta researchers, as well as many academic partners, including TU München, Oxford, ENSAE/IP Paris, ENS Paris and the Hebrew University.

Installation

Install OTT-JAX from PyPI as:

pip install ott-jax

or with conda via conda-forge as:

conda install -c conda-forge ott-jax

What is optimal transport?

Optimal transport can be loosely described as the branch of mathematics and optimization that studies matching problems: given two families of points and a cost function on pairs of points, find a "good" (low-cost) way to associate each point in the first family bijectively with a point in the second.

Such problems appear in all areas of science and are easy to describe, yet hard to solve. Indeed, while optimally matching two sets of $n$ points under a pairwise cost can be done with the Hungarian algorithm, doing so costs on the order of $O(n^3)$ operations and lacks flexibility, since one may want to couple families of different sizes.

Optimal transport extends all of this through faster algorithms (in $n^2$, or even linear in $n$) along with numerous generalizations that can help it handle weighted sets of different sizes, partial matchings, and even more involved so-called quadratic matching problems.
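To make the "in $n^2$" claim concrete, here is a minimal, unoptimized log-domain Sinkhorn sketch in plain JAX. It is not OTT-JAX's implementation (which adds the scheduling, momentum and initialization tweaks mentioned above); the function name `sinkhorn_log`, the fixed iteration count and the default epsilon are assumptions for illustration. Each half-step is one log-sum-exp over the n × m cost matrix, i.e. O(nm) work per iteration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log(x, y, a, b, epsilon=0.1, num_iters=1000):
    """Entropic OT coupling between weighted point clouds (illustrative sketch)."""
    # Squared Euclidean cost matrix C, shape [n, m].
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, fg):
        f, g = fg
        # Each half-step is one log-sum-exp over the [n, m] matrix: O(n m) work.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    init = (jnp.zeros(x.shape[0]), jnp.zeros(y.shape[0]))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    # Coupling P_ij = exp((f_i + g_j - C_ij) / epsilon); columns sum to b right
    # after the last g-update, rows sum to a up to the convergence tolerance.
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
```

Called on point clouds and weights like those in the example, it returns an n × m coupling whose row sums approximate a and whose column sums match b.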

In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly (2D vectors, compared with the squared Euclidean distance):

Example

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# sample two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n,d)) + 1
y = jax.random.uniform(rngs[1], (m,d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)
# Computes the couplings using the Sinkhorn algorithm.
geom = pointcloud.PointCloud(x, y)
prob = linear_problem.LinearProblem(geom, a, b)

solver = sinkhorn.Sinkhorn()
out = solver(prob)

The call to solver(prob) above computes the optimal transport solution. The out object contains a transport matrix (here of size $12\times 14$) that quantifies how strongly each point of the first point cloud is associated with one or more points of the second, as illustrated in the plot below. We provide more flexibility to define custom cost functions, objectives and solvers, as detailed in the full documentation.

Figure: obtained coupling

Citation

If you have found this work useful, please consider citing this reference:

@article{cuturi2022optimal,
  title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein},
  author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and
          Davis, Geoff and Teboul, Olivier},
  journal={arXiv preprint arXiv:2201.12324},
  year={2022}
}

See also

The moscot package for OT analysis of multi-omics data also uses OTT as a backbone.

ott's People

Contributors

adrhill, alantian, antoinebelloir, awehenkel, bamos, bosr, bunnech, daniel-packer, ersisimou, geoff-davis, giovp, gjhuizing, guillaumehu, jtt94, laetitiapapaxanthos, lucaeyring, marcocuturi, meyerscetbon, michalk8, mucdk, nvesseron, olivierteboul, othmanesebbouh, pierreablin, pkassraie, sauravmaheshkar, selmanozleyen, soerenab, theouscidda6, zoepiran


ott's Issues

Jitting low-rank Sinkhorn doesn't speed up computations

It seems that the low-rank alternative of Sinkhorn (https://ott-jax.readthedocs.io/en/latest/notebooks/LRSinkhorn.html) does not take advantage of jit. Even worse, using jit seems to slow the calculations down.

Here is a colab notebook that illustrates this: https://colab.research.google.com/drive/1QAlY9x5BfJhM_DBD7WOFrf_FuNnWeAek?usp=sharing. Through experiments on the Sinkhorn divergence, the notebook also shows that, unlike low-rank Sinkhorn, regular Sinkhorn is well accelerated when jitted.

Do you have an explanation?

Thank you very much for your help.
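One common cause of "jit makes it slower" measurements, offered here as something to rule out rather than a diagnosis of the low-rank solver, is timing the first call (which includes tracing and XLA compilation) and not waiting for JAX's asynchronous dispatch to finish. A minimal benchmarking sketch in plain JAX; `toy_solver` and `benchmark` are hypothetical stand-ins, not OTT-JAX code:

```python
import time

import jax
import jax.numpy as jnp


def toy_solver(x, num_steps=20):
    # Hypothetical stand-in for the solver being benchmarked.
    v = x
    for _ in range(num_steps):
        v = jnp.tanh(v @ x.T @ x)
    return v


def benchmark(fn, *args, repeats=10):
    """Return (first_call_seconds, mean_later_call_seconds)."""
    start = time.perf_counter()
    fn(*args).block_until_ready()  # first call: includes tracing/compilation if jitted
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready waits for JAX's asynchronous dispatch to finish.
        fn(*args).block_until_ready()
    later = (time.perf_counter() - start) / repeats
    return first, later


x = jax.random.normal(jax.random.PRNGKey(0), (256, 64)) / 16.0
jitted_solver = jax.jit(toy_solver)
print("eager:  first=%.4fs later=%.4fs" % benchmark(toy_solver, x))
print("jitted: first=%.4fs later=%.4fs" % benchmark(jitted_solver, x))
```

Only the later-call numbers reflect steady-state speed; a slow jitted first call is compilation, which is paid again whenever input shapes or dtypes change.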

PointCloud `apply_lse_kernel` for large scale optimal transport requires too much memory

Running on CPU, this tries to allocate ~122GiB. It happens for online=True; my current hypothesis is that the two nested vmaps in the online code path (the outer one in PointCloud.apply_lse_kernel and the inner one in _cost, both visible in the traceback below) create a large intermediate array that unfortunately is not optimized away.
If that's the case, using jax.lax.scan instead of the outer vmap might be an alternative approach.

Minimal reproducible example:

import numpy as np
import jax.numpy as jnp
import ott

exp = 16  # 2 ** 16
x = np.random.normal(size=(2 ** exp, 30))
y = np.random.normal(size=(2 ** (exp - 1), 30))
f = x.sum(1)
g = y.sum(1)
eps = 1e-2

pc = ott.geometry.PointCloud(jnp.asarray(x), jnp.asarray(y), epsilon=eps, online=True)
print(pc.shape)  # (65536, 32768)

pc.apply_lse_kernel(f, g, eps)  # raises the RuntimeError (OOM) below

Traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [43], in <module>
----> 1 pc.apply_lse_kernel(f, g, eps)

File /vol/storage/MK/repos/ott/ott/geometry/pointcloud.py:136, in PointCloud.apply_lse_kernel(self, f, g, eps, vec, axis)
    129 app = jax.vmap(
    130     _apply_lse_kernel_xy,
    131     in_axes=[
    132         None, 0, None, self._axis_norm, None, 0, None, None, None, None
    133     ])
    135 if axis == 0:
--> 136   h_res, h_sgn = app(self.x, self.y, self._norm_x, self._norm_y, f, g, eps,
    137                      vec, self._cost_fn, self.power)
    138   h_res = eps * h_res - jnp.where(jnp.isfinite(g), g, 0)
    139 if axis == 1:

    [... skipping hidden 3 frame]

File /vol/storage/MK/repos/ott/ott/geometry/pointcloud.py:314, in _apply_lse_kernel_xy(x, y, norm_x, norm_y, f, g, eps, vec, cost_fn, cost_pow)
    312 def _apply_lse_kernel_xy(x, y, norm_x, norm_y, f, g, eps,
    313                          vec, cost_fn, cost_pow):
--> 314   c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow)
    315   return ops.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1)

File /vol/storage/MK/repos/ott/ott/geometry/pointcloud.py:335, in _cost(x, y, norm_x, norm_y, cost_fn, cost_pow)
    333 def _cost(x, y, norm_x, norm_y, cost_fn, cost_pow):
    334   one_line_pairwise = jax.vmap(cost_fn.pairwise, in_axes=[0, None])
--> 335   return (norm_x + norm_y + one_line_pairwise(x, y)) ** (0.5 * cost_pow)

    [... skipping hidden 3 frame]

File /vol/storage/MK/repos/ott/ott/geometry/costs.py:96, in Euclidean.pairwise(self, x, y)
     95 def pairwise(self, x, y):
---> 96   return -2 * dot(x, y)

File /vol/storage/MK/repos/ott/ott/geometry/costs.py:30, in dot(x, y)
     27 """Accelerator dependent dot. Implemented to avoid OOMs with online mode."""
     28 platform = xla_bridge.get_backend().platform
     29 return jnp.where(platform == 'gpu',
---> 30                  jnp.sum(x * y),
     31                  jnp.vdot(x, y))

    [... skipping hidden 1 frame]

File /vol/storage/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:6747, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   6745 if not isinstance(other, _accepted_binop_types):
   6746   return NotImplemented
-> 6747 return binary_op(self, other)

    [... skipping hidden 12 frame]

File /vol/storage/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/dispatch.py:444, in _execute_compiled(name, compiled, output_buffer_counts, result_handlers, kept_var_idx, *args)
    441 device, = compiled.local_devices()
    442 input_bufs = util.flatten(
    443     device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
--> 444 out_bufs = compiled.execute(input_bufs)
    445 check_special(name, out_bufs)
    446 if output_buffer_counts is None:

RuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 257698037760 bytes.

Version: current master (fed7dfd)
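The failed allocation in the traceback, 257698037760 bytes, equals 2^16 · 2^15 · 30 · 4 bytes, i.e. one float32 array of shape [n, m, d]; this fits the hypothesis that the nested vmaps materialize the elementwise product x * y for all pairs. A minimal sketch of the suggested scan-based alternative in plain JAX (not OTT-JAX's code): rows of x are processed in chunks with jax.lax.map, which is built on scan, and the squared Euclidean cost is expanded via norms and a matmul, so only a [chunk_size, m] block is alive at once. The name `chunked_lse_rows`, the squared Euclidean cost, and the requirement that chunk_size divides n are assumptions; OTT-JAX's apply_lse_kernel additionally rescales by eps and handles the vec/sign arguments, which are omitted here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def chunked_lse_rows(x, y, f, g, eps, chunk_size):
    """Compute logsumexp_j((f_i + g_j - C_ij) / eps) for every row i, chunk by chunk.

    C_ij = ||x_i - y_j||^2. Assumes x.shape[0] is divisible by chunk_size.
    """
    n, d = x.shape
    x_chunks = x.reshape(n // chunk_size, chunk_size, d)
    f_chunks = f.reshape(n // chunk_size, chunk_size)
    norm_y = jnp.sum(y ** 2, axis=1)

    def one_chunk(args):
        xc, fc = args
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>: no [chunk, m, d] buffer.
        cost = jnp.sum(xc ** 2, axis=1)[:, None] + norm_y[None, :] - 2.0 * xc @ y.T
        return logsumexp((fc[:, None] + g[None, :] - cost) / eps, axis=1)

    # lax.map processes the chunks sequentially (it is implemented with scan).
    return jax.lax.map(one_chunk, (x_chunks, f_chunks)).reshape(n)
```

Peak memory then scales with chunk_size × m instead of n × m × d, at the price of running the chunks one after another.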

Cost without norm and PointCloud in `online` mode does not work

To reproduce:

import ott
import numpy as np
import jax.numpy as jnp

x = np.random.normal(size=(100, 30))
y = np.random.normal(size=(90, 30))

pc = ott.geometry.pointcloud.PointCloud(jnp.asarray(x), jnp.asarray(y), scale_cost=None,
                                        cost_fn=ott.geometry.costs.Cosine(),
                                        batch_size=32)

pc.apply_cost(jnp.ones((100,)))

The problem is that pc._axis_norm is None, so _norm_y is not vmapped over in _apply_cost (in_axes=None), and the full (90,) vector of norms gets added to a (100,) row.
Since _norm_x and _norm_y are correctly set to all 0s, a simple fix is to always vmap across this dimension.

Traceback:

TypeError                                 Traceback (most recent call last)
Input In [65], in <cell line: 1>()
----> 1 pc.apply_cost(jnp.ones((100,)))

File /opt/projects/helmholtz/ott_mk/ott/geometry/pointcloud.py:373, in PointCloud.apply_cost(self, arr, axis, fn, is_linear)
    370 if self.is_squared_euclidean and (fn is None or is_linear):
    371   return self.vec_apply_cost(arr, axis, fn=fn)
--> 373 return self._apply_cost(arr, axis, fn=fn)

File /opt/projects/helmholtz/ott_mk/ott/geometry/pointcloud.py:390, in PointCloud._apply_cost(self, arr, axis, fn)
    388 if axis == 0:
    389   print(self._norm_x.shape, self._norm_y.shape)
--> 390   return app(
    391       self.x, self.y, self._norm_x, self._norm_y, arr, self._cost_fn,
    392       self.power, self.inv_scale_cost, fn
    393   )
    394 if axis == 1:
    395   return app(
    396       self.y, self.x, self._norm_y, self._norm_x, arr, self._cost_fn,
    397       self.power, self.inv_scale_cost, fn
    398   )

    [... skipping hidden 3 frame]

File /opt/projects/helmholtz/ott_mk/ott/geometry/pointcloud.py:712, in _apply_cost_xy(x, y, norm_x, norm_y, vec, cost_fn, cost_pow, scale_cost, fn)
    689 def _apply_cost_xy(
    690     x, y, norm_x, norm_y, vec, cost_fn, cost_pow, scale_cost, fn=None
    691 ):
    692   """Apply [num_b, num_a] fn(cost) matrix (or transpose) to vector.
    693 
    694   Applies [num_b, num_a] ([num_a, num_b] if axis=1 from `apply_cost`)
   (...)
    710     A jnp.ndarray corresponding to cost x vector
    711   """
--> 712   c = _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost)
    713   return jnp.dot(c, vec) if fn is None else jnp.dot(fn(c), vec)

File /opt/projects/helmholtz/ott_mk/ott/geometry/pointcloud.py:685, in _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost)
    683 one_line_pairwise = jax.vmap(cost_fn.pairwise, in_axes=[0, None])
    684 print(norm_x.shape, norm_y.shape, "NX x NY")
--> 685 return ((norm_x + norm_y + one_line_pairwise(x, y)) ** (0.5 * cost_pow) *
    686         scale_cost)

File ~/.miniconda3/envs/ott/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4586, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   4584 if not isinstance(other, _accepted_binop_types):
   4585   return NotImplemented
-> 4586 return binary_op(self, other)

    [... skipping hidden 14 frame]

File ~/.miniconda3/envs/ott/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:81, in _maybe_bool_binop.<locals>.fn(x1, x2)
     79 def fn(x1, x2):
     80   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
---> 81   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

    [... skipping hidden 7 frame]

File ~/.miniconda3/envs/ott/lib/python3.9/site-packages/jax/_src/lax/lax.py:1443, in _broadcasting_shape_rule(name, *avals)
   1441     non_1s = {d for d in ds if not core.symbolic_equal_dim(d, 1)}
   1442     if len(non_1s) > 1:
-> 1443       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1444                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1445     result_shape.append(non_1s.pop() if non_1s else 1)
   1446 return tuple(result_shape)

TypeError: add got incompatible shapes for broadcasting: (100,), (90,).
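The shape error can be reproduced without OTT-JAX. In this plain-JAX sketch, `pairwise_term` is a hypothetical stand-in for `_cost`: vmapping over the rows of y with `in_axes=None` for the per-point norm passes all 90 norms into every call, where they meet a (100,) row; mapping the norm along axis 0, as the issue proposes, gives each call a scalar.

```python
import jax
import jax.numpy as jnp


def pairwise_term(x, y_row, norm_x, norm_y_row):
    # x: [n, d], y_row: [d]; norm_y_row is expected to be a scalar here.
    return norm_x + norm_y_row - 2.0 * x @ y_row


x = jnp.ones((100, 30))
y = jnp.ones((90, 30))
norm_x = jnp.zeros(100)  # norms are all zeros for costs without a norm term
norm_y = jnp.zeros(90)

# Fixed: map over the rows of y *and* over norm_y -> one [100] row per y_j.
ok = jax.vmap(pairwise_term, in_axes=(None, 0, None, 0))(x, y, norm_x, norm_y)
print(ok.shape)  # (90, 100)

# Buggy: norm_y is not mapped (in_axes=None), so each call sees all 90 norms.
try:
    jax.vmap(pairwise_term, in_axes=(None, 0, None, None))(x, y, norm_x, norm_y)
except (TypeError, ValueError) as err:  # the exact error type may vary by JAX version
    print(type(err).__name__, err)
```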

catching `online=True` in `gromov_wasserstein`

Hi,
we noticed that one can accidentally pass online=True to gromov_wasserstein, and it will propagate through **kwargs. The mapping eventually fails because of a dimensionality assertion, but we thought it may be useful to catch this earlier and warn.
What do you think?
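A minimal sketch of the suggested early check, written as a hypothetical helper that a GW entry point could call on its **kwargs. The name `check_gw_kwargs` and the choice to warn and drop the argument (rather than raise) are assumptions, not OTT-JAX behavior:

```python
import warnings
from typing import Any, Dict


def check_gw_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Drop `online=True` before it propagates through **kwargs, with a warning."""
    kwargs = dict(kwargs)  # do not mutate the caller's dict
    if kwargs.pop("online", None):
        warnings.warn(
            "`online=True` is not supported here and was ignored.",
            UserWarning,
            stacklevel=2,
        )
    return kwargs
```

Raising a ValueError instead would be the stricter choice if silently ignoring the flag is considered surprising.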

Remove `jit` in solver classes

In light of #169 and #170, I advocate removing jit from the solvers (such as Sinkhorn), since the user most likely wants to jit the outermost function anyway.
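A plain-JAX sketch of the proposed pattern: the solver class does no jitting of its own, and the user jits the outermost function, so XLA compiles the solver together with the surrounding loss (and its gradient). `ToySolver` is hypothetical, a Newton iteration for sqrt(x) standing in for Sinkhorn:

```python
import jax
import jax.numpy as jnp


class ToySolver:
    """Hypothetical solver without an internal jax.jit."""

    def __init__(self, num_iters: int = 20):
        self.num_iters = num_iters  # static hyperparameter, fixed at trace time

    def __call__(self, x):
        # Newton iterations for sqrt(x); traced as part of the caller's jit.
        def body(_, v):
            return 0.5 * (v + x / v)

        return jax.lax.fori_loop(0, self.num_iters, body, jnp.ones_like(x))


solver = ToySolver()


@jax.jit  # a single jit at the outermost level
def loss(x):
    return jnp.sum(solver(x))


print(loss(jnp.array([4.0, 9.0])))            # approximately 5.0
print(jax.grad(loss)(jnp.array([4.0, 9.0])))  # approximately [0.25, 0.1667]
```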

`scale_cost` throws an error if used with `online=False`

from ott.core.sinkhorn import sinkhorn
from ott.geometry.pointcloud import PointCloud
import jax.numpy as jnp
import jax

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, shape=(10,2))
y = jax.random.normal(key, shape=(10,2))
geom = PointCloud(x, y, online=False, scale_cost="max_cost")
out = sinkhorn(geom)

throws an error because

@property
def is_online(self) -> bool:
  return self._online is not None

returns True even though online=False was passed.

Will open a PR soon.
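If the flag is stored as passed, online=False is kept as False, and `False is not None` evaluates to True. One way to make the check robust, sketched on a hypothetical minimal class (`OnlineFlag` is not OTT-JAX's `PointCloud`, and `DEFAULT_BATCH_SIZE` is an assumed value), is to normalize the argument once in the constructor so that False and None both mean "not online":

```python
from typing import Optional, Union

DEFAULT_BATCH_SIZE = 1024  # hypothetical default used when online=True


class OnlineFlag:
    def __init__(self, online: Union[bool, int, None] = None):
        # Normalize: False/None -> None, True -> default batch size, int -> itself.
        if online is True:
            self._online: Optional[int] = DEFAULT_BATCH_SIZE
        elif online is False or online is None:
            self._online = None
        else:
            self._online = int(online)

    @property
    def is_online(self) -> bool:
        return self._online is not None
```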

put adaptive momentum as default for Sinkhorn

This did not generate many comments (#77), but I think that switching to an adaptive momentum schedule by default could make a fairly big difference across many applications for everything using Sinkhorn (and propagate to GW). What about making this the default? I think initializations are nice and complementary to this (at least this is what we find), but adaptive momentum seems to be a real game changer that works consistently.

Does anyone have an opinion on this?
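For context, here is a minimal plain-JAX sketch of log-domain Sinkhorn with a fixed momentum (over-relaxation) parameter omega. It is not OTT-JAX's adaptive schedule, which chooses omega automatically; the function name and defaults are assumptions. With omega = 1 it reduces to plain Sinkhorn, and values in (1, 2) often need fewer iterations.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_momentum(cost, a, b, epsilon=0.1, omega=1.5, num_iters=200):
    """Log-domain Sinkhorn with fixed over-relaxation (illustrative sketch)."""
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, fg):
        f, g = fg
        f_new = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        f = (1.0 - omega) * f + omega * f_new  # momentum step on f
        g_new = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        g = (1.0 - omega) * g + omega * g_new  # momentum step on g
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
```

In the multiplicative form this is u ← u^(1-omega) · (a / Kv)^omega; an adaptive schedule replaces the constant omega with one estimated from the observed convergence rate.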

run tests locally with pytest

Hi,

In order to run tests locally with pytest, the packages pytest, pytest-xdist, pytest-memray and pytest-cov are currently included in setup.cfg.

However, it is not possible to install pytest-memray on a Mac, so we might want to remove it.

Also (very minor), we might want to change the test path given in contributing.md to a valid path to a test in ott.

`fixpoint_iter_backprop` doesn't correctly handle integer constants

import jax
import jax.numpy as jnp
import ott

def cond_fn(*_, **__): return True

def body_fn(it: int, const: int, state: float, compute_error):
    return state + const

def foo(x: jnp.ndarray):
    res = ott.core.fixed_point_loop.fixpoint_iter_backprop(
        cond_fn,
        body_fn,
        min_iterations=1,
        max_iterations=20,
        inner_iterations=1,
        constants=2,
        state=x
    )
    return jnp.sum(res)

v, g = jax.value_and_grad(foo)(jnp.zeros((10,)))

raises

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/.miniconda3/envs/cellrank/lib/python3.8/runpy.py:194, in _run_module_as_main(***failed resolving arguments***)
    193     sys.argv[0] = mod_spec.origin
--> 194 return _run_code(code, main_globals, None,
    195                  "__main__", mod_spec)

File ~/.miniconda3/envs/cellrank/lib/python3.8/runpy.py:87, in _run_code(***failed resolving arguments***)
     80 run_globals.update(__name__ = mod_name,
     81                    __file__ = fname,
     82                    __cached__ = cached,
   (...)
     85                    __package__ = pkg_name,
     86                    __spec__ = mod_spec)
---> 87 exec(code, run_globals)
     88 return run_globals

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel_launcher.py:17, in <module>
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/traitlets/config/application.py:976, in Application.launch_instance(***failed resolving arguments***)
    975 app.initialize(argv)
--> 976 app.start()

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(***failed resolving arguments***)
    711 try:
--> 712     self.io_loop.start()
    713 except KeyboardInterrupt:

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(***failed resolving arguments***)
    198     asyncio.set_event_loop(self.asyncio_loop)
--> 199     self.asyncio_loop.run_forever()
    200 finally:

File ~/.miniconda3/envs/cellrank/lib/python3.8/asyncio/base_events.py:570, in BaseEventLoop.run_forever(***failed resolving arguments***)
    569 while True:
--> 570     self._run_once()
    571     if self._stopping:

File ~/.miniconda3/envs/cellrank/lib/python3.8/asyncio/base_events.py:1859, in BaseEventLoop._run_once(***failed resolving arguments***)
   1858     else:
-> 1859         handle._run()
   1860 handle = None

File ~/.miniconda3/envs/cellrank/lib/python3.8/asyncio/events.py:81, in Handle._run(***failed resolving arguments***)
     80 try:
---> 81     self._context.run(self._callback, *self._args)
     82 except (SystemExit, KeyboardInterrupt):

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:508, in Kernel.dispatch_queue(***failed resolving arguments***)
    507 try:
--> 508     await self.process_one()
    509 except Exception:

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:497, in Kernel.process_one(***failed resolving arguments***)
    496         return None
--> 497 await dispatch(*args)

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:404, in Kernel.dispatch_shell(***failed resolving arguments***)
    403     if inspect.isawaitable(result):
--> 404         await result
    405 except Exception:

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:728, in Kernel.execute_request(***failed resolving arguments***)
    727 if inspect.isawaitable(reply_content):
--> 728     reply_content = await reply_content
    730 # Flush output before sending the reply.

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(***failed resolving arguments***)
    382 if with_cell_id:
--> 383     res = shell.run_cell(
    384         code,
    385         store_history=store_history,
    386         silent=silent,
    387         cell_id=cell_id,
    388     )
    389 else:

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
    527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(***failed resolving arguments***)
   2880 try:
-> 2881     result = self._run_cell(
   2882         raw_cell, store_history, silent, shell_futures, cell_id
   2883     )
   2884 finally:

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(***failed resolving arguments***)
   2935 try:
-> 2936     return runner(coro)
   2937 except BaseException as e:

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(***failed resolving arguments***)
   3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3136        interactivity=interactivity, compiler=compiler, result=result)
   3138 self.last_execution_succeeded = not has_raised

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
   3337     asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
   3339     return True

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3398, in InteractiveShell.run_code(***failed resolving arguments***)
   3397     else:
-> 3398         exec(code_obj, self.user_global_ns, self.user_ns)
   3399 finally:
   3400     # Reset our crash handler in place

Input In [1], in <cell line: 22>()
     20     return jnp.sum(res)
---> 22 jax.value_and_grad(foo)(jnp.zeros((10,)))

Input In [1], in foo(***failed resolving arguments***)
     10 def foo(x: jnp.ndarray):
---> 11     res = ott.core.fixed_point_loop.fixpoint_iter_backprop(
     12         cond_fn,
     13         body_fn,
     14         min_iterations=1,
     15         max_iterations=20,
     16         inner_iterations=1,
     17         constants=2,
     18         state=x
     19     )
     20     return jnp.sum(res)

JaxStackTraceBeforeTransformation: TypeError: Called add with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 22>()
     11     res = ott.core.fixed_point_loop.fixpoint_iter_backprop(
     12         cond_fn,
     13         body_fn,
   (...)
     18         state=x
     19     )
     20     return jnp.sum(res)
---> 22 jax.value_and_grad(foo)(jnp.zeros((10,)))

    [... skipping hidden 11 frame]

File /opt/projects/ott_jt/ott/core/fixed_point_loop.py:225, in fixpoint_iter_bwd(***failed resolving arguments***)
    219   (_, g_state, g_constants), _ = jax.lax.scan(
    220       lambda carry, x: unrolled_body_fn(carry), (0, g, g_constants),
    221       None,
    222       length=max_iterations // inner_iterations
    223   )
    224 else:
--> 225   _, g_state, g_constants = jax.lax.while_loop(
    226       bwd_cond_fn, unrolled_body_fn,
    227       (iteration - inner_iterations, g, g_constants)
    228   )
    230 return g_constants, g_state

    [... skipping hidden 11 frame]

File /opt/projects/ott_jt/ott/core/fixed_point_loop.py:212, in fixpoint_iter_bwd.<locals>.unrolled_body_fn(iteration_g_gconst)
    208 _, pullback = jax.vjp(
    209     unrolled_body_fn_no_errors, iteration, constants, state
    210 )
    211 _, gi_constants, g_state = pullback(g)
--> 212 g_constants = jax.tree_util.tree_map(
    213     lambda x, y: x + y, g_constants, gi_constants
    214 )
    215 out = (iteration - inner_iterations, g_state, g_constants)
    216 return (out, None) if force_scan else out

    [... skipping hidden 2 frame]

File /opt/projects/ott_jt/ott/core/fixed_point_loop.py:213, in fixpoint_iter_bwd.<locals>.unrolled_body_fn.<locals>.<lambda>(x, y)
    208 _, pullback = jax.vjp(
    209     unrolled_body_fn_no_errors, iteration, constants, state
    210 )
    211 _, gi_constants, g_state = pullback(g)
    212 g_constants = jax.tree_util.tree_map(
--> 213     lambda x, y: x + y, g_constants, gi_constants
    214 )
    215 out = (iteration - inner_iterations, g_state, g_constants)
    216 return (out, None) if force_scan else out

    [... skipping hidden 1 frame]

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4630, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   4628 args = (other, self) if swap else (self, other)
   4629 if isinstance(other, _accepted_binop_types):
-> 4630   return binary_op(*args)
   4631 if isinstance(other, _rejected_binop_types):
   4632   raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4633                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")

    [... skipping hidden 7 frame]

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/ufuncs.py:80, in _maybe_bool_binop.<locals>.fn(x1, x2)
     79 def fn(x1, x2):
---> 80   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
     81   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/util.py:343, in _promote_args(fun_name, *args)
    341 """Convenience function to apply Numpy argument shape and dtype promotion."""
    342 _check_arraylike(fun_name, *args)
--> 343 _check_no_float0s(fun_name, *args)
    344 return _promote_shapes(fun_name, *_promote_dtypes(*args))

File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/util.py:330, in _check_no_float0s(fun_name, *args)
    328 """Check if none of the args have dtype float0."""
    329 if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
--> 330   raise TypeError(
    331       f"Called {fun_name} with a float0 array. "
    332       "float0s do not support any operations by design because they "
    333       "are not compatible with non-trivial vector spaces. No implicit dtype "
    334       "conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
    335       "to cast a float0 array to a regular zeros array. \n"
    336       "If you didn't expect to get a float0 you might have accidentally "
    337       "taken a gradient with respect to an integer argument.")

TypeError: Called add with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.
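The error comes from the integer inputs: jax.vjp returns cotangents of dtype float0 for integer primals (here iteration and constants=2), and float0 arrays cannot be added. One possible workaround, sketched in plain JAX, is to skip float0 leaves when accumulating cotangents. `safe_add` is a hypothetical helper, not OTT-JAX code, and this body_fn drops the issue's compute_error argument for brevity.

```python
import jax
import jax.numpy as jnp


def safe_add(acc, new):
    """Add cotangents, ignoring float0 ones (they come from integer primals)."""
    if getattr(new, "dtype", None) == jax.dtypes.float0:
        return acc  # dtype is static, so this Python `if` is also fine under tracing
    return acc + new


def body_fn(iteration, const, state):
    return state + const


_, pullback = jax.vjp(body_fn, 1, 2, jnp.zeros(3))
g_iteration, g_const, g_state = pullback(jnp.ones(3))
print(g_const.dtype == jax.dtypes.float0)  # True: no gradient for the integer constant

# Accumulating over two "iterations" works because float0 leaves are skipped.
acc = jax.tree_util.tree_map(safe_add, (g_const, g_state), (g_const, g_state))
print(acc[1])  # [2. 2. 2.]
```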

Add CI for notebooks

In #99, I found 2 notebooks that were not runnable. I think it's a good idea to test/regenerate the notebooks on the CI to make sure they are always up to date.

application of `low-rank` to single-cell data

TL;DR: in this colab, we provide an example of our failure to obtain a valid mapping using low-rank.

Problem setup: in this example dataset, we map spatial transcriptomics at single-cell resolution from mouse embryonic tissues across two time points. The quadratic term accounts for distances in spatial coordinates, and the linear term captures distances in gene expression.

Evaluation of the mapping: as an initial sanity check, we look at the cell-transition table, i.e. the transition matrix with entries grouped by cell type (we assess the row-stochastic, forward setting). The naive assumption is that cells of a given type, e.g. brain, will mainly be mapped to cells of the same type. Evaluating the regular FGW and unbalanced FGW, this is indeed what we observe. However, for low-rank we get a matrix with constant columns. We observed a similar phenomenon at different time points. Comparing the results, there are hints that the constant columns correspond to cell types that are also favored in the regular regime.
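The cell-transition table described above can be sketched in plain JAX as follows; the function and argument names are hypothetical, not moscot's or OTT-JAX's API. The coupling is summed over all source/target cell pairs of each type pair, and each row is then normalized to sum to one (forward, row-stochastic setting).

```python
import jax.numpy as jnp


def cell_transition_table(coupling, src_types, tgt_types, num_src_types, num_tgt_types):
    """Aggregate an [n, m] coupling into a row-stochastic [K, L] type-to-type table."""
    # One-hot membership matrices: [n, K] for sources and [m, L] for targets.
    src_onehot = (src_types[:, None] == jnp.arange(num_src_types)[None, :]).astype(coupling.dtype)
    tgt_onehot = (tgt_types[:, None] == jnp.arange(num_tgt_types)[None, :]).astype(coupling.dtype)
    # table[k, l] = sum over i of type k and j of type l of coupling[i, j]
    table = src_onehot.T @ coupling @ tgt_onehot
    # Row-normalize; assumes every source type carries some mass.
    return table / jnp.sum(table, axis=1, keepdims=True)
```

With this convention, "constant columns" means every row of the table is the same, i.e. the predicted fate distribution does not depend on the source cell type.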


Follow up on issue #6 of archived repo https://github.com/google-research/ott

Hello,

This is a follow-up to this issue (currently read-only) from the previous development branch of this toolbox.

It was suggested there that, to check whether Sinkhorn has run without numerical issues, we should check out.converged and/or out.errors.

However, when the call to the sinkhorn function is included in another function (which we would like to jax.jit), checking out.converged and/or out.errors raises a ConcretizationTypeError.

I share here a colab demonstrating the issue.

Are there any pointers as to how this could be avoided?

Thanks in advance.
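Assuming out.converged is a JAX boolean array, under jax.jit it is a tracer, so using it in Python control flow (`if out.converged: ...`) forces concretization and raises the error. A plain-JAX sketch of the usual workaround: keep the check inside the traced computation with jnp.where (or jax.lax.cond), and return the flag so it can be inspected with ordinary Python after the jitted call. `toy_solve` is a hypothetical stand-in for a solver that reports convergence.

```python
import jax
import jax.numpy as jnp


def toy_solve(x, tol=1e-5, num_iters=50):
    """Hypothetical solver: fixed-point iteration v <- 0.5 * v + x (solution 2x)."""
    v = jax.lax.fori_loop(0, num_iters, lambda _, v: 0.5 * v + x, jnp.zeros_like(x))
    error = jnp.max(jnp.abs(0.5 * v + x - v))  # residual of the fixed-point equation
    return v, error < tol  # `converged` is a traced boolean array under jit


@jax.jit
def loss(x):
    v, converged = toy_solve(x)
    # `if converged:` here would fail, because `converged` has no concrete value
    # while tracing. Select with jnp.where instead (NaN signals non-convergence).
    value = jnp.where(converged, jnp.sum(v), jnp.nan)
    return value, converged


value, converged = loss(jnp.array([1.0, 2.0]))
if not converged:  # Python control flow is fine on the concrete output, outside jit.
    print("solver did not converge")
```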

`nan` in gradients of a function whose output is itself a function of an optimal transport plan depending on the inputs: switching from `jnp.float32` to `jnp.float64` can help

I encountered a problem when optimizing a function loss whose output is itself a function of the optimal transport plan between two measures defined from the inputs of loss.

Indeed, while the value of loss did not diverge and the norm of the gradients did not explode, from a certain point on I obtained nan in the gradients.

All these computations were done with DeviceArrays of dtype jnp.float32. I then switched to jnp.float64: there were no more nans, and all calculations were done correctly.

I'm not sure why this change of type solved the problem (maybe you can help me with this?), but I think that switching from jnp.float32 to jnp.float64 can help some users!

Remark: by default, JAX enforces single-precision numbers to mitigate the NumPy API's tendency to aggressively promote operands to double. To allow switching to jnp.float64, you need to run the following code at startup:

import jax
jax.config.update("jax_enable_x64", True)

Ref: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
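One plausible mechanism, offered as a hypothesis rather than a diagnosis of this issue: with a small epsilon, kernel entries exp(-C_ij / epsilon) underflow to exactly 0 in float32 far sooner than in float64, and ratios or logarithms involving those zeros then produce NaN, including in gradients. A quick NumPy check of the two precisions (the cost and epsilon values are arbitrary):

```python
import numpy as np

cost, epsilon = 2.0, 0.01  # C / epsilon = 200

k32 = np.exp(-np.float32(cost) / np.float32(epsilon))
k64 = np.exp(-np.float64(cost) / np.float64(epsilon))
print(k32, k64)  # 0.0 in float32, about 1.4e-87 in float64

with np.errstate(invalid="ignore"):
    print(k32 / k32)  # nan: 0 / 0
    print(k64 / k64)  # 1.0
```

The smallest positive float32 is about 1.4e-45, so any exponent below roughly -103 already rounds to zero.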

handling padding when computing scaling for cost matrices.

The segmented approaches used to pad point clouds efficiently rely on padding with "default" (e.g. zero) vectors. These padded vectors create spurious entries in cost matrices. When running Sinkhorn, these entries are ignored because they have 0 weight (in the respective a and b weight vectors). However, they would play a role when computing scalings (such as the mean cost).

A possible approach would be to redefine some of those statistics (e.g. max, mean, etc.) to only apply to entries that have a positive weight (i.e. C_ij is considered iff a_i b_j > 0). The mean cost might also be weighted by a_i b_j.
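The proposed rule can be sketched in plain JAX as follows (hypothetical helper names, not OTT-JAX's scale_cost implementation): an entry C_ij enters the statistic only if a_i b_j > 0, and the mean is weighted by a_i b_j.

```python
import jax.numpy as jnp


def masked_max_cost(cost, a, b):
    """Max of C_ij over pairs with a_i * b_j > 0 (padded points are ignored)."""
    mask = (a[:, None] * b[None, :]) > 0
    return jnp.max(jnp.where(mask, cost, -jnp.inf))


def weighted_mean_cost(cost, a, b):
    """Mean of C_ij weighted by a_i * b_j; zero-weight (padded) entries drop out."""
    w = a[:, None] * b[None, :]
    # jnp.where guards against 0 * inf = nan if padded costs are infinite.
    return jnp.sum(w * jnp.where(w > 0, cost, 0.0)) / jnp.sum(w)
```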

Bug in the way GW handles a KL loss

At the moment, the KL loss in the GW solver applies directly to the cost matrices.

This is not really how we envisioned it in the original paper http://proceedings.mlr.press/v48/peyre16.pdf, where KL was applied to the kernel matrices, not to the cost matrices.

At the moment, the only workaround is quite clumsy, since it involves instantiating a Geometry whose cost_matrix is the kernel_matrix of another one, e.g. with the old API based on make:

geom_1 = pointcloud.PointCloud(x)
geom_xx = geometry.Geometry(geom_1.kernel_matrix)
_ = gromov_wasserstein.gromov_wasserstein(geom_xx,...

I think that, by default, using the KL loss on geometries should mean that the kernel matrices (and not the costs) are considered directly. This can also impact other areas (e.g. (F)GW barycenters, which at the moment take cost matrices as input and not geometries, IIUC).
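A small illustration of why kernel matrices suit the KL loss better than cost matrices, assuming the generalized KL loss p log(p/q) - p + q from the cited paper and a squared-Euclidean cost (both are assumptions of this sketch, which does not use ott). Cost matrices of a point cloud with itself have a zero diagonal, so p log(p/q) becomes 0 * log(0/0) = nan, whereas kernel entries exp(-C/eps) are strictly positive.

```python
import jax.numpy as jnp


def kl_loss(p, q):
    """Generalized KL loss p * log(p / q) - p + q, summed over all entries."""
    return jnp.sum(p * jnp.log(p / q) - p + q)


def sq_euclidean_cost(x):
    # Intra-domain cost matrix; its diagonal is exactly zero.
    return jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)


x = jnp.array([[0.0], [1.0], [3.0]])
y = jnp.array([[0.0], [1.5], [2.5]])
eps = 1.0

cost_x, cost_y = sq_euclidean_cost(x), sq_euclidean_cost(y)
kernel_x, kernel_y = jnp.exp(-cost_x / eps), jnp.exp(-cost_y / eps)

print(kl_loss(cost_x, cost_y))      # nan, from 0 * log(0 / 0) on the diagonal
print(kl_loss(kernel_x, kernel_y))  # finite and non-negative
```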

pre-commit

I'd like to suggest using pre-commit, mostly to unify formatting and to add some simple checks.

For starters, I'd recommend the following:

An example config file I like to use is here, though there are other checks, such as for typing and documentation: https://github.com/theislab/squidpy/blob/master/.pre-commit-config.yaml
black/isort are configured here: https://github.com/theislab/squidpy/blob/master/pyproject.toml#L5-L58

For new contributors, the workflow would then be:

git clone https://github.com/ott-jax/ott && cd ott
pip install -e '.[dev]'  # the `dev` extra contains the `pre-commit` package
pre-commit install

Restructuring the codebase

As discussed previously, we have a pretty flat API at the moment.
One proposal would be to split it based on the problems:

- linear
    - initializer.py
    - problem.py
    - solver.py
    - output.py  # not sure if necessary, can be in `solver.py`
    - barycenter.py  # contains both continuous and discrete barycenters
    - math
        - implicit_diff.py
        - momentum.py 
    - low_rank
        - initializer.py
        - solver.py
        - output.py
- quadratic  # same as above, except no `low_rank/math`
- nn  # alt. would be to have `ott/linear/nn` or `ott/nn/linear`, though might be too deep
    - neuraldual.py
    - icnn.py
    - layers.py
- math
    - decomposition.py
    - fixpoint_iter.py
    - matrix_square_root.py
- utils
    - dataclasses.py
    - segment.py
- geometry
    - geometry.py
    - point_cloud.py
    - low_rank.py
    - grid.py
    - graph.py
    - epsilon.py  # not sure
    - costs.py  # we could also split them in separate files in `costs/`, but imho not necessary

Or alternatively as:

- problems
    - linear
        - problem.py
        - discrete_barycenter.py
        - continuous_barycenter.py
    - quadratic
        - problem.py
        - gw_barycenter.py
    - something for `nn` training?
- solvers
    - linear
        - sinkhorn.py  # would also contain the output
        - sinkhorn_lr.py
    - quadratic
        - gromov_wasserstein.py
    - nn
        - neuraldual.py
        - icnn.py
        - layers.py
- initializers
    - linear
    - quadratic
    - nn
# same as above
- geometry
- utils
- math

TBD: tools, esp. gaussians, see #124.
My personal preference is to go with something resembling the 2nd suggestion, but I would appreciate comments from both @marcocuturi @ersisimou.

`type not understood` error when calling ott.tools.plot.Plot() on high-dimensional data

Hello,

I have noticed that a "type not understood" error occurs when calling ott.tools.plot.Plot() on JAX arrays with dimension higher than 2.

While investigating, I found that the error occurs in the bidimensional() function, which is called on high-dimensional data to project it onto 2 dimensions.

More specifically, the error occurs when scipy.sparse.linalg.svds is called on the output of jnp.concatenate (a JAX device array). If the input vectors are passed as NumPy arrays, the function works properly.

To Reproduce

To reproduce the error, you can run the initial example provided:

  import jax
  import jax.numpy as jnp
  from ott.tools import transport
  # Samples two point clouds and their weights.
  rngs = jax.random.split(jax.random.PRNGKey(0),4)
  n, m, d = 12, 14, 20
  x = jax.random.normal(rngs[0], (n,d)) + 1
  y = jax.random.uniform(rngs[1], (m,d))
  a = jax.random.uniform(rngs[2], (n,))
  b = jax.random.uniform(rngs[3], (m,))
  a, b = a / jnp.sum(a), b / jnp.sum(b)
  # Computes the couplings via Sinkhorn algorithm.
  ot = transport.solve(x, y, a=a, b=b)
  P = ot.matrix

where d = 20 is the dimension of the vectors instead of 2 in the original example.

Then run the following:

import ott.tools.plot

plott = ott.tools.plot.Plot()
_ = plott(ot)

The same error occurs when running

import scipy.sparse.linalg
scipy.sparse.linalg.svds(y, k=2)
# y is a JAX DeviceArray here

But it is solved by running

import numpy as np
scipy.sparse.linalg.svds(np.array(y), k=2)
# here it is converted to a NumPy array
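A possible workaround, sketched under the assumption that converting to NumPy before calling SciPy is enough (as the report suggests). The helper name is hypothetical and this is not the actual body of ott's bidimensional(); it simply projects the stacked clouds onto their top-2 right singular vectors.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg


def project_to_2d(x, y):
    """Projects two stacked point clouds onto their top-2 right singular vectors."""
    z = np.asarray(jnp.concatenate([x, y], axis=0))  # NumPy array, which SciPy accepts
    _, _, vt = scipy.sparse.linalg.svds(z, k=2)      # vt has shape [2, d]
    return z @ vt.T                                  # shape [n + m, 2]


rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (12, 20)) + 1
y = jax.random.uniform(rngs[1], (14, 20))
proj = project_to_2d(x, y)
```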

Marking fast tests

Currently, running all tests takes about 1h; we should mark some of them as fast and add a new CI job for them.
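One way to do this with standard pytest markers; the marker name `fast` and the test below are hypothetical. The marker would also need to be registered, e.g. via `markers = ["fast: quick tests"]` under `[tool.pytest.ini_options]` in pyproject.toml, to avoid PytestUnknownMarkWarning. The new CI job could then run `pytest -m fast`, while the full job keeps running plain `pytest`.

```python
import jax.numpy as jnp
import pytest


@pytest.mark.fast  # hypothetical marker; selected with `pytest -m fast`
def test_uniform_weights_sum_to_one():
    a = jnp.ones(10) / 10
    assert jnp.allclose(jnp.sum(a), 1.0)
```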

`power=1.0` differentiation is unstable due to instability of differentiating distance when points are nearby

Describe the bug
We originally chose to implement the squared Euclidean distance as jnp.sum(x**2,axis=-1) + jnp.sum(y**2,axis=-1) - 2 * jnp.vdot(x,y). Although this works with power=2.0 (which leaves it unchanged), it fails to differentiate gracefully when x ~= y and power=1.0, because of 0/0 mishandling. Some custom differentiation is needed in that case, and it is already implemented in jnp.linalg.norm.

To Reproduce

import jax
import jax.numpy as jnp

g1 = jax.grad(lambda x,y: jnp.linalg.norm(x-y))
g2 = jax.grad(lambda x,y: (x**2 + y**2 - 2*x*y) ** 0.5)
for eps in range(-3,-9,-1):
  print(g1(1.+1*10**eps,1.))
  print(g2(1.+1*10**eps,1.))
  print('--')

Expected behavior
Here we would have expected similar behaviour, but what I see is

1.0
1.0240479
--
1.0
nan
--
1.0
nan
--
1.0
nan
--
1.0
nan
--
nan
nan
--

As a result, I propose to rename the current (misnamed) Euclidean cost function to SqEuclidean and leave it as it is, to default to power=1.0 for all PointCloud geometries, and to introduce a Euclidean cost function that uses jnp.linalg.norm. This will also resolve the contradiction that power=2.0 is the current default, which is messy for any cost that is not SqEuclidean.
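The output above shows that even jnp.linalg.norm returns nan in the last case, where x - y rounds to exactly 0 in float32. A hedged sketch of one alternative, the "double where" trick: compute the difference first (no cancellation) and substitute a safe value under the square root when the distance is exactly zero, so the gradient there is 0 instead of nan. The function name is hypothetical; this is not ott's Euclidean cost.

```python
import jax
import jax.numpy as jnp


def safe_euclidean(x, y):
    """||x - y||_2 with a finite (zero) gradient at x == y."""
    sq = jnp.sum((x - y) ** 2)              # difference first: no cancellation
    is_zero = sq == 0.0
    safe_sq = jnp.where(is_zero, 1.0, sq)   # keeps sqrt's derivative finite
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))


grad_fn = jax.grad(safe_euclidean)
print(grad_fn(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0])))  # [0. 0.]
print(grad_fn(jnp.array([3.0, 4.0]), jnp.zeros(2)))           # approximately [0.6 0.8]
```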

Negative loss in unbalanced Sinkhorn

Unbalanced Sinkhorn results in negative reg_ot_cost.

import jax
from ott.core.sinkhorn import Sinkhorn
from ott.geometry.pointcloud import PointCloud
import jax.numpy as jnp
from ott.core import LinearProblem

n = 1000
dim = 30

rng = jax.random.PRNGKey(0)
rng, *rngs = jax.random.split(rng, 5)
x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (n, dim)) + 0.1
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (n,))
a_scaled = a / jnp.sum(a)
b_scaled = b / jnp.sum(b)

lp = LinearProblem(PointCloud(x, y, epsilon=1), a_scaled, b_scaled, tau_a=0.5, tau_b=0.6)
s = Sinkhorn()(lp)

print(s.reg_ot_cost)

The regularised OT cost is -8.937. The more unbalanced the problem, the more negative the OT cost, which should not be the case in general.

Installation ott-jax: from source.

Vectorization memory issues with `PointCloud.apply_lse_kernel` with `online=True`

We've noticed that with online=True, memory performance is worse than with online=False.
After checking the traceback (attached below), we found that the two nested vmaps in apply_lse_kernel vectorize the code such that the full n x m matrix is materialized (n and m being the numbers of points in the PointCloud).

I have a small fix in here which uses online: Optional[int] = None as a batch size. Then in apply_lse_kernel, I use lax.scan and within the loop body, the fully vectorized computation for that batch is used - this reduces the memory complexity from O(n * m) -> O(max(n, m) * batch_size).
So far, I can only report that I can run sinkhorn on a PointCloud of shape (65536, 32768) using 8516 MiB of memory (online=16384), which previously raised an OOM error on a 16 GiB GPU (Tesla T4).
It took 1157 s to run sinkhorn with epsilon=1e-2; I will try to do more comprehensive benchmarks later.
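A minimal sketch of the batching idea described above, written independently of ott: the function name, the squared-Euclidean cost and the requirement that n be a multiple of the batch size are all assumptions. lax.scan processes the rows of x in blocks, so each step only builds a [batch_size, m] block of the cost matrix instead of the full n x m matrix.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def batched_lse_potential(x, y, g, eps, batch_size):
    """f_i = -eps * logsumexp_j((g_j - C_ij) / eps), with C_ij = ||x_i - y_j||^2.

    Rows of x are processed `batch_size` at a time with lax.scan.
    """
    n, d = x.shape
    assert n % batch_size == 0, "this sketch assumes n is a multiple of batch_size"
    x_batches = x.reshape(n // batch_size, batch_size, d)
    y_sq = jnp.sum(y**2, axis=1)

    def body(carry, x_b):
        # Only a [batch_size, m] cost block is built per step.
        cost = jnp.sum(x_b**2, axis=1)[:, None] + y_sq[None, :] - 2.0 * x_b @ y.T
        return carry, -eps * logsumexp((g[None, :] - cost) / eps, axis=1)

    _, f = jax.lax.scan(body, None, x_batches)
    return f.reshape(n)
```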
As for tests, they all pass except these 4:

FAILED tests/core/sinkhorn_bures_test.py::SinkhornTest::test_bures_point_cloud_ker-batch
FAILED tests/geometry/geometry_pointcloud_apply_test.py::ApplyTest::test_apply_cost_and_kernel
FAILED tests/core/sinkhorn_test.py::SinkhornTest::test_apply_transport_geometry_from_potentials
FAILED tests/core/sinkhorn_test.py::SinkhornTest::test_apply_transport_geometry_from_scalings

I think there might be a more efficient approach.

For benchmarking, I've used the same code as in #9
jax_online_err.txt

Documentation side panel is not entirely visible

Hi,
Thanks for this amazing library! While exploring the API docs, I noticed that the side panel is too crowded and does not help the readability of the docs. I think it would be great if you could wrap the text in it to make sure the API elements are visible.
[Screenshot 2022-09-26 at 18:43:51]

Unifying passing PRNG seeds

Sometimes we pass seed: int; other times we use the PRNGKeyArray directly.
We should be more consistent with the jax ecosystem and pass the PRNGKeyArray directly in most cases.
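A sketch of the convention being proposed, with a hypothetical helper: functions accept a PRNG key, and the caller creates it once and splits it, instead of passing integer seeds around.

```python
import jax


def sample_point_cloud(rng: jax.Array, n: int, d: int) -> jax.Array:
    """Hypothetical helper that takes a PRNG key instead of an integer seed."""
    return jax.random.normal(rng, (n, d))


# The key is created once by the caller and split for independent draws.
key = jax.random.PRNGKey(0)
key_x, key_y = jax.random.split(key)
x = sample_point_cloud(key_x, 5, 2)
y = sample_point_cloud(key_y, 7, 2)
```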

Instability with GW

Hi, I am noticing that adding small noise, of the order of 1e-4, can significantly affect GW performance. When I try to align two vectors that differ only by this noise, GW can a) produce widely different results with different epsilons and b) produce significantly different results across runs (i.e. there doesn't seem to be a clear pattern by which I can set the epsilon value).

Steps to reproduce the behavior:
Here is a colab notebook with minimal code.
https://colab.research.google.com/drive/101IiOthicKKG0TLMKobFURIpNVrdXmZi?usp=sharing (Please feel free to modify anything)

Expected behavior
I expected near 100% accuracy across different runs and epsilon values.

Additional context
I am new to ott-jax and it might be that I am using it incorrectly. Any pointers/tips regarding best practices would be very helpful!

custom_vjp seems to not be used during computation of higher order derivatives of Sinkhorn

Hello,

It seems that the custom_vjp (defined either with _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd) in sinkhorn.py for implicit differentiation, or with fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd) in fixed_point_loop for unrolling) is not used when computing higher-order derivatives. I believe this is also related to the error discussed previously in this issue. This is not evident when computing jax.hessian (which is implemented as jax.jacfwd(jax.jacrev)), as the second differentiation uses forward-mode automatic differentiation, which is compatible with jax.lax.while_loop. Therefore, no error is raised even if the custom_vjp is ignored the second time.

I think this can be seen by adding a breakpoint in _while_loop_jvp in JAX's control_flow.py. For instance, the computation of jax.hessian first passes through _while_loop_jvp, then once through fixpoint_iter for implicit diff (or fixpoint_iter_fwd for unrolling), and then once through _iterations_implicit_bwd for implicit diff (or fixpoint_iter_bwd for unrolling). When a jax.lax.scan is forced by setting min_iterations equal to max_iterations (and both Jacobians are computed in reverse mode), then instead of the initial pass through _while_loop_jvp, one gets a final pass through _scan_transpose in JAX's control_flow.py.

In either case, I think that if the custom_vjp was not ignored during rederivation, one should get two passes through _iterations_implicit_bwd (equivalently fixpoint_iter_bwd), right?

I am not sure how this could be fixed. The only relevant info that I could find in JAX's documentation for custom_vjp was this:

"Notice that f_jvp calls f to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original f to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of f in our rule and also have the rule apply in all orders of higher-order differentiation.)"

I hope I am not missing something. If there is a solution to this, it would be great, as what I need to compute is jax.jacrev(jax.grad) for the sinkhorn divergence and the forced scan option causes a significant computational overhead, especially for the two autocorrelation terms Pxx, Pyy.
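A toy (not ott code) that is consistent with the JAX documentation quoted above, as we understand it: under jax.grad(jax.grad(f)), the custom rule supplies the first derivative, and the outer derivative is obtained by applying standard autodiff to the operations inside f_fwd and f_bwd, not by invoking the custom rule a second time. The backward rule must therefore itself be differentiable.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    # Primal output plus the residual saved for the backward pass.
    return jnp.sin(x), jnp.cos(x)


def f_bwd(cos_x, g):
    # Custom first-order rule, written with differentiable jnp operations.
    return (g * cos_x,)


f.defvjp(f_fwd, f_bwd)

first = jax.grad(f)(1.0)             # given by f_bwd: cos(1)
second = jax.grad(jax.grad(f))(1.0)  # autodiff through f_fwd / f_bwd: -sin(1)
```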

Many thanks!

Bug in ICNN neural dual notebook

Hello,

I believe there is a bug in the ICNN neural dual notebook when computing the inverse transport map in plot_ot_map:

def plot_ot_map(neural_dual, source, target, inverse=False):

    if not inverse:
      grad_state_s = neural_dual.transport(source)
    else:
      grad_state_s = neural_dual.inverse_transport(target)

This gets called by:
plot_ot_map(neural_dual, data_target, data_source, inverse=True)

Here, the inverse transport gets computed with data_source as input. However, I think it should be computed using data_target instead, as it is also done during computation of the Sinkhorn loss for potential f:
pred_source = neural_dual.inverse_transport(data_target)

After changing target to source in the plot_ot_map function, I get a different transport map seen on the bottom here:

[Screenshot 2022-05-12 at 17:34:41]
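A self-contained sketch of the suggested fix. StubDual is a hypothetical stand-in for a trained neural dual (not ott's class), and mapped_points only mirrors the branching of plot_ot_map, not its plotting code. The point is that `source` always holds the points being mapped, so with the call above (data_target passed as `source`) it is `source` that must go through inverse_transport.

```python
import jax.numpy as jnp


class StubDual:
    """Hypothetical stand-in for a trained neural dual."""

    def transport(self, x):
        return x + 1.0

    def inverse_transport(self, y):
        return y - 1.0


def mapped_points(neural_dual, source, target, inverse=False):
    # `target` is kept only to mirror the notebook's signature.
    if not inverse:
        return neural_dual.transport(source)
    return neural_dual.inverse_transport(source)  # was: inverse_transport(target)


data_source = jnp.zeros((3, 2))
data_target = jnp.full((3, 2), 5.0)
pred_source = mapped_points(StubDual(), data_target, data_source, inverse=True)
```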

How to build an iterable DataLoader for a pre-existing dataset

Dear authors of OTT,
I am interested in learning a Kantorovich dual using an ICNN (mentioned here https://ott-jax.readthedocs.io/en/latest/notebooks/neural_dual.html). However, in the given example, the dataloaders were created from a data generator, which may differ from the common use case in which data are read from an existing source. In such a case (the training data already exist), how should the dataloaders be defined and passed to the solver to get the neural dual?
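A generic sketch for wrapping a pre-existing dataset, under the assumption that the solver only needs an endless iterator yielding arrays of shape [batch, dim] that it consumes with next(), as the notebook's generator-based loaders appear to. The function and variable names are hypothetical. Each pass over the data reshuffles it, and incomplete final batches are dropped.

```python
import jax
import jax.numpy as jnp
import numpy as np


def array_batch_iterator(data, batch_size, rng):
    """Endless iterator of random mini-batches drawn from a fixed dataset."""
    data = jnp.asarray(data)
    n = data.shape[0]
    while True:
        rng, perm_rng = jax.random.split(rng)
        perm = jax.random.permutation(perm_rng, n)  # reshuffle every epoch
        for start in range(0, n - batch_size + 1, batch_size):
            yield data[perm[start:start + batch_size]]


source_data = np.random.default_rng(0).normal(size=(100, 2))  # e.g. loaded from disk
trainloader_source = array_batch_iterator(source_data, 32, jax.random.PRNGKey(0))
batch = next(trainloader_source)  # shape (32, 2)
```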
Looking forward to your explanation.
Thank you very much!

Manipulation of cost matrix and infinite cost

Problem:
While trying to reproduce the code from the Single-cell genomics page on my own data, I wondered whether it is possible to manipulate the cost matrix before using the sinkhorn function.

If we know that some couplings are impossible, regardless of the distance between the corresponding points in the x and y arrays, is it possible to set these costs to infinity to penalize them when running the sinkhorn function? Is it theoretically acceptable to have infinite costs when using OT with the sinkhorn function or other OT algorithms? Are there alternatives for penalizing couplings that do not involve infinite values? Should I just run the sinkhorn function on several separate PointCloud objects, each containing only observations with possible couplings (so no infinite values)?

Solutions considered:
In case infinite values is acceptable, how should I create this new cost matrix?

  • Replacing the cost matrix with geom.cost_matrix = new_cost_matrix returns an error (AttributeError).
  • Create a custom CostFn that performs this step directly, in a custom pairwise method for example?
  • Or a new PointCloud method/parameter to multiply the cost matrix element-wise by a matrix of 1.0 or infinite values (something similar to scale_cost, but with a matrix instead of a scalar)?
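A hedged sketch of the "large finite penalty" alternative, written in plain JAX with a basic log-domain Sinkhorn loop rather than ott's solver; the function names and the penalty value are assumptions. A finite penalty that is large relative to cost / epsilon drives the mass on forbidden pairs to (numerically) zero, while avoiding infinite arithmetic such as inf - inf or 0 * inf, which yields nan. Rather than assigning geom.cost_matrix, such a matrix could presumably be wrapped in a new geometry, e.g. ott.geometry.geometry.Geometry(cost_matrix=...), if the installed version exposes that constructor.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def penalized_cost(cost, allowed, penalty=1e4):
    """Adds a large finite penalty to forbidden pairs (allowed[i, j] == False)."""
    return jnp.where(allowed, cost, cost + penalty)


def sinkhorn_log(cost, a, b, eps=0.5, n_iter=500):
    """Plain log-domain Sinkhorn (illustration only, not ott's solver)."""

    def body(_, fg):
        f, g = fg
        f = eps * jnp.log(a) - eps * logsumexp((g[None, :] - cost) / eps, axis=1)
        g = eps * jnp.log(b) - eps * logsumexp((f[:, None] - cost) / eps, axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, n_iter, body, (jnp.zeros_like(a), jnp.zeros_like(b)))
    return jnp.exp((f[:, None] + g[None, :] - cost) / eps)  # transport plan


cost = jnp.abs(jnp.arange(3.0)[:, None] - jnp.arange(3.0)[None, :])
allowed = jnp.ones((3, 3), dtype=bool).at[0, 0].set(False)  # forbid the pair (0, 0)
a = b = jnp.ones(3) / 3
plan = sinkhorn_log(penalized_cost(cost, allowed), a, b)   # plan[0, 0] is ~0
```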

Deprecating the functional API

We should add warnings that they will be removed in a future release, and then remove the sinkhorn/gromov_wasserstein/make/... functions, since they provide too little (as wrappers) for too much maintenance.
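One standard-library way to emit such warnings before removal; the decorator and the toy `make` below are hypothetical stand-ins, not ott code. `stacklevel=2` points the warning at the caller's line rather than at the wrapper.

```python
import functools
import warnings


def deprecated(replacement):
    """Hypothetical decorator flagging a functional wrapper for removal."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"`{fn.__name__}` is deprecated and will be removed in a future "
                f"release; use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@deprecated("the solver classes")
def make(x):  # toy stand-in for a functional wrapper such as `make`
    return x
```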

Sinkhorn barycenter with custom cost function

I'm trying to compute the Sinkhorn barycenter of a set of histograms on [0, 1] with a custom cost function.

from jax.tree_util import register_pytree_node_class
from ott.geometry import costs

@register_pytree_node_class
class MyDist(costs.CostFn):
  def pairwise(self, x, y):
    return (x - y) ** 4

The following code is used to instantiate a Geometry and compute the barycenter:

import numpy as np
from ott.core import discrete_barycenter
from ott.geometry.pointcloud import PointCloud

xgrid = np.linspace(0, 1, 1000)
geometry = PointCloud(xgrid.reshape(-1, 1), cost_fn=MyDist())

a = np.random.uniform(size=(10, 1000))
a = a / np.sum(a, axis=1, keepdims=True)
barycenter = discrete_barycenter.discrete_barycenter(geometry, a)

However, I get the following error:

File /opt/homebrew/Caskroom/miniconda/base/envs/scipy-dev/lib/python3.10/site-packages/ott/core/discrete_barycenter.py:79, in discrete_barycenter(geom, a, weights, dual_initialization, threshold, norm_error, inner_iterations, min_iterations, max_iterations, lse_mode, debiased)
     75   raise ValueError(f'weights must have positive values and size {batch_size}')
     77 if dual_initialization is None:
     78   # initialization strategy from https://arxiv.org/pdf/1503.02533.pdf, (3.6)
---> 79   dual_initialization = geom.apply_cost(a.T, axis=0).T
     80   dual_initialization -= jnp.average(dual_initialization,
     81                                      weights=weights,
     82                                      axis=0)[jnp.newaxis, :]
     84 if debiased and not geom.is_symmetric:

File /opt/homebrew/Caskroom/miniconda/base/envs/scipy-dev/lib/python3.10/site-packages/ott/geometry/pointcloud.py:324, in PointCloud.apply_cost(self, arr, axis, fn)
    305 """Applies cost matrix to array (vector or matrix).
    306 
    307 This function applies the geometry's cost matrix, to perform either
   (...)
    321   A jnp.ndarray, [num_b, batch] if axis=0 or [num_a, batch] if axis=1
    322 """
    323 if fn is None:
--> 324   return self.vec_apply_cost(arr, axis, fn=fn)
    325 # Switch to efficient computation for the squared euclidean case.
    326 return jnp.where(jnp.logical_and(self.is_squared_euclidean,
    327                                  geometry.is_affine(fn)),
    328                  self.vec_apply_cost(arr, axis, fn=fn),
    329                  self._apply_cost(arr, axis, fn=fn))

File /opt/homebrew/Caskroom/miniconda/base/envs/scipy-dev/lib/python3.10/site-packages/ott/geometry/pointcloud.py:378, in PointCloud.vec_apply_cost(self, arr, axis, fn)
    375 nx, ny = (nx, ny) if axis == 0 else (ny, nx)
    377 applied_cost = jnp.dot(nx, arr).reshape(1, -1)
--> 378 applied_cost += ny.reshape(-1, 1) * jnp.sum(arr, axis=0).reshape(1, -1)
    379 cross_term = -2.0 * jnp.dot(y, jnp.dot(x.T, arr))
    380 applied_cost += cross_term[:, None] if rank == 1 else cross_term

File /opt/homebrew/Caskroom/miniconda/base/envs/scipy-dev/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5252, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   5250 if not isinstance(other, _accepted_binop_types):
   5251   return NotImplemented
-> 5252 return binary_op(self, other)

    [... skipping hidden 14 frame]

File /opt/homebrew/Caskroom/miniconda/base/envs/scipy-dev/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py:87, in _maybe_bool_binop.<locals>.fn(x1, x2)
     85 def fn(x1, x2):
     86   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
---> 87   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

    [... skipping hidden 6 frame]

File /opt/homebrew/Caskroom/miniconda/base/envs/scipy-dev/lib/python3.10/site-packages/jax/_src/lax/lax.py:1443, in _broadcasting_shape_rule(name, *avals)
   1441     non_1s = {d for d in ds if not core.symbolic_equal_dim(d, 1)}
   1442     if len(non_1s) > 1:
-> 1443       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1444                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1445     result_shape.append(non_1s.pop() if non_1s else 1)
   1446 return tuple(result_shape)

TypeError: add got incompatible shapes for broadcasting: (1, 10000), (1, 10).

After a bit of digging, it seems related to the definition of the norm function in my cost. However, I don't know how to approach solving it.
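The traceback above shows that vec_apply_cost combines squared norms (nx, ny) with a -2 * y x^T cross term, i.e. the expansion |x|^2 + |y|^2 - 2<x, y>, which only reproduces a squared-Euclidean cost, not (x - y)^4. A hedged workaround sketch, in plain JAX with hypothetical names, is to materialize the custom cost matrix and apply it explicitly. The resulting matrix could presumably be passed to a geometry built from an explicit cost matrix, though that route is untested here.

```python
import jax.numpy as jnp


def quartic_cost_matrix(x, y):
    """C_ij = sum_k (x_ik - y_jk)^4 for x: [n, d], y: [m, d]."""
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 4, axis=-1)


def apply_cost_explicit(cost, arr, axis=0):
    """C^T @ arr ([m, batch]) for axis=0, C @ arr ([n, batch]) for axis=1."""
    return cost.T @ arr if axis == 0 else cost @ arr


xgrid = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
a = jnp.ones((3, 5)) / 5                     # 3 histograms on the grid
cost = quartic_cost_matrix(xgrid, xgrid)
init = apply_cost_explicit(cost, a.T, axis=0).T  # [3, 5], as in the dual initialization

# The norm-based expansion reproduces (x - y)^2, not the custom (x - y)^4 cost.
sq_norms = jnp.sum(xgrid**2, axis=1)
shortcut = sq_norms[:, None] + sq_norms[None, :] - 2.0 * xgrid @ xgrid.T
```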

Duplicate Gaussian code

Duplicate code for estimating Gaussian moments and computing transport maps for Gaussian OT problems:
- In tools, Gaussian, and
  m = matrix_square_root.sqrtm_only(
- And in the ICNN initialisers,
  def compute_gaussian_map(self, inputs):
  and
  mo = sqrtm_only(jnp.dot(jnp.dot(covs_sqrt, covt), covs_sqrt))

It may be worth consolidating these. It may also be worth moving Gaussian computations to /core, or creating a hierarchical structure that does not cause cyclical import errors (see e.g. #98 (comment))
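A sketch of what a single consolidated helper could compute: the standard closed-form OT map between Gaussians, T(x) = m_t + A (x - m_s) with A = S^{-1/2} (S^{1/2} T S^{1/2})^{1/2} S^{-1/2}, whose middle factor matches the `mo = sqrtm_only(...)` line above. Names are hypothetical, and the matrix square root uses an eigendecomposition instead of ott's own matrix_square_root utilities.

```python
import jax.numpy as jnp


def sqrtm_psd(mat):
    """Square root of a symmetric PSD matrix via an eigendecomposition."""
    w, v = jnp.linalg.eigh(mat)
    w = jnp.maximum(w, 0.0)  # guard against tiny negative eigenvalues
    return (v * jnp.sqrt(w)) @ v.T


def gaussian_map_matrix(cov_s, cov_t):
    """A = S^{-1/2} (S^{1/2} T S^{1/2})^{1/2} S^{-1/2}, so that A S A = T."""
    s_half = sqrtm_psd(cov_s)
    s_inv_half = jnp.linalg.inv(s_half)
    middle = sqrtm_psd(s_half @ cov_t @ s_half)
    return s_inv_half @ middle @ s_inv_half


def gaussian_transport(x, mean_s, cov_s, mean_t, cov_t):
    """Applies T(x) = mean_t + A (x - mean_s) to the rows of x."""
    a = gaussian_map_matrix(cov_s, cov_t)
    return mean_t + (x - mean_s) @ a.T
```

For a standard normal source and a target covariance diag(4, 9), A reduces to diag(2, 3), which is a quick sanity check for any consolidated implementation.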
