jamescao2048 / complex-gated-recurrent-neural-networks

This project forked from v0lta/complex-gated-recurrent-neural-networks

Complex domain recurrent neural network gating and Stiefel-manifold optimization in TensorFlow, NeurIPS 2018

License: Apache License 2.0

Python 99.48% Shell 0.52%

complex-gated-recurrent-neural-networks's Introduction

Code for the paper on complex gated recurrent neural networks (https://arxiv.org/pdf/1806.08267v2.pdf). This project was developed using Python 3.6 and TensorFlow 1.10.0 on NVIDIA Titan Xp cards, but it does not require 12 GB of card memory to run.

To recreate the results in Table 1, run bonn_eval_gate_diff.py twice: once for the adding problem and once for the memory problem. Then set the proper log directories in ./eval/eval.py, and it will run the evaluation for you.

To re-run the human-motion prediction and music transcription experiments from the paper, see the human_motion_exp and music_exp directories.

Use montreal_eval.py to recreate our experiments on the memory and adding problems shown in Figures 2 and 3 of the paper.

This repository contains TensorFlow ports of the Theano code at https://github.com/amarshah/complex_RNN and https://github.com/stwisdom/urnn

The custom optimizers class contains the Stiefel-manifold optimizer proposed in "Full-Capacity Unitary Recurrent Neural Networks" by Wisdom et al. (https://arxiv.org/abs/1611.00035); it is the default. You can turn off Stiefel-manifold optimization by setting stiefel=False when creating the cell. In that case you will need a bounded cell activation function, such as the Hirose non-linearity, for the optimization to remain stable. To work with the basis proposed by Arjovsky, Shah et al. in "Unitary Evolution Recurrent Neural Networks" (https://arxiv.org/abs/1511.06464), set arjovski_basis=True for the complex cells implemented in custom_cells.py. This setting works with the default ModReLU activation.
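To illustrate why Stiefel-manifold optimization keeps the recurrent matrix well behaved, here is a minimal NumPy sketch of one common form of the Cayley-transform update. It is not the repository's TensorFlow implementation: the function name `cayley_step`, the learning rate, the matrix size, and the random stand-in gradient are all assumptions made for illustration, and sign and transpose conventions for the skew-Hermitian matrix `A` differ between write-ups.

```python
import numpy as np


def cayley_step(W, G, lr):
    """One multiplicative update that keeps W unitary (orthogonal if real).

    Hypothetical helper for illustration, not the repository's API.
    """
    # Skew-Hermitian matrix built from the gradient G and the weights W.
    A = G @ W.conj().T - W @ G.conj().T
    I = np.eye(W.shape[0], dtype=W.dtype)
    # Cayley transform: solve (I + lr/2 * A) W_new = (I - lr/2 * A) W.
    return np.linalg.solve(I + (lr / 2.0) * A, (I - (lr / 2.0) * A) @ W)


rng = np.random.default_rng(0)
n = 4  # assumed toy state size

# Start from a random unitary matrix (Q factor of a complex Gaussian matrix).
W, _ = np.linalg.qr(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))

# Stand-in gradient; during training this would come from backpropagation.
G = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))

W_new = cayley_step(W, G, lr=0.1)

# Distance of W_new^H W_new from the identity; stays at round-off level.
unitarity_error = np.linalg.norm(W_new.conj().T @ W_new - np.eye(n))
print(unitarity_error)
```

Because the Cayley transform of a skew-Hermitian matrix is itself unitary, the product with a unitary `W` stays unitary, so the state transition remains norm-preserving after every step. A plain gradient step `W - lr * G` offers no such guarantee.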

You don't have to work in the complex domain. To create real-valued cells, set the real argument in the constructor to True and choose a real-valued activation such as ReLU. The Stiefel-manifold optimizer also works in the real domain.
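The activation functions named in this README can be compared with a small scalar sketch in plain Python. The definitions below follow the commonly cited forms of modReLU and the Hirose non-linearity. The function names, the bias `b`, and the scale `m` are assumptions for illustration and are not taken from custom_cells.py.

```python
import cmath
import math


def mod_relu(z, b):
    """modReLU: ReLU applied to the modulus shifted by bias b; phase is kept."""
    r = abs(z)
    if r == 0.0:
        return 0j
    return max(r + b, 0.0) * (z / r)


def hirose(z, m=1.0):
    """Hirose non-linearity: tanh squashes the modulus; phase is kept."""
    r = abs(z)
    if r == 0.0:
        return 0j
    return math.tanh(r / m ** 2) * (z / r)


def relu(x):
    """Real-valued ReLU for cells created with real=True."""
    return max(x, 0.0)


z = 3.0 + 4.0j  # modulus 5
print(abs(mod_relu(z, b=-1.0)))       # modulus shrinks by the bias, unbounded above
print(abs(hirose(z)))                 # modulus is squashed below 1
print(cmath.phase(hirose(z)), cmath.phase(z))  # phase is unchanged
print(relu(-2.0), relu(2.0))
```

The Hirose output never exceeds modulus 1, which is what makes it a bounded activation, whereas the modReLU output grows with the input modulus.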

If you find the code in this repository useful please consider citing:

@inproceedings{wolter2018complex,
     author = {Wolter, Moritz and Yao, Angela},
      title = {Complex Gated Recurrent Neural Networks},
  booktitle = {Advances in Neural Information Processing Systems 31},
       year = {2018},
   abstract = {Complex numbers have long been favoured for digital signal processing, yet
               complex representations rarely appear in deep learning architectures. RNNs, widely
               used to process time series and sequence information, could greatly benefit from
               complex representations. We present a novel complex gated recurrent cell, which
               is a hybrid cell combining complex-valued and norm-preserving state transitions
               with a gating mechanism. The resulting RNN exhibits excellent stability and
               convergence properties and performs competitively on the synthetic memory and
               adding task, as well as on the real-world tasks of human motion prediction.}
}
