
hkuds / flashst


[ICML'2024] "FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction"

Home Page: https://arxiv.org/abs/2405.17898

License: Apache License 2.0

Python 100.00%
pre-training prompt-tuning smart-cities spatio-temporal-prediction traffic-flow-prediction urban-computing

flashst's Introduction

FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction

A PyTorch implementation of the paper: FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction

Zhonghang Li, Lianghao Xia, Yong Xu, Chao Huang* (*Correspondence)

Data Intelligence Lab@University of Hong Kong, South China University of Technology, PAZHOU LAB


Introduction

In this work, we introduce a simple and universal spatio-temporal prompt-tuning framework, which addresses the significant challenge posed by distribution shift in this field. To achieve this objective, we present FlashST, a framework that adapts pretrained models to the specific characteristics of diverse downstream datasets, thereby improving generalization across various prediction scenarios. We begin by utilizing a lightweight spatio-temporal prompt network for in-context learning, capturing spatio-temporal invariant knowledge and facilitating effective adaptation to diverse scenarios. Additionally, we incorporate a distribution mapping mechanism to align the data distributions of pre-training and downstream data, facilitating effective knowledge transfer in spatio-temporal forecasting. Empirical evaluations demonstrate the effectiveness of our FlashST across different spatio-temporal prediction tasks.

The detailed framework of the proposed FlashST.


Getting Started



1. Code Structure

  • conf: This folder includes parameter settings for FlashST (config.conf) as well as all other baseline models.
  • data: This folder contains all datasets used in our work, along with pre-generated files and the code that generates the files required by certain baselines.
  • lib: Utility modules for data processing and training setup, as follows:
    • data_process.py: Data loading, splitting, generation, normalization, slicing, etc.
    • logger.py: For output printing.
    • metrics.py: Methods for calculating evaluation metrics.
    • predifineGraph.py: Predefined graph generation method.
    • TrainInits.py: Training initialization, including settings of optimizer, device, random seed, etc.
  • model: The implementation of FlashST and all baseline models, along with the code needed to run the framework. The args.py script generates the pre-generated data and parameter configurations required by the different baselines.
  • SAVE: This folder stores the trained models, organized by mode: pretrain, eval and ori. The pre-trained models are saved under pretrain.
│  README.md
│  requirements.txt
│
├─conf
│  ├─AGCRN
│  ├─ASTGCN
│  ├─FlashST
│  │  │  config.conf
│  │  │  Params_pretrain.py
│  ├─GWN
│  ├─MSDR
│  ├─MTGNN
│  ├─PDFormer
│  ├─ST-WA
│  ├─STFGNN
│  ├─STGCN
│  ├─STSGCN
│  └─TGCN
│
├─data
│  ├─CA_District5
│  ├─chengdu_didi
│  ├─NYC_BIKE
│  ├─PEMS03
│  ├─PEMS04
│  ├─PEMS07
│  ├─PEMS07M
│  ├─PEMS08
│  ├─PDFormer
│  ├─STFGNN
│  └─STGODE
│
├─lib
│  │  data_process.py
│  │  logger.py
│  │  metrics.py
│  │  predifineGraph.py
│  │  TrainInits.py
│
├─model
│  │  FlashST.py
│  │  PromptNet.py
│  │  Run.py
│  │  Trainer.py
│  │
│  ├─AGCRN
│  ├─ASTGCN
│  ├─DMSTGCN
│  ├─GWN
│  ├─MSDR
│  ├─MTGNN
│  ├─PDFormer
│  ├─STFGNN
│  ├─STGCN
│  ├─STGODE
│  ├─STSGCN
│  ├─ST_WA
│  └─TGCN
│
└─SAVE
    └─pretrain
        ├─GWN
        │      GWN_P8437.pth
        │
        ├─MTGNN
        │      MTGNN_P8437.pth
        │
        ├─PDFormer
        │      PDFormer_P8437.pth
        │
        └─STGCN
                STGCN_P8437.pth
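
The SAVE/pretrain layout in the tree above (one subfolder per model, each holding .pth checkpoints) can be navigated with a short helper. This is a minimal sketch and not part of the repository: the function name `find_pretrained` and its arguments are hypothetical, and it assumes only the folder layout shown in the tree.

```python
from pathlib import Path


def find_pretrained(save_root, model_name):
    """List .pth checkpoints under <save_root>/pretrain/<model_name>.

    Hypothetical helper: it relies only on the folder layout shown in the
    tree above, not on any function from the FlashST code base.
    """
    model_dir = Path(save_root) / "pretrain" / model_name
    if not model_dir.is_dir():
        # No checkpoint folder for this model.
        return []
    # Keep only checkpoint files, sorted by name for a stable order.
    return sorted(p for p in model_dir.iterdir() if p.suffix == ".pth")
```

For example, `find_pretrained("SAVE", "MTGNN")` run from the repository root would list the files in SAVE/pretrain/MTGNN, and returns an empty list when that folder does not exist.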
            

2. Environment

The code has been tested in the following environment; other versions of the required packages may also work.

  • python==3.9.12
  • numpy==1.23.1
  • pytorch==1.9.0
  • cudatoolkit==11.1.1

Alternatively, set up the required environment by running the following commands:

# create new environment
conda create -n FlashST python=3.9.12

# activate environment
conda activate FlashST

# Torch with CUDA 11.1
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

# Install required libraries
pip install -r requirements.txt

3. Run the code

cd model
  • To test different models in various modes, run Run.py. Some examples:
# Evaluate the performance of MTGNN enhanced by FlashST on the PEMS07M dataset
python Run.py -dataset_test PEMS07M -mode eval -model MTGNN

# Evaluate the performance of STGCN enhanced by FlashST on the CA_District5 dataset
python Run.py -dataset_test CA_District5 -mode eval -model STGCN

# Evaluate the original performance of STGCN on the chengdu_didi dataset
python Run.py -dataset_test chengdu_didi -mode ori -model STGCN

# Pretrain from scratch with MTGNN model, checkpoint will be saved in FlashST-main/SAVE/pretrain/MTGNN(model name)/xxx.pth
python Run.py -mode pretrain -model MTGNN
  • Parameter settings. Parameters fall into two groups: those of the pre-training framework (FlashST) and those of the baseline model. To avoid clashes between overlapping parameter names, FlashST parameters are passed with a single hyphen (-) and baseline-model parameters with a double hyphen (--). Here is an example:
# Set first_layer_embedding_size and out_layer_dim to 32 in STFGNN
python Run.py -model STFGNN -mode ori -dataset_test PEMS08 --first_layer_embedding_size 32 --out_layer_dim 32
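
The single-hyphen / double-hyphen convention can be illustrated with two argparse parsers chained through `parse_known_args`. This is a minimal sketch of one way such a split could work, not the actual parsing code in Run.py or args.py: the function name `split_args` and all default values are assumptions, and only the option names are taken from the commands above.

```python
import argparse


def split_args(argv):
    """Split a command line into framework and baseline options.

    Illustrative only: the defaults below are made up, and the real
    Run.py / args.py may parse arguments differently.
    """
    # Framework-level (FlashST) options use a single hyphen.
    framework = argparse.ArgumentParser(add_help=False)
    framework.add_argument("-model", type=str, default="MTGNN")
    framework.add_argument("-mode", type=str, default="eval")
    framework.add_argument("-dataset_test", type=str, default="PEMS07M")
    # Anything the framework parser does not recognize is passed along.
    fw_args, rest = framework.parse_known_args(argv)

    # Baseline-model options use a double hyphen.
    baseline = argparse.ArgumentParser(add_help=False)
    baseline.add_argument("--first_layer_embedding_size", type=int, default=64)
    baseline.add_argument("--out_layer_dim", type=int, default=64)
    bl_args, unknown = baseline.parse_known_args(rest)
    return fw_args, bl_args, unknown


fw, bl, unknown = split_args(
    ["-model", "STFGNN", "-mode", "ori", "-dataset_test", "PEMS08",
     "--first_layer_embedding_size", "32", "--out_layer_dim", "32"]
)
```

Because the two groups use different prefixes, a baseline option can share its name with a framework option without the two parsers colliding.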

Citation

If you find FlashST useful in your research or applications, please kindly cite:

@misc{li2024flashst,
      title={FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction}, 
      author={Zhonghang Li and Lianghao Xia and Yong Xu and Chao Huang},
      year={2024},
      eprint={2405.17898},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Acknowledgements

We developed our code framework drawing inspiration from AGCRN and GPT-ST. Furthermore, the implementation of the baselines primarily relies on a combination of the code released by the original authors and the code from LibCity. We extend our heartfelt gratitude for their remarkable contributions.



flashst's Issues

Reproduction of results in Table 4

How are the w/o fine-tuning results in Table 4 derived?
I tried to use -mode ori to generate a pre-trained model of STGCN on PEMS07M, but the pretrain mode in the code only yields results after enhancement with FlashST.

DSU parameter

In FlashST.py, why is there a DSU parameter that is never used?

Dimension mismatch problem

I ran the provided examples:

# Evaluate the performance of MTGNN enhanced by FlashST on the PEMS07M dataset
python Run.py -dataset_test PEMS07M -mode eval -model MTGNN

# Pretrain from scratch with MTGNN model, checkpoint will be saved in FlashST-main/SAVE/pretrain/MTGNN(model name)/xxx.pth
python Run.py -mode pretrain -model MTGNN

But both examples encounter errors:

  File "/home/liyang/Developer/FlashST/model/PromptNet.py", line 118, in forward
    hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=-1).transpose(1, 3)
RuntimeError: Sizes of tensors must match except in dimension 3. Expected size 228 but got size 57 for tensor number 1 in the list.

  File "/home/liyang/Developer/FlashST/model/PromptNet.py", line 118, in forward
    hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=-1).transpose(1, 3)
RuntimeError: Sizes of tensors must match except in dimension 3. Expected size 358 but got size 90 for tensor number 1 in the list.

I have prepared the data according to the README.md.

Reproduce results

Hi,

python Run.py -dataset_test PEMS07M -mode eval -model MTGNN

produces:

============================scaler_mae_loss
Applying learning rate decay.
2024-08-10 16:46: Experiment log path in: /home/seyed/PycharmProjects/step/FlashST/model/../SAVE/eval/MTGNN
0%| | 0/20 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/home/seyed/PycharmProjects/step/FlashST/model/Run.py", line 173, in
trainer.train_eval()
File "/home/seyed/PycharmProjects/step/FlashST/model/Trainer.py", line 128, in train_eval
train_epoch_loss, loss_pre = self.eval_trn_eps()
File "/home/seyed/PycharmProjects/step/FlashST/model/Trainer.py", line 180, in eval_trn_eps
out, q = self.model(data, data, self.args.dataset_test, self.batch_seen, nadj=nadj, lpls=lpls, useGNN=True, DSU=True)
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/_utils.py", line 425, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/seyed/PycharmProjects/step/FlashST/model/FlashST.py", line 152, in forward
return self.forward_pretrain(source, label, select_dataset, batch_seen, nadj, lpls, useGNN, DSU)
File "/home/seyed/PycharmProjects/step/FlashST/model/FlashST.py", line 155, in forward_pretrain
x_prompt_return = self.pretrain_model(source[..., :self.input_base_dim], source, None, nadj, lpls, useGNN)
File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/seyed/PycharmProjects/step/FlashST/model/PromptNet.py", line 118, in forward
hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=-1).transpose(1, 3)
RuntimeError: Sizes of tensors must match except in dimension 2. Got 228 and 114 (The offending index is 1)

What should I do?

No data

Where is the data? There is no data folder matching the code structure.
