beabevi / sun

Understanding and Extending Subgraph GNNs by Rethinking their Symmetries (NeurIPS 2022 Oral)

Home Page: https://arxiv.org/abs/2206.11140

License: MIT License

Understanding and Extending Subgraph GNNs by Rethinking Their Symmetries

This repository contains the official code of the paper Understanding and Extending Subgraph GNNs by Rethinking Their Symmetries (NeurIPS 2022 Oral).

The master branch contains the most recent version of the code, using pyg=2.2.0 and pytorch=1.13.1. For the version submitted to NeurIPS 2022, check out the neurips22 tag.

The code builds on top of the ESAN framework.

Install

First, create a conda environment

conda env create -f environment.yml

and activate it

conda activate subgraph

Then, set up wandb (for example, by logging in with wandb login).
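
As an optional sanity check after activating the environment, the sketch below compares the installed package versions against the ones this README names for the master branch (pyg 2.2.0, pytorch 1.13.1). It is not part of the repository; the distribution names torch and torch_geometric are assumptions about how the packages were installed, and local build tags such as +cu117 are ignored.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions stated in this README for the master branch
# (distribution names are an assumption about the install).
EXPECTED = {"torch": "1.13.1", "torch_geometric": "2.2.0"}


def check_versions(expected):
    """Return {package: (expected, installed)} for every mismatch.

    Local build tags such as "+cu117" are stripped before comparing;
    a package that is not installed is reported with installed=None.
    """
    mismatches = {}
    for name, want in expected.items():
        try:
            have = version(name)
        except PackageNotFoundError:
            mismatches[name] = (want, None)
            continue
        if have.split("+")[0] != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    problems = check_versions(EXPECTED)
    for pkg, (want, have) in problems.items():
        print(f"{pkg}: expected {want}, found {have}")
    if not problems:
        print("Installed versions match the README.")
```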

Reproduce ogbg-molhiv and ZINC results

We provide the hyperparameter configurations to obtain the reported results on ogbg-molhiv (Table 2) and ZINC (Table 1).

Prepare the data

python data.py --dataset ZINC --policies ego_nets ego_nets_plus
python data.py --dataset ogbg-molhiv --policies ego_nets_plus

Obtain a sweep id <sweep-id> by running

wandb sweep configs/deterministic/<config-name>

where configs/deterministic/<config-name> is either configs/deterministic/SUN-ogbg-molhiv.yaml or configs/deterministic/SUN-ZINC.yaml.
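
To see how many runs a sweep will launch, a grid config can be expanded into its individual configurations. The sketch below is not part of the repository and assumes the standard wandb grid format, where each entry under parameters has either a single value or a list of values; the example dict is hypothetical and does not reproduce the contents of SUN-ZINC.yaml or SUN-ogbg-molhiv.yaml.

```python
import itertools


def expand_grid(sweep_config):
    """Expand a wandb grid-sweep config (parsed YAML dict) into run configs."""
    params = sweep_config["parameters"]
    names = sorted(params)
    choices = []
    for name in names:
        spec = params[name]
        # A grid sweep uses either a fixed `value` or a list of `values`.
        choices.append(spec["values"] if "values" in spec else [spec["value"]])
    # Cartesian product of all choices: one dict per run.
    return [dict(zip(names, combo)) for combo in itertools.product(*choices)]


# Hypothetical config: one fixed learning rate, ten seeds.
example = {
    "method": "grid",
    "parameters": {
        "learning_rate": {"value": 0.001},
        "seed": {"values": list(range(1, 11))},
    },
}
runs = expand_grid(example)
print(len(runs))  # 10
```

To inspect a real config file, parse it with PyYAML (yaml.safe_load) and pass the resulting dict to expand_grid.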

Run the 10 seeds with

wandb agent <sweep-id>

and compute the mean and standard deviation of Metric/test_mean over the runs in the sweep to obtain the SUN results in Tables 1 and 2.
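
The aggregation step can be sketched as follows. fetch_sweep_metric uses the wandb public API (wandb.Api().sweep and each run's summary) and expects the sweep path in the form <entity>/<project>/<sweep-id>. summarize reports the population standard deviation, matching numpy.std's default; this is an assumption, since the README does not state which convention was used for the tables.

```python
import statistics


def summarize(values):
    """Return (mean, population std) of the per-seed scores.

    pstdev matches numpy.std's default (ddof=0); use statistics.stdev
    instead if the sample std (ddof=1) is wanted.
    """
    return statistics.mean(values), statistics.pstdev(values)


def fetch_sweep_metric(sweep_path, key="Metric/test_mean"):
    """Collect `key` from every run of a sweep via the wandb public API.

    sweep_path has the form "<entity>/<project>/<sweep-id>".
    """
    import wandb  # imported lazily so summarize() works without wandb

    sweep = wandb.Api().sweep(sweep_path)
    values = [run.summary.get(key) for run in sweep.runs]
    # Skip runs that crashed or have not logged the metric.
    return [v for v in values if v is not None]
```

Typical usage would be summarize(fetch_sweep_metric("<entity>/<project>/<sweep-id>")), with the placeholders replaced by your own wandb entity, project and sweep id.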

Reproduce other results

First, prepare the data. Run

python data.py --dataset $DATASET --policies $POLICY

where $DATASET is one of the following:

  • TUDatasets (MUTAG, PTC, PROTEINS, NCI1, NCI109, IMDB-BINARY, IMDB-MULTI) - Table 4
  • graphproperty - Table 5
  • subgraphcount (aka counting substructures) - Table 1

and $POLICY is one of the following:

  • ego_nets
  • ego_nets_plus
  • node_marked
  • null
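
For preparing several datasets at once, a small driver can generate the data.py calls. This is a hypothetical helper, not part of the repository: it passes every policy in a single --policies argument, mirroring the ZINC example earlier in this README, and it only prints the commands instead of running them. Preprocessing every policy for every dataset may be more than a given experiment needs.

```python
import shlex

# Dataset and policy names as listed in this README.
TU_DATASETS = ["MUTAG", "PTC", "PROTEINS", "NCI1", "NCI109", "IMDB-BINARY", "IMDB-MULTI"]
OTHER_DATASETS = ["graphproperty", "subgraphcount"]
POLICIES = ["ego_nets", "ego_nets_plus", "node_marked", "null"]


def data_commands(datasets, policies):
    """Build one `python data.py` argv per dataset, passing all policies at once."""
    return [
        ["python", "data.py", "--dataset", ds, "--policies", *policies]
        for ds in datasets
    ]


if __name__ == "__main__":
    for cmd in data_commands(TU_DATASETS + OTHER_DATASETS, POLICIES):
        # Dry run: replace print with subprocess.run(cmd, check=True) to execute.
        print(shlex.join(cmd))
```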

To perform hyperparameter tuning, make use of wandb:

  1. In the configs/deterministic folder, choose the yaml file corresponding to the dataset of interest, say <config-name>. This file contains the hyperparameter grid.

  2. Run

    wandb sweep configs/deterministic/<config-name>

    to obtain a sweep id <sweep-id>

  3. Run the hyperparameter tuning with

    wandb agent <sweep-id>

    You can run the above command multiple times, on as many machines as you like, to contribute to the same grid search.

  4. Open your project in your wandb account in the browser to see the results:

    • For the TUDatasets, refer to Metric/valid_mean and Metric/valid_std to obtain the results.

    • For graphproperty and subgraphcount, group the runs by all hyperparameters except the seed, and compute the mean and std of Metric/train_mean, Metric/valid_mean and Metric/test_mean over the seeds. Then, report the results of the configuration with the best validation metric.
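
The selection rule in step 4 for graphproperty and subgraphcount can be sketched in plain Python. The input format, a list of dicts with config and summary entries (for instance built from each wandb run's config and summary), is an assumption, as is lower_is_better=True, which treats the validation metric as an error to minimise; set it to False for accuracy-style metrics.

```python
import statistics
from collections import defaultdict

# Metrics aggregated per configuration in step 4.
METRICS = ("Metric/train_mean", "Metric/valid_mean", "Metric/test_mean")


def best_config(runs, seed_key="seed", select="Metric/valid_mean", lower_is_better=True):
    """Return (config, stats) of the group with the best mean validation metric.

    Each run is a dict {"config": {...}, "summary": {...}}; stats maps each
    metric to its (mean, population std) over the seeds of that group.
    """
    # Group runs by every hyperparameter except the seed (values must be hashable).
    groups = defaultdict(list)
    for run in runs:
        key = tuple(sorted((k, v) for k, v in run["config"].items() if k != seed_key))
        groups[key].append(run["summary"])

    # Mean and population std of each metric across the seeds of a group.
    stats = {
        key: {
            m: (statistics.mean([s[m] for s in summaries]),
                statistics.pstdev([s[m] for s in summaries]))
            for m in METRICS
        }
        for key, summaries in groups.items()
    }

    # Keep the configuration with the best mean validation metric.
    pick = min if lower_is_better else max
    best = pick(stats, key=lambda k: stats[k][select][0])
    return dict(best), stats[best]
```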

Note that in configs/deterministic/SUN-subgraphcount.yaml, the key task_idx selects the target: 0, 1, 2 and 3 correspond to Triangle, Tailed Tri., Star and 4-Cycle, respectively. Similarly, in configs/deterministic/SUN-graphproperty.yaml, task_idx 0, 1 and 2 correspond to IsConnected, Diameter and Radius, respectively.
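
The task_idx encoding above can be written as a lookup table. The helper names are hypothetical and not part of the repository; the index 3 returned for 4-Cycle matches the --task_idx=3 used for the 4-Cycles curves below.

```python
# task_idx -> target name, in the order documented for each config.
TASK_NAMES = {
    "subgraphcount": ["Triangle", "Tailed Tri.", "Star", "4-Cycle"],
    "graphproperty": ["IsConnected", "Diameter", "Radius"],
}


def task_idx(dataset, target):
    """Return the task_idx value that selects `target` for `dataset`."""
    return TASK_NAMES[dataset].index(target)


def task_name(dataset, idx):
    """Return the target selected by task_idx=`idx` for `dataset`."""
    return TASK_NAMES[dataset][idx]
```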

Get generalisation curves

Values for GIN and GNN-AK models are obtained with the GNN-AK code; DSS-GNN, DS-GNN and NGNN values can be obtained by running the code in this repo with the appropriate model.

We report results for these methods in the out/ folder.

SUN curves can be obtained as detailed below.

4-Cycles (EGO) (Figure 4a)

Prepare the data

python data.py --dataset subgraphcount --policies ego_nets

Run

for i in {1..10}; do python plot.py --batch_size=128 --channels=96 --dataset=subgraphcount --drop_ratio=0 --emb_dim=110 --epochs=250 --gnn_type=originalgin --jk=concat --learning_rate=0.001 --model=sun --num_layer=5 --policy=ego_nets --task_idx=3 --seed="$i"; done

Then, plot the curve in ego_nets-plot.pdf by running

python make_plot.py --policy ego_nets

4-Cycles (EGO+) (Figure 4b)

Prepare the data

python data.py --dataset subgraphcount --policies ego_nets_plus

Run

for i in {1..10}; do python plot.py --batch_size=128 --channels=96 --dataset=subgraphcount --drop_ratio=0 --emb_dim=96 --epochs=250 --gnn_type=originalgin --jk=concat --learning_rate=0.001 --model=sun --num_layer=6 --policy=ego_nets_plus --task_idx=3 --seed="$i"; done

Then, plot the curve in ego_nets_plus-plot.pdf by running

python make_plot.py --policy ego_nets_plus

ZINC (Figure 4c)

Prepare the data

python data.py --dataset ZINC --policies ego_nets

Run

for i in {1..10}; do python plot.py --batch_size=128 --channels=96 --dataset=ZINC --drop_ratio=0 --emb_dim=64 --epochs=400 --gnn_type=zincgin --learning_rate=0.001 --model=sun --num_layer=6 --patience=40 --policy=ego_nets --num_hops=3 --seed="$i"; done

Then, plot the curve in ego_nets-ZINC-plot.pdf by running

python make_plot_zinc.py
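
The three seed loops can also be driven from Python, which makes the per-figure differences explicit. The flag values are copied from the commands above; the plot_commands helper and the figure keys are hypothetical, and the script only prints the commands (swap print for subprocess.run(cmd, check=True) to execute them).

```python
import shlex

# Flags shared by all three seed loops in this README.
COMMON = {"batch_size": 128, "channels": 96, "drop_ratio": 0,
          "learning_rate": 0.001, "model": "sun"}

# Per-figure flags, copied from the commands above.
FIGURES = {
    "4a": {"dataset": "subgraphcount", "emb_dim": 110, "epochs": 250,
           "gnn_type": "originalgin", "jk": "concat", "num_layer": 5,
           "policy": "ego_nets", "task_idx": 3},
    "4b": {"dataset": "subgraphcount", "emb_dim": 96, "epochs": 250,
           "gnn_type": "originalgin", "jk": "concat", "num_layer": 6,
           "policy": "ego_nets_plus", "task_idx": 3},
    "4c": {"dataset": "ZINC", "emb_dim": 64, "epochs": 400,
           "gnn_type": "zincgin", "num_layer": 6, "patience": 40,
           "policy": "ego_nets", "num_hops": 3},
}


def plot_commands(figure, seeds=range(1, 11)):
    """Equivalent of the bash `for i in {1..10}` loop: one plot.py argv per seed."""
    flags = {**COMMON, **FIGURES[figure]}
    base = ["python", "plot.py"] + [f"--{k}={v}" for k, v in flags.items()]
    return [base + [f"--seed={s}"] for s in seeds]


if __name__ == "__main__":
    for fig in FIGURES:
        for cmd in plot_commands(fig):
            print(shlex.join(cmd))  # dry run only
```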

Credits

For attribution in academic contexts, please cite

@inproceedings{frasca2022understanding,
title={Understanding and Extending Subgraph GNNs by Rethinking Their Symmetries},
author={Frasca, Fabrizio and Bevilacqua, Beatrice and Bronstein, Michael M and Maron, Haggai},
booktitle={Advances in Neural Information Processing Systems},
year={2022},
}
