yongfeiyan / gumbel_softmax_vae
PyTorch implementation of a Variational Autoencoder with Gumbel-Softmax Distribution
Hi, what does the latent_dim mean in your code? Could it be changed to other numbers? I can understand that categorical_dim means 10 categories for 10 digits, but I'm confused about the latent_dim. Thanks!
In your code the KL divergence is calculated by:
KLD = torch.sum(qy * (log_qy - 1. / categorical_dim), dim=-1).mean()
I think the term 1. / categorical_dim should be replaced by torch.log(1. / categorical_dim); otherwise, the expression is not the KL divergence.
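For reference, a short derivation of the quantity under discussion. It assumes the prior is uniform over the K = categorical_dim categories (suggested by the 1/K term, but not stated in the thread) and writes q_k for the entries of qy along the last dimension:

```latex
\begin{aligned}
\mathrm{KL}(q \,\|\, p)
  &= \sum_{k=1}^{K} q_k \left( \log q_k - \log p_k \right),
     \qquad p_k = \tfrac{1}{K} \\
  &= \sum_{k=1}^{K} q_k \left( \log q_k - \log \tfrac{1}{K} \right) \\
  &= \sum_{k=1}^{K} q_k \log \left( K\, q_k \right) \\
  &= \log K - H(q),
     \qquad H(q) = -\sum_{k=1}^{K} q_k \log q_k
\end{aligned}
```

Under that assumption the prior enters as log(1/K), not 1/K, which matches the correction proposed above.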
The gumbel_softmax_sample function is logits + gumbel_sample. But it should have been F.log_softmax(logits) + gumbel_sample according to the paper. Is this not a mistake?
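The two forms should give the same sample, provided the sample is produced by a softmax over (scores + noise) / temperature: log_softmax(logits) differs from logits only by a constant per row (the log-sum-exp), and softmax ignores constant shifts. Below is a minimal plain-Python sketch of that check. It is not the repository's code; the helper names are hypothetical and only mirror the names used in the question.

```python
import math
import random

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(xs):
    # log_softmax(x) = x - logsumexp(x), a constant shift of x.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def sample_gumbel(n, rng, eps=1e-20):
    # Gumbel(0, 1) noise via -log(-log(U)), with eps guarding log(0).
    return [-math.log(-math.log(rng.random() + eps) + eps) for _ in range(n)]

def gumbel_softmax_sample(scores, noise, temperature):
    # Relaxed sample: softmax((scores + noise) / temperature).
    return softmax([(s + g) / temperature for s, g in zip(scores, noise)])

rng = random.Random(0)
logits = [2.0, -1.0, 0.5, 0.0]          # illustrative values
noise = sample_gumbel(len(logits), rng)  # same noise reused for both variants
temperature = 0.5

from_logits = gumbel_softmax_sample(logits, noise, temperature)
from_log_probs = gumbel_softmax_sample(log_softmax(logits), noise, temperature)

# Largest elementwise difference between the two samples.
max_diff = max(abs(a - b) for a, b in zip(from_logits, from_log_probs))
print(max_diff)
```

The difference is at floating-point noise level, so under this assumption adding log_softmax changes nothing in the sampled output.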
Hi @YongfeiYan, thanks for sharing your project with us.
I would like to know how to modify the implementation to use Bernoulli variables. I need the network to generate codes consisting of 0s and 1s.
Thanks.
Gumbel_Softmax_VAE/gumbel_softmax_vae.py
Line 83 in 7d3df6c
Is this temp parameter needed? Isn't temp passed in during the forward pass?
I was wondering what exactly this line in the KLD calculation does:
log_ratio = torch.log(qy * categorical_dim + 1e-20)
In the definition of the ELBO loss, the KLD should be computed between the variational distribution q(z|x) and the prior p(z). Why did you not simply use the PyTorch implementation of KLD (kl_div)?
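One reading of the log_ratio line, assuming the prior p(z) is uniform over categorical_dim categories: with p = 1/K, log(q / p) equals log(q * K), so the prior is folded into the logarithm and 1e-20 only guards against log(0). The plain-Python sketch below (hypothetical function names, not the repository's code) checks numerically that this form agrees with the textbook definition.

```python
import math

def kl_to_uniform_log_ratio(qy, eps=1e-20):
    # Form used in the quoted line: sum_k q_k * log(q_k * K + eps).
    k = len(qy)
    return sum(q * math.log(q * k + eps) for q in qy)

def kl_explicit(qy, py):
    # Textbook KL(q || p) = sum_k q_k * (log q_k - log p_k).
    # Terms with q_k == 0 contribute zero and are skipped.
    return sum(q * (math.log(q) - math.log(p)) for q, p in zip(qy, py) if q > 0)

qy = [0.7, 0.2, 0.1]                      # illustrative posterior
uniform = [1.0 / len(qy)] * len(qy)       # assumed uniform prior

a = kl_to_uniform_log_ratio(qy)
b = kl_explicit(qy, uniform)
print(a, b)

# A one-hot posterior reaches the maximum value, log K.
one_hot = kl_to_uniform_log_ratio([1.0, 0.0, 0.0])
print(one_hot, math.log(3))
```

As for kl_div: torch.nn.functional.kl_div takes log-probabilities as its first argument and probabilities as the target, and computes target * (log target - input) elementwise. It could express the same quantity with the log-prior as input and qy as target, so which form to use is largely a matter of taste.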
Hey, sorry I haven't run your code to check, but here does this reset the temperature parameter after each epoch?
I'm using your loss function code in my project and getting negative values...
Is it normal to get negative KLD value?
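A true KL divergence is never negative. One possible source of negative values, assuming the loss was copied from the expression quoted earlier in this thread that subtracts 1. / categorical_dim rather than its logarithm: that expression equals -H(q) - 1/K, which can never exceed -1/K. The plain-Python sketch below (hypothetical helper names, random illustrative distributions) compares the two.

```python
import math
import random

def quoted_expression(qy):
    # sum_k q_k * (log q_k - 1/K): the form quoted in the thread.
    # This equals -H(q) - 1/K and is not a KL divergence.
    k = len(qy)
    return sum(q * (math.log(q) - 1.0 / k) for q in qy if q > 0)

def kl_to_uniform(qy):
    # KL(q || uniform) = sum_k q_k * (log q_k - log(1/K)).
    k = len(qy)
    return sum(q * (math.log(q) - math.log(1.0 / k)) for q in qy if q > 0)

def random_distribution(k, rng):
    # Normalize positive random weights into a probability vector.
    weights = [rng.random() + 1e-12 for _ in range(k)]
    total = sum(weights)
    return [w / total for w in weights]

rng = random.Random(0)
samples = [random_distribution(10, rng) for _ in range(1000)]

max_quoted = max(quoted_expression(q) for q in samples)  # stays below -1/K
min_kl = min(kl_to_uniform(q) for q in samples)          # stays at or above 0
print(max_quoted, min_kl)
```

If the loss instead uses the log(qy * categorical_dim + 1e-20) form, the KL term itself should stay non-negative up to floating-point error, and a negative value would have to come from elsewhere in the loss.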