
victoeywilly / da-rnn-in-tensorflow-2-and-pytorch


This project forked from kaelzhang/da-rnn-in-tensorflow-2-and-pytorch


A Tensorflow 2 (Keras) implementation of DA-RNN (A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction, arXiv:1704.02971)

License: MIT License

Dockerfile 0.02% Makefile 0.06% Python 2.78% Jupyter Notebook 97.14%

da-rnn-in-tensorflow-2-and-pytorch's Introduction

Tensorflow 2 / Torch DA-RNN

A TensorFlow 2 (Keras) and PyTorch implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction.

Paper: https://arxiv.org/abs/1704.02971

Run notebook demo

Install dependencies (using Anaconda to manage environments is recommended):

make install

Run notebook:

cd notebook
jupyter lab

# Run `pytorch.ipynb`

Install

For Tensorflow 2

pip install "da-rnn[keras]"

For PyTorch

pip install "da-rnn[torch]"

Usage

For TensorFlow 2 (still buggy for now)

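from da_rnn.keras import DARNN

model = DARNN(T=10, m=128)

# Assumption: DARNN is a tf.keras.Model subclass, so it must be compiled before fit()
model.compile(optimizer='adam', loss='mse')

# Train (train_ds / val_ds: tf.data.Dataset objects yielding (features, labels) batches)
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)

# Predict
y_hat = model(inputs)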

For PyTorch (Tested. Works)

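import torch
from poutyne import Model
from da_rnn.torch import DARNN

darnn = DARNN(n=50, T=10, m=128)

# Poutyne's Model requires an optimizer and a loss function
optimizer = torch.optim.Adam(darnn.parameters())
model = Model(darnn, optimizer, torch.nn.MSELoss())

# Train (train_ds / val_ds: DataLoaders yielding (features, labels) batches)
model.fit_generator(
    train_ds,
    valid_generator=val_ds,
    epochs=100,
    verbose=1
)

# Predict with the underlying nn.Module in evaluation mode
darnn.eval()
with torch.no_grad():
    y_hat = darnn(inputs)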

Python Docstring Notations

In the docstrings of this project's methods, we use the following notation convention:

variable_{subscript}__{superscript}

For example:

  • y_T__i means y_T^i, the i-th prediction value at time T.
  • alpha_t__k means α_t^k, the attention weight measuring the importance of the k-th input feature (driving series) at time t.
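To make the convention mechanical, here is a minimal sketch of how a docstring name splits into its parts: the superscript follows the double underscore, and the subscript follows the last single underscore of what remains. Both `parse_notation` and `to_latex` are hypothetical helpers written for illustration; they are not part of the `da_rnn` package, and the set of Greek names is an assumption.

```python
from typing import Optional, Tuple

# Assumption: names that should render as Greek letters in LaTeX.
GREEK = {"alpha", "beta"}


def parse_notation(name: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Split 'variable_{subscript}__{superscript}' into its three parts."""
    # Everything after the double underscore is the superscript.
    base, sep, superscript = name.partition("__")
    sup = superscript if sep else None
    # Everything after the last single underscore of the rest is the subscript.
    if "_" in base:
        variable, sub = base.rsplit("_", 1)
    else:
        variable, sub = base, None
    return variable, sub, sup


def to_latex(name: str) -> str:
    """Render a docstring name as a LaTeX math expression."""
    variable, sub, sup = parse_notation(name)
    out = "\\" + variable if variable in GREEK else variable
    if sub is not None:
        out += "_{" + sub + "}"
    if sup is not None:
        out += "^{" + sup + "}"
    return out


print(to_latex("y_T__i"))      # y_{T}^{i}
print(to_latex("alpha_t__k"))  # \alpha_{t}^{k}
```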

TensorFlow 2 (Keras): DARNN(T, m, p, y_dim=1)

PyTorch: DARNN(n, T, m, p, y_dim=1)

The naming of the following (hyper)parameters is consistent with the paper, except y_dim, which is not mentioned in the paper.

  • n (torch only) int: input size, the number of features of a single driving series
  • T int: the length (number of time steps) of the window
  • m int: the size of the encoder hidden state (number of hidden units)
  • p int: the size of the decoder hidden state (number of hidden units)
  • y_dim int=1: the prediction dimension. Defaults to 1.

Returns the DA-RNN model instance.
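To show how n, T and m shape the encoder, here is a NumPy sketch of one input-attention step as we read it from the paper (arXiv:1704.02971): the score for driving series k is e_t^k = v_e^T tanh(W_e [h_{t-1}; s_{t-1}] + U_e x^k), and α_t^k is the softmax of those scores over k. This is not this repository's implementation; the weight names W_e, U_e, v_e follow the paper, while the function names and the random toy weights are assumptions for shape checking only. The decoder size p does not appear in this step.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D vector."""
    e = np.exp(z - z.max())
    return e / e.sum()


def input_attention(x, h_prev, s_prev, W_e, U_e, v_e):
    """Input attention weights alpha_t__k for one encoder step.

    x      : (T, n)   one window; column k is the driving series x^k
    h_prev : (m,)     encoder hidden state h_{t-1}
    s_prev : (m,)     encoder cell state s_{t-1}
    W_e    : (T, 2m), U_e : (T, T), v_e : (T,)
    Returns a vector of shape (n,) whose entries sum to 1.
    """
    query = W_e @ np.concatenate([h_prev, s_prev])  # (T,)
    keys = U_e @ x                                  # (T, n): U_e x^k for every k
    scores = v_e @ np.tanh(query[:, None] + keys)   # (n,): e_t__k for every k
    return softmax(scores)


# Toy sizes and random weights (hypothetical, for shape checking only).
rng = np.random.default_rng(0)
n, T, m = 5, 10, 16
x = rng.normal(size=(T, n))
h_prev = np.zeros(m)
s_prev = np.zeros(m)
W_e = rng.normal(size=(T, 2 * m))
U_e = rng.normal(size=(T, T))
v_e = rng.normal(size=T)

alpha = input_attention(x, h_prev, s_prev, W_e, U_e, v_e)
t = 0
x_tilde = alpha * x[t]  # re-weighted input fed to the encoder LSTM at step t, shape (n,)
```

Because the weights are shared across k, computing U_e @ x once scores all n driving series in a single matrix product.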

Data Processing

Each feature item of the dataset should have shape (batch_size, T, length_of_driving_series + y_dim).

Each label item of the dataset should have shape (batch_size, y_dim).
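A minimal sketch of producing items with these shapes from raw, time-aligned series, under two assumptions not stated above: length_of_driving_series is taken to mean the number of driving series n, and each label is the target value at the time step immediately after its window. `make_windows` is a hypothetical helper, not part of `da_rnn`.

```python
import numpy as np


def make_windows(driving: np.ndarray, target: np.ndarray, T: int):
    """Slice aligned series into DA-RNN samples.

    driving : (N, n)      n driving series over N time steps
    target  : (N, y_dim)  target series over the same N steps
    Returns
      features : (N - T, T, n + y_dim)  driving series and past targets per window
      labels   : (N - T, y_dim)         target at the step right after each window
    """
    if driving.shape[0] != target.shape[0]:
        raise ValueError("driving and target must have the same number of time steps")
    num_samples = driving.shape[0] - T
    if num_samples < 1:
        raise ValueError("need more than T time steps")
    data = np.concatenate([driving, target], axis=1)  # (N, n + y_dim)
    features = np.stack([data[i:i + T] for i in range(num_samples)])
    labels = target[T:]  # row i is target[i + T], the step after window i
    return features, labels


driving = np.arange(20, dtype=float).reshape(10, 2)  # N=10 steps, n=2 series
target = np.arange(10, dtype=float).reshape(10, 1)   # y_dim=1
X, y = make_windows(driving, target, T=3)
print(X.shape, y.shape)  # (7, 3, 3) (7, 1)
```

Batching these arrays along the first axis (for example with a PyTorch `DataLoader` or a `tf.data.Dataset`) gives feature batches of shape (batch_size, T, n + y_dim) and label batches of shape (batch_size, y_dim).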

TODO

  • Remove the hardcoded prediction dimensionality (currently 1)

License

MIT

da-rnn-in-tensorflow-2-and-pytorch's People

Contributors

kaelzhang
