[PyTorch Question] How to differentiate GW w.r.t. C1 and C2?
Hi @tvayer,
I have been following some of your work and the POT framework lately. I noticed that there is no autograd.Function available for the Gromov-Wasserstein distance, so I started writing one myself.
Could you let me know whether, and how, you compute the gradient of the gromov_wasserstein2 loss (given here) with respect to C1 and C2, the cost matrices of the source and target spaces?
I need this because the backward() method of the autograd.Function mentioned above has to return those gradients.
Here is what I have so far:
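import numpy as np
import torch
from torch.autograd import Function
from ot.gromov import gromov_wasserstein2


class GromovWassersteinLossFunction(Function):
    """Return the GW loss for inputs (C1, C2, p, q)."""

    @staticmethod
    def forward(ctx, C1, C2, p, q):
        # Convert the inputs to float64 NumPy arrays for POT
        C1_np = C1.detach().cpu().numpy().astype(np.float64)
        C2_np = C2.detach().cpu().numpy().astype(np.float64)
        p_np = p.detach().cpu().numpy().astype(np.float64)
        q_np = q.detach().cpu().numpy().astype(np.float64)
        p_np /= p_np.sum()
        q_np /= q_np.sum()
        # gromov_wasserstein2 returns the GW distance (a scalar), not a plan;
        # with log=True the optimal coupling is stored in log['T']
        gw_dist, log = gromov_wasserstein2(C1_np, C2_np, p_np, q_np,
                                           loss_fun='kl_loss', log=True)
        T = torch.from_numpy(np.asarray(log['T']))  # needed for the gradients
        grad_C1 = None  # TODO: gradient of the GW loss w.r.t. C1
        grad_C2 = None  # TODO: gradient of the GW loss w.r.t. C2
        # p and q are treated as constants: backward returns None for them
        # (mark_non_differentiable only applies to outputs, not inputs)
        ctx.save_for_backward(grad_C1, grad_C2)
        return torch.tensor(gw_dist, dtype=C1.dtype, device=C1.device)

    @staticmethod
    def backward(ctx, grad_output):
        grad_C10, grad_C20 = ctx.saved_tensors
        grad_C1, grad_C2 = None, None
        # Chain rule: scale the saved gradients by the incoming gradient
        if ctx.needs_input_grad[0]:
            grad_C1 = grad_output * grad_C10
        if ctx.needs_input_grad[1]:
            grad_C2 = grad_output * grad_C20
        # backward must return one value per forward input (C1, C2, p, q)
        return grad_C1, grad_C2, None, None


def GW(C1, C2, p, q):
    """loss = gromov_wasserstein2(C1, C2, p, q)"""
    return GromovWassersteinLossFunction.apply(C1, C2, p, q)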
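A minimal sketch of what the two TODO lines could compute, assuming the common envelope-theorem shortcut: treat the optimal coupling T returned by the solver as a constant and differentiate the GW objective sum_{i,j,k,l} L(C1[i,k], C2[j,l]) T[i,j] T[k,l] with respect to C1 and C2. The helper names `gw_loss_fixed_T` and `gw_grads` are hypothetical (not part of POT), and the loss definitions assume POT's conventions: 'square_loss' is (a - b)^2 and 'kl_loss' is a log(a/b) - a + b, which needs strictly positive entries. Because GW is non-convex and solved approximately, these are approximate gradients that ignore how T itself moves with C1 and C2.

```python
import numpy as np


def gw_loss_fixed_T(C1, C2, T, loss_fun="square_loss"):
    """Brute-force GW objective for a FIXED coupling T (O(n^4) memory, for checks only)."""
    a = C1[:, None, :, None]  # C1[i, k] broadcast over (i, j, k, l)
    b = C2[None, :, None, :]  # C2[j, l] broadcast over (i, j, k, l)
    if loss_fun == "square_loss":
        L = (a - b) ** 2
    elif loss_fun == "kl_loss":
        L = a * np.log(a / b) - a + b
    else:
        raise ValueError(f"unknown loss_fun: {loss_fun}")
    # Sum L[i, j, k, l] * T[i, j] * T[k, l] over all indices
    return float(np.einsum("ijkl,ij,kl->", L, T, T))


def gw_grads(C1, C2, T, loss_fun="square_loss"):
    """Hypothetical helper: gradients of the GW objective w.r.t. C1 and C2, T held fixed."""
    p = T.sum(axis=1)  # source marginal (sum over j of T[i, j])
    q = T.sum(axis=0)  # target marginal (sum over i of T[i, j])
    if loss_fun == "square_loss":
        g1 = 2.0 * (C1 * np.outer(p, p) - T @ C2 @ T.T)
        g2 = 2.0 * (C2 * np.outer(q, q) - T.T @ C1 @ T)
    elif loss_fun == "kl_loss":
        g1 = np.log(C1) * np.outer(p, p) - T @ np.log(C2) @ T.T
        g2 = np.outer(q, q) - (T.T @ C1 @ T) / C2
    else:
        raise ValueError(f"unknown loss_fun: {loss_fun}")
    return g1, g2
```

Inside forward(), one would pass the NumPy cost matrices and the coupling from log['T'] to gw_grads, convert the results with torch.from_numpy, and save them with ctx.save_for_backward. Newer POT releases may accept torch tensors in gromov_wasserstein2 directly and propagate gradients to C1 and C2 themselves, so it is worth checking the installed version before maintaining a custom Function.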
It would be a huge help if you could let me know whether, or how, you compute these gradients. Someone on the POT Slack suggested computing them by hand, but as a beginner in optimal transport I couldn't work out the original equation for this function.
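For reference, a sketch of the equation in question, assuming POT's conventions: L is the elementwise loss selected by loss_fun ('square_loss': (a - b)^2, 'kl_loss': a log(a/b) - a + b) and \Pi(p, q) is the set of couplings with marginals p and q. The gradients treat the optimal coupling T^* as a constant (an envelope / Danskin-type argument), which is exact only at a true optimum; since GW is non-convex, they should be read as approximate. The derivation only uses \sum_j T_{ij} = p_i and \sum_i T_{ij} = q_j; \odot and \oslash denote elementwise product and division.

```latex
\mathrm{GW}(C_1, C_2, p, q) = \min_{T \in \Pi(p,q)} \sum_{i,j,k,l} L\big(C_1[i,k], C_2[j,l]\big)\, T_{ij}\, T_{kl},
\qquad \Pi(p,q) = \{\, T \ge 0 : T\mathbf{1} = p,\; T^\top \mathbf{1} = q \,\}

% Square loss, T^* held fixed:
\frac{\partial \mathrm{GW}}{\partial C_1[i,k]} = 2 \sum_{j,l} \big(C_1[i,k] - C_2[j,l]\big) T^*_{ij} T^*_{kl}
  \;\Rightarrow\; \nabla_{C_1} \mathrm{GW} = 2\big(C_1 \odot p p^\top - T^* C_2 T^{*\top}\big),
\qquad \nabla_{C_2} \mathrm{GW} = 2\big(C_2 \odot q q^\top - T^{*\top} C_1 T^*\big)

% KL loss, T^* held fixed:
\nabla_{C_1} \mathrm{GW} = \log C_1 \odot p p^\top - T^* (\log C_2)\, T^{*\top},
\qquad \nabla_{C_2} \mathrm{GW} = q q^\top - \big(T^{*\top} C_1 T^*\big) \oslash C_2
```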