---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File ~/.miniconda3/envs/cellrank/lib/python3.8/runpy.py:194, in _run_module_as_main(***failed resolving arguments***)
193 sys.argv[0] = mod_spec.origin
--> 194 return _run_code(code, main_globals, None,
195 "__main__", mod_spec)
File ~/.miniconda3/envs/cellrank/lib/python3.8/runpy.py:87, in _run_code(***failed resolving arguments***)
80 run_globals.update(__name__ = mod_name,
81 __file__ = fname,
82 __cached__ = cached,
(...)
85 __package__ = pkg_name,
86 __spec__ = mod_spec)
---> 87 exec(code, run_globals)
88 return run_globals
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel_launcher.py:17, in <module>
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/traitlets/config/application.py:976, in Application.launch_instance(***failed resolving arguments***)
975 app.initialize(argv)
--> 976 app.start()
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(***failed resolving arguments***)
711 try:
--> 712 self.io_loop.start()
713 except KeyboardInterrupt:
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(***failed resolving arguments***)
198 asyncio.set_event_loop(self.asyncio_loop)
--> 199 self.asyncio_loop.run_forever()
200 finally:
File ~/.miniconda3/envs/cellrank/lib/python3.8/asyncio/base_events.py:570, in BaseEventLoop.run_forever(***failed resolving arguments***)
569 while True:
--> 570 self._run_once()
571 if self._stopping:
File ~/.miniconda3/envs/cellrank/lib/python3.8/asyncio/base_events.py:1859, in BaseEventLoop._run_once(***failed resolving arguments***)
1858 else:
-> 1859 handle._run()
1860 handle = None
File ~/.miniconda3/envs/cellrank/lib/python3.8/asyncio/events.py:81, in Handle._run(***failed resolving arguments***)
80 try:
---> 81 self._context.run(self._callback, *self._args)
82 except (SystemExit, KeyboardInterrupt):
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:508, in Kernel.dispatch_queue(***failed resolving arguments***)
507 try:
--> 508 await self.process_one()
509 except Exception:
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:497, in Kernel.process_one(***failed resolving arguments***)
496 return None
--> 497 await dispatch(*args)
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:404, in Kernel.dispatch_shell(***failed resolving arguments***)
403 if inspect.isawaitable(result):
--> 404 await result
405 except Exception:
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/kernelbase.py:728, in Kernel.execute_request(***failed resolving arguments***)
727 if inspect.isawaitable(reply_content):
--> 728 reply_content = await reply_content
730 # Flush output before sending the reply.
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(***failed resolving arguments***)
382 if with_cell_id:
--> 383 res = shell.run_cell(
384 code,
385 store_history=store_history,
386 silent=silent,
387 cell_id=cell_id,
388 )
389 else:
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(***failed resolving arguments***)
2880 try:
-> 2881 result = self._run_cell(
2882 raw_cell, store_history, silent, shell_futures, cell_id
2883 )
2884 finally:
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(***failed resolving arguments***)
2935 try:
-> 2936 return runner(coro)
2937 except BaseException as e:
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(***failed resolving arguments***)
3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3136 interactivity=interactivity, compiler=compiler, result=result)
3138 self.last_execution_succeeded = not has_raised
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
3337 asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
3339 return True
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3398, in InteractiveShell.run_code(***failed resolving arguments***)
3397 else:
-> 3398 exec(code_obj, self.user_global_ns, self.user_ns)
3399 finally:
3400 # Reset our crash handler in place
Input In [1], in <cell line: 22>()
20 return jnp.sum(res)
---> 22 jax.value_and_grad(foo)(jnp.zeros((10,)))
Input In [1], in foo(***failed resolving arguments***)
10 def foo(x: jnp.ndarray):
---> 11 res = ott.core.fixed_point_loop.fixpoint_iter_backprop(
12 cond_fn,
13 body_fn,
14 min_iterations=1,
15 max_iterations=20,
16 inner_iterations=1,
17 constants=2,
18 state=x
19 )
20 return jnp.sum(res)
JaxStackTraceBeforeTransformation: TypeError: Called add with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array.
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Input In [1], in <cell line: 22>()
11 res = ott.core.fixed_point_loop.fixpoint_iter_backprop(
12 cond_fn,
13 body_fn,
(...)
18 state=x
19 )
20 return jnp.sum(res)
---> 22 jax.value_and_grad(foo)(jnp.zeros((10,)))
[... skipping hidden 11 frame]
File /opt/projects/ott_jt/ott/core/fixed_point_loop.py:225, in fixpoint_iter_bwd(***failed resolving arguments***)
219 (_, g_state, g_constants), _ = jax.lax.scan(
220 lambda carry, x: unrolled_body_fn(carry), (0, g, g_constants),
221 None,
222 length=max_iterations // inner_iterations
223 )
224 else:
--> 225 _, g_state, g_constants = jax.lax.while_loop(
226 bwd_cond_fn, unrolled_body_fn,
227 (iteration - inner_iterations, g, g_constants)
228 )
230 return g_constants, g_state
[... skipping hidden 11 frame]
File /opt/projects/ott_jt/ott/core/fixed_point_loop.py:212, in fixpoint_iter_bwd.<locals>.unrolled_body_fn(iteration_g_gconst)
208 _, pullback = jax.vjp(
209 unrolled_body_fn_no_errors, iteration, constants, state
210 )
211 _, gi_constants, g_state = pullback(g)
--> 212 g_constants = jax.tree_util.tree_map(
213 lambda x, y: x + y, g_constants, gi_constants
214 )
215 out = (iteration - inner_iterations, g_state, g_constants)
216 return (out, None) if force_scan else out
[... skipping hidden 2 frame]
File /opt/projects/ott_jt/ott/core/fixed_point_loop.py:213, in fixpoint_iter_bwd.<locals>.unrolled_body_fn.<locals>.<lambda>(x, y)
208 _, pullback = jax.vjp(
209 unrolled_body_fn_no_errors, iteration, constants, state
210 )
211 _, gi_constants, g_state = pullback(g)
212 g_constants = jax.tree_util.tree_map(
--> 213 lambda x, y: x + y, g_constants, gi_constants
214 )
215 out = (iteration - inner_iterations, g_state, g_constants)
216 return (out, None) if force_scan else out
[... skipping hidden 1 frame]
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4630, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
4628 args = (other, self) if swap else (self, other)
4629 if isinstance(other, _accepted_binop_types):
-> 4630 return binary_op(*args)
4631 if isinstance(other, _rejected_binop_types):
4632 raise TypeError(f"unsupported operand type(s) for {opchar}: "
4633 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
[... skipping hidden 7 frame]
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/ufuncs.py:80, in _maybe_bool_binop.<locals>.fn(x1, x2)
79 def fn(x1, x2):
---> 80 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
81 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/util.py:343, in _promote_args(fun_name, *args)
341 """Convenience function to apply Numpy argument shape and dtype promotion."""
342 _check_arraylike(fun_name, *args)
--> 343 _check_no_float0s(fun_name, *args)
344 return _promote_shapes(fun_name, *_promote_dtypes(*args))
File ~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/numpy/util.py:330, in _check_no_float0s(fun_name, *args)
328 """Check if none of the args have dtype float0."""
329 if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
--> 330 raise TypeError(
331 f"Called {fun_name} with a float0 array. "
332 "float0s do not support any operations by design because they "
333 "are not compatible with non-trivial vector spaces. No implicit dtype "
334 "conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
335 "to cast a float0 array to a regular zeros array. \n"
336 "If you didn't expect to get a float0 you might have accidentally "
337 "taken a gradient with respect to an integer argument.")
TypeError: Called add with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array.
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.