Comments (6)
Yeah, that's what I meant by shuffling; thanks for confirming, so that is not the cause. I am still confused, because AFAIK this repository and the author's code appear to be doing exactly the same thing when calculating MIG (on dSprites at least), yet this repository gives a much lower MIG when loading the same models. I'll look into it more over the weekend.
I also looked into the discrete estimation of MIG used in [1], appendix C: essentially, discretize samples from z into ~20 bins and use sklearn to estimate the discrete MI between the latents z and the generative factors v (sketched below). Unfortunately it does not agree with the MIG computed using the sampling-based estimation (it is consistently lower, and reasonably insensitive to the number of bins), irrespective of whether we use the mean of the latent representation or sample from q(z|x). So the jury is still out on how best to estimate MIG, I suppose.
Edit: It seems the MIG scores reported in [1] are consistently lower anyway (around 0.2 for the best models), so perhaps this is expected.
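For concreteness, here is a minimal sketch of the discrete estimator from [1], appendix C, as I understand it. The variable names, the equal-width binning, and the helper name discrete_mig are my own assumptions; z is an (N, K) array of latent codes (means or samples from q(z|x)) and v is an (N, D) array of integer ground-truth factor indices.

import numpy as np
from sklearn.metrics import mutual_info_score

def discrete_mig(z, v, n_bins=20):
    N, K = z.shape
    D = v.shape[1]
    # Discretize each latent dimension into equal-width bins
    z_binned = np.zeros((N, K), dtype=int)
    for k in range(K):
        edges = np.histogram_bin_edges(z[:, k], bins=n_bins)
        z_binned[:, k] = np.digitize(z[:, k], edges[1:-1])
    # Discrete MI (in nats) between every latent and every factor
    mi = np.array([[mutual_info_score(z_binned[:, k], v[:, d])
                    for d in range(D)] for k in range(K)])  # shape (K, D)
    # Entropy of each factor, via MI(v_d, v_d) = H(v_d)
    h = np.array([mutual_info_score(v[:, d], v[:, d]) for d in range(D)])
    mi_sorted = np.sort(mi, axis=0)[::-1]  # descending over latent dimensions
    # MIG: gap between the two most informative latents, normalized by H(v_d)
    return np.mean((mi_sorted[0] - mi_sorted[1]) / h)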
[1]: https://arxiv.org/abs/1811.12359
from disentangling-vae.
The small MIG is definitely (and unfortunately) something we always had in our experiments. Importantly, I got the same results when using the author's implementation. This is one of the reasons we introduced AAM, which measures only disentanglement rather than disentanglement plus the amount of information z carries about v. I am surprised you get a small AAM, though.
Here are the results we were getting:
We see that increasing β by a small amount (from 1 to 4) greatly increases axis alignment (from 20% to 65%) due to the regularisation of the total correlation, while increasing β by a large amount (from 4 to 50) decreases axis alignment due to the penalisation of the dimension-wise KL. I.e., the effect is not monotonic.
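For context, this non-monotonicity is consistent with the ELBO decomposition used in β-TCVAE (Chen et al., 2018), in which the aggregate KL term that a plain β-VAE scales as a whole splits into three parts:

$$\mathbb{E}_{p(x)}\!\left[\mathrm{KL}\!\left(q(z|x)\,\|\,p(z)\right)\right] = \underbrace{I_q(x;z)}_{\text{index-code MI}} \;+\; \underbrace{\mathrm{KL}\!\Big(q(z)\,\Big\|\,\prod_j q(z_j)\Big)}_{\text{total correlation}} \;+\; \underbrace{\sum_j \mathrm{KL}\!\left(q(z_j)\,\|\,p(z_j)\right)}_{\text{dimension-wise KL}}$$

Penalizing the total correlation encourages axis alignment, while penalizing the dimension-wise KL pushes each q(z_j) toward the prior regardless of alignment, which is the trade-off described above.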
from disentangling-vae.
Yes it is; if you get an answer or insights, please post them here. I would be interested, and other people might be too.
And just to be clear, I have not tried rerunning the author's code; I only tried using their MIG code to compute the MIG for our results :/ I.e., it does not seem that the issue comes from the computation of MIG, but to be honest I have not spent much time on MIG, as it was a late addition before a deadline.
from disentangling-vae.
After some digging, I am getting better results using the author's MIG calculation code (around 0.3-0.8 for most of my trained models). Perhaps the problem lies in shuffling the dataloader? I notice that when I shuffle the dataloader I get a very low MIG (on dSprites).
# Load the dataloader in the dataset's native order (do NOT shuffle)
all_loader = (..., shuffle=False)
vae = model
N = len(all_loader.dataset)  # number of data samples
K = vae.latent_dim           # number of latent dimensions
nparams = 2                  # (mu, logvar) of the diagonal Gaussian posterior
vae.eval()
qzCx_params = torch.empty(N, K, nparams)
n = 0
with torch.no_grad():
    for x, gen_factors in all_loader:
        batch_size = x.size(0)
        x, gen_factors = x.to(device, dtype=torch.float), gen_factors.to(device)
        # The encoder returns the posterior parameters for each latent dimension
        qzCx_stats = torch.stack(vae.encoder(x)['continuous'], dim=2)
        qzCx_params[n:n + batch_size] = qzCx_stats.data
        n += batch_size
# Reshape so the leading axes index the known dSprites generative factors:
# shape (3), scale (6), orientation (40), posX (32), posY (32)
qzCx_params = qzCx_params.view(3, 6, 40, 32, 32, K, nparams).to(device)
# Sample from the diagonal Gaussian posterior q(z|x) given (mu, logvar)
qzCx_samples = qzCx_sample(params=qzCx_params)
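For reference, qzCx_sample is a helper from the repo; the body below is a minimal sketch of what it is assumed to do (reparameterized sampling from the diagonal Gaussian), my guess rather than the actual implementation:

import torch

def qzCx_sample(params):
    # params has shape (..., K, 2); index 0 is mu, index 1 is logvar (assumed)
    mu, logvar = params[..., 0], params[..., 1]
    std = (0.5 * logvar).exp()
    return mu + std * torch.randn_like(mu)  # z = mu + sigma * eps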
I think the reshape near the end of the snippet requires the dataset to be in its native order, so that the generative factors line up correctly. It's not obvious that they should be in that order, though; this is a quirk of the dSprites dataset (see the sanity check below).
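To make the ordering assumption concrete, here is a small sanity check; the factor sizes follow dSprites' documented layout, and the helper name is mine. With shuffle=False, the flat sample index decodes into per-factor indices by mixed-radix arithmetic, which is exactly what the view(3, 6, 40, 32, 32, K, nparams) above relies on.

# dSprites factor sizes in native order: shape, scale, orientation, posX, posY
FACTOR_SIZES = [3, 6, 40, 32, 32]

def index_to_factors(n, sizes=FACTOR_SIZES):
    # Decode a flat dataset index into per-factor indices (mixed radix)
    idx = []
    for size in reversed(sizes):
        idx.append(n % size)
        n //= size
    return list(reversed(idx))

# The first sample is (shape=0, scale=0, orientation=0, posX=0, posY=0);
# a shuffled loader breaks this correspondence and hence the reshape.
assert index_to_factors(0) == [0, 0, 0, 0, 0]
assert index_to_factors(31) == [0, 0, 0, 0, 31]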
from disentangling-vae.
Thanks for digging into it. What exactly do you mean by shuffling? We do not shuffle the test loader (Line 235 in a54b794).
BTW: I'm more than happy to accept PRs.
from disentangling-vae.