stgmt-tensorflow-implementation's Introduction

STGMT-Tensorflow2-implementation

Traffic prediction based on the spatial-temporal guided multi graph Sandwich-Transformer (STGMT)

Background

Spatial-temporal traffic prediction is crucial for urban computing, traffic management and future autonomous driving. In this paper, a novel spatial-temporal guided multi graph Sandwich-Transformer model (STGMT) is proposed, composed of encoder, decoder and attention mechanism modules. The three modules are responsible, respectively, for extracting features from historical traffic data, for autoregressive prediction, and for capturing features along the spatial and temporal dimensions. Compared with the original Transformer framework, the output of the spatial-temporal embedding layer is used to guide the attention mechanism through meta learning, which accounts for the heterogeneity of spatial nodes and temporal nodes. Temporal and spatial features are encoded with Time2Vec (T2V) and Node2Vec (N2V) and coupled into spatial-temporal embedding blocks. In addition, multiple graphs are used to perform multi-head spatial self-attention (MSA). Finally, the attention modules and the feed-forward layers are recombined to form the Sandwich-Transformer.
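
As a rough illustration of the temporal encoding mentioned above, the sketch below follows the published Time2Vec definition: the first component is linear in time and the remaining components are sinusoids. It is plain Python rather than the repository's TensorFlow layer, and the function name time2vec, the frequencies in OMEGA and the phases in PHI are assumptions chosen for illustration; in the model these parameters would be learned.

```python
import math


def time2vec(tau, omega, phi):
    """Encode a scalar time value tau as a vector (Time2Vec-style).

    omega and phi are equal-length lists of frequencies and phase shifts.
    Component 0 is linear in tau; the others are sin(omega_i * tau + phi_i).
    In a trained model omega and phi are learned; here they are fixed by hand.
    """
    if len(omega) != len(phi) or not omega:
        raise ValueError("omega and phi must be non-empty and of equal length")
    linear = omega[0] * tau + phi[0]
    periodic = [math.sin(w * tau + p) for w, p in zip(omega[1:], phi[1:])]
    return [linear] + periodic


# Hypothetical parameters: one linear term plus one sinusoid with period 24.
OMEGA = [0.5, 2 * math.pi / 24]
PHI = [1.0, 0.0]

embedding = time2vec(6, OMEGA, PHI)
print(embedding)
```

The sinusoidal components let the encoding repeat over a fixed period (24 time steps in this toy setup), while the linear component keeps track of non-periodic progression.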

Preliminary

Before using this project, you may need to set up an environment based on TensorFlow 2.x with GPU support.

!pip install node2vec

Dataset

If you want to run this project, please download the datasets and weight files from Google. Then put checkpoints_NYC and checkpoints_pems08 into the project, naming the one you want to use 'checkpoints'. After this setup, you can run data_fac.py to generate the data files in pkl format for training and testing, which may take a long time. The pkl file contains the following parts: training data, validation data, test data, the multi graph, the node2vec results, and the scaler used for inverse_transform.
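
To make the layout of such a file more concrete, here is a minimal sketch of saving and loading a bundle with the parts listed above using Python's pickle module. The key names, the toy values and the file name example.pkl are assumptions for illustration; the structure actually written by data_fac.py may differ (for example, it may be a tuple rather than a dict).

```python
import os
import pickle
import tempfile

# Hypothetical bundle: key names and toy values are illustrative only.
bundle = {
    "train": [[1.0, 2.0], [3.0, 4.0]],
    "val": [[5.0, 6.0]],
    "test": [[7.0, 8.0]],
    "multi_graph": {"distance": [[0, 1], [1, 0]]},
    "node2vec": [[0.1, 0.2], [0.3, 0.4]],
    "scaler": {"min": 1.0, "max": 8.0},  # stand-in for a fitted scaler object
}

# Write to a temporary folder so the sketch does not touch project files.
path = os.path.join(tempfile.mkdtemp(), "example.pkl")

with open(path, "wb") as f:
    pickle.dump(bundle, f)

with open(path, "rb") as f:
    loaded = pickle.load(f)

print(sorted(loaded))
```

Only unpickle files from a source you trust, since loading a pickle can execute arbitrary code.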

Weight

If you only want to run inference and not train on your own dataset, take the checkpoint folder of the dataset you want and rename it to checkpoints, for example checkpoints_pems08 -> checkpoints.
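
The renaming step can also be scripted. The helper below is a minimal sketch, not part of the repository: the function name activate_checkpoints is hypothetical, and it copies the folder instead of renaming it so that the downloaded weights stay intact.

```python
import shutil
from pathlib import Path


def activate_checkpoints(project_dir, source_name, target_name="checkpoints"):
    """Copy project_dir/source_name to project_dir/target_name.

    Any existing target folder is removed first, so only one set of
    weights is active at a time.
    """
    project = Path(project_dir)
    source = project / source_name
    target = project / target_name
    if not source.is_dir():
        raise FileNotFoundError(f"checkpoint folder not found: {source}")
    if target.exists():
        shutil.rmtree(target)
    shutil.copytree(source, target)
    return target
```

For example, calling activate_checkpoints(".", "checkpoints_pems08") from the project root would make the PEMS08 weights the active ones; switching to NYC is the same call with "checkpoints_NYC".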

Training

Figure: the STGMT backbone.

operations.py, layer.py and framework.py are the most important components of this project. You are welcome to try out your own ideas, and you can change the hyperparameters in Hyperparameters.py if you like. Before training on your own dataset, edit train.py: line 18 sets the dataset path (taken from Hyperparameters.py), and line 52 uses the L1 loss. You can then train the model by running the following command:

python train.py

Your trained weights will be saved as new files in the checkpoints folder. Don't worry about getting an error if the folder already contains weight files: they will be overwritten during training. The CheckpointManager in the code (lines 62 to 68) makes it possible to continue training or to resume it later.
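
The resume behaviour mentioned above can be sketched with tf.train.Checkpoint and tf.train.CheckpointManager. This is a minimal sketch, not the code at lines 62 to 68 of train.py: it tracks a single step counter instead of the model and optimizer, and the temporary directory and max_to_keep=3 are assumptions for illustration.

```python
import tempfile

import tensorflow as tf

# Stand-in for the training state; a real script would track model and optimizer.
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)

# Temporary folder so the sketch does not touch a real 'checkpoints' directory.
ckpt_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=3)

# Resume if a checkpoint exists; otherwise start from scratch.
if manager.latest_checkpoint:
    ckpt.restore(manager.latest_checkpoint)

for _ in range(5):
    step.assign_add(1)  # placeholder for one training step
    manager.save()      # checkpoints beyond max_to_keep are deleted

# Simulate a restart: a fresh variable picks up the saved value.
resumed_step = tf.Variable(0, dtype=tf.int64)
tf.train.Checkpoint(step=resumed_step).restore(manager.latest_checkpoint)
print(int(resumed_step.numpy()))
```

Because the manager keeps only the most recent checkpoints, older weight files in the folder are replaced as training proceeds.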

Testing

If you only want to run inference on our datasets, that is fine. The New York dataset is used as the example here; PEMS08 works the same way. test.py is the core script. Before testing, do the following:

Change the data path -> line 19
Change the error path and the compare path -> line 78, line 79
python test.py

We provide three metrics: mean absolute error (MAE), root mean squared error (RMSE), and mean absolute percentage error (MAPE).

When the run finishes, the terminal shows the errors at prediction steps 3, 6, 9 and 12 and the average error over the steps. Three tables are saved into your project: gap.csv, pred.csv, and ana.xlsx.
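
For reference, the three metrics and the per-step report can be sketched in plain Python as follows. This is an illustrative sketch using the standard definitions, not the code in test.py: the function names, the toy data, and the choice to skip near-zero targets in MAPE are assumptions.

```python
import math


def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true, y_pred):
    """Root mean squared error."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mape(y_true, y_pred, eps=1e-8):
    """MAPE in percent, skipping targets whose magnitude is below eps."""
    terms = [abs((t - p) / t) for t, p in zip(y_true, y_pred) if abs(t) > eps]
    return 100.0 * sum(terms) / len(terms) if terms else float("nan")


def errors_at_steps(y_true, y_pred, steps=(3, 6, 9, 12)):
    """Return {step: (MAE, RMSE, MAPE)} for the given 1-indexed horizon steps.

    y_true and y_pred are lists of sequences, one forecast horizon per sample.
    """
    report = {}
    for s in steps:
        t = [seq[s - 1] for seq in y_true]
        p = [seq[s - 1] for seq in y_pred]
        report[s] = (mae(t, p), rmse(t, p), mape(t, p))
    return report


# Toy data: two samples, 12 prediction steps each.
TRUE = [[10.0] * 12, [20.0] * 12]
PRED = [[11.0] * 12, [18.0] * 12]

for step, (m, r, p) in errors_at_steps(TRUE, PRED).items():
    print(f"step {step}: MAE={m:.3f} RMSE={r:.3f} MAPE={p:.2f}%")
```

RMSE penalizes large errors more heavily than MAE, and MAPE is scale-free but unstable when the true values are close to zero, which is why the sketch filters those targets out.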

Results

Figure: results of the NYC prediction.

For more details, please see the paper.

Contributing

Finally, many thanks to the co-author of the article for their contribution, and thanks to my girlfriend for giving me the courage to pursue a Ph.D.

License

MIT © YanjieWen

stgmt-tensorflow-implementation's Issues

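train.py running time

Hello, I would like to ask how long train.py takes to run on the PEMS08 data. My hardware is an NVIDIA GeForce RTX 3050 Ti Laptop GPU, but after two days the run has still not finished, and PyCharm even crashed. What could be the cause? Also, is there any way to recover the records of the previous run after reopening PyCharm?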

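Dataset download

Hello, the link for downloading the datasets and weight files cannot be accessed. Is it your personal cloud drive? Would it be possible to provide a new download link? Thank you.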

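Some questions about the code

Hello, I would like to ask why the function data_preprocessing in data_fac.py calls the normalization function minmaxsca even though minmaxsca is not defined anywhere in the code. Some libraries are imported at the top of the file, but the import statement 'from sklearn.preprocessing import MinMaxScaler' is reported as unused. Why is that? Does the call to minmaxsca in data_preprocessing have no effect? I hope the author can answer. Thank you very much!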
