
eagle's Introduction

EAGLE: Environment-Aware Dynamic Graph Learning for Out-of-Distribution Generalization

This repository is the official implementation of "Environment-Aware Dynamic Graph Learning for Out-of-Distribution Generalization (EAGLE)" accepted by the 37th Conference on Neural Information Processing Systems (NeurIPS 2023).


0. Abstract

Dynamic graph neural networks (DGNNs) are increasingly pervasive in exploiting spatio-temporal patterns on dynamic graphs. However, existing works fail to generalize under distribution shifts, which are common in real-world scenarios. As the generation of dynamic graphs is heavily influenced by latent environments, investigating their impacts on out-of-distribution (OOD) generalization is critical. However, this problem remains unexplored and poses two challenges: 1) How can the complex environments on dynamic graphs with distribution shifts be properly modeled and inferred? 2) How can invariant patterns be discovered given the inferred spatio-temporal environments? To address these challenges, we propose a novel Environment-Aware dynamic Graph LEarning (EAGLE) framework for OOD generalization that models complex coupled environments and exploits spatio-temporal invariant patterns. Specifically, we first design the environment-aware EA-DGNN to model environments through multi-channel environment disentangling. Then, we propose an environment instantiation mechanism for environment diversification with inferred distributions. Finally, we discriminate spatio-temporal invariant patterns for out-of-distribution prediction with the invariant pattern recognition mechanism and perform fine-grained, node-wise causal interventions with a mixture of instantiated environment samples. Experiments on real-world and synthetic dynamic graph datasets demonstrate the superiority of our method over state-of-the-art baselines under distribution shifts. To the best of our knowledge, we are the first to study OOD generalization on dynamic graphs from the environment learning perspective.
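The node-wise intervention idea mentioned in the abstract can be pictured with a toy example. This is a hypothetical illustration, not EAGLE's implementation: each node keeps its own invariant feature vector, and its environment part is filled with a vector drawn independently, per node, from a pool of instantiated environment samples (a uniform mixture). The function name `node_wise_intervention`, the concatenation layout, and the toy inputs are all assumptions.

```python
import random
from typing import List

Vector = List[float]


def node_wise_intervention(
    invariant: List[Vector],
    env_pool: List[Vector],
    rng: random.Random,
) -> List[Vector]:
    """Toy illustration (hypothetical, not EAGLE's code).

    For every node, keep its invariant features and attach an environment
    vector drawn independently from a shared pool of instantiated samples.
    """
    if not env_pool:
        raise ValueError("env_pool must contain at least one environment sample")
    intervened = []
    for z_inv in invariant:
        # A fresh draw for each node is what makes the intervention node-wise.
        z_env = rng.choice(env_pool)
        intervened.append(list(z_inv) + list(z_env))
    return intervened


if __name__ == "__main__":
    rng = random.Random(0)
    invariant = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # hypothetical per-node features
    env_pool = [[9.0], [8.0], [7.0]]                  # hypothetical environment samples
    for row in node_wise_intervention(invariant, env_pool, rng):
        print(row)
```

Because each node draws its own environment sample, two nodes in the same graph can be exposed to different environments in the same pass, which is the contrast with a single graph-level intervention.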

1. Requirements

Main package requirements:

  • CUDA == 10.1
  • Python == 3.8.12
  • PyTorch == 1.9.1
  • PyTorch-Geometric == 2.0.1

To install all required packages, run the following command from the root directory of the repository:

pip install -r requirements.txt

2. Quick Start

Training

To train EAGLE, run the following command in the directory ./scripts:

python main.py --mode=train --use_cfg=1 --dataset=<dataset_name>

Explanations for the arguments:

  • use_cfg: whether to train with the preset configurations (set to 1 to use them).
  • dataset: name of the dataset. collab, yelp, and act correspond to Table 1; collab_04, collab_06, and collab_08 correspond to Table 2.
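To run every dataset of one table in sequence, a small driver script can generate the commands documented here. This is a hypothetical helper, not part of the repository: `build_commands`, `run_all`, and `TABLE_DATASETS` are illustrative names, it relies only on the `--mode`, `--use_cfg`, and `--dataset` flags shown in this README, and it must be launched from ./scripts. It invokes `python` literally, as the README does; substitute your environment's interpreter if needed.

```python
import shlex
import subprocess
from typing import Dict, List

# Dataset groups as listed in this README (hypothetical mapping name).
TABLE_DATASETS: Dict[int, List[str]] = {
    1: ["collab", "yelp", "act"],
    2: ["collab_04", "collab_06", "collab_08"],
}


def build_commands(table: int, mode: str = "train", use_cfg: int = 1) -> List[str]:
    """Return one main.py command line per dataset of the given table."""
    if mode not in ("train", "eval"):
        raise ValueError(f"unknown mode: {mode!r}")
    return [
        f"python main.py --mode={mode} --use_cfg={use_cfg} --dataset={name}"
        for name in TABLE_DATASETS[table]
    ]


def run_all(table: int, mode: str = "train", dry_run: bool = True) -> None:
    """Print the commands (dry run) or execute them one after another."""
    for cmd in build_commands(table, mode):
        print(cmd)
        if not dry_run:
            # check=True stops at the first failing run instead of continuing.
            subprocess.run(shlex.split(cmd), check=True)


if __name__ == "__main__":
    run_all(table=1, mode="train", dry_run=True)
```

Passing `mode="eval"` produces the evaluation commands instead; keeping `dry_run=True` by default lets you inspect the commands before launching multi-hour runs.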

Evaluation

To evaluate EAGLE with trained models, run the following command in the directory ./scripts:

python main.py --mode=eval --use_cfg=1 --dataset=<dataset_name>

Place the trained models in the directory ./saved_model. All pre-trained models are already provided in this directory for quick re-evaluation.

Reproducibility

All experiment logs for the main results in Table 1 and Table 2 are provided in the directory ./logs/history. To reproduce the results in results.txt, run the following command in the directory ./scripts:

python show_result.py

3. Citation

If you find this repository helpful, please consider citing the following paper. For any discussion, feel free to contact [email protected].

@inproceedings{yuan2023environmentaware,
  title={Environment-Aware Dynamic Graph Learning for Out-of-Distribution Generalization},
  author={Yuan, Haonan and Sun, Qingyun and Fu, Xingcheng and Zhang, Ziwei and Ji, Cheng and Peng, Hao and Li, Jianxin},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=n8JWIzYPRz}
}

4. Acknowledgements

Part of this code is inspired by Zeyang Zhang et al.'s DIDA. We owe sincere thanks to their valuable efforts and contributions.

eagle's People

Contributors

haonan-yuan, suchun-sv

Forkers

ringbdstack

eagle's Issues

About running the model

Hello, thank you for your excellent work! I've encountered some issues while running the code on the ACT dataset. I've run it multiple times with your default hyperparameters, and every time the best result is obtained at epoch 52. I set min_epoch to 200, but the outcome is the same. The running process is shown below. The loss spikes sharply at regular intervals, at epochs 50, 100, 150, and 200. Could you help me with this?

using gpu:2 to train the model
Loading dataset act
[2024-07-09 10:27:02,570 INFO] Note: NumExpr detected 40 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[2024-07-09 10:27:02,570 INFO] NumExpr defaulting to 8 threads.
total length: 30, test length: 8
0%| | 0/1000 [00:00<?, ?it/s]Epoch:1, Loss: 4.4040, Time: 16.574
Current: Epoch:1, Train AUC:0.7657, Val AUC: 0.7697, Test AUC: 0.7674
Train: Epoch:1, Train AUC:0.7657, Val AUC: 0.7697, Test AUC: 0.7674
Test: Epoch:1, Train AUC:0.7600, Val AUC: 0.7496, Test AUC: 0.7388
1%|▎ | 9/1000 [02:22<4:22:17, 15.88s/it]Epoch:10, Loss: 2.8576, Time: 16.647
Current: Epoch:10, Train AUC:0.8257, Val AUC: 0.8378, Test AUC: 0.8327
Train: Epoch:10, Train AUC:0.8257, Val AUC: 0.8378, Test AUC: 0.8327
Test: Epoch:10, Train AUC:0.7906, Val AUC: 0.7742, Test AUC: 0.7571
2%|▋ | 19/1000 [05:01<4:16:47, 15.71s/it]Epoch:20, Loss: 1.4233, Time: 15.612
Current: Epoch:20, Train AUC:0.8825, Val AUC: 0.8955, Test AUC: 0.8960
Train: Epoch:20, Train AUC:0.8825, Val AUC: 0.8955, Test AUC: 0.8960
Test: Epoch:20, Train AUC:0.8283, Val AUC: 0.8109, Test AUC: 0.7844
3%|█▏ | 29/1000 [07:37<4:15:07, 15.76s/it]Epoch:30, Loss: 0.9628, Time: 16.965
Current: Epoch:30, Train AUC:0.8947, Val AUC: 0.9126, Test AUC: 0.9125
Train: Epoch:30, Train AUC:0.8947, Val AUC: 0.9126, Test AUC: 0.9125
Test: Epoch:30, Train AUC:0.8295, Val AUC: 0.8270, Test AUC: 0.7905
4%|█▌ | 39/1000 [10:12<4:04:24, 15.26s/it]Epoch:40, Loss: 0.7599, Time: 15.123
Current: Epoch:40, Train AUC:0.9005, Val AUC: 0.9101, Test AUC: 0.9212
Train: Epoch:32, Train AUC:0.8976, Val AUC: 0.9166, Test AUC: 0.9154
Test: Epoch:32, Train AUC:0.8278, Val AUC: 0.8300, Test AUC: 0.7970
5%|█▉ | 49/1000 [12:45<4:04:45, 15.44s/it]Epoch:50, Loss: 63.3256, Time: 16.026
Current: Epoch:50, Train AUC:0.9080, Val AUC: 0.9197, Test AUC: 0.9234
Train: Epoch:50, Train AUC:0.9080, Val AUC: 0.9197, Test AUC: 0.9234
Test: Epoch:50, Train AUC:0.8223, Val AUC: 0.8309, Test AUC: 0.8028
6%|██▎ | 59/1000 [15:20<3:59:56, 15.30s/it]Epoch:60, Loss: 0.7564, Time: 15.065
Current: Epoch:60, Train AUC:0.9014, Val AUC: 0.9090, Test AUC: 0.9153
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
7%|██▋ | 69/1000 [17:54<4:01:16, 15.55s/it]Epoch:70, Loss: 0.7080, Time: 15.075
Current: Epoch:70, Train AUC:0.9050, Val AUC: 0.9131, Test AUC: 0.9225
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
8%|███ | 79/1000 [20:25<3:54:53, 15.30s/it]Epoch:80, Loss: 0.8027, Time: 15.724
Current: Epoch:80, Train AUC:0.9033, Val AUC: 0.9172, Test AUC: 0.9215
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
9%|███▍ | 89/1000 [23:00<3:54:27, 15.44s/it]Epoch:90, Loss: 0.7757, Time: 15.098
Current: Epoch:90, Train AUC:0.9051, Val AUC: 0.9129, Test AUC: 0.9211
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
10%|███▊ | 99/1000 [25:33<3:48:09, 15.19s/it]Epoch:100, Loss: 60.5369, Time: 14.992
Current: Epoch:100, Train AUC:0.9034, Val AUC: 0.9151, Test AUC: 0.9222
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
11%|████▏ | 109/1000 [28:04<3:45:16, 15.17s/it]Epoch:110, Loss: 1.1229, Time: 14.996
Current: Epoch:110, Train AUC:0.8894, Val AUC: 0.9036, Test AUC: 0.9088
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
12%|████▌ | 119/1000 [30:37<3:48:15, 15.54s/it]Epoch:120, Loss: 0.8284, Time: 15.144
Current: Epoch:120, Train AUC:0.8918, Val AUC: 0.9062, Test AUC: 0.9105
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
13%|████▉ | 129/1000 [33:07<3:37:54, 15.01s/it]Epoch:130, Loss: 0.8237, Time: 15.151
Current: Epoch:130, Train AUC:0.9033, Val AUC: 0.9166, Test AUC: 0.9220
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
14%|█████▎ | 139/1000 [35:40<3:39:43, 15.31s/it]Epoch:140, Loss: 0.7556, Time: 15.184
Current: Epoch:140, Train AUC:0.9045, Val AUC: 0.9181, Test AUC: 0.9224
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
15%|█████▋ | 149/1000 [38:13<3:35:59, 15.23s/it]Epoch:150, Loss: 60.8410, Time: 15.222
Current: Epoch:150, Train AUC:0.9074, Val AUC: 0.9215, Test AUC: 0.9241
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
16%|██████ | 159/1000 [40:45<3:34:45, 15.32s/it]Epoch:160, Loss: 1.0731, Time: 15.109
Current: Epoch:160, Train AUC:0.8933, Val AUC: 0.9030, Test AUC: 0.9126
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
17%|██████▍ | 169/1000 [43:18<3:32:52, 15.37s/it]Epoch:170, Loss: 1.0648, Time: 15.693
Current: Epoch:170, Train AUC:0.8975, Val AUC: 0.9027, Test AUC: 0.9145
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
18%|██████▊ | 179/1000 [45:49<3:26:08, 15.07s/it]Epoch:180, Loss: 1.1164, Time: 15.134
Current: Epoch:180, Train AUC:0.8979, Val AUC: 0.9104, Test AUC: 0.9179
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
19%|███████▏ | 189/1000 [48:20<3:23:20, 15.04s/it]Epoch:190, Loss: 1.2160, Time: 15.072
Current: Epoch:190, Train AUC:0.9023, Val AUC: 0.9136, Test AUC: 0.9202
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
20%|███████▌ | 199/1000 [50:54<3:31:23, 15.83s/it]Epoch:200, Loss: 59.7033, Time: 15.058
Current: Epoch:200, Train AUC:0.9054, Val AUC: 0.9215, Test AUC: 0.9227
Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248
Test: Epoch:52, Train AUC:0.8208, Val AUC: 0.8367, Test AUC: 0.8107
20%|███████▌ | 200/1000 [51:24<3:25:38, 15.42s/it]
train_auc val_auc test_auc ... test_val_auc test_test_auc epoch_time
0 0.910956 0.921692 0.924762 ... 0.836722 0.810747 15.42341
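One quick way to confirm the pattern in this log is to parse it. The sketch below is a hypothetical diagnostic, not part of EAGLE: it assumes the line formats shown above, flags epochs whose loss exceeds a chosen multiple of the previously logged loss, and reads the epoch on the most recent `Train:` line, which in this log appears to track the selected best epoch (it stays at 52 while `Current:` advances). The function names and the factor of 10 are assumptions.

```python
import re
from typing import Dict, List, Optional

# "...]Epoch:50, Loss: 63.3256, Time: 16.026" -> (50, 63.3256)
LOSS_RE = re.compile(r"Epoch:(\d+), Loss: ([0-9.]+)")
# "Train: Epoch:52, Train AUC:..." at the start of a line -> 52
SELECTED_RE = re.compile(r"^Train: Epoch:(\d+)", re.MULTILINE)


def parse_losses(log_text: str) -> Dict[int, float]:
    """Map each logged epoch to its training loss."""
    return {int(epoch): float(loss) for epoch, loss in LOSS_RE.findall(log_text)}


def find_spikes(losses: Dict[int, float], factor: float = 10.0) -> List[int]:
    """Epochs whose loss exceeds `factor` times the previously logged loss."""
    epochs = sorted(losses)
    return [cur for prev, cur in zip(epochs, epochs[1:])
            if losses[cur] > factor * losses[prev]]


def last_selected_epoch(log_text: str) -> Optional[int]:
    """Epoch reported on the most recent "Train:" line, or None if there is none."""
    found = SELECTED_RE.findall(log_text)
    return int(found[-1]) if found else None


if __name__ == "__main__":
    sample = (
        "Epoch:40, Loss: 0.7599, Time: 15.123\n"
        "Epoch:50, Loss: 63.3256, Time: 16.026\n"
        "Epoch:60, Loss: 0.7564, Time: 15.065\n"
        "Train: Epoch:52, Train AUC:0.9110, Val AUC: 0.9217, Test AUC: 0.9248\n"
    )
    print(find_spikes(parse_losses(sample)))  # [50]
    print(last_selected_epoch(sample))        # 52
```

Applied to the full ACT log above with the default factor, it flags epochs 50, 100, 150, and 200, and reports 52 as the last selected epoch.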

Additionally, on the COLLAB dataset, the computation produces NaN values after running for a while (with the default hyperparameters), which causes the AUC calculation to fail. Could you provide some assistance? Thank you again for your wonderful work.
[2024-07-09 12:02:35,339 INFO] Note: NumExpr detected 40 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[2024-07-09 12:02:35,339 INFO] NumExpr defaulting to 8 threads.
total length: 16, test length: 5
0%| | 0/1000 [00:00<?, ?it/s]Epoch:1, Loss: 3.7839, Time: 16.402
Current: Epoch:1, Train AUC:0.6514, Val AUC: 0.6667, Test AUC: 0.6666
Train: Epoch:1, Train AUC:0.6514, Val AUC: 0.6667, Test AUC: 0.6666
Test: Epoch:1, Train AUC:0.7670, Val AUC: 0.7898, Test AUC: 0.7506
1%|▎ | 9/1000 [02:22<4:23:53, 15.98s/it]Epoch:10, Loss: 2.3704, Time: 15.174
Current: Epoch:10, Train AUC:0.6912, Val AUC: 0.7103, Test AUC: 0.7075
Train: Epoch:10, Train AUC:0.6912, Val AUC: 0.7103, Test AUC: 0.7075
Test: Epoch:10, Train AUC:0.7833, Val AUC: 0.8110, Test AUC: 0.7552
2%|▋ | 19/1000 [04:55<4:09:15, 15.25s/it]Epoch:20, Loss: 2.4864, Time: 15.275
Current: Epoch:20, Train AUC:0.7333, Val AUC: 0.7546, Test AUC: 0.7527
Train: Epoch:20, Train AUC:0.7333, Val AUC: 0.7546, Test AUC: 0.7527
Test: Epoch:20, Train AUC:0.8021, Val AUC: 0.8389, Test AUC: 0.7667
3%|█▏ | 29/1000 [07:29<4:11:06, 15.52s/it]Epoch:30, Loss: 2.3170, Time: 15.400
Current: Epoch:30, Train AUC:0.7735, Val AUC: 0.8012, Test AUC: 0.7981
Train: Epoch:30, Train AUC:0.7735, Val AUC: 0.8012, Test AUC: 0.7981
Test: Epoch:30, Train AUC:0.8201, Val AUC: 0.8556, Test AUC: 0.7804
3%|█▎ | 33/1000 [08:46<4:16:54, 15.94s/it]
Traceback (most recent call last):
  File "/home/jpf/xx/EAGLE/scripts/main.py", line 35, in <module>
    results = runner.run()
  File "/home/jpf/xx/EAGLE/EAGLE/runner.py", line 529, in run
    test_results = self.test(epoch, self.data["test"])
  File "/home/jpf/xx/EAGLE/EAGLE/runner.py", line 605, in test
    auc, ap = self.loss.predict(z, pos_edge, neg_edge, self.model.edge_decoder)
  File "/home/jpf/xx/EAGLE/EAGLE/utils/loss.py", line 56, in predict
    return roc_auc_score(y, pred), 0
  File "/home/jpf/.local/lib/python3.8/site-packages/sklearn/metrics/_ranking.py", line 551, in roc_auc_score
    y_score = check_array(y_score, ensure_2d=False)
  File "/home/jpf/.local/lib/python3.8/site-packages/sklearn/utils/validation.py", line 921, in check_array
    _assert_all_finite(
  File "/home/jpf/.local/lib/python3.8/site-packages/sklearn/utils/validation.py", line 161, in _assert_all_finite
    raise ValueError(msg_err)
ValueError: Input contains NaN.
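The traceback shows `roc_auc_score` rejecting `y_score`, so the predicted scores themselves already contain NaN and the problem originates upstream in the model output, not in the metric. A hedged first step is to fail early with a message that says how many scores are non-finite. The sketch below is hypothetical: `count_non_finite` and `pairwise_auc` are illustrative names, and `pairwise_auc` is a dependency-free stand-in for `sklearn.metrics.roc_auc_score` that computes binary AUC as the fraction of correctly ordered (positive, negative) pairs, counting ties as half.

```python
import math
from typing import Sequence


def count_non_finite(values: Sequence[float]) -> int:
    """Number of NaN or +/-inf entries in `values`."""
    return sum(1 for v in values if not math.isfinite(v))


def pairwise_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Binary ROC AUC: share of (positive, negative) pairs in which the
    positive scores higher, with ties counted as half a correct pair."""
    bad = count_non_finite(scores)
    if bad:
        # Report a count instead of the generic "Input contains NaN."
        raise ValueError(f"{bad} of {len(scores)} predicted scores are NaN or inf")
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative edge")
    correct = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos) * len(neg))


if __name__ == "__main__":
    print(pairwise_auc([1, 1, 0, 0], [0.9, 0.4, 0.5, 0.1]))  # 0.75
```

To find where the NaN first appears, PyTorch's `torch.autograd.set_detect_anomaly(True)` can help, and gradient clipping with `torch.nn.utils.clip_grad_norm_` is a common mitigation for divergence such as the periodic loss spikes in the ACT log. Whether either fits EAGLE's training loop is an assumption.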

Question about the decoder of the ECVAE

Hello, I read your NeurIPS 2023 paper with great interest, but I am confused about Equation (8). The decoder of z, pω(z | e, y), does not appear in the loss function of Equation (8), so how can it be updated? Equation (8) actually seems to reconstruct y; how, then, does the ECVAE generate the latent z?
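For reference only, the generic conditional VAE (CVAE) objective is written below with a generic condition c, target y, and latent z. This is the textbook form, not a restatement of the paper's Equation (8), and mapping c, y, and z onto EAGLE's symbols is an assumption.

```latex
\log p_\theta(y \mid c) \;\ge\;
\mathbb{E}_{q_\phi(z \mid c, y)}\!\left[\log p_\theta(y \mid c, z)\right]
\;-\; \mathrm{KL}\!\left(q_\phi(z \mid c, y) \,\middle\|\, p_\theta(z \mid c)\right)
```

In this generic form, the decoder term reconstructs y, yet the networks over z still receive gradients: the recognition network q_φ(z | c, y) is trained through both terms (via the reparameterization trick), and the conditional prior p_θ(z | c) is trained only through the KL term. At generation time, z is sampled from the conditional prior and passed to the decoder.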
