
emilemathieu / pvae


Code for "Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders".

Home Page: https://arxiv.org/abs/1901.06033

License: MIT License

Languages: Python 100.00%
Topics: vae pytorch hyperbolic-geometry poincare-embeddings hierarchical-data

pvae's Introduction

[Figure: demonstrative figure]

Code for reproducing the experiments in the paper:

@inproceedings{mathieu2019poincare,
  title={Continuous Hierarchical Representations with Poincar\'e Variational Auto-Encoders},
  author={Mathieu, Emile and Le Lan, Charline and Maddison, Chris J. and Tomioka, Ryota and Teh, Yee Whye},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}

Prerequisites

pip install -U -r requirements.txt or python3 setup.py install --user

Models

VAE (--manifold Euclidean):

  • Prior distribution (--prior): Normal (WrappedNormal is theoretically equivalent)
  • Posterior distribution (--posterior): Normal (WrappedNormal is theoretically equivalent)
  • Decoder architecture (--dec): Linear (MLP) (Wrapped is theoretically equivalent)
  • Encoder architecture (--enc): Linear (MLP) (Wrapped is theoretically equivalent)

PVAE (--manifold PoincareBall):

  • Curvature (--c): 1.0
  • Prior distribution (--prior): WrappedNormal or RiemannianNormal
  • Posterior distribution (--posterior): WrappedNormal or RiemannianNormal
  • Decoder architecture (--dec):
    • Linear (MLP)
    • Wrapped (logarithm map followed by MLP)
    • Geo (first layer is based on the geodesic distance to hyperplanes, followed by MLP; see the formula after this list)
    • Mob (based on the hyperbolic feed-forward layers of Ganea et al. (2018))
  • Encoder architecture (--enc): Wrapped or Mob
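
For reference, the geodesic distance to the gyroplane $\tilde{H}_{a,p}$ that the Geo decoder's first layer is built on is, as derived in Ganea et al. (2018),

$$d_c(x, \tilde{H}_{a,p}) = \frac{1}{\sqrt{c}} \sinh^{-1}\left(\frac{2\sqrt{c}\,|\langle -p \oplus_c x,\, a\rangle|}{(1 - c\,\|{-p} \oplus_c x\|^2)\,\|a\|}\right)$$

where $\oplus_c$ denotes Möbius addition and $a$, $p$ are the layer's orientation and offset parameters.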

Run experiments

Synthetic dataset

python3 pvae/main.py --model tree --manifold PoincareBall --latent-dim 2 --hidden-dim 200 --prior-std 1.7 --c 1.2 --data-size 50 --data-params 6 2 1 1 5 5 --dec Wrapped --enc Wrapped  --prior RiemannianNormal --posterior RiemannianNormal --epochs 1000 --save-freq 1000 --lr 1e-3 --batch-size 64 --iwae-samples 5000

MNIST dataset

python3 pvae/main.py --model mnist --manifold Euclidean             --latent-dim 2 --hidden-dim 600 --prior Normal        --posterior Normal        --dec Wrapped --enc Wrapped --lr 5e-4 --epochs 80 --save-freq 80 --batch-size 128 --iwae-samples 5000
python3 pvae/main.py --model mnist --manifold PoincareBall --c 0.7  --latent-dim 2 --hidden-dim 600 --prior WrappedNormal --posterior WrappedNormal --dec Geo     --enc Wrapped --lr 5e-4 --epochs 80 --save-freq 80 --batch-size 128 --iwae-samples 5000

Custom dataset via CSV file (placed in /data, no header, integer labels in the last column)

python3 pvae/main.py --model csv --data-param CSV_NAME --data-size NB_FEATURES
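
For example, a toy file matching this format could be generated as follows (the filename and feature count are purely illustrative):

import numpy as np

# Hypothetical toy dataset: 100 rows, 3 features, integer label in the last column.
rng = np.random.default_rng(0)
features = rng.standard_normal((100, 3))
labels = rng.integers(0, 4, size=(100, 1))
# No header; the label is written as an integer in the last column.
np.savetxt("data/toy.csv", np.hstack([features, labels]),
           delimiter=",", fmt=["%.6f", "%.6f", "%.6f", "%d"])

It could then be trained with python3 pvae/main.py --model csv --data-param toy.csv --data-size 3 (how CSV_NAME is resolved, e.g. with or without the .csv extension, should be checked against pvae's csv loader).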

pvae's People

Contributors

dependabot[bot], emilemathieu, macio232


pvae's Issues

Negative KL loss

When I tested the Poincaré version on MNIST with latent_dim > 32, I got a negative KL loss. Is this normal?

replicate the Graph embeddings experiment

I've been trying to replicate the graph embeddings experiment in section 5.3 of the paper. The other two experiments of section 5 are explicit in the README with their input parameters. Could anybody help me with this? I downloaded the datasets used in the graph embedding experiment from their sources, but they are in a format other than CSV, and I am not sure whether I just need to convert them to CSV and run the code with the --model parameter set to csv, or do something else.

normdist2plane potential implementation error?

Hi! It looks like normdist2plane multiplies by the incorrect term when norm=True.

In normdist2plane, there is the line
a_norm = a.norm(dim=dim, keepdim=keepdim, p=2).clamp_min(MIN_NORM)
which computes the Euclidean norm,

and subsequently, when norm=True,
res = res * a_norm

However, I believe res should be multiplied by the Riemannian norm. From equation (24) in Hyperbolic Neural Networks, the distance is computed as
$$sign(\langle -p_k \oplus_c x, a_k\rangle) \sqrt{g^c_{p_k}(a_k, a_k)} d_c(x, \tilde{H}_{a_k, p_k})$$

And $\sqrt{g^c_{p_k}(a_k, a_k)}$ should take the place of a_norm in the multiplication.
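
Concretely, since the ball's metric is conformal, $g^c_{p}(a, a) = (\lambda^c_{p})^2 \langle a, a\rangle$, the suggested change amounts to rescaling the Euclidean norm by the conformal factor. An illustrative sketch (not the repository's code):

import torch

MIN_NORM = 1e-15

def conformal_factor(p, c=1.0):
    # lambda^c_p = 2 / (1 - c * ||p||^2)
    return 2.0 / (1.0 - c * p.pow(2).sum(-1, keepdim=True)).clamp_min(MIN_NORM)

def riemannian_norm(a, p, c=1.0):
    # sqrt(g^c_p(a, a)) = lambda^c_p * ||a||, i.e. the Euclidean norm rescaled
    a_norm = a.norm(dim=-1, keepdim=True, p=2).clamp_min(MIN_NORM)
    return conformal_factor(p, c) * a_norm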

Please correct me if I am wrong. I appreciate your help in advance.

Question about the implementation of wrapped normal

Hello! My name is Seunghyuk Cho, and I'm currently doing research based on your code.

I have a question on the implementation of wrapped normal (pvae/distributions/wrapped_normal.py).

In your paper, Algorithm 1 shows that sampling from the wrapped normal distribution proceeds as follows:

  1. sample from a normal distribution
  2. divide by lambda
  3. apply the exponential map

However, your code includes a parallel transport between steps 2) and 3). The original paper that proposed the wrapped normal used transport, but your paper doesn't, and I think omitting it makes sense. Could you clarify whether using the transport is the correct implementation, or whether it is just a mistake?
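
For concreteness, here is a minimal sketch of both variants (assuming c = 1 and the standard gyrovector formulas; this is not the repository's code). Since parallel transport from the origin on the ball is just a rescaling by lambda_0 / lambda_mu, the two variants produce the same point:

import torch

MIN = 1e-15

def lam(x):
    # conformal factor lambda_x = 2 / (1 - ||x||^2), with c = 1
    return 2.0 / (1.0 - x.pow(2).sum(-1, keepdim=True)).clamp_min(MIN)

def mobius_add(x, y):
    # Moebius addition on the Poincare ball, c = 1
    x2 = x.pow(2).sum(-1, keepdim=True)
    y2 = y.pow(2).sum(-1, keepdim=True)
    xy = (x * y).sum(-1, keepdim=True)
    num = (1 + 2 * xy + y2) * x + (1 - x2) * y
    return num / (1 + 2 * xy + x2 * y2).clamp_min(MIN)

def expmap(mu, u):
    # exp_mu(u) = mu (+) tanh(lambda_mu * ||u|| / 2) * u / ||u||
    n = u.norm(dim=-1, keepdim=True).clamp_min(MIN)
    return mobius_add(mu, torch.tanh(lam(mu) * n / 2) * u / n)

def sample_algorithm_1(v, mu):
    # steps 1-3 exactly as stated: divide by lambda_mu, then exponential map
    return expmap(mu, v / lam(mu))

def sample_with_transport(v, mu):
    zero = torch.zeros_like(mu)
    u = v / lam(zero)            # divide by lambda_0 = 2
    u = u * lam(zero) / lam(mu)  # transport 0 -> mu: rescale by lambda_0 / lambda_mu
    return expmap(mu, u)         # composed scaling is v / lam(mu): same as Algorithm 1

Passing the same Gaussian sample v ~ N(0, Sigma) to both functions returns the same z.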

Thank you for providing the implementation of your awesome paper!

Why is the decoder's output range required to be from 0 to 1 in MNIST?

While inspecting the VAE structure, I noticed that the model applies an additional operation to the encoder's output and another to the decoder's output in the final step, which requires the decoder's output to be in the range 0 to 1. However, I encounter an error when training the model. Could you please explain the reasons and purposes behind these two operations, and how to address this error? The training command I am using is: python3 pvae/main.py --model mnist --manifold PoincareBall --c 0.7 --latent-dim 2 --hidden-dim 600 --prior WrappedNormal --posterior WrappedNormal --dec Geo --enc Wrapped --lr 5e-4 --epochs 80 --save-freq 80 --batch-size 128 --iwae-samples 5000
