
Comments (12)

taldatech commented on June 1, 2024

Are you calculating the FID as described in the code?
Anyway, here is a checkpoint, you can load it on Colab:

https://mega.nz/file/FF0G0QLb#JKskRwfFUt4gfnAakbbzxk_BnGdoEDnjr35rd1Yun5A
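Roughly, loading it on Colab looks like this (a sketch only: the filename is a placeholder for wherever you download the file, and it assumes the checkpoint is either a plain state_dict or a dict wrapping one, so check the keys first):

import torch

ckpt = torch.load("soft_intro_vae_cifar10.pth", map_location="cpu")  # placeholder path for the downloaded file
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
print(list(state_dict.keys())[:5])  # sanity check: should show encoder/decoder parameter names

# then, with a SoftIntroVAE built using the same hyper-parameters as in the tutorial notebook:
# model.load_state_dict(state_dict)
# model.to(device)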


GloryyrolG commented on June 1, 2024

Hi, maybe not really... Previously, I trained using the train_soft_intro_vae.py script as well as soft_intro_vae_image_code_tutorial.ipynb, and both work well with or without model.eval(), which I think is the expected behavior.

Btw, running the checkpoint you provided earlier, I got 4.1 FID in eval mode and even 3.9 in train mode, which is much lower (better) than the reported result (4.6). This makes me a bit confused.


taldatech commented on June 1, 2024

Hi,
Can you please run the Colab example for 180-220 epochs with the recommended hyper-parameters? You should be able to reach the reported FID (± std) from the paper.

https://colab.research.google.com/github/taldatech/soft-intro-vae-pytorch/blob/main/soft_intro_vae_tutorial/soft_intro_vae_image_code_tutorial.ipynb


GloryyrolG commented on June 1, 2024

Hi Daniel,

Thanks for your quick reply. I have tried your suggestions. The outcome: running soft_intro_vae_image_code_tutorial.ipynb I got 5.8 FID, while running train_soft_intro_vae.py I got 6.6 FID. The result reported in the paper, 4.6, is lower (better). For both, I ran 250 epochs with the recommended hyperparameters.


GloryyrolG commented on June 1, 2024

Thanks for providing the checkpoint!

To calculate FID, I directly adopted the code here, after calling model.eval():

if with_fid and ((epoch == 0) or (epoch >= 100 and epoch % 20 == 0) or epoch == num_epochs - 1):

Then I got 4.1 FID. However, the reconstructed images look more washed out (whiter) than real-looking images.


GloryyrolG commented on June 1, 2024

Are you calculating the FID as described in the code?
Anyway, here is a checkpoint, you can load it on Colab:

https://mega.nz/file/FF0G0QLb#JKskRwfFUt4gfnAakbbzxk_BnGdoEDnjr35rd1Yun5A


Did you use extra datasets for training? The reconstructions look much more normal if I set the model to train mode. Or is there some point I missed? Thanks.


taldatech commented on June 1, 2024

I think the whitening artifacts are due to the BatchNorm in IntroVAE's decoder. Try it without model.eval(); just wrap the code with torch.no_grad():
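As a rough sketch of that suggestion (model and batch are placeholders for the objects from the tutorial/training script, and the forward call may return several tensors there):

import torch

model.train()               # leave BatchNorm in train mode (per-batch statistics)
with torch.no_grad():       # but don't build the autograd graph while reconstructing/sampling
    outputs = model(batch)  # forward pass as in the tutorial; unpack outputs as needed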


GloryyrolG commented on June 1, 2024
  • With model.eval():
    Reconstruction: [image: cifa10_grid_reconstructions]
    Generation: [image: cifa10_grid_generated]

  • With model.eval() commented out, i.e., in model.train():
    Reconstruction: [image: cifa10_grid_reconstructions]
    Generation: [image]

We can see the model works well in training mode but not in evaluation mode.

Both the encoder and the decoder use BatchNorm, so couldn't both be possible causes? Besides, since the model is trained on the training data, I think it is conventional and reasonable to set eval mode and expect the model to work well?
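For example, one way to check whether the decoder's BatchNorm alone is the culprit would be to keep only those layers in train mode (a sketch; it assumes the decoder is reachable as model.decoder, so adjust the attribute name if the repo's model class differs):

import torch

model.eval()                        # running statistics everywhere...
for m in model.decoder.modules():   # ...except in the decoder
    if isinstance(m, torch.nn.BatchNorm2d):
        m.train()                   # decoder BatchNorm uses per-batch statistics again
with torch.no_grad():
    outputs = model(batch)          # reconstruct/generate as in the tutorial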

I also updated the code used for FID calculation:

model.eval()

import sys; sys.path.append('../soft_intro_vae/')
from metrics.fid_score import calculate_fid_given_dataset

# additional imports needed by this snippet (model, batch_size and device are defined earlier in the notebook)
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

# FID against the CIFAR-10 training set, using 50,000 images
train_set = CIFAR10(root='./cifar10_ds', train=True, download=True, transform=transforms.ToTensor())
train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
print("calculating fid...")
fid = calculate_fid_given_dataset(train_data_loader, model, batch_size, cuda=True, dims=2048,
                                  device=device, num_images=50000)
print("fid:", fid)

Thanks & Regards,


taldatech commented on June 1, 2024

We used the original IntroVAE architecture. The other types of Soft-IntroVAE (2D, 3D, Style-SIntro-VAE) don't use BatchNorm in the decoder. I believe that if you remove it from the decoder, you can use model.eval() as usual (let me know if you try that; it probably also requires hyper-parameter tuning). BatchNorm doesn't always behave well in generative models (especially in the generator/decoder).
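Roughly, removing it could look like this (a sketch, not a drop-in fix: the decoder attribute name is assumed, and the model would need to be retrained and re-tuned afterwards):

import torch.nn as nn

def remove_batchnorm(module: nn.Module) -> None:
    # recursively swap every BatchNorm2d for a no-op (nn.Identity);
    # a normalization that doesn't depend on batch statistics (e.g. GroupNorm) would be another option
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.Identity())
        else:
            remove_batchnorm(child)

remove_batchnorm(model.decoder)  # decoder attribute name assumed; then retrain with tuned hyper-parameters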


taldatech commented on June 1, 2024

Well, I don't have an explanation for this; the only difference I can think of is the PyTorch version, assuming you are indeed using the same code for the FID calculation...


GloryyrolG commented on June 1, 2024

Well, thanks for your kind reply. Actually, I am using the environment.yml from the repo, and the FID calculation code is the one mentioned above, borrowed directly from the training script. Anyway, thanks again :)...


taldatech commented on June 1, 2024

Yeah no problem :).

