
Comments (6)

marcocuturi commented on June 14, 2024

That would be an excellent contribution for sure. I will leave the issue open in case there's some news on your side.

ersisimou commented on June 14, 2024

Sorry, my bad. For implicit differentiation, I saw that there are two passes through custom_linear_solve (which, in fact, you had mentioned before). So I guess you don't need the two passes from _iterations_implicit_bwd, and I also suppose that custom_vjp is seen during re-derivation. What is still unclear is why there are these initial passes through scan_jvp and _while_loop_jvp (in fact, even for jax.jacfwd(jax.jacrev) it goes through both scan_jvp and _while_loop_jvp). An explanation for that would be very helpful. Do you think it is possible that going through _while_loop_jvp should not be happening at all? For example, maybe there should be a stop_gradient somewhere? Also, if the forced scan is inevitable, maybe it would make sense to use a different number of iterations for Pxy than for Pxx and Pyy? Many thanks!

ersisimou commented on June 14, 2024

Actually, it seems that the two passes through custom_linear_solve (in the case of implicit differentiation) come from solving the two linear systems corresponding to the two Schur complements. Also, I think the passes through scan_jvp and _while_loop_jvp are due to the fact that both scan and while_loop are used in the fixpoint_iter implementation. Therefore, I am under the impression that the custom_vjp is in fact not being used during re-derivation.
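A minimal sketch of that behaviour, using a toy loop rather than OTT's actual fixpoint_iter (which, as noted above, combines scan and while_loop). `fixpoint_while` is a hypothetical function that iterates x <- 0.5 * x + theta with lax.while_loop until two successive iterates agree. Forward mode differentiates through while_loop's JVP rule (hence a `_while_loop_jvp` pass under jax.jacfwd), whereas reverse mode through a while_loop with a data-dependent trip count is rejected by JAX.

```python
import jax
import jax.numpy as jnp


def fixpoint_while(theta, tol=1e-6, max_iter=200):
    """Toy loop x <- 0.5 * x + theta, stopped when two iterates agree.

    The fixed point is x* = 2 * theta. The trip count depends on the data,
    as with any convergence-based stopping rule.
    """
    def cond(state):
        i, _, err = state
        return jnp.logical_and(i < max_iter, err > tol)

    def body(state):
        i, x, _ = state
        x_new = 0.5 * x + theta
        return i + 1, x_new, jnp.abs(x_new - x)

    init = (jnp.asarray(0), jnp.zeros_like(theta),
            jnp.asarray(jnp.inf, dtype=theta.dtype))
    return jax.lax.while_loop(cond, body, init)[1]


theta = jnp.float32(1.0)

# Forward mode: the while_loop is differentiated by its JVP rule.
d_fwd = jax.jacfwd(fixpoint_while)(theta)

# Reverse mode: transposing a while_loop with a dynamic trip count is
# not supported, so JAX raises an error here.
try:
    jax.jacrev(fixpoint_while)(theta)
    rev_ok = True
except ValueError:
    rev_ok = False
```

Since the fixed point is x* = 2 * theta, the forward-mode derivative comes out close to 2, up to the stopping tolerance, while the reverse-mode call fails.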

The unrolling case may fall under the situation described in the JAX documentation (see my first comment above), because the checkpointed states ("intermediate values of f") are used in the vjp rule. This should not be the case with implicit differentiation, though, right? There, only the optimal (final) state is used in the vjp rule.
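For contrast with the implicit case, here is a sketch of the unrolled approach. It assumes a fixed iteration count and a toy update, and `fixpoint_unrolled` is hypothetical, not OTT's code. Because the loop is a lax.scan, both reverse mode and forward-over-reverse (jax.hessian) work. However, the reverse pass keeps residuals for every iteration, so memory grows with `num_iter`. By contrast, an implicit rule only needs the final state.

```python
import jax
import jax.numpy as jnp


def fixpoint_unrolled(theta, num_iter=30):
    """Run exactly num_iter steps of x <- 0.5 * x + theta ** 2 with lax.scan.

    The limit is x* = 2 * theta ** 2. Unlike a while_loop, scan supports
    reverse mode, but its VJP stores residuals for each iteration.
    """
    def step(x, _):
        return 0.5 * x + theta ** 2, None

    x_final, _ = jax.lax.scan(step, jnp.zeros_like(theta), xs=None, length=num_iter)
    return x_final


theta = jnp.float32(1.5)
g = jax.grad(fixpoint_unrolled)(theta)     # close to 4 * theta = 6
h = jax.hessian(fixpoint_unrolled)(theta)  # close to 4
```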

Thanks in advance and apologies for the multiple comments :)

marcocuturi commented on June 14, 2024

Hi Ersi, thanks a lot for your comments, and apologies for the late reply. You are indeed right to count those two linear solves in the solve function in implicit_differentiation.py. These two small linear solves are there to solve the larger system more efficiently. When differentiating again (computing the Hessian), I expect the solutions to these linear systems to be differentiated again, but this would call the custom differentiation rules for linear systems (as described here). Have you found anything else that's suspicious or buggy in that pipeline? I do expect some numerical instabilities to arise (implicit differentiation of those linear systems will obviously be impacted by bad conditioning of these systems). Have you experienced those?
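For reference, here is a small sketch of differentiating twice through jax.lax.custom_linear_solve. The matrix `M`, the ridge value and `ridge_solve` are illustrative assumptions, not the actual solves in implicit_differentiation.py. In this setup, the derivative of the solution is obtained by further calls to the supplied solver, with the matvec differentiated, rather than by differentiating the solver's internals. The result can be checked against plain jnp.linalg.solve.

```python
import jax
import jax.numpy as jnp

# Hypothetical symmetric positive-definite matrix and right-hand side.
M = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, -1.0])


def ridge_solve(theta, ridge=1e-3):
    """Solve (theta * M + ridge * I) x = b with lax.custom_linear_solve."""
    def matvec(x):
        return theta * (M @ x) + ridge * x

    def solve(_matvec, rhs):
        # Dense inner solver; autodiff never differentiates its internals.
        return jnp.linalg.solve(theta * M + ridge * jnp.eye(2), rhs)

    # symmetric=True lets JAX reuse `solve` for the transposed system.
    return jax.lax.custom_linear_solve(matvec, b, solve, symmetric=True)


def loss(theta):
    return jnp.sum(ridge_solve(theta))


t0 = jnp.float32(1.0)
d1 = jax.grad(loss)(t0)             # first derivative via the linear-solve rule
d2 = jax.grad(jax.grad(loss))(t0)   # second derivative, again via linear solves
```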

ersisimou commented on June 14, 2024

Hi @marcocuturi and thanks for the reply!

re: custom_vjp and re-derivation: It seems that by adding a @custom_vjp decorator to _iterations_implicit and then defining the custom_vjp rule as _iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd), re-derivation becomes possible using the custom_vjp rule of implicit diff, and one does not get the lax.scan error. A similar approach can be used for unrolling: one can define an iterations function in sinkhorn.py with a @custom_vjp decorator. In that case, a bit more care is needed in the iterations_bwd definition, because you need two more pull-backs (due to the inputs and outputs of fixpoint_iter_fwd and fixpoint_iter_bwd). However, for unrolling (even if the custom_vjp is seen during re-derivation), one gets the lax.scan error again. So I think that for unrolling it is not possible to actually use the custom_vjp in re-derivation. This could be related to the saving of intermediate states that I mentioned before. Still, I would not call this a bug: the Hessian computation is correct (as the tests also show). It would simply be nice to have a lighter way of computing higher-order derivatives.
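Here is a minimal sketch of the decorator pattern described above. It is modeled on the fixed-point example in the JAX "Custom derivative rules" tutorial rather than on OTT's code, and `fixed_point`, `_iterate` and `_adjoint_step` are hypothetical names. The forward rule calls the decorated function itself, and the backward rule solves the adjoint equation with that same decorated function. As a result, differentiating the gradient again goes back through the custom rule instead of reverse-differentiating the raw while_loop. The sketch assumes that f(theta, ·) is a contraction near its fixed point, so that both the forward and adjoint iterations converge. It is only exercised with reverse-over-reverse (jax.grad twice); note that JAX does not support applying forward mode directly to a custom_vjp function.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _iterate(f, theta, x0, tol=1e-6, max_iter=500):
    """Plain loop x <- f(theta, x); lax.while_loop is not reverse-differentiable."""
    def cond(state):
        i, _, err = state
        return jnp.logical_and(i < max_iter, err > tol)

    def body(state):
        i, x, _ = state
        x_new = f(theta, x)
        return i + 1, x_new, jnp.max(jnp.abs(x_new - x))

    init = (jnp.asarray(0), x0, jnp.asarray(jnp.inf, dtype=x0.dtype))
    return jax.lax.while_loop(cond, body, init)[1]


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, theta, x0):
    return _iterate(f, theta, x0)


def fixed_point_fwd(f, theta, x0):
    # Call the decorated function, not _iterate: when this forward rule is
    # differentiated again, the custom rule is reused instead of the raw loop.
    x_star = fixed_point(f, theta, x0)
    return x_star, (theta, x_star)


def _adjoint_step(f, packed, u):
    # One step of u <- g + (df/dx)^T u, evaluated at the fixed point.
    theta, x_star, g = packed
    _, vjp_x = jax.vjp(lambda x: f(theta, x), x_star)
    return g + vjp_x(u)[0]


def fixed_point_bwd(f, res, g):
    theta, x_star = res
    # Implicit function theorem: only the final state x_star is needed.
    # The adjoint solve also goes through fixed_point, so it is differentiable.
    u = fixed_point(partial(_adjoint_step, f), (theta, x_star, g), g)
    _, vjp_theta = jax.vjp(lambda t: f(t, x_star), theta)
    # The fixed point does not depend on the initial guess x0.
    return vjp_theta(u)[0], jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)


def newton_sqrt(a):
    # Babylonian update whose fixed point is sqrt(a).
    return fixed_point(lambda a, x: 0.5 * (x + a / x), a, a)


a = jnp.float32(4.0)
d1 = jax.grad(newton_sqrt)(a)            # analytically 1 / (2 * sqrt(4)) = 0.25
d2 = jax.grad(jax.grad(newton_sqrt))(a)  # analytically -1 / (4 * 4 ** 1.5) = -0.03125
```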

re: conditioning of the linear system: Since one in fact needs to tune the ridge parameters to ensure that the linear system in implicit diff is well-conditioned (both for the gradient and for higher-order derivatives), I am using unrolling to compute higher-order derivatives, to be on the safe side :) .
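As a rough illustration of why the ridge matters, consider a simplified case (the systems in implicit_differentiation.py need not be symmetric): a symmetric positive semi-definite matrix A with extreme eigenvalues λ_max ≥ λ_min ≥ 0. Adding a ridge τ > 0 shifts every eigenvalue by τ, so

$$
\kappa(A + \tau I) = \frac{\lambda_{\max} + \tau}{\lambda_{\min} + \tau} \le \frac{\lambda_{\max} + \tau}{\tau}.
$$

For instance, λ_max = 1, λ_min = 10^{-8} and τ = 10^{-3} give κ(A) = 10^8 but κ(A + τI) ≈ 1.0 × 10^3. The price is solving a slightly perturbed system.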

I might make a PR for re-derivation with implicit diff once I look into it more carefully. I do think it would be nice to have a more computationally efficient way to compute higher-order derivatives :)

marcocuturi commented on June 14, 2024

Closing this for now, can reopen later.
