
zhengchuanpan / gman

Stars: 401, Watchers: 11, Forks: 105, Size: 4.84 MB

GMAN: A Graph Multi-Attention Network for Traffic Prediction (https://fanxlxmu.github.io/publication/aaai2020/), accepted at AAAI-2020.

License: Apache License 2.0

Python 100.00%
gman traffic-prediction aaai2020

gman's People

Contributors

fanxlxmu, zhengchuanpan


gman's Issues

Some questions about GMAN

At line 142 of model.py, shouldn't x and y have the same shape? Also, which TensorFlow version did you use? Thanks.

How was the dimension mismatch handled when evaluating the model without transformAttention?

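To study the contribution of each component, the authors evaluated the model with the transformAttention module removed. If that part of the code is simply commented out, the line X = tf.concat((X, STE), axis = -1) in the decoder's spatialAttention raises a dimension-mismatch error. How did you handle the dimensions when evaluating the model without the transformAttention module?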

Time Features - Ordinality

Doesn't the way time features were encoded introduce ordinality?

For example, if Sunday is encoded as 1 and Thursday as 5, doesn't that lead the model to treat Thursday as "larger" (more important) than Sunday?

Is this understanding correct? If so, could you help me understand why that decision was made during model design?
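Whether ordinality is introduced depends on how the integer codes reach the network: fed as raw numbers, they do carry an ordering; one-hot encoded first, each day becomes an independent category. A minimal NumPy sketch of one-hot temporal features, assuming pandas' Monday=0 convention and 288 five-minute slots per day; the helper `one_hot` is hypothetical, and model.py should be checked for what GMAN itself does:

```python
import numpy as np

def one_hot(indices, depth):
    """Map integer category codes to one-hot rows of length `depth`."""
    indices = np.asarray(indices, dtype=np.int64).ravel()
    # Row i of the identity matrix is the unit vector for category i.
    return np.eye(depth, dtype=np.float32)[indices]

# Day of week (Monday=0 ... Sunday=6) and 5-minute time-of-day slot (0 ... 287).
dayofweek = one_hot([6, 3], depth=7)        # Sunday, Thursday
timeofday = one_hot([0, 100], depth=288)    # 00:00, 08:20
temporal_features = np.concatenate([dayofweek, timeofday], axis=-1)  # shape (2, 295)
```

With one-hot inputs, any two distinct days are the same distance apart (sqrt(2)), so Thursday is no "larger" than Sunday; any ordering has to be learned by the layers that follow.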

group spatial attention

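The paper mentions partitioning the nodes into groups, which in theory reduces the computational complexity. Where in the spatial-attention code is this group-wise computation implemented?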
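For reference, here is a rough sketch of group spatial attention as the paper describes it (our reading of the paper, not code from this repository): vertices are randomly partitioned into G groups of M = N/G, attention is computed within each group, each group is max-pooled to one vector, attention is computed among the G pooled vectors, and each vertex adds its group's inter-group output. Learned projections, multiple heads, and gating are omitted, and all function names are hypothetical:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(h):
    """Scaled dot-product self-attention over axis -2 (no learned projections)."""
    d = h.shape[-1]
    scores = h @ np.swapaxes(h, -1, -2) / np.sqrt(d)
    return softmax(scores, axis=-1) @ h

def group_spatial_attention(h, num_groups, seed=0):
    """h: (N, D) vertex states at one time step; returns an (N, D) array."""
    n, d = h.shape
    if n % num_groups:
        raise ValueError("this sketch assumes N is divisible by num_groups")
    perm = np.random.default_rng(seed).permutation(n)          # random partition
    grouped = h[perm].reshape(num_groups, n // num_groups, d)  # (G, M, D)
    intra = self_attention(grouped)                             # within each group
    pooled = grouped.max(axis=1)                                # (G, D) group summaries
    inter = self_attention(pooled)                              # among the G groups
    combined = intra + inter[:, None, :]                        # add group-level output to each member
    out = np.empty_like(h)
    out[perm] = combined.reshape(n, d)                          # restore original vertex order
    return out
```

The score matrices hold G·(N/G)² + G² = N²/G + G² entries instead of N², which is where the claimed saving comes from.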

Validation error nan

I've been trying to run the METR example, and from the first iteration I'm getting a validation error of "nan"; as a consequence, the model stops learning after 10 iterations. Is there a problem with the code?

Need help: how to fix AttributeError: 'numpy.bytes_' object has no attribute 'delta'

This line in utils.py raises an error: timeofday = (Time.hour * 3600 + Time.minute * 60 + Time.second) // Time.freq.delta.total_seconds()

Traceback (most recent call last):
  File "/Users/crowd/PycharmProjects/GMAN/METR/train.py", line 55, in <module>
    mean, std) = utils.loadData(args)
  File "/Users/crowd/PycharmProjects/GMAN/METR/utils.py", line 73, in loadData
    timeofday = (Time.hour * 3600 + Time.minute * 60 + Time.second) // Time.freq.delta.total_seconds()
AttributeError: 'numpy.bytes_' object has no attribute 'delta'

I haven't modified the author's source code. How did others solve this problem?
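One plausible cause (an assumption, not confirmed in this thread) is that newer pandas versions do not restore the index `freq` stored in the HDF5 file, leaving `Time.freq` as a raw `numpy.bytes_` value. A minimal workaround sketch skips `freq` entirely and computes the time-of-day slot from raw datetime64 values plus an explicit or inferred sampling interval. The function `time_of_day_slots` is hypothetical:

```python
import numpy as np

def time_of_day_slots(timestamps, interval_seconds=None):
    """Return an (n, 1) array of time-of-day slot indices, never touching `freq`."""
    # Work on plain datetime64 values at second resolution.
    ts = np.asarray(timestamps, dtype="datetime64[s]")
    if interval_seconds is None:
        if ts.size < 2:
            raise ValueError("need at least two timestamps to infer the interval")
        # Infer the sampling interval from the smallest gap between consecutive timestamps.
        interval_seconds = int(np.diff(ts).min().astype(np.int64))
    # Seconds elapsed since midnight of each timestamp's own day.
    seconds = (ts - ts.astype("datetime64[D]")).astype(np.int64)
    return (seconds // interval_seconds).reshape(-1, 1)
```

In loadData this could be called as something like `time_of_day_slots(df.index.values, 300)` for 5-minute data; that call has not been tested against the original METR.h5 file.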

HELP: NotImplementedError: reshaping is not supported for Index objects

Traceback (most recent call last):
  File "D:/GitHub源代码/GMAN-master/GMAN-master/METR/train.py", line 55, in <module>
    mean, std) = utils.loadData(args)
  File "D:\GitHub源代码\GMAN-master\GMAN-master\METR\utils.py", line 72, in loadData
    dayofweek = np.reshape(Time.weekday, newshape = (-1, 1))
  File "D:\Software\Anaconda3\envs\tensorflow\lib\site-packages\numpy\core\fromnumeric.py", line 232, in reshape
    return _wrapfunc(a, 'reshape', newshape, order=order)
  File "D:\Software\Anaconda3\envs\tensorflow\lib\site-packages\numpy\core\fromnumeric.py", line 57, in _wrapfunc
    return getattr(obj, method)(*args, **kwds)
  File "D:\Software\Anaconda3\envs\tensorflow\lib\site-packages\pandas\core\indexes\base.py", line 1149, in reshape
    raise NotImplementedError("reshaping is not supported "
NotImplementedError: reshaping is not supported for Index objects
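The traceback shows `np.reshape` delegating to the pandas `Index` returned by `Time.weekday`, and pandas `Index` objects refuse to reshape. Converting to a NumPy array first is a minimal fix; `day_of_week` is a hypothetical helper name:

```python
import numpy as np
import pandas as pd

def day_of_week(index: pd.DatetimeIndex) -> np.ndarray:
    """Return an (n, 1) array of weekday codes (Monday=0 ... Sunday=6)."""
    # np.asarray turns the pandas Index into a plain ndarray, which supports reshape.
    # The shape is passed positionally, which works across NumPy versions.
    return np.asarray(index.weekday).reshape(-1, 1)

# Example: one week starting on Sunday, 2017-01-01.
week = pd.date_range("2017-01-01", periods=7, freq="D")
codes = day_of_week(week)
```

Equivalently, the original line can become `dayofweek = np.reshape(np.asarray(Time.weekday), (-1, 1))`.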

Reproducing the results

Hello,
Thank you very much for sharing your code with the community.

After many attempts with different hyperparameters, we have not been able to reproduce any results from the paper (or even get close). Has anyone been able to reproduce the results, or do the authors have any pointers on how to achieve this?
Thank you.

Why does Time = df.index in loadData() raise an error?

I'm using the METR.h5 file downloaded from DCRNN and reading it with pandas. When generating the time embedding, the line TIME = df.index raises the following error:
ssh://[email protected]:22/home/tank/anaconda3/envs/lpb/bin/python3.6 -u /home/tank/lxl/GMAN/GMAN-master/METR/analyzeData.py
Traceback (most recent call last):
  File "/home/tank/lxl/GMAN/GMAN-master/METR/analyzeData.py", line 37, in <module>
    print(df.index)
  File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 852, in __repr__
    attrs = self._format_attrs()
  File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/indexes/datetimelike.py", line 381, in _format_attrs
    freq = self.freqstr
  File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/indexes/extension.py", line 54, in fget
    result = getattr(self.data, name)
  File "/home/tank/anaconda3/envs/lpb/lib/python3.6/site-packages/pandas/core/arrays/datetimelike.py", line 1104, in freqstr
    return self.freq.freqstr
AttributeError: 'numpy.bytes_' object has no attribute 'freqstr'

Process finished with exit code 1

Is my dataset wrong, or is my pandas version (1.1.4) the problem? Why can't I access this index? Many thanks.

Data

I'm very interested in this work. Could you provide the complete data, including SE?

A doubt

Why is GMAN good at long-term forecasting but less clearly better at short-term forecasting?
Is it because of the temporal attention? Looking forward to your answer, thanks.

Has anyone reproduced the results on PeMS?

I ran it on TensorFlow 2 in compatibility mode and increased patience to 20. The average test MAE is 1.66, which falls short of the reported results:

testing time: 36.1s
          MAE   RMSE  MAPE
train     1.32  2.87  2.78%
val       1.59  3.72  3.61%
test      1.66  3.82  3.74%
performance in each prediction step
step: 01  0.99  1.88  1.96%
step: 02  1.21  2.47  2.50%
step: 03  1.38  2.97  2.93%
step: 04  1.52  3.35  3.30%
step: 05  1.62  3.65  3.61%
step: 06  1.71  3.90  3.86%
step: 07  1.78  4.09  4.08%
step: 08  1.85  4.25  4.26%
step: 09  1.90  4.38  4.42%
step: 10  1.95  4.49  4.55%
step: 11  1.99  4.59  4.67%
step: 12  2.03  4.67  4.78%
average   1.66  3.72  3.74%
total time: 3.4min
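The "average" RMSE (3.72) differs from the "test" RMSE (3.82) while MAE and MAPE agree. The twelve per-step RMSEs above average to 3.72, and because the square root is concave, the mean of per-step RMSEs never exceeds the RMSE over all steps. A minimal sketch, assuming unmasked metrics and nonzero labels for MAPE (function names are hypothetical):

```python
import numpy as np

def metrics(pred, label):
    """MAE, RMSE and MAPE over every entry (assumes no zero labels, no masking)."""
    err = pred - label
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(np.mean(err ** 2)))
    mape = float(np.mean(np.abs(err / label)))
    return mae, rmse, mape

def per_step_metrics(pred, label):
    """pred, label: (num_samples, Q, N); one (MAE, RMSE, MAPE) tuple per horizon step."""
    return [metrics(pred[:, q], label[:, q]) for q in range(pred.shape[1])]

# Toy case: step 0 is exact, step 1 is off by 2 everywhere.
label = np.ones((2, 2, 1))
pred = label.copy()
pred[:, 1] += 2.0
overall = metrics(pred, label)
steps = per_step_metrics(pred, label)
```

Here the overall RMSE is sqrt(2) ≈ 1.41 while the per-step RMSEs (0 and 2) average to 1.0; the MAEs agree either way.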

The data shapes differ from those in DCRNN and Graph WaveNet.

In previous work, the data shapes are:
train shape X(36465, 12, 325, 2) Y(36465, 12, 325, 2)
val shape X(5209, 12, 325, 2) Y(5209, 12, 325, 2)
test shape X(10419, 12, 325, 2) Y(10419, 12, 325, 2)
Yours are:
trainX: (36458, 12, 325) trainY: (36458, 12, 325)
valX: (5189, 12, 325) valY: (5189, 12, 325)
testX: (10400, 12, 325) testY: (10400, 12, 325)
I'm confused about it.
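The sample counts can be reproduced if one assumes the two codebases order windowing and splitting differently: DCRNN-style preprocessing slides a (P + Q)-step window over the whole series and then splits the samples, while a GMAN-style loader splits the raw time steps first and then windows each split, losing P + Q - 1 samples per split. The sketch below assumes P = Q = 12, a 0.7/0.1/0.2 split, and a series of 52116 steps (chosen because it matches the printed shapes); the function names are hypothetical:

```python
def gman_style_counts(num_steps, P=12, Q=12, train_ratio=0.7, test_ratio=0.2):
    """Split the raw series first, then slide a (P + Q)-step window inside each split."""
    train = round(train_ratio * num_steps)
    test = round(test_ratio * num_steps)
    val = num_steps - train - test
    return tuple(s - P - Q + 1 for s in (train, val, test))

def dcrnn_style_counts(num_steps, P=12, Q=12, train_ratio=0.7, test_ratio=0.2):
    """Slide the window over the whole series first, then split the samples."""
    samples = num_steps - P - Q + 1
    test = round(samples * test_ratio)
    train = round(samples * train_ratio)
    return train, samples - train - test, test

T = 52116  # assumed series length, consistent with the shapes above
gman_counts = gman_style_counts(T)    # (train, val, test)
dcrnn_counts = dcrnn_style_counts(T)  # (train, val, test)
```

The trailing dimension of 2 in the DCRNN arrays is presumably the speed plus an extra per-step input channel such as time of day, which GMAN instead supplies through its separate time features.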

about model performance

Great work!
I have a question about the computation of the attention coefficients. Did you ever run experiments comparing model performance with and without the STE block?

ZeroDivisionError: float division by zero

When generating SE, the line normalized_probs = [float(u_prob)/norm_const for u_prob in unnormalized_probs] in preprocess_transition_probs() fails with: ZeroDivisionError: float division by zero
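This division can only fail when a node has neighbors but all of their edge weights sum to zero, for example when zero-weight entries from the adjacency file are added to the graph as edges. Dropping zero-weight edges before building the graph is probably the cleaner fix. As a stopgap, here is a guarded normalization sketch that falls back to a uniform distribution; `normalize_probs` is a hypothetical helper, not part of the node2vec code:

```python
def normalize_probs(unnormalized_probs):
    """Normalize transition weights; fall back to uniform if they sum to zero."""
    norm_const = float(sum(unnormalized_probs))
    if norm_const == 0.0:
        n = len(unnormalized_probs)
        # All weights are zero: treat every neighbor as equally likely instead of dividing by zero.
        return [1.0 / n] * n if n else []
    return [float(u_prob) / norm_const for u_prob in unnormalized_probs]
```

The uniform fallback changes the random walks on those nodes, so it is worth checking how many nodes trigger it.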

Masking in Loss function

I have seen masking applied in various places in the code, yet it isn't mentioned in the paper. In particular, mae_loss() applies a mask. What is the purpose of this?
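In the DCRNN-style traffic datasets, a reading of 0 generally marks a missing value (for example, a sensor outage) rather than a real speed. Masking excludes those entries so the model is not trained or scored on predicting zeros. A minimal NumPy sketch of a masked MAE in that style; the name `masked_mae` and the zero-as-missing convention are assumptions here:

```python
import numpy as np

def masked_mae(pred, label, null_value=0.0):
    """Mean absolute error over entries whose label is not `null_value`."""
    mask = (np.asarray(label) != null_value).astype(np.float64)
    valid_fraction = mask.mean()
    if valid_fraction == 0:
        # No valid entries: return 0 instead of NaN so training does not diverge.
        return 0.0
    # Rescale the mask so averaging over all entries equals averaging over valid ones.
    mask /= valid_fraction
    return float(np.mean(np.abs(np.asarray(pred) - label) * mask))
```

Without the zero-fraction guard, a batch whose labels are all missing would produce NaN.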

Inconsistencies with the paper

Hello, first I would like to thank you for sharing the code. I was looking at the Spatial Attention component (line 56 in model.py) and noticed some differences from what is presented in the paper:

  1. I cannot find where you split the vertices into G partitions (and compute the intra/inter-group attention). As far as I can tell, the spatialAttention function only performs the intra-group spatial attention, applied over all vertices without any grouping restriction.
  2. After computing eq 7 (line 86 in model.py), the output is projected again by two FC layers, which are not described in the paper. What is the reason for this?
  3. In eq 7, the input of function f3 is the previous hidden representation, whereas your code also uses the static graph embeddings (e_{v,tj}).

Looking forward to your reply.

The length of the PEMS data

The paper states: "traffic speed prediction on the PeMS dataset (Li et al. 2018b), which contains 6 months of data recorded by 325 traffic sensors ranging from January 1st, 2017 to June 30th, 2017 in the Bay Area." However, the referenced paper states that the data was collected from January 1st, 2017 to May 31st, 2017. Could you provide the 6-month data instead?
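A quick way to check which range a downloaded file covers is to compare its row count with the expected number of samples for each date range. The sketch below assumes 5-minute sampling; `expected_steps` is a hypothetical helper:

```python
from datetime import date

def expected_steps(start, end_inclusive, interval_minutes=5):
    """Number of samples between two dates (both inclusive) at a fixed interval."""
    days = (end_inclusive - start).days + 1
    return days * (24 * 60 // interval_minutes)

five_months = expected_steps(date(2017, 1, 1), date(2017, 5, 31))  # Jan 1 to May 31
six_months = expected_steps(date(2017, 1, 1), date(2017, 6, 30))   # Jan 1 to Jun 30
```

For comparison, assuming split-then-window preprocessing with P = Q = 12, the GMAN shapes quoted in the data-shape issue imply 36458 + 5189 + 10400 + 3 × 23 = 52116 steps. That is one hour short of the six-month figure and far above the five-month one.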
