gwtest's People

Contributors

tvayer

gwtest's Issues

[PyTorch Question] How to differentiate GW w.r.t. C1 and C2?

Hi @tvayer,
I have been following some of your work and the POT framework lately. I noticed that there is no autograd.Function available for the Gromov-Wasserstein distance, so I started writing one myself.
Can you let me know whether, and how, you compute the gradient of the loss returned by gromov_wasserstein2 (given here) with respect to C1 and C2, the cost matrices of the source and target spaces?
I need this because the backward() method of that autograd.Function has to return exactly these gradients.
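For reference, here is a sketch of the relevant math. It assumes the discrete GW objective as written in POT's documentation (no 1/2 factor), with POT's definitions of square_loss and kl_loss. The key observation is that the constraint set Π(p, q) does not depend on C1 or C2, so by the envelope (Danskin) theorem the gradient of the optimal value equals the partial derivative of the objective evaluated at the optimal coupling T, provided that coupling is unique. Since POT's solver only returns a local optimum of this nonconvex problem, treat the result as the gradient at the returned coupling. Below, ⊙ and ⊘ are elementwise product and division, and log is applied elementwise.

```latex
\mathrm{GW}(C_1, C_2, p, q) = \min_{T \in \Pi(p,q)} \sum_{i,j,k,l} L\big(C_1[i,k],\, C_2[j,l]\big)\, T_{ij}\, T_{kl},
\qquad \Pi(p,q) = \{\, T \ge 0 : T\mathbf{1} = p,\; T^\top\mathbf{1} = q \,\}

% Differentiate at the optimal T, using \sum_j T_{ij} = p_i, \sum_l T_{kl} = p_k,
% \sum_i T_{ij} = q_j and \sum_k T_{kl} = q_l.

% Square loss, L(a,b) = (a-b)^2:
\nabla_{C_1}\mathrm{GW} = 2\, C_1 \odot (p p^\top) - 2\, T C_2 T^\top,
\qquad
\nabla_{C_2}\mathrm{GW} = 2\, C_2 \odot (q q^\top) - 2\, T^\top C_1 T

% KL loss, L(a,b) = a \log(a/b) - a + b (entries of C_1, C_2 assumed positive):
\nabla_{C_1}\mathrm{GW} = \log(C_1) \odot (p p^\top) - T \log(C_2)\, T^\top,
\qquad
\nabla_{C_2}\mathrm{GW} = q q^\top - (T^\top C_1 T) \oslash C_2
```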

Here is what I have so far:

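import numpy as np
import torch
from torch.autograd import Function
from ot.gromov import gromov_wasserstein2


def gw_grads(C1, C2, p, q, T, loss_fun='kl_loss'):
    """Placeholder (hypothetical helper): dGW/dC1 and dGW/dC2 at coupling T.

    This is the open question of this issue, so it is left unimplemented.
    """
    raise NotImplementedError("gradient of GW w.r.t. C1 and C2 not derived yet")


class GromovWassersteinLossFunction(Function):
    """Return the GW loss for inputs (C1, C2, p, q)."""

    @staticmethod
    def forward(ctx, C1, C2, p, q):
        # Convert to float64 NumPy arrays for POT
        C1_np = C1.detach().cpu().numpy().astype(np.float64)
        C2_np = C2.detach().cpu().numpy().astype(np.float64)
        p_np = p.detach().cpu().numpy().astype(np.float64)
        q_np = q.detach().cpu().numpy().astype(np.float64)
        p_np /= p_np.sum()
        q_np /= q_np.sum()
        # gromov_wasserstein2 returns the GW value, not a coupling;
        # with log=True the optimal coupling is stored in log['T']
        gw, log = gromov_wasserstein2(C1_np, C2_np, p_np, q_np,
                                      loss_fun='kl_loss', log=True)
        T = log['T']
        grad_C1, grad_C2 = gw_grads(C1_np, C2_np, p_np, q_np, T, loss_fun='kl_loss')
        # save_for_backward only accepts tensors; match the inputs' dtype/device
        ctx.save_for_backward(
            torch.as_tensor(grad_C1, dtype=C1.dtype, device=C1.device),
            torch.as_tensor(grad_C2, dtype=C2.dtype, device=C2.device),
        )
        return torch.tensor(float(gw), dtype=C1.dtype, device=C1.device)

    @staticmethod
    def backward(ctx, grad_output):
        grad_C10, grad_C20 = ctx.saved_tensors
        grad_C1 = grad_C2 = None
        # Chain rule: scale by the incoming gradient of the scalar output
        if ctx.needs_input_grad[0]:
            grad_C1 = grad_output * grad_C10
        if ctx.needs_input_grad[1]:
            grad_C2 = grad_output * grad_C20
        # One return value per forward() input; p and q receive no gradient
        return grad_C1, grad_C2, None, None


def GW(C1, C2, p, q):
    """loss = gromov_wasserstein(C1, C2, p, q)"""
    return GromovWassersteinLossFunction.apply(C1, C2, p, q)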


It would be a huge help if you could let me know whether, and how, you compute these gradients. Someone on the POT Slack suggested deriving them by hand, but as a beginner in optimal transport I couldn't work out the underlying equation for this function.
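A minimal NumPy sketch of the by-hand gradients, assuming the envelope-theorem argument: when the optimal coupling T is unique, the gradient of the GW value equals the partial derivative of sum L(C1[i,k], C2[j,l]) T[i,j] T[k,l] with T held fixed. `gw_grads` and `gw_objective` are hypothetical helper names, not POT functions; the loss formulas follow the square_loss and kl_loss definitions in POT's documentation, and `eps` is an assumed guard against log(0) and division by zero. The usage lines check the result against central finite differences at a fixed feasible coupling, not at a true GW optimum.

```python
import numpy as np


def gw_grads(C1, C2, p, q, T, loss_fun="square_loss", eps=1e-15):
    """Gradients of sum_{ijkl} L(C1[i,k], C2[j,l]) T[i,j] T[k,l] w.r.t. C1 and C2.

    T is held fixed (hypothetical helper); its row sums must equal p and its
    column sums q. eps is an assumed guard for log/division in kl_loss.
    """
    pp = np.outer(p, p)
    qq = np.outer(q, q)
    if loss_fun == "square_loss":
        g1 = 2 * C1 * pp - 2 * T @ C2 @ T.T
        g2 = 2 * C2 * qq - 2 * T.T @ C1 @ T
    elif loss_fun == "kl_loss":
        g1 = np.log(C1 + eps) * pp - T @ np.log(C2 + eps) @ T.T
        g2 = qq - (T.T @ C1 @ T) / (C2 + eps)
    else:
        raise ValueError(f"unknown loss_fun: {loss_fun}")
    return g1, g2


def gw_objective(C1, C2, T, loss_fun="square_loss"):
    """Brute-force objective at a fixed coupling T (O(n^2 m^2) memory, small inputs only)."""
    A = C1[:, None, :, None]  # A[i, _, k, _] = C1[i, k]
    B = C2[None, :, None, :]  # B[_, j, _, l] = C2[j, l]
    if loss_fun == "square_loss":
        L = (A - B) ** 2
    else:  # kl_loss, assumes strictly positive entries
        L = A * np.log(A / B) - A + B
    return np.einsum("ijkl,ij,kl->", L, T, T)


# Usage: compare one entry of the gradient with a central finite difference.
rng = np.random.default_rng(0)
C1 = rng.random((4, 4)) + 0.1
C2 = rng.random((3, 3)) + 0.1
p, q = np.full(4, 1 / 4), np.full(3, 1 / 3)
T = np.outer(p, q)  # product coupling, standing in for POT's log['T']
g1, g2 = gw_grads(C1, C2, p, q, T)
E = np.zeros_like(C1)
E[1, 2] = 1e-6
fd = (gw_objective(C1 + E, C2, T) - gw_objective(C1 - E, C2, T)) / 2e-6
print(abs(fd - g1[1, 2]))
```

In an autograd.Function, these arrays would be computed in forward from the coupling POT returns in log['T'] and multiplied by grad_output in backward. An alternative design is to compute T with POT outside the autograd graph and evaluate the objective in PyTorch at that fixed T, letting autograd differentiate it; by the same envelope argument this should yield the same gradients.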
