
tpa-lstm's Introduction

TPA-LSTM

Original implementation of "Temporal Pattern Attention for Multivariate Time Series Forecasting".

Dependencies

  • Python 3.6.6

The other dependencies are listed in requirements.txt and can be installed with pip:

$ pip install -r requirements.txt
# to install TensorFlow, you can refer to https://www.tensorflow.org/install/

Usage

The following example shows how to train and test a TPA-LSTM model on MuseData with the settings used in this work.

Training

$ python main.py --mode train \
    --attention_len 16 \
    --batch_size 32 \
    --data_set muse \
    --dropout 0.2 \
    --learning_rate 1e-5 \
    --model_dir ./models/model \
    --num_epochs 40 \
    --num_layers 3 \
    --num_units 338

Testing

$ python main.py --mode test \
    --attention_len 16 \
    --batch_size 32 \
    --data_set muse \
    --dropout 0.2 \
    --learning_rate 1e-5 \
    --model_dir ./models/model \
    --num_epochs 40 \
    --num_layers 3 \
    --num_units 338
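The training and testing commands above differ only in --mode. Every other flag must match, presumably because the test run rebuilds the same graph and restores the checkpoint saved under --model_dir. The sketch below builds both argument lists from one settings dictionary so the two runs cannot drift apart. MUSE_SETTINGS and build_command are hypothetical helpers; the flag names and values are copied from the commands above.

```python
# Hypothetical helper: build identical train/test argument lists for main.py.
# Flag names and values are copied from the README's MuseData commands.
MUSE_SETTINGS = {
    "attention_len": 16,
    "batch_size": 32,
    "data_set": "muse",
    "dropout": 0.2,
    "learning_rate": 1e-5,
    "model_dir": "./models/model",
    "num_epochs": 40,
    "num_layers": 3,
    "num_units": 338,
}


def build_command(mode, settings):
    """Return the argument list for `python main.py --mode <mode> ...`."""
    if mode not in ("train", "test"):
        raise ValueError(f"unknown mode: {mode!r}")
    args = ["python", "main.py", "--mode", mode]
    for name, value in settings.items():
        args += [f"--{name}", str(value)]
    return args


train_cmd = build_command("train", MUSE_SETTINGS)
test_cmd = build_command("test", MUSE_SETTINGS)
print(" ".join(train_cmd))
```

Each list can be passed directly to subprocess.run. Keeping the hyperparameters in one place means changing, say, --num_units for training also changes it for testing.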

tpa-lstm's People

Contributors

shunyaoshih


tpa-lstm's Issues

How to deal with varying max_len across batches?

Very nice project!

But I'm confused about two things:

  1. Why apply the CNN to the RNN outputs instead of to the original features to extract temporal
    information for each variable? Maybe you ran some experiments but forgot to mention them in the paper?

  2. The maximum length of the time series may differ from batch to batch. How can we handle this if we use a static kernel size?

Cheers!
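Both questions concern how temporal pattern attention is built. Going by our reading of the paper's equations (not the repository's TensorFlow code), the CNN filters run over the matrix H of the last w = attention_len RNN hidden states. H has one row per hidden unit, and each filter is as long as that window. Because w is fixed by attention_len rather than by a batch's maximum sequence length, a static kernel size is consistent across batches. The pure-Python sketch below follows that formulation: filter responses per row, a bilinear score against the current state h_t, sigmoid weights, a context vector v_t, and h'_t = W_h h_t + W_v v_t. All weights are random placeholders and every function name is hypothetical.

```python
import math
import random


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def temporal_pattern_attention(H, h_t, C, W_a, W_h, W_v):
    """One TPA step as we read the paper's equations (hypothetical sketch).

    H   : m x w, row i = hidden unit i over the last w (= attention_len) steps
    h_t : current hidden state, length m
    C   : k x w filters; each filter spans the whole window
    W_a : k x m scoring matrix
    W_h : m x m and W_v : m x k output projections
    Returns (h_prime, alphas).
    """
    # H^C[i][j]: filter j applied to row i; filter length == w gives one dot product.
    HC = [[sum(h * c for h, c in zip(row, filt)) for filt in C] for row in H]
    # Score each row against the current state: (H^C_i)^T W_a h_t.
    Wa_h = matvec(W_a, h_t)
    scores = [sum(a * b for a, b in zip(hc_row, Wa_h)) for hc_row in HC]
    # Sigmoid rather than softmax, so several rows can be weighted highly at once.
    alphas = [sigmoid(s) for s in scores]
    # Context vector v_t = sum_i alpha_i * H^C_i, length k.
    v_t = [sum(a * hc_row[j] for a, hc_row in zip(alphas, HC)) for j in range(len(C))]
    # Combine with the current state: h'_t = W_h h_t + W_v v_t.
    h_prime = [a + b for a, b in zip(matvec(W_h, h_t), matvec(W_v, v_t))]
    return h_prime, alphas


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
m, w, k = 4, 16, 3  # hidden units, attention_len, number of filters
H = rand_matrix(m, w, rng)
h_t = [rng.uniform(-1.0, 1.0) for _ in range(m)]
h_prime, alphas = temporal_pattern_attention(
    H, h_t, rand_matrix(k, w, rng), rand_matrix(k, m, rng),
    rand_matrix(m, m, rng), rand_matrix(m, k, rng))
```

Under this reading, sequences longer than w only change which hidden states fall into the window. They do not change the shapes of H or C.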

Extracting Weights from Model?

Hi! I'm interested in extracting the attention weights on each variable (I have 75 different variables) to determine the relative importance of each variable in predicting the output. How would I go about extracting these weights?
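One possible approach, under assumptions: locate the tensor that holds the attention weights in the graph (its name is not shown on this page) and add it to the fetches list of sess.run next to the loss. Collect one row of weights per sample, then average. A caveat: in the paper's formulation each weight scores a row of the RNN hidden-state matrix, not an input variable directly, so mapping the weights back to 75 input variables is an interpretation step rather than a given. The hypothetical helper below does only the averaging and ranking part.

```python
def rank_by_mean_weight(weight_rows, names=None):
    """Rank positions by their mean attention weight across samples.

    weight_rows: one list of attention weights per sample, all the same length.
    names: optional labels for the positions (defaults to 0..n-1).
    Returns (label, mean_weight) pairs sorted from highest to lowest mean.
    """
    if not weight_rows:
        raise ValueError("no attention weights collected")
    n = len(weight_rows[0])
    means = [sum(row[i] for row in weight_rows) / len(weight_rows) for i in range(n)]
    labels = names if names is not None else list(range(n))
    return sorted(zip(labels, means), key=lambda pair: pair[1], reverse=True)


# Example with made-up weights for three positions over two samples.
ranked = rank_by_mean_weight([[0.9, 0.2, 0.5], [0.7, 0.4, 0.5]])
```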

What does assert self.para.highway == self.para.attention_len mean in data_generator?

When I switch the dataset to electricity and train, the assertion self.para.highway == self.para.attention_len fails. When I delete that line, training runs but then reports that the input cannot be empty.
I don't understand the relationship between the parameters highway and attention_len. Could you please help? Thanks a lot.
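This page does not explain why the data generator asserts that the two values are equal. One plausible reading, and it is only an assumption, is that highway is the size of an LSTNet-style linear autoregressive component over the last highway time steps, and that the generator builds a single input window shared by attention and highway. The hypothetical sketch below shows such a component. It refuses to run when the window is shorter than the highway size, which is the kind of mismatch that could leave an empty slice if the check is simply removed.

```python
def autoregressive_highway(window, weights, bias):
    """Linear autoregressive ("highway") prediction, LSTNet-style (hypothetical sketch).

    window : list of time steps, oldest first; each step is a list of n variable values
    weights: one coefficient per lag; len(weights) plays the role of `highway`
    Returns one prediction per variable, using only the last len(weights) steps.
    """
    hw = len(weights)
    if hw == 0 or len(window) < hw:
        raise ValueError(f"window has {len(window)} steps but highway needs {hw}")
    recent = window[-hw:]
    n = len(window[0])
    # Same lag coefficients shared by every variable.
    return [sum(w * step[v] for w, step in zip(weights, recent)) + bias for v in range(n)]


window = [[float(t), 10.0 * t] for t in range(4)]  # 4 steps, 2 variables
prediction = autoregressive_highway(window, [0.5, 0.5], 0.0)
```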

data

How do I replace the data with my own data?
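This page does not describe the input format the repository expects, so any answer involves assumptions. The sketch below assumes your data is a CSV file with one row per time step, one numeric column per variable, and no header. It turns that file into (input window, next-step target) pairs, using the window length you would pass as --attention_len. Wiring these pairs into the repository's data generator is not shown. load_series, make_windows, and the file name my_data.csv are hypothetical.

```python
import csv
import io


def load_series(csv_text):
    """Parse CSV text: one row per time step, one numeric column per variable, no header."""
    return [[float(x) for x in row] for row in csv.reader(io.StringIO(csv_text)) if row]


def make_windows(series, window_len, horizon=1):
    """Pair each run of window_len consecutive steps with the step `horizon` steps after it."""
    samples = []
    for start in range(len(series) - window_len - horizon + 1):
        inputs = series[start:start + window_len]
        target = series[start + window_len + horizon - 1]
        samples.append((inputs, target))
    return samples


# Usage (hypothetical file):
#     with open("my_data.csv") as f:
#         samples = make_windows(load_series(f.read()), window_len=16)
series = load_series("1,10\n2,20\n3,30\n4,40\n5,50\n")
samples = make_windows(series, window_len=3)
```

A series shorter than window_len + horizon yields no samples at all, so it is worth checking len(samples) before training.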

About the normalization in the "TimeSeriesDataGenerator" class

In the "_preprocess" function of the "TimeSeriesDataGenerator" class, why are there two normalization methods (min-max normalization for the "electricity" dataset, and a different method for the other time series datasets)?
Is the second normalization method reasonable? It uses the other variables' information when scaling a single variable.
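This page does not show the second method in _preprocess, so the sketch below only illustrates the concern raised in the question. It contrasts per-variable min-max scaling, which uses only each column's own range, with a hypothetical scheme that divides every value by one statistic computed across all variables. In the second scheme, a variable with a small range ends up squashed near zero purely because another variable is large.

```python
def minmax_per_variable(series):
    """Scale each column to [0, 1] using only that column's own min and max."""
    cols = list(zip(*series))
    lows = [min(c) for c in cols]
    highs = [max(c) for c in cols]
    return [[(x - lo) / (hi - lo) if hi > lo else 0.0
             for x, lo, hi in zip(row, lows, highs)] for row in series]


def scale_by_global_max(series):
    """Hypothetical cross-variable scheme: divide by the largest |value| over ALL columns."""
    peak = max(abs(x) for row in series for x in row)
    return [[x / peak for x in row] for row in series] if peak else series


series = [[1.0, 100.0], [2.0, 200.0], [3.0, 300.0]]
per_variable = minmax_per_variable(series)  # both columns span [0, 1]
global_scaled = scale_by_global_max(series)  # column 0 never exceeds 0.01
```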

Model Adaptation to PyTorch

Introduction

Hi, we are a group of students from the University of Toronto. We read the research paper and really like the model. Our motivation for this adaptation stems from the shift in TensorFlow versions from 1.x to 2.x, which posed compatibility challenges for the original model. This project aims to make the model more accessible and maintainable by leveraging the flexibility and user-friendliness of PyTorch.

Project Goals

Recreate the Original Model: Faithfully adapt the model's architecture and functionality from TensorFlow to PyTorch.
Community Collaboration: Encourage contributions and feedback from the community to improve and evolve the model.

How to Use This Repository

Installation: Instructions on setting up the environment and installing necessary dependencies.
Model Architecture: Detailed explanation of the model's architecture, including differences from the original version, if any.
Training and Evaluation: Step-by-step guide on how to train and evaluate the model using provided datasets or custom data.
Contributing: Guidelines for contributing to the project, including coding standards, submitting pull requests, and reporting issues.

Acknowledgements

Original Authors: Recognition of the authors of the original research paper and model.
Community Contributors: @shunyaoshih @Daikon-Sun

Help: AttributeError

AttributeError: module 'tensorflow.python.ops.rnn_cell_impl' has no attribute '_like_rnncell'

AttributeError: 'NoneType' object has no attribute 'inputs'

Traceback (most recent call last):
  File "main.py", line 37, in <module>
    main()
  File "main.py", line 16, in main
    graph, model, data_generator = create_graph(para)
  File "D:\study\keyan\paper\TPA-LSTM-master\lib\model_utils.py", line 32, in create_graph
    model = PolyRNN(para, data_generator)
  File "D:\study\keyan\paper\TPA-LSTM-master\lib\model.py", line 14, in __init__
    self._build_graph()
  File "D:\study\keyan\paper\TPA-LSTM-master\lib\model.py", line 26, in _build_graph
    self.rnn_inputs, self.rnn_inputs_len, self.target_outputs = self.data_generator.inputs(
AttributeError: 'NoneType' object has no attribute 'inputs'
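The first error refers to a private TensorFlow helper (note the leading underscore). That suggests a TensorFlow version different from the one the code was written against, though this page does not say which version that is. The traceback shows that self.data_generator is None by the time model.py calls .inputs(...). The source of create_graph is not shown here, but a common cause of this pattern is a factory that silently returns None for an unrecognized --data_set value, such as a typo. The hypothetical sketch below contrasts a lenient lookup with one that fails immediately and names the bad value. GENERATORS, MuseGenerator, and both factory functions are illustrative only.

```python
class MuseGenerator:
    """Placeholder for a real data generator (hypothetical)."""

    def inputs(self):
        return "muse batches"


GENERATORS = {"muse": MuseGenerator}


def make_generator_lenient(data_set):
    # Unknown names silently become None; the failure only surfaces later,
    # as "'NoneType' object has no attribute 'inputs'".
    cls = GENERATORS.get(data_set)
    return cls() if cls is not None else None


def make_generator_strict(data_set):
    # Unknown names fail immediately, naming the bad value and the valid ones.
    if data_set not in GENERATORS:
        raise ValueError(
            f"unknown data_set {data_set!r}; expected one of {sorted(GENERATORS)}")
    return GENERATORS[data_set]()
```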

MuseData cannot be downloaded

urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='docs.google.com', port=443): Max retries exceeded with url: /uc?export=download&id=1a5361IfxxEY1mmTfqAviiIkq6u2OYFJ7 (Caused by NewConnectionError('<urllib3.connection.VerifiedHTTPSConnection object at 0x7f992b7dfa58>: Failed to establish a new connection: [Errno 111] Connection refused',))

I can't download MuseData from docs.google.com via the URL. How can I solve this?
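When the host refuses connections, retrying in code will not help. The practical workaround is to download the file some other way (a browser, a proxy, another machine) and place it where the code expects it. This page does not say where that is, so the destination path below is an assumption. The sketch uses the standard library's urllib.request rather than the urllib3-based code in the traceback: it skips the download when a local copy already exists, retries a few times otherwise, and writes to a temporary ".part" file so an interrupted download is never mistaken for a complete one. fetch_dataset is hypothetical; the URL is the one from the error message.

```python
import os
import urllib.request


def fetch_dataset(url, dest, retries=3):
    """Return dest, downloading url into it only if no local copy exists yet."""
    if os.path.exists(dest):
        return dest  # a manually downloaded file is used as-is
    part = dest + ".part"
    last_error = None
    for _ in range(retries):
        try:
            urllib.request.urlretrieve(url, part)
            os.replace(part, dest)  # only a complete download becomes dest
            return dest
        except OSError as err:  # urllib.error.URLError is a subclass of OSError
            last_error = err
            if os.path.exists(part):
                os.remove(part)
    raise RuntimeError(
        f"could not download {url}; download it manually and save it as {dest}"
    ) from last_error


MUSE_URL = "https://docs.google.com/uc?export=download&id=1a5361IfxxEY1mmTfqAviiIkq6u2OYFJ7"
```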

Question about attention_wrapper.py

Should the code "new_state = (new_state, new_attns, new_attn_states)" be written as "new_state = (new_state, output, new_attn_states)" instead?

TRAIN ERROR

When train.py reaches line 29,

    [loss, global_step, _] = sess.run(
        fetches=[model.loss, model.global_step, model.update])

it raises an error and jumps to the except tf.errors.OutOfRangeError: branch. Why?
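In TensorFlow 1.x input pipelines, sess.run raises tf.errors.OutOfRangeError when the input iterator is exhausted. Training loops commonly catch it to end an epoch, so reaching that except branch is not necessarily a bug. It is suspicious only if it happens on the very first step, which means the pipeline produced no batches at all. The pure-Python analog below mimics that control flow without TensorFlow. OutOfRangeError, Iterator, and run_epoch are stand-ins written for illustration, not TensorFlow's classes.

```python
class OutOfRangeError(Exception):
    """Stand-in for tf.errors.OutOfRangeError (illustration only)."""


class Iterator:
    """Stand-in for an input iterator that signals exhaustion with OutOfRangeError."""

    def __init__(self, batches):
        self._it = iter(batches)

    def get_next(self):
        try:
            return next(self._it)
        except StopIteration:
            raise OutOfRangeError("End of sequence") from None


def run_epoch(batches):
    """Consume batches until the iterator is exhausted; return the number of steps run."""
    iterator = Iterator(batches)
    steps = 0
    try:
        while True:
            iterator.get_next()  # plays the role of sess.run(fetches=[...])
            steps += 1
    except OutOfRangeError:
        pass  # normal end of the epoch
    if steps == 0:
        raise RuntimeError("input pipeline produced no batches; check the dataset and window sizes")
    return steps
```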
