Comments (12)
Yes, I understand you, but I think you implement only the maximization procedure and lost the minimization in the code. In my opinion, you implement just this:
max EV[log Dφ(y)] + EP[log(1 − Dφ(Eψ(x)))] (this is the PRIOR term in the code)
but you lost this:
min EV[log Dφ(y)]
because a GAN is a minimax two-player game. Is that right?
from deepinfomaxpytorch.
Thanks for the comments. I'll take a look this weekend. Translation below for those that don't speak Chinese.
zhaoxingwudy: It should be correct. The original two terms, E(logD(y)) and E(1−D(E(x))), are both maximized; the "min-max" here means both terms are maximized, and a negative sign is then taken when adding them to the total loss, so all three terms are maximized at the same time.
SuJingZhi:
I think maybe not. As in the original generative adversarial network, when optimizing the discriminator the generator parameters are fixed and E(logD(y)) + E(1−D(E(x))) is maximized, so that the discriminator can more easily tell real samples from fake ones; when optimizing the generator the discriminator parameters are fixed and E(1−D(E(x))) is minimized, so that the samples the generator produces can more easily fool the discriminator. These are two separate processes, one a maximization and one a minimization, so this code should be problematic.
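To make the alternating scheme described above concrete, here is a minimal PyTorch sketch of a generic GAN-style update. All the networks, shapes, and optimizers are stand-ins I made up for illustration; this is not the repo's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E = nn.Linear(8, 4)                                # stand-in encoder E_psi
D = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())   # stand-in discriminator D_phi
opt_d = torch.optim.SGD(D.parameters(), lr=0.1)
opt_e = torch.optim.SGD(E.parameters(), lr=0.1)

x = torch.randn(16, 8)       # a batch of inputs
y_prior = torch.rand(16, 4)  # samples drawn from the prior V

# Step 1: update D with E frozen (note the detach), maximizing
# E[log D(y)] + E[log(1 - D(E(x)))] by minimizing its negative.
d_loss = -(torch.log(D(y_prior)).mean()
           + torch.log(1.0 - D(E(x).detach())).mean())
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Step 2: update E with D held fixed (only opt_e is stepped),
# minimizing E[log(1 - D(E(x)))].
e_loss = torch.log(1.0 - D(E(x))).mean()
opt_e.zero_grad()
e_loss.backward()
opt_e.step()
```

The point of the sketch is only the structure: two losses, two optimizers, and a detach so each update touches one set of parameters.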
zhaoxingwidy:
Here G is E, and this formula trains D; lines 97 and 98 of train.py are the code that trains E and D respectively. But I am very curious why this loss is shared: all three terms are summed into one loss that trains both the encoder and the deep MI network.
I understand your point, and it could be the case, as the results are not as expected. The research team has new code out, so I'll take a look and compare.
OK, thank you very much.
Hmm... in the original paper, the appendix says:
Here we provide architectural details for our experiments. Example code for running Deep Infomax (DIM) can be found at
REDACTED.
What does REDACTED mean? Thanks.
Redacted means the information was removed before publication. It's a bit odd: redaction is normally done for security or legal reasons, and I don't see how either applies here.
Ok, so I finally got around to looking at this.
From the paper..
DIM imposes statistical constraints onto learned representations by implicitly training the encoder so that the push-forward distribution, Uψ,P, matches a prior, V. This is done (see Figure 6 in App. A.1) by training a discriminator, Dφ : Y → R, to estimate the divergence, D(V || Uψ,P), then training the encoder to minimize this estimate:

(ω̂, ψ̂)P = arg min_ψ arg max_φ D̂φ(V || Uψ,P) = EV[log Dφ(y)] + EP[log(1 − Dφ(Eψ(x)))]   (7)
And the code..
prior = torch.rand_like(y)                        # uniform samples from the prior V
term_a = torch.log(self.prior_d(prior)).mean()    # EV[log Dφ(y)]
term_b = torch.log(1.0 - self.prior_d(y)).mean()  # EP[log(1 − Dφ(Eψ(x)))]
PRIOR = - (term_a + term_b) * self.gamma          # negated and scaled by gamma
In words..
Make a uniform random tensor with the same shape as the encoder output.
Feed it through the discriminator network to get the probability it assigns to prior samples.
Feed the encoder output, y, through the same network.
Take the complement, 1 − D(y), of the output's probability.
Take the expectation (the mean) of the log of both.
Add them together, then multiply by −1 and the gamma hyperparameter.
I think this works. The maximization happens because we multiply by −1 at the end: minimizing the total loss then maximizes term_a + term_b.
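The steps above can be run end to end with a toy stand-in for self.prior_d (the real discriminator lives in the repo); the sizes and the gamma value here are arbitrary choices of mine, purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
prior_d = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())  # toy stand-in discriminator
gamma = 0.1                                              # arbitrary hyperparameter value
y = torch.randn(32, 64)                                  # pretend encoder output

prior = torch.rand_like(y)                   # uniform noise, same shape as y
term_a = torch.log(prior_d(prior)).mean()    # EV[log D(y)]
term_b = torch.log(1.0 - prior_d(y)).mean()  # EP[log(1 - D(E(x)))]
PRIOR = -(term_a + term_b) * gamma           # negate: minimizing the total loss
                                             # then maximizes term_a + term_b
```

Since both log terms are logs of probabilities, they are negative, so PRIOR comes out positive and shrinks as the discriminator gets worse at separating y from the prior.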
Thoughts?
I think you could have a look at the paper 'Adversarial Autoencoders'; it may be helpful.
Thank you for such a clean and neat code. I have learned a lot from it!
Back to this prior question: I think @SuJingZhi is correct. I am not familiar with PyTorch, but I guess we probably need to move these lines into a separate loss function (e.g. prior_loss), which updates the PriorDiscriminator first.
And then come back to DeepInfoMaxLoss and replace the PRIOR with something like
class DeepInfoMaxLoss(nn.Module):
    ...
        y_rescaled = torch.sigmoid(y)                   # the original author applies this nonlinearity
        ENCLOSS = -prior_loss(y_rescaled) * self.gamma  # minus sign is to "trick" the PriorDiscriminator
        return LOCAL + GLOBAL + ENCLOSS
This is how the original author implements it (which is a bit complicated, because he is trying to make everything fit into his 'cortex' framework).
https://github.com/rdevon/DIM/blob/bac4765a8126746675f517c7bfa1b04b88044d51/cortex_DIM/models/prior_matching.py#L70-L75
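To sketch what that split might look like: everything below (the stand-in prior_d, the prior_loss helper, gamma, the shapes) is illustrative, not the repo's or the original author's actual API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
prior_d = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())  # stand-in PriorDiscriminator
gamma = 0.1                                              # arbitrary value

def prior_loss(y):
    """Discriminator objective: separate prior samples from (rescaled) encodings."""
    prior = torch.rand_like(y)
    term_a = torch.log(prior_d(prior)).mean()
    term_b = torch.log(1.0 - prior_d(y)).mean()
    return -(term_a + term_b)   # the discriminator optimizer minimizes this

y = torch.randn(32, 64)        # pretend encoder output
y_rescaled = torch.sigmoid(y)  # nonlinearity the original author applies

# First, step the discriminator optimizer on detached encodings...
d_loss = prior_loss(y_rescaled.detach())
# ...then flip the sign for the encoder term, so that minimizing ENCLOSS
# (via the encoder optimizer) pushes encodings toward fooling the discriminator.
ENCLOSS = -prior_loss(y_rescaled) * gamma
```

In a full training loop the two losses would be stepped by separate optimizers, as in the linked prior_matching.py.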