
Learning Discrete Structured Variational Auto-Encoder using Natural Evolution Strategies

This repository contains the official PyTorch implementation of Learning Discrete Structured VAE using NES.

In our paper, we propose optimizing discrete structured VAEs with a black-box, gradient-free method. We experimentally demonstrate that it is as effective as popular gradient-based approximations while being more general, more scalable, and simpler to implement.
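To make the idea concrete, here is a minimal sketch of an NES gradient estimator with antithetic sampling (illustrative only, not the repository's exact implementation; n_perturb mirrors the --n_perturb flag described below, and sigma is an assumed perturbation scale):

import torch

def nes_gradient(loss_fn, params, n_perturb=8, sigma=0.1):
    # Antithetic sampling: evaluate the loss at mirrored perturbations.
    grad = torch.zeros_like(params)
    for _ in range(n_perturb // 2):
        eps = torch.randn_like(params)
        grad += (loss_fn(params + sigma * eps) - loss_fn(params - sigma * eps)) * eps
    return grad / (n_perturb * sigma)

# Toy usage: minimize a quadratic without ever calling backward().
theta = torch.zeros(5)
target = torch.arange(5.0)
for _ in range(200):
    theta -= 0.1 * nes_gradient(lambda p: ((p - target) ** 2).sum(), theta)

Because the estimator only queries loss values, the computation inside loss_fn may contain arbitrarily discrete, non-differentiable structure.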


Specifically, the repository contains code to:

  1. Optimize an unstructured discrete VAE using NES (used to demonstrate that NES scales well with network size)
  2. Reproduce the dependency parsing experiments (parsing directory)

Note: for the latent structure recovery experiments, we relied on the official implementation of Gradient Estimation with Stochastic Softmax Tricks.

Installation

First, install Python 3.7.
Then, clone this repository and install the dependencies (preferably using a conda environment):

git clone https://github.com/BerlinerA/DSVAE-NES
cd DSVAE-NES
pip install -r requirements.txt

For running the dependency parsing experiments, you should also manually install the following packages:

Before running the code, configure the GPUs that NES will run on using the CUDA_VISIBLE_DEVICES environment variable:

export CUDA_VISIBLE_DEVICES=0,1,2
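You can then verify which devices PyTorch actually sees (a quick illustrative check, not part of the repository):

import torch
print(torch.cuda.is_available(), torch.cuda.device_count())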

Unstructured discrete VAE

Optimize the unstructured VAE on one of the four supported benchmarks.

For optimization using NES, run:

python train.py --dataset [MNIST, FashionMNIST, KMNIST or Omniglot] --binarize --validate --nes --n_perturb [number of samples]

For optimization using SST, run:

python train.py --dataset [MNIST, FashionMNIST, KMNIST or Omniglot] --binarize --validate --sst
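The --binarize flag maps images to {0, 1} so that a Bernoulli decoder likelihood is well-defined. A hedged sketch of such preprocessing with torchvision (the repository's exact transform may differ):

from torchvision import datasets, transforms

binarize = transforms.Compose([
    transforms.ToTensor(),                          # scale pixels to [0, 1]
    transforms.Lambda(lambda x: (x > 0.5).float()), # threshold to {0, 1}
])
mnist = datasets.MNIST('data', train=True, download=True, transform=binarize)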

Dependency parsing

Data

The experiments were performed on datasets from Universal Dependencies. The datasets should be organized as follows:

datasets_dir
├── dataset1_name
│     ├── train.conllu
│     ├── dev.conllu
│     └── test.conllu
└── dataset2_name
      ├── train.conllu
      ├── dev.conllu
      └── test.conllu
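To sanity-check that the files are readable before training, here is a small sketch using the third-party conllu package (an assumption; the repository may parse the files differently):

from conllu import parse_incr

# 'datasets_dir/dataset1_name' follows the layout above.
with open('datasets_dir/dataset1_name/train.conllu', encoding='utf-8') as f:
    for sentence in parse_incr(f):
        # Each token records the head index used for dependency parsing.
        print([(tok['form'], tok['head']) for tok in sentence])
        break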

Running experiments

To run unsupervised domain adaptation experiments, begin by training the model on the source domain in a supervised manner:

python train.py --source [source domain name] --target [target domain name] --ext_emb [external word vectors file] --target_epochs 0

You may include the --non_projective flag for non-projective dependency parsing.

Next, train the structured VAE on the target domain in an unsupervised manner.

For optimization using NES, run:

python train.py --source [source domain name] --target [target domain name] --ext_emb [external word vectors path] --source_epochs 0 --pretrained_path [pretrained model weights directory] --nes --freeze_decoder

Again, you may include the --non_projective flag for non-projective dependency parsing.

For optimization using DPP, run:

python train.py --source [source domain name] --target [target domain name] --ext_emb [external word vectors path] --source_epochs 0 --pretrained_path [pretrained model weights directory]

For optimization using sparseMAP, run:

python train.py --source [source domain name] --target [target domain name] --ext_emb [external word vectors path] --source_epochs 0 --pretrained_path [pretrained model weights directory] --non_projective

For all methods, tune the --target_lr parameter over the [1e-5, 5e-4] interval.
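A minimal way to run such a sweep (an illustrative helper, not part of the repository, shown for the NES variant; the values below are hypothetical placeholders to substitute with your own domains, vectors, and paths):

import subprocess

# Hypothetical arguments -- substitute your own setup.
for lr in [1e-5, 5e-5, 1e-4, 5e-4]:
    subprocess.run(
        ['python', 'train.py',
         '--source', 'source_domain', '--target', 'target_domain',
         '--ext_emb', 'vectors.txt',
         '--source_epochs', '0', '--pretrained_path', 'pretrained/',
         '--nes', '--freeze_decoder',
         '--target_lr', str(lr)],
        check=True,
    )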

Cite

If you make use of this code in your research, please cite our paper:

@inproceedings{berliner2021learning,
  title={Learning Discrete Structured Variational Auto-Encoder using Natural Evolution Strategies},
  author={Berliner, Alon and Rotman, Guy and Adi, Yossi and Reichart, Roi and Hazan, Tamir},
  booktitle={International Conference on Learning Representations},
  year={2021}
}
