Comments (6)
That would be an excellent contribution for sure. I will leave the issue open in case there's some news on your side.
from ott.
Sorry, my bad. For implicit differentiation I saw that there are two passes through `custom_linear_solve` (which you had mentioned before, in fact), so I guess you don't need the two passes from `_iterations_implicit_bwd`, and I also suppose the `custom_vjp` is seen during re-derivation. What is still unclear is why there are these initial passes through `scan_jvp` and `_while_loop_jvp`. (In fact, even `jax.jacfwd(jax.jacrev)` goes through both `scan_jvp` and `_while_loop_jvp`.) If there is an explanation for that, it would be very helpful. Do you think there is any possibility that the pass through `while_loop_jvp` should not be happening, e.g. that a `stop_gradient` is missing somewhere? Also, if the forced scan is inevitable, maybe it would make sense to allow a different number of iterations for Pxy than for Pxx and Pyy? Many thanks!
Actually, it seems that the two passes through `custom_linear_solve` (in the case of implicit differentiation) are due to the solution of the two linear systems corresponding to the two Schur complements. Also, the passes through `scan_jvp` and `_while_loop_jvp` are, I think, due to the fact that both `scan` and `while` are used in the `fixpoint_iter` implementation. Therefore, I am under the impression that the `custom_vjp` is in fact not being used during re-derivation.
Although the case of unrolling maybe falls under the case described in the JAX documentation (in my first comment above), because the checkpointed states ("intermediate values of f") are used in the vjp rule, this should not be the case with implicit diff, right? Because in that case only the optimal (final) state is used in the vjp rule.
Thanks in advance, and apologies for the multiple comments :)
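For what it's worth, this is the kind of implicit-diff pattern I have in mind, as a self-contained toy (hypothetical names and a made-up contraction `f`, not ott's code), where the vjp rule really does use only the optimal state:

```python
import jax
import jax.numpy as jnp

def f(a, z):
    # Toy contraction in z (|df/dz| <= 0.5), standing in for one
    # solver update; purely illustrative, not a Sinkhorn step.
    return jnp.tanh(a + 0.5 * z)

@jax.custom_vjp
def fixed_point(a):
    # Primal: iterate to the fixed point z* = f(a, z*).
    z = jnp.zeros_like(a)
    for _ in range(100):  # plain Python loop, for clarity only
        z = f(a, z)
    return z

def fixed_point_fwd(a):
    z_star = fixed_point(a)
    return z_star, (a, z_star)  # residuals: only the final state

def fixed_point_bwd(res, g):
    a, z_star = res
    # Adjoint fixed point: w = g + (df/dz)^T w, again only at z*.
    _, vjp_z = jax.vjp(lambda z: f(a, z), z_star)
    w = g
    for _ in range(100):
        w = g + vjp_z(w)[0]
    _, vjp_a = jax.vjp(lambda a_: f(a_, z_star), a)
    return (vjp_a(w)[0],)

fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

a = jnp.array(0.3)
dz = jax.grad(fixed_point)(a)  # uses only z*, no intermediate iterates
```

No intermediate iterate ever appears in the residuals, which is the contrast with unrolling that I meant.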
Hi Ersi, thanks a lot for your comments, and apologies for the late reply. You are indeed right to count those two linear solves in the `solve` function in `implicit_differentiation.py`. These two small linear solves are there to solve the larger system more efficiently. When differentiating again (computing the Hessian), I expect the solutions to these linear systems to be differentiated again, but this would call the custom differentiation rules for linear systems (as described here). Have you found anything else that's suspicious or buggy in that pipeline? I do expect some numerical instabilities to arise (implicit diff of those linear systems will obviously be impacted by bad conditioning of these systems), have you experienced those?
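To illustrate the mechanism on a toy system (hypothetical code, not the solves in `implicit_differentiation.py`): differentiating twice through `lax.custom_linear_solve` triggers extra tangent/transposed solves via the solve's custom rules, rather than differentiating the solver's internals.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])  # fixed SPD matrix (toy data)
b = jnp.ones(2)

def solve_system(a):
    # x solves (A + a*I) x = b; only matvec closes over the parameter a.
    matvec = lambda x: (A + a * jnp.eye(2)) @ x
    # The inner solver (CG here) is a black box for autodiff purposes.
    solve = lambda mv, rhs: jax.scipy.sparse.linalg.cg(mv, rhs)[0]
    x = jax.lax.custom_linear_solve(matvec, b, solve, symmetric=True)
    return jnp.sum(x)

g = jax.grad(solve_system)(0.5)              # one extra (transposed) solve
h = jax.jacfwd(jax.grad(solve_system))(0.5)  # Hessian: tangent solves on top
```

Each derivative order re-solves systems with the same matrix, so poor conditioning of `A + a*I` degrades every order at once, consistent with the instabilities mentioned above.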
Hi @marcocuturi, and thanks for the reply!
re: custom_vjp and re-derivation: It seems that by adding a `@custom_vjp` decorator to `_iterations_implicit` and then defining the custom vjp rule as `_iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd)`, re-derivation is possible using the custom_vjp rule of implicit diff, and one does not get the `lax.scan` error. A similar approach can be used for unrolling: one can define an `iterations` function in `sinkhorn.py` with a `@custom_vjp` decorator. In that case, a bit more care is needed in the `iterations_bwd` definition, because two more pull-backs are needed (due to the inputs/outputs of `fixpoint_iter_fwd` and `fixpoint_iter_bwd`). However, for unrolling one gets the `lax.scan` error again, even if the `custom_vjp` is seen in re-derivation. So I think that for unrolling it is not really possible to use the `custom_vjp` in re-derivation; this could be related to the saving of intermediate states that I mentioned before. However, I would not call this a bug. The Hessian computation is correct (as shown also in the tests); it simply would be nice to have a lighter way of computing higher-order derivatives.
re: conditioning of the linear system: Since one in fact needs to tune the ridge parameters to ensure that the linear system in implicit diff is well-conditioned (both for gradients and for higher orders), I am using unrolling for the computation of the higher-order derivatives, to be on the safe side :)
I might make a PR for re-derivation with implicit diff once I look into it more carefully. I do think it would be nice to have a more computationally efficient way to compute higher-order derivatives :)
Closing this for now, can reopen later.