
cnn-vae's Introduction

CNN-VAE

A ResNet-style VAE with an adjustable perceptual loss computed with a pre-trained VGG19.
Based on the Deep Feature Consistent Variational Autoencoder.

Latent space interpolation

My PyTorch Deep Learning Series on YouTube

Whole Playlist
PyTorch VAE Basics

If you found this code useful

Buy me a Coffee

NEW!

Let me know if any other features would be useful!

1.3) The default model is now much larger but has similar memory usage and much better performance. Added some new arguments for greater customization:
--norm_type selects the normalization layer: BatchNorm (bn) or GroupNorm (gn). Use GroupNorm if you can only train with a small batch size.
--num_res_blocks sets how many Res-Identity blocks sit at the bottleneck of both the Encoder and the Decoder. Increase it for a deeper model while keeping a low memory footprint.
--deep_model adds a Res-Identity block to each up/down-sampling stage of both the Encoder and the Decoder. Use it to increase depth, at the cost of a larger memory footprint and slower training.
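A minimal sketch of what a --norm_type switch might do, assuming a hypothetical helper make_norm and 8 groups for GroupNorm (neither is taken from the repository). It shows why GroupNorm suits small batches: it normalizes over channel groups within each sample instead of across the batch.

```python
import torch
from torch import nn


def make_norm(norm_type: str, channels: int, num_groups: int = 8) -> nn.Module:
    """Hypothetical helper: map a --norm_type value to a normalization layer."""
    if norm_type == "bn":
        # BatchNorm statistics are computed across the batch,
        # so they become noisy when the batch is very small.
        return nn.BatchNorm2d(channels)
    if norm_type == "gn":
        # GroupNorm normalizes over channel groups inside each sample,
        # so its statistics do not depend on the batch size.
        return nn.GroupNorm(num_groups, channels)
    raise ValueError(f"unknown norm_type: {norm_type!r}")


x = torch.randn(2, 64, 16, 16)
print(type(make_norm("bn", 64)).__name__, make_norm("bn", 64)(x).shape)
print(type(make_norm("gn", 64)).__name__, make_norm("gn", 64)(x).shape)
```

Both layers preserve the input shape, so swapping one for the other does not change the rest of the architecture.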

1.2) Added a Dynamic Depth Architecture: define the "blocks" parameter, a list of channel scales. Each scale creates a new Res up/down block, and each block scales the feature map up/down by a factor of 2. The default parameters downsample a 3x64x64 image to a 256x4x4 latent space, although any square image will work.
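The spatial arithmetic can be sketched as follows, assuming each entry of the blocks list adds exactly one stride-2 stage and that the image side divides evenly by 2 to the power of the number of blocks. latent_shape is a hypothetical helper; 256 latent channels matches the 256x4x4 latent of the basic training command.

```python
from typing import Sequence, Tuple


def latent_shape(image_size: int, blocks: Sequence[int], latent_channels: int) -> Tuple[int, int, int]:
    """Return (channels, height, width) of the latent for a square input.

    Assumes every entry in `blocks` adds one down-sampling stage that halves
    the spatial size. This sketch only handles sizes that halve cleanly.
    """
    downsample = 2 ** len(blocks)
    if image_size % downsample != 0:
        raise ValueError(f"{image_size} is not divisible by {downsample}")
    side = image_size // downsample
    return (latent_channels, side, side)


print(latent_shape(64, [1, 2, 4, 8], 256))      # (256, 4, 4)
print(latent_shape(128, [1, 2, 4, 4, 8], 256))  # (256, 4, 4)
```

Adding one entry to the list doubles the input size that maps to the same latent resolution.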

1.1) Added a training script with loss logging, etc. The dataset uses the PyTorch "ImageFolder" dataset. The code assumes there is no pre-defined train/test split and creates one with a fixed random seed, so the split is the same every time the code is run.
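The idea behind a reproducible split can be sketched in plain Python. The seed 42, the 10% test fraction and the helper name split_indices are all illustrative assumptions; in PyTorch the analogous call is torch.utils.data.random_split with generator=torch.Generator().manual_seed(...).

```python
import random
from typing import List, Tuple


def split_indices(n_items: int, test_fraction: float, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Shuffle indices with a private, fixed-seed RNG and split them into train/test.

    Using random.Random(seed) instead of the global RNG keeps the split identical
    across runs, regardless of any other random calls made by the program.
    """
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    n_test = int(n_items * test_fraction)
    return indices[n_test:], indices[:n_test]


train_a, test_a = split_indices(1000, 0.1)
train_b, test_b = split_indices(1000, 0.1)
print(len(train_a), len(test_a))  # 900 100
print(test_a == test_b)           # True: same seed, same split
```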

Training Examples

Notes:
Avoid a bottleneck feature map smaller than 4x4, as all conv kernels are 3x3. If you do use one, set --num_res_blocks to 0 to avoid adding many model parameters that won't do much.
If you can only train with a very small batch size, consider using GroupNorm instead of BatchNorm, i.e. set --norm_type to gn.
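One way to see why tiny bottlenecks are wasteful (our illustration, not the author's stated reasoning): assuming a 3x3 kernel, stride 1 and one pixel of zero padding, the sketch below counts how many kernel taps land on real activations rather than on padding. Under those assumptions, at 2x2 only 4/9 of the taps see real activations, against roughly 69% at 4x4 and just 1/9 at 1x1.

```python
def real_tap_fraction(side: int, kernel: int = 3) -> float:
    """Fraction of kernel taps that land on real pixels rather than zero padding.

    Assumes a square side x side map, stride 1 and 'same' padding of kernel // 2.
    """
    pad = kernel // 2
    # Along one axis, count the taps of every output position that fall inside the map.
    per_axis = sum(
        1
        for i in range(side)
        for d in range(-pad, pad + 1)
        if 0 <= i + d < side
    )
    # The 2-D count is the product of the two identical 1-D counts.
    return per_axis ** 2 / (side ** 2 * kernel ** 2)


for side in (1, 2, 4, 8):
    print(side, round(real_tap_fraction(side), 3))
```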


Basic training command:
This creates a 51-million-parameter VAE for 64x64 images with a 256x4x4 latent representation.
python train_vae.py -mn test_run --dataset_root #path to dataset root#

Starting from an existing checkpoint:
The code will attempt to load a checkpoint with the provided name from the specified "save_dir".
python train_vae.py -mn test_run --load_checkpoint --dataset_root #path to dataset root#

Train without a feature loss:
This also stops the VGG19 model from being created, resulting in faster training but lower-quality image features.
python train_vae.py -mn test_run --feature_scale 0 --dataset_root #path to dataset root#
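A sketch of how a scaled perceptual loss can be wired in, loosely modeled on the feature_loss function visible in a user's traceback in the Issues section. Assumptions, none confirmed by the repository: per-layer MSE, averaging over layers, feature_scale as a plain multiplier, and tiny untrained layers standing in for VGG19. The names perceptual_loss and total_loss are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def perceptual_loss(img, recon_logits, feature_layers):
    """Hypothetical perceptual loss: MSE between activations of real and reconstructed images.

    `feature_layers` is a list of modules applied in sequence (a stand-in for
    slices of a frozen, pre-trained VGG19).
    """
    batch = img.shape[0]
    # Real images and sigmoid(reconstruction logits) share one forward pass.
    x = torch.cat((img, torch.sigmoid(recon_logits)), 0)
    loss = img.new_zeros(())
    for layer in feature_layers:
        x = layer(x)
        # Compare reconstructed activations against detached real activations.
        loss = loss + F.mse_loss(x[batch:], x[:batch].detach())
    return loss / len(feature_layers)


def total_loss(vae_loss, img, recon_logits, feature_layers, feature_scale):
    """Add the scaled perceptual term; with feature_scale == 0 the feature network never runs."""
    if feature_scale == 0:
        return vae_loss
    return vae_loss + feature_scale * perceptual_loss(img, recon_logits, feature_layers)


# Tiny untrained stand-in layers, only to make the sketch executable.
layers = [nn.Conv2d(3, 8, 3, padding=1), nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))]
img = torch.rand(2, 3, 16, 16)
recon_logits = torch.randn(2, 3, 16, 16)
print(total_loss(torch.tensor(1.0), img, recon_logits, layers, feature_scale=1.0))
```

Short-circuiting on a zero scale is what makes skipping the feature network free: no forward pass through it, and no need to build it at all.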

Define a Custom Architecture:
Example showing how to define each of the main parameters of the VAE Architecture.
python train_vae.py -mn test_run --latent_channels 128 --block_widths 1 2 4 8 --ch_multi 64 --dataset_root #path to dataset root#
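A sketch of how these flags could be parsed and turned into per-stage channel counts. The parser is hypothetical: the flag names come from the command above, but the defaults are placeholders (apart from 256 latent channels, which matches the basic command's 256x4x4 latent), and the rule "stage channels = ch_multi x block width" is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names follow the README, defaults are placeholders.
    parser = argparse.ArgumentParser()
    parser.add_argument("--latent_channels", type=int, default=256)
    # nargs="+" lets one flag take a list such as: --block_widths 1 2 4 8
    parser.add_argument("--block_widths", type=int, nargs="+", default=[1, 2, 4, 8])
    parser.add_argument("--ch_multi", type=int, default=64)
    parser.add_argument("--image_size", type=int, default=64)
    return parser


args = build_parser().parse_args(
    ["--latent_channels", "128", "--block_widths", "1", "2", "4", "8", "--ch_multi", "64"]
)
# Assumed rule: each stage's channel count is ch_multi times its block width.
stage_channels = [args.ch_multi * w for w in args.block_widths]
print(stage_channels)  # [64, 128, 256, 512]
```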

Define Deeper Architecture for a larger image:
Example showing how to change the image size (to 128x128) while keeping the same latent representation (256x4x4) by changing the number of blocks.
python train_vae.py -mn test_run --image_size 128 --block_widths 1 2 4 4 8 --dataset_root #path to dataset root#

Train on 128x128 images with a deeper model by adding Res-Identity blocks to each down/up-sampling stage, without additional downsampling.

python train_vae.py -mn test_run --image_size 128 --deep_model --latent_channels 64 --dataset_root #path to dataset root#

Latent space representation will be 64x8x8, same number of latent variables as before, but a different shape!
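The equality is simple arithmetic: the number of latent variables is the product of channels, height and width. A quick check (num_latent_variables is a hypothetical helper; math.prod needs Python 3.8+):

```python
from math import prod


def num_latent_variables(shape):
    """Total number of latent variables for a (channels, height, width) latent."""
    return prod(shape)


print(num_latent_variables((256, 4, 4)))  # 4096
print(num_latent_variables((64, 8, 8)))   # 4096
```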


Results


Results on validation images of the STL10 dataset at 64x64 with a latent vector size of 512 (images on top are the reconstructions). NOTE: RES_VAE_64_old.py was used to generate the results below.

With Perception loss
VAE Trained with perception/feature loss

Without Perception loss
VAE Trained without perception/feature loss

Additional Results - celeba

The images in STL10 have a lot of variation, meaning more "features" need to be encoded in the latent space to achieve a good reconstruction. Using a dataset with less variation (and the same latent vector size) should result in a higher-quality reconstructed image.

Celeba trained with perception loss

New model: test images from a VAE trained on CelebA at 128x128 resolution (latent space is therefore 512x2x2), using all layers of the VGG model for the perception loss.
CelebA 128x128 test images trained with perception loss

As Used in:

@article{ditria2023long,
  title={Long-Term Prediction of Natural Video Sequences with Robust Video Predictors},
  author={Ditria, Luke and Drummond, Tom},
  journal={arXiv preprint arXiv:2308.11079},
  year={2023}
}

cnn-vae's People

Contributors: dependabot[bot], lukeditria


cnn-vae's Issues

Cool! How well does it work?

I've been thinking about implementing a ResNet-style VAE for images, and while preparing I came across your repo: very interesting! Would you be willing to share your thoughts on how well this works? In particular, I'd like to understand the impact of the skip connections. Also, qualitatively, what fidelity do you see in the reconstructed images?

Kind regards,
Hudson

How to modify the image size?

Thanks for your code! It worked when I trained with image_size=64, but when I changed image_size to 128, the following error occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-23-cd633620fbd2> in <module>
      6 
      7         #VAE loss
----> 8         loss = vae_loss(recon_data, data[0].to(device), mu, logvar)
      9 
     10         #Perception loss

<ipython-input-4-ba4d722462c7> in vae_loss(recon, x, mu, logvar)
     14 
     15 def vae_loss(recon, x, mu, logvar):
---> 16     recon_loss = F.binary_cross_entropy_with_logits(recon, x)
     17     KL_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
     18     loss = recon_loss + 0.01 * KL_loss

~/anaconda3/envs/pt1.0/lib/python3.6/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2096 
   2097     if not (target.size() == input.size()):
-> 2098         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2099 
   2100     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([8, 3, 128, 128])) must be the same as input size (torch.Size([8, 3, 320, 320]))

So how can I correctly change the input image size? Thank you very much if you can answer my question!
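For reference, the vae_loss in the traceback above adds a KL divergence term, weighted by 0.01, to a BCE reconstruction loss. A framework-free sketch of that KL term, using plain Python lists in place of tensors (the helper name kl_term is ours), shows what it measures: the divergence of N(mu, exp(logvar)) from the N(0, 1) prior, averaged over latent elements.

```python
import math
from typing import Sequence


def kl_term(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """Mean KL(N(mu, exp(logvar)) || N(0, 1)) per latent element.

    Mirrors the traceback's expression:
    -0.5 * (1 + logvar - mu**2 - exp(logvar)), averaged over all elements.
    """
    terms = [-0.5 * (1 + lv - m * m - math.exp(lv)) for m, lv in zip(mu, logvar)]
    return sum(terms) / len(terms)


print(kl_term([0.0, 0.0], [0.0, 0.0]))  # 0.0: the posterior already equals the prior
print(kl_term([1.0], [0.0]))            # 0.5: shifting the mean by 1 costs 0.5 nats
```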

How about 1-channel image reconstruction?

Thanks for your code! It worked when I trained with image_channel=3, but when I tried to test with 1-channel images (channel=1), the following error occurred:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-38-0565c2bab7d7> in <module>
      9 
     10         #Perception loss
---> 11         loss_feature = feature_loss(data[0].to(device), recon_data, feature_extractor)
     12 
     13         loss += loss_feature

<ipython-input-30-2deca954bbe0> in feature_loss(img, recon_data, feature_extractor)
     26 def feature_loss(img, recon_data, feature_extractor):
     27     img_cat = torch.cat((img, torch.sigmoid(recon_data)), 0)
---> 28     out = feature_extractor(img_cat)
     29     loss = 0
     30     for i in range(len(feature_extractor)):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    421 
    422     def forward(self, input: Tensor) -> Tensor:
--> 423         return self._conv_forward(input, self.weight)
    424 
    425 class Conv3d(_ConvNd):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    417                             weight, self.bias, self.stride,
    418                             _pair(0), self.dilation, self.groups)
--> 419         return F.conv2d(input, weight, self.bias, self.stride,
    420                         self.padding, self.dilation, self.groups)
    421 

RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[64, 1, 256, 256] to have 3 channels, but got 1 channels instead

So how can I modify the code correctly? Thank you very much if you could answer my question!
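The error says VGG19's first convolution (weight [64, 3, 3, 3]) expects 3 input channels. One common workaround, sketched here as an assumption rather than the repository's own fix, is to repeat the single channel three times just before images enter the feature extractor (both the input batch and sigmoid(recon_data) in the traceback), so the VAE itself can keep working on 1 channel. Alternatively, torchvision's transforms.Grayscale(num_output_channels=3) converts images to 3 channels at load time. to_three_channels is a hypothetical helper.

```python
import torch


def to_three_channels(x: torch.Tensor) -> torch.Tensor:
    """Repeat a single-channel batch (N, 1, H, W) so it matches VGG's 3-channel input."""
    if x.shape[1] == 1:
        return x.repeat(1, 3, 1, 1)
    return x


gray = torch.rand(4, 1, 64, 64)
print(to_three_channels(gray).shape)  # torch.Size([4, 3, 64, 64])
```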

Thank you

Hey, I just wanted to say I really like your project. I have tried other VAE repos in the past, and this one is much easier to use and performs much better. I was able to get set up training on my dataset very quickly. The VGG19 model vastly improves training performance for me compared to other network structures I've used before. The code is very clean as well. Cheers!

About training on 1024x1024 images

Dear LukeDitria,
I am a long-time user of this repo and raised an issue earlier about scaling to 256x256 resolution (if you remember). Now I want to train on much higher-resolution images (1024x1024). Do you think that is possible, and what changes would I need to make?
Best regards,
Jay

Question about class ResUp

In class Decoder, the number of convolutional filters (channels) decreases (from z to ch*8 -> ch*8 -> ch*4 -> ch*2 -> ch).

In class ResUp, the first convolution is self.conv1 = nn.Conv2d(channel_in, channel_out//2, 3, 1, 1).
This makes the number of filters in the middle of the block smaller than channel_out.
Taking ResUp(ch*8, ch*4) as an example, the number of channels goes ch*8 -> ch*2 -> ch*4.

Is there a mistake in self.conv1 in class ResUp?
I think it should be self.conv1 = nn.Conv2d(channel_in, channel_in//2, 3, 1, 1) or nn.Conv2d(channel_in, channel_out*2, 3, 1, 1), or something else that doesn't "shrink" the number of channels in the middle of the block.
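The trade-off raised in this question can be put in numbers. A back-of-the-envelope sketch, assuming ch = 64, a bias on both convs, normalization layers ignored, and that conv2 is a 3x3 convolution from the middle width to channel_out (only conv1 is quoted above); the helper names are hypothetical.

```python
def conv3x3_params(c_in: int, c_out: int) -> int:
    # Weights (c_out x c_in x 3 x 3) plus one bias per output channel.
    return c_in * c_out * 9 + c_out


def main_path_params(c_in: int, c_out: int, c_mid: int) -> int:
    # Two stacked 3x3 convs: c_in -> c_mid -> c_out (normalization layers ignored).
    return conv3x3_params(c_in, c_mid) + conv3x3_params(c_mid, c_out)


ch = 64
c_in, c_out = ch * 8, ch * 4  # ResUp(ch*8, ch*4) from the question
print(main_path_params(c_in, c_out, c_out // 2))  # 885120  (current: 512 -> 128 -> 256)
print(main_path_params(c_in, c_out, c_in // 2))   # 1769984 (proposed: 512 -> 256 -> 256)
```

Under these assumptions the channel_out//2 middle width makes the main path about half as expensive as channel_in//2, the same motivation as the narrow middle layer in ResNet bottleneck blocks; whether it costs reconstruction quality is an empirical question.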

Size of the VAE input image

Can the size of the VAE input image be changed? Why must it be 64x64?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.