
jianningli / skullvae


Training β-VAE by Aggregating a Learned Gaussian Posterior with a Decoupled Decoder

Python 100.00%
disentangled-representations disentanglement skull-reconstruction variational-autoencoder variational-inference shape-completion vae monai unsupervised-learning

skullvae's Introduction

Main Take Away Message

  • Unsupervised skull shape completion using a variational autoencoder (poster).

  • In a regular beta-VAE with a large beta (e.g., beta=100), the reconstruction (Dice) loss does not decrease to the desired small value, as in the following curve in red. Note: the initial decrease is due to the random initialization of the network before training.

[Figure: reconstruction (Dice) loss curve]

  • The latent variables obtained with beta=100 can still be used for reconstruction by training an independent decoder on them, and the reconstruction (Dice) loss then decreases to a desirably small value.

  • The encoder of the VAE trained with a large beta and the independently trained decoder can be aggregated to form a new VAE that satisfies the latent Gaussian assumption and produces good reconstructions.
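The trade-off described in the points above can be illustrated with a minimal, dependency-free sketch. It assumes the standard beta-VAE objective (a reconstruction term plus beta times the KL divergence to a standard normal prior) and a soft Dice loss; the function names and the toy values are hypothetical and are not taken from the repository's scripts.

```python
import math

def soft_dice_loss(pred, target, eps=1e-6):
    """1 - Dice overlap between predicted probabilities and a binary target."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (total + eps)

def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))

def beta_vae_loss(pred, target, mu, logvar, beta):
    """Reconstruction (Dice) term plus the beta-weighted KL term."""
    recon = soft_dice_loss(pred, target)
    kl = kl_to_standard_normal(mu, logvar)
    return recon + beta * kl, recon, kl

# Toy values (hypothetical): four voxels and a two-dimensional latent code.
pred = [0.9, 0.8, 0.1, 0.2]
target = [1.0, 1.0, 0.0, 0.0]
mu, logvar = [0.5, -0.5], [0.0, 0.0]

for beta in (0.0001, 100):
    total, recon, kl = beta_vae_loss(pred, target, mu, logvar, beta)
    print(f"beta={beta}: recon={recon:.3f}, beta*KL={beta * kl:.3f}, total={total:.3f}")
```

With a large beta the weighted KL term dominates the total, so the optimizer is pushed towards matching the Gaussian prior rather than lowering the Dice term, which is consistent with the behaviour reported above.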

Code

$$z_{cr}^{co} = z_{ts} + \gamma \, \mathrm{DEV}_{cr}$$
$$z_{fa}^{co} = z_{ts} + \gamma \, \mathrm{DEV}_{fa}$$
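Read literally, the two expressions above add a vector DEV, scaled by γ, to a latent code z_ts. The README does not define these symbols, so the sketch below only shows the arithmetic; the helper name `shift_latent` and the 4-dimensional values are hypothetical.

```python
def shift_latent(z_ts, dev, gamma):
    """Return z_ts + gamma * dev, element by element."""
    if len(z_ts) != len(dev):
        raise ValueError("z_ts and dev must have the same length")
    return [z + gamma * d for z, d in zip(z_ts, dev)]

# Hypothetical 4-dimensional values, for illustration only.
z_ts = [0.2, -1.0, 0.5, 0.0]
dev_cr = [1.0, 0.0, -2.0, 0.5]

z_co_cr = shift_latent(z_ts, dev_cr, gamma=0.5)
print(z_co_cr)
```

Setting gamma to 0 returns the original code unchanged, and larger values move it further along the direction given by DEV.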


(1) Train the initial VAE using beta=100 or beta=0.0001

python monaiSkullVAE.py --phase train
#python monaiSkullVAE.py --phase test

(2) Train a decoder using the latent variables from the previously trained VAE (beta=100)

python VAEDecoderRetrain.py --phase train
#python VAEDecoderRetrain.py --phase test

The decoupled decoder 'newDecoder' takes as input the latent variables 'z' from Step (1) and outputs a reconstruction. It is trained using only the reconstruction (Dice) loss.

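# 'model' is the VAE trained with beta=100; 'z' is the latent variable for the batch 'inputs'.
_, _, _, z = model(inputs)
# Detach 'z' from the encoder's graph (and move it to the CPU, as before),
# so that gradients from the new decoder never reach the trained VAE.
z = z.detach().cpu()
# 'newDecoder' is the decoupled decoder.
recon_batch = newDecoder(z)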
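Step (2) can be sketched end to end as follows. This is a toy illustration under stated assumptions, not the code in VAEDecoderRetrain.py: small linear layers stand in for the real encoder and decoder, random binary vectors stand in for skull volumes, and the names `encoder`, `new_decoder` and `soft_dice_loss` are hypothetical. Only the latent size of 32 comes from this README.

```python
import torch
from torch import nn

torch.manual_seed(0)

latent_dim, n_voxels = 32, 64

# Stand-in for the encoder of the VAE trained with beta=100 (kept frozen).
encoder = nn.Linear(n_voxels, latent_dim)
for p in encoder.parameters():
    p.requires_grad_(False)

# Stand-in for the decoupled decoder; only its weights are optimized.
new_decoder = nn.Sequential(nn.Linear(latent_dim, n_voxels), nn.Sigmoid())
optimizer = torch.optim.Adam(new_decoder.parameters(), lr=1e-2)

def soft_dice_loss(pred, target, eps=1e-6):
    """Mean over the batch of (1 - soft Dice overlap)."""
    intersection = (pred * target).sum(dim=1)
    total = pred.sum(dim=1) + target.sum(dim=1)
    return (1 - (2 * intersection + eps) / (total + eps)).mean()

# Random binary vectors in place of skull volumes (assumption).
inputs = (torch.rand(8, n_voxels) > 0.5).float()

# Latent variables are computed once, without tracking gradients.
with torch.no_grad():
    z = encoder(inputs)

losses = []
for _ in range(200):
    optimizer.zero_grad()
    loss = soft_dice_loss(new_decoder(z), inputs)  # reconstruction loss only, no KL term
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"Dice loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Because the latent variables are fixed and no KL term is involved, the decoder is free to minimize the reconstruction loss without affecting the latent distribution learned in Step (1).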

(3) Make predictions using the aggregated VAE (encoder from beta=100 + decoupled decoder)

python AggreegateVAE.py

Dataset

Download the dataset here. The dataset is extended from the AutoImplant Challenge. It contains 100 healthy skulls and 100 skulls with facial and cranial defects.

Latent distributions of the skull variables (dimension of the latent variables reduced from 32 to 2 for illustrative purposes).


References:

Dataset (SkullFix)

@inproceedings{li2020dataset,
  title={Dataset descriptor for the AutoImplant cranial implant design challenge},
  author={Li, Jianning and Egger, Jan},
  booktitle={Cranial Implant Design Challenge},
  pages={10--15},
  year={2020},
  organization={Springer}
}

Methods

@article{li2022training,
  title={Training β-VAE by Aggregating a Learned Gaussian Posterior with a Decoupled Decoder},
  author={Li, Jianning and Fragemann, Jana and Ahmadi, Seyed-Ahmad and Kleesiek, Jens and Egger, Jan},
  journal={arXiv preprint arXiv:2209.14783},
  year={2022}
}

⭐ Check out our other skull-reconstruction project with MONAI at SkullRec

📧 For questions about the code, feel free to contact [email protected]


skullvae's Issues

Where is the "000.nrrd" created?

I have been trying to reproduce your experiment using your code for academic purposes. On reaching the third file (AggregateVAE.py), I am having trouble on line 257.

temp,h=nrrd.read('000.nrrd')

I've been wondering if you or anyone could help me understand where the "000.nrrd" file is created?

Thank you, Eric Li, for your amazing work, and also thank you in advance!
