guanghaoliang / prog

This project is forked from sheldonresearch/prog.

Repository size: 91.66 MB

All in One: Multi-task Prompting for Graph Neural Networks, KDD 2023.

Home Page: https://arxiv.org/abs/2307.01504

License: MIT License

Python 100.00%

prog's Introduction



| Website | Paper | Video | Raw Code |

ProG (Prompt Graph) is a library built upon PyTorch to easily conduct single- or multi-task prompting for pre-trained Graph Neural Networks (GNNs). The idea is derived from the paper: Xiangguo Sun, Hong Cheng, Jia Li, et al. All in One: Multi-task Prompting for Graph Neural Networks. KDD 2023, in which the authors released the raw code (see the Raw Code link above). This repository is a polished version of that code with extensive changes and updates.

Call for Contributors!

Once you are invited as a contributor, you will be asked to follow these steps (a command-line sketch follows the list):

  • step 1. Create a temp branch (e.g. xgTemp) from the latest xgsun. (xgsun is a beta branch; only xgsun can be merged into the main branch.)
  • step 2. Fetch origin/xgTemp to your local xgTemp and make your changes (via PyCharm etc.).
  • step 3. Push your changes from your local xgTemp to your GitHub cloud branch origin/xgTemp.
  • step 4. Open a pull request to merge your branch into xgsun.
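In plain git commands, one equivalent flow looks roughly like this (a sketch; branch names are taken from the steps above, and the remote is assumed to be called origin):

git fetch origin
git checkout -b xgTemp origin/xgsun   # step 1: branch off the latest xgsun
# ... edit and commit your changes ...  (step 2)
git push origin xgTemp                # step 3: push to origin/xgTemp
# step 4: open a pull request on GitHub from xgTemp into xgsun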

When you have finished these steps, I will get a notification and approve the merge of your branch into xgsun. Once that is done, I will delete your branch; for your next contribution, repeat the steps above.

It would be greatly appreciated if you could finish these steps on weekdays (Monday-Friday, Beijing time). I will resolve any conflicts over the weekend and update the latest xgsun branch before Sunday (Beijing time).

A widely tested xgsun branch will then be merged into the main branch, and a new version will be released once or twice a month.

Quick Start

Package Dependencies

  • PyTorch 1.13.1
  • torchmetrics 0.11.4
  • torch_geometric 2.2.0
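A minimal environment can be set up with pip (a sketch; the exact wheels depend on your platform and CUDA version, and torch_geometric's companion packages such as torch-scatter may need to be installed from PyG's wheel index):

pip install torch==1.13.1 torchmetrics==0.11.4 torch_geometric==2.2.0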

Pre-train your GNN model

The following code presents a simple example of how to pre-train a GNN model via GraphCL. You can also find an integrated function pretrain() in no_meta_demo.py.

from ProG.utils import mkdir, load_data4pretrain
from ProG import PreTrain

mkdir('./pre_trained_gnn/')  # pre-trained weights will be stored here

pretext = 'GraphCL'  # 'GraphCL', 'SimGRACE'
gnn_type = 'TransformerConv'  # 'GAT', 'GCN'
dataname, num_parts, batch_size = 'CiteSeer', 200, 10

print("load data...")
graph_list, input_dim, hid_dim = load_data4pretrain(dataname, num_parts)

print("create PreTrain instance...")
pt = PreTrain(pretext, gnn_type, input_dim, hid_dim, gln=2)

print("pre-training...")
pt.train(dataname, graph_list, batch_size=batch_size,
         aug1='dropN', aug2="permE", aug_ratio=None,
         lr=0.01, decay=0.0001, epochs=100)

Create Related Models

from ProG.prompt import GNN, LightPrompt
from torch import nn, optim
import torch

# load pre-trained GNN
gnn = GNN(100, hid_dim=100, out_dim=100, gcn_layer_num=2, gnn_type="TransformerConv")
pre_train_path = './pre_trained_gnn/{}.GraphCL.{}.pth'.format("CiteSeer", "TransformerConv")
gnn.load_state_dict(torch.load(pre_train_path))
print("successfully load pre-trained weights for gnn! @ {}".format(pre_train_path))
# freeze the pre-trained GNN; only the prompt parameters are tuned
for p in gnn.parameters():
    p.requires_grad = False

# prompt with hand-crafted answering template (no answering head tuning)
PG = LightPrompt(token_dim=100, token_num_per_group=100, group_num=6, inner_prune=0.01)

opi = optim.Adam(filter(lambda p: p.requires_grad, PG.parameters()),
                 lr=0.001, weight_decay=0.00001)

lossfn = nn.CrossEntropyLoss(reduction='mean')

The above code is also integrated as the function model_create(dataname, gnn_type, num_class, task_type) in this project.

Prompt learning with hand-crafted answering template

from ProG.data import multi_class_NIG
import torch

# dataname, num_class, gnn_type, task_type follow from the sections above (e.g. 'CiteSeer')
train, test, _, _ = multi_class_NIG(dataname, num_class)
gnn, PG, opi, lossfn, _, _ = model_create(dataname, gnn_type, num_class, task_type)
prompt_epoch = 200

# training stage
PG.train()
# the GNN is frozen, so the input-graph embeddings only need to be computed once
emb0 = gnn(train.x, train.edge_index, train.batch)
for j in range(prompt_epoch):
    pg_batch = PG.inner_structure_update()
    pg_emb = gnn(pg_batch.x, pg_batch.edge_index, pg_batch.batch)
    dot = torch.mm(emb0, torch.transpose(pg_emb, 0, 1))
    sim = torch.softmax(dot, dim=1)
    train_loss = lossfn(sim, train.y)
    print('{}/{} training loss: {:.8f}'.format(j, prompt_epoch, train_loss.item()))
    opi.zero_grad()
    train_loss.backward()
    opi.step()
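After training, the prompt can be evaluated the same way: score the test split against the prompt-graph embeddings and take the argmax. A minimal sketch, assuming test carries the same x, edge_index, batch, and y fields as train, and using torchmetrics (already a dependency) for accuracy:

from torchmetrics import Accuracy

# evaluation stage (sketch)
PG.eval()
with torch.no_grad():
    test_emb = gnn(test.x, test.edge_index, test.batch)
    pg_batch = PG.inner_structure_update()
    pg_emb = gnn(pg_batch.x, pg_batch.edge_index, pg_batch.batch)
    dot = torch.mm(test_emb, torch.transpose(pg_emb, 0, 1))
    pred = torch.softmax(dot, dim=1).argmax(dim=1)

acc = Accuracy(task='multiclass', num_classes=num_class)
print('test accuracy: {:.4f}'.format(acc(pred, test.y).item()))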

More Detailed Tutorial

For more detailed usage examples (e.g. prompt with answer tuning, prompt with meta-learning), please check the demos in:

  • no_meta_demo.py
  • meta_demo.py

Compare this new implementation with the raw code

Multi-class node classification (100-shots)

Results on CiteSeer:

Method                                 |  ACC  | Macro-F1
---------------------------------------|-------|---------
reported in the paper (Prompt)         | 80.50 |  80.05
this version, run once (Prompt)        | 81.00 |  81.23
reported in the paper (Prompt w/o h)   | 80.00 |  80.05
this version, run once (Prompt w/o h)  | 79.78 |  80.01


Citation


@inproceedings{sun2023all,
  title={All in One: Multi-Task Prompting for Graph Neural Networks},
  author={Sun, Xiangguo and Cheng, Hong and Li, Jia and Liu, Bo and Guan, Jihong},
  booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD'23)},
  year={2023}
}

Contact

  • For more information or further discussion, contact: Website
  • Email: xiangguosun at cuhk dot edu dot hk

prog's People

Contributors

  • sheldonresearch
  • barristen
