allensmile / pytorch-ntm

This project forked from loudinthecloud/pytorch-ntm


Neural Turing Machines (NTM) - PyTorch Implementation

License: BSD 3-Clause "New" or "Revised" License

Languages: Jupyter Notebook 81.30%, Python 18.70%


PyTorch Neural Turing Machine (NTM)

PyTorch implementation of Neural Turing Machines (NTM).

An NTM is a memory-augmented neural network: it is attached to an external memory, and all interactions with that memory (addressing, reading, writing) are performed through differentiable transformations. The network is therefore end-to-end differentiable and can be trained with a gradient-based optimizer.

Like an LSTM, the NTM processes its input as a sequence, but with additional benefits: (1) the external memory makes algorithmic tasks easier to learn; (2) memory capacity can be increased without increasing the number of trainable parameters.

The external memory allows the NTM to learn algorithmic tasks that are much harder for an LSTM, and to maintain its internal state for much longer than a traditional LSTM.
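
To make the "differentiable memory access" point concrete, here is a minimal sketch of content-based addressing plus the read and write operations for a single head. It is a simplification for illustration only, not this repository's exact implementation.

```python
import torch
import torch.nn.functional as F

def content_address(memory, key, beta):
    # memory: (N, M) matrix of N slots with M features; key: (M,); beta: scalar sharpness >= 0
    scores = F.cosine_similarity(memory, key.unsqueeze(0), dim=1)  # (N,) similarity per slot
    return F.softmax(beta * scores, dim=0)                         # soft attention over slots

def read(memory, w):
    # Differentiable read: a convex combination of memory rows.
    return w @ memory                                              # (M,)

def write(memory, w, erase, add):
    # Differentiable write: blend erase and add vectors into every slot, weighted by w.
    memory = memory * (1 - w.unsqueeze(1) * erase.unsqueeze(0))
    return memory + w.unsqueeze(1) * add.unsqueeze(0)

N, M = 128, 20
mem = torch.zeros(N, M)
w = content_address(mem, key=torch.randn(M), beta=torch.tensor(5.0))
mem = write(mem, w, erase=torch.sigmoid(torch.randn(M)), add=torch.randn(M))
r = read(mem, w)
```

Because every step above is built from smooth tensor operations, gradients flow from the loss back through the read/write weights into the controller that produced them.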

A PyTorch Implementation

This repository implements a vanilla NTM in a straightforward way. The following architecture is used:

NTM Architecture

  • Support for batch learning
  • Any read/write head configuration is supported (for example, 5 read heads and 2 write heads); the order in which the heads operate is specified by the user (see the configuration sketch below)
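
As a purely illustrative example of what such a configuration could look like, the dictionary below expresses batched training with an ordered list of heads. The keys are hypothetical and do not correspond to a specific interface in this repository.

```python
# Illustrative configuration sketch only; key names are assumptions, not the repository's API.
config = {
    "batch_size": 16,                        # batched learning is supported
    "memory": {"slots": 128, "width": 20},   # external memory size
    "heads": ["read"] * 5 + ["write"] * 2,   # e.g. 5 read heads and 2 write heads, applied in this order
}
```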

Example of training convergence for the copy task using 4 different seeds.

NTM Convergence

The following plot shows the cost per sequence length during training. The network was trained with seed=7 and converges quickly; other seeds may not perform as well, but should converge in less than 30K iterations.

NTM Convergence
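
For reference, copy-task training data is typically generated as random binary vectors followed by a delimiter bit, with the target being an exact reproduction of the input. The sketch below follows that convention from the NTM paper; the generator in this repository may differ in details.

```python
import torch

def copy_batch(batch_size=16, seq_len=10, width=8):
    # Random binary sequence to be memorized and reproduced.
    seq = torch.bernoulli(torch.full((seq_len, batch_size, width), 0.5))
    # Input has one extra channel for the delimiter that marks the end of the sequence.
    inp = torch.zeros(seq_len + 1, batch_size, width + 1)
    inp[:seq_len, :, :width] = seq
    inp[seq_len, :, width] = 1.0
    return inp, seq                          # target: reproduce the sequence exactly

inp, target = copy_batch()
print(inp.shape, target.shape)               # torch.Size([11, 16, 9]) torch.Size([10, 16, 8])
```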

Here is an animated GIF that shows how the model generalizes. The model was evaluated after every 500 training samples, using the target sequence shown in the upper part of the image; the bottom part shows the network's output at each training stage.

NTM Convergence

The following is the same evaluation, but with sequence length = 80. Note that the network was trained on sequences of lengths 1 to 20.

NTM Convergence

Installation

The NTM can be used as a reusable module, although it is not currently distributed as a package.

  1. Clone the repository
  2. Install PyTorch
  3. pip install -r requirements.txt

Usage

python train.py
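
train.py handles training end to end. For orientation only, a single copy-task training step generally looks like the sketch below, using binary cross-entropy on the reproduced sequence and RMSprop as in the NTM paper. A plain nn.LSTM stands in for the NTM module so the snippet runs on its own; this is an assumption for illustration, not the repository's actual training loop.

```python
import torch
from torch import nn

width, hidden, seq_len, batch = 8, 100, 10, 16
model = nn.LSTM(input_size=width + 1, hidden_size=hidden)   # stand-in for the NTM module
head = nn.Linear(hidden, width)
optimizer = torch.optim.RMSprop(
    list(model.parameters()) + list(head.parameters()), lr=1e-4, momentum=0.9)
criterion = nn.BCEWithLogitsLoss()

# Build one copy-task batch: random bits plus a delimiter channel.
target = torch.bernoulli(torch.full((seq_len, batch, width), 0.5))
inp = torch.zeros(seq_len + 1, batch, width + 1)
inp[:seq_len, :, :width] = target
inp[seq_len, :, width] = 1.0

_, state = model(inp)                                        # absorb the input sequence
out, _ = model(torch.zeros(seq_len, batch, width + 1), state)  # emit the copy from blank inputs
loss = criterion(head(out), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(float(loss))
```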

