divelab / dig

A library for graph deep learning research

Home Page: https://diveintographs.readthedocs.io/

License: GNU General Public License v3.0

Languages: Python 88.04%, Jupyter Notebook 11.29%, C++ 0.36%, Cuda 0.28%, Shell 0.02%, C 0.02%

Topics: deep-learning, graph-neural-network, graph-generation, explainable-ml, self-supervised-learning, 3d-graph

dig's Introduction


Documentation | Paper [JMLR] | Tutorials | Benchmarks | Examples | Colab Demo | Slack community

DIG: Dive into Graphs is a turnkey library for graph deep learning research.

🔥 Update (2022/07): We have upgraded the DIG library to build on PyG 2.0.0. We recommend installing our latest version.

Why DIG?

The key difference from existing graph deep learning libraries, such as PyTorch Geometric (PyG) and the Deep Graph Library (DGL), is that while PyG and DGL support basic graph deep learning operations, DIG provides a unified testbed for higher-level, research-oriented tasks such as graph generation, self-supervised learning, explainability, 3D graphs, and graph out-of-distribution (OOD) learning.

If you are working on, or plan to work on, graph deep learning research, DIG enables you to develop your own methods within our extensible framework and to compare against current baselines on common datasets and evaluation metrics without extra effort.

Overview

DIG includes unified implementations of data interfaces, common algorithms, and evaluation metrics for several advanced tasks. Our goal is to enable researchers to easily implement and benchmark algorithms. Currently, we consider the following research directions (a quick import sketch follows the list).

  • Graph Generation: dig.ggraph
  • Self-supervised Learning on Graphs: dig.sslgraph
  • Explainability of Graph Neural Networks: dig.xgraph
  • Deep Learning on 3D Graphs: dig.threedgraph
  • Graph OOD: dig.oodgraph
  • Graph Augmentation: dig.auggraph
  • Fair Graph Learning: dig.fairgraph
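Each research direction maps to a subpackage under dig. As a quick orientation sketch (an editor's addition; the import paths are the ones used elsewhere in this README and its issues):

from dig.ggraph.dataset import ZINC250k        # graph generation
from dig.sslgraph.method import GraphCL        # self-supervised learning on graphs
from dig.xgraph.method import SubgraphX        # explainability of GNNs
from dig.threedgraph.method import SphereNet   # deep learning on 3D graphs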


Usage

Example: a few lines of code to run SphereNet on QM9, incorporating the 3D information of molecules.

import torch

from dig.threedgraph.dataset import QM93D
from dig.threedgraph.method import SphereNet
from dig.threedgraph.evaluation import ThreeDEvaluator
from dig.threedgraph.method import run

# Load the dataset and split
dataset = QM93D(root='dataset/')
target = 'U0'
dataset.data.y = dataset.data[target]
split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=110000, valid_size=10000, seed=42)
train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]

# Define model, loss, and evaluation
model = SphereNet(energy_and_force=False, cutoff=5.0, num_layers=4,
                  hidden_channels=128, out_channels=1, int_emb_size=64,
                  basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256,
                  num_spherical=3, num_radial=6, envelope_exponent=5,
                  num_before_skip=1, num_after_skip=2, num_output_layers=3)                 
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

# Train and evaluate
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # pick GPU when available
run3d = run()
run3d.run(device, train_dataset, valid_dataset, test_dataset, model, loss_func, evaluation,
          epochs=20, batch_size=32, vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=15)
  1. For details of all included APIs, please refer to the documentation.
  2. We provide a hands-on tutorial for each direction to help you get started with DIG: Graph Generation, Self-supervised Learning on Graphs, Explainability of Graph Neural Networks, Deep Learning on 3D Graphs, and Graph OOD (GOOD) datasets.
  3. We also provide examples of using the APIs provided in DIG. You can get started with the directions you are interested in via the corresponding links.

Installation

Install from pip

The key dependencies of DIG: Dive into Graphs are PyTorch (>=1.10.0), PyTorch Geometric (>=2.0.0), and RDKit.

  1. Install PyTorch (>=1.10.0)
$ python -c "import torch; print(torch.__version__)"
>>> 1.10.0
  2. Install PyG (>=2.0.0)
$ python -c "import torch_geometric; print(torch_geometric.__version__)"
>>> 2.0.0
  3. Install DIG: Dive into Graphs.
pip install dive-into-graphs

After installation, you can check the version. You have successfully installed DIG: Dive into Graphs if no error occurs.

$ python
>>> from dig.version import __version__
>>> print(__version__)

Install from source

If you want to try the latest features that have not been released yet, you can install DIG from source.

git clone https://github.com/divelab/DIG.git
cd DIG
pip install .

Contributing

We welcome contributions of all forms, such as reporting bugs and adding new features. Please refer to our contributing guidelines for details.

Citing DIG

Please cite our paper if you find DIG useful in your work:

@article{JMLR:v22:21-0343,
  author  = {Meng Liu and Youzhi Luo and Limei Wang and Yaochen Xie and Hao Yuan and Shurui Gui and Haiyang Yu and Zhao Xu and Jingtun Zhang and Yi Liu and Keqiang Yan and Haoran Liu and Cong Fu and Bora M Oztekin and Xuan Zhang and Shuiwang Ji},
  title   = {{DIG}: A Turnkey Library for Diving into Graph Deep Learning Research},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {240},
  pages   = {1-9},
  url     = {http://jmlr.org/papers/v22/21-0343.html}
}

The Team

DIG: Dive into Graphs is developed by DIVE@TAMU. Contributors are Meng Liu*, Youzhi Luo*, Limei Wang*, Yaochen Xie*, Hao Yuan*, Shurui Gui*, Haiyang Yu*, Zhao Xu, Jingtun Zhang, Yi Liu, Keqiang Yan, Haoran Liu, Cong Fu, Bora Oztekin, Xuan Zhang, and Shuiwang Ji.

Acknowledgments

This work was supported in part by National Science Foundation grants IIS-2006861, IIS-1955189, IIS-1908220, IIS-1908198, DBI-2028361, and DBI-1922969.

Contact

If you have any technical questions, please open a new issue or raise it in our DIG Slack community.

If you have any other questions, please contact us: Meng Liu [[email protected]] and Shuiwang Ji [[email protected]].

dig's People

Contributors

alirezadizaji, ameya98, borao, cm-bf, congffu, floatlazer, gloria-liu, hannesstark, hongyiling, hydra324, jacoblau0513, joaquincabezas, kirillshmilovich, kruskallin, limei0307, lyzustc, mengliu1998, michael1015198808, nate1874, oceanusity, ordinarycrazy, r-kellerm, sairamanareddy, yang-han, ycremar, ykq98, zoexu119


dig's Issues

Xgraph for graph level predictions

Thanks for the work behind this repo! I was wondering: have you implemented the explainability methods for graph-level predictions as well, or only for node-level ones?

Also, do you plan to develop the library to work with DGL?

Thanks!

Zero-padding in SubgraphX

In SubgraphX, it is very inspiring that zero-padding is introduced.

I have the same observation that GNN is very very sensitive to topology perturbation. Thus I have some questions:

  1. Regarding setting features to 0: in the MUTAG dataset, I suppose the default node feature 0 represents a Carbon atom. In the zero-padding part, do you set the feature value to 0 in a way that actually turns all the nodes into Carbon, or into some other atom?
    (After thinking it through, I believe you mean setting the features to all zeros: since the node label is one-hot encoded, setting zeros deactivates the weights for the node features. Correct me if I'm wrong. A toy illustration follows this list.)

  2. Compared with simply removing the nodes from the graph, how does zero-padding affect the model's prediction in your empirical study?
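(Editor's toy illustration of point 1, assuming a hypothetical 3-type one-hot encoding; MUTAG's real encoding has more atom types:)

import torch

# Each row is a node; each column an atom type.
x = torch.tensor([[1., 0., 0.],    # node 0: Carbon
                  [0., 1., 0.]])   # node 1: Nitrogen
W = torch.randn(3, 4)              # a first-layer GNN weight matrix

# Zero-padding a node's one-hot row removes its contribution entirely:
x_masked = x.clone()
x_masked[1] = 0.
print(x_masked @ W)   # row 1 is all zeros; the node's feature weights are deactivated

# Relabeling the node as type 0 would instead make it "Carbon":
x_relabel = x.clone()
x_relabel[1] = torch.tensor([1., 0., 0.])
print(x_relabel @ W)  # row 1 now equals Carbon's embedding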

Thanks!

Issue in GNNExplainer tutorial.

I tried to run the tutorial for GNNExplainer. However, I am consistently getting the following ValueError. It would be a great help if you could help me out here:

---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
8 if torch.isnan(data.y[0].squeeze()):
9 continue
---> 10 edge_masks, hard_edge_masks, related_preds = explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)
11
12 x_collector.collect_data(hard_edge_masks, related_preds, data.y[0].squeeze().long().item())

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in forward(self, x, edge_index, mask_features, **kwargs)
143 self.clear_masks()
144 self.set_masks(x, self_loop_edge_index)
--> 145 edge_masks.append(self.control_sparsity(self.gnn_explainer_alg(x, edge_index, ex_label), sparsity=kwargs.get('sparsity')))
146 # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))
147

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in gnn_explainer_alg(self, x, edge_index, ex_label, mask_features, **kwargs)
84 h = x
85 raw_preds = self.model(x=h, edge_index=edge_index, **kwargs)
---> 86 loss = self.loss(raw_preds, ex_label)
87 if epoch % 20 == 0 and debug:
88 print(f'Loss:{loss.item()}')

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in loss(self, raw_preds, x_label)
44 def loss(self, raw_preds: Tensor, x_label: Union[Tensor, int]):
45 if self.explain_graph:
---> 46 loss = cross_entropy_with_logit(raw_preds, x_label)
47 else:
48 loss = cross_entropy_with_logit(raw_preds[self.node_idx].unsqueeze(0), x_label)

~/.local/lib/python3.8/site-packages/dig/xgraph/method/gnnexplainer.py in cross_entropy_with_logit(y_pred, y_true, **kwargs)
10
11 def cross_entropy_with_logit(y_pred: torch.Tensor, y_true: torch.Tensor, **kwargs):
---> 12 return cross_entropy(y_pred, y_true.long(), **kwargs)
13
14 class GNNExplainer(ExplainerBase):

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2822 if size_average is not None or reduce is not None:
2823 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2825
2826

ValueError: Expected input batch_size (700) to match target batch_size (1).
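An editor's reading of the traceback, offered as a hypothesis: raw_preds has one row per node (700) while x_label is a single graph label, and the loss() source shown above only pairs those shapes on the self.explain_graph branch, so the explainer appears to be configured for graph-level explanation while the task is node-level. If the explainer exposes the flag seen in the source (self.explain_graph) at construction time, something like the following may fix it:

# Hypothetical: build the explainer for node-level explanation so that
# loss() takes the node_idx branch instead of the whole-graph branch.
explainer = GNNExplainer(model, epochs=100, explain_graph=False)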

local variable 'x' referenced before assignment

First of all, I really appreciate your work in contributing this library; it's really useful. However, while following your tutorial, I got the following error:

---------------------------------------------------------------------------
UnboundLocalError                         Traceback (most recent call last)
<ipython-input-5-7e5867836373> in <module>
     16 print(f'explain graph node {node_idx}')
     17 data.to(device)
---> 18 logits = model(data.x, data.edge_index)
     19 prediction = logits[node_idx].argmax(-1).item()
     20 

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/models/models.py in forward(self, *args, **kwargs)
    147         :return:
    148         """
--> 149         x, edge_index, batch = self.arguments_read(*args, **kwargs)
    150 
    151         post_conv = self.relu1(self.conv1(x, edge_index))

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/models/models.py in arguments_read(self, *args, **kwargs)
     42             elif len(args) == 2:
     43                 x, edge_index, batch = args[0], args[1], \
---> 44                                        torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
     45             elif len(args) == 3:
     46                 x, edge_index, batch = args[0], args[1], args[2]

UnboundLocalError: local variable 'x' referenced before assignment
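An editor's note on the likely cause, inferred from the traceback: in the len(args) == 2 branch of arguments_read, the right-hand side of the tuple assignment reads x.device before the same statement has bound x, and Python evaluates the entire right-hand side first. A self-contained sketch of the semantics and a corrected form (an assumption, not the maintainers' fix):

import torch

args = (torch.randn(5, 3), torch.zeros((2, 4), dtype=torch.long))

# Fails with UnboundLocalError inside a function: the whole right-hand side
# is evaluated before `x` is bound.
# x, edge_index, batch = args[0], args[1], torch.zeros(args[0].shape[0], device=x.device)

# Works: refer to args[0], which is already bound.
x, edge_index, batch = args[0], args[1], \
    torch.zeros(args[0].shape[0], dtype=torch.int64, device=args[0].device)

A practical workaround until the library is patched is to pass an explicit batch tensor as the third argument, so the len(args) == 3 branch is taken instead.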

The corresponding code block is:

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector()

index = -1
node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()
data = dataset[0]

from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result, k_hop_subgraph_with_default_whole_graph
plotutils = PlotUtils(dataset_name='ba_shapes')

# Visualization
max_nodes = 5
node_idx = node_indices[6]
print(f'explain graph node {node_idx}')
data.to(device)
logits = model(data.x, data.edge_index)
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = \
    explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)

plotutils = PlotUtils(dataset_name='ba_shapes')
explainer.visualization(explanation_results,
                        prediction,
                        max_nodes=max_nodes,
                        plot_utils=plotutils,
                        y=data.y)

My installed packages are as follows:

# packages in environment at /home/*/anaconda3/envs/dig:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       1_gnu    conda-forge
anyio                     3.1.0            py38h578d9bd_0    conda-forge
argon2-cffi               20.1.0           py38h497a2fe_2    conda-forge
ase                       3.21.1                   pypi_0    pypi
async_generator           1.10                       py_0    conda-forge
attrs                     21.2.0             pyhd8ed1ab_0    conda-forge
babel                     2.9.1              pyh44b312d_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                        py_2    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl
bleach                    3.3.0              pyh44b312d_0    conda-forge
boost                     1.74.0           py38hc10631b_3    conda-forge
boost-cpp                 1.74.0               hc6e9bd1_3    conda-forge
brotlipy                  0.7.0           py38h497a2fe_1001    conda-forge
bzip2                     1.0.8                h7b6447c_0
ca-certificates           2021.5.30            ha878542_0    conda-forge
cairo                     1.16.0            h6cf1ce9_1008    conda-forge
captum                    0.2.0                    pypi_0    pypi
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cffi                      1.14.5           py38ha65f79e_0    conda-forge
chardet                   4.0.0            py38h578d9bd_1    conda-forge
cilog                     1.2.0                    pypi_0    pypi
cloudpickle               1.6.0                    pypi_0    pypi
cryptography              3.4.7            py38ha5dfef3_0    conda-forge
cudatoolkit               10.1.243             h6bb024c_0
cycler                    0.10.0                     py_2    conda-forge
decorator                 4.4.2                    pypi_0    pypi
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
dive-into-graphs          0.0.4                    pypi_0    pypi
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
et-xmlfile                1.1.0                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
fontconfig                2.13.1            hba837de_1005    conda-forge
freetype                  2.10.4               h5ab3b9f_0
gettext                   0.19.8.1          h0b5b191_1005    conda-forge
gmp                       6.2.1                h2531618_2
gnutls                    3.6.15               he1e5248_0
googledrivedownloader     0.4                      pypi_0    pypi
greenlet                  1.1.0            py38h709712a_0    conda-forge
h5py                      3.2.1                    pypi_0    pypi
icu                       68.1                 h58526e2_0    conda-forge
idna                      2.10               pyh9f0ad1d_0    conda-forge
importlib-metadata        4.5.0            py38h578d9bd_0    conda-forge
intel-openmp              2021.2.0           h06a4308_610
ipykernel                 5.5.5            py38hd0cf306_0    conda-forge
ipython                   7.24.1           py38hd0cf306_0    conda-forge
ipython-genutils          0.2.0                    pypi_0    pypi
ipython_genutils          0.2.0                      py_1    conda-forge
isodate                   0.6.0                    pypi_0    pypi
jedi                      0.18.0           py38h578d9bd_2    conda-forge
jinja2                    3.0.1              pyhd8ed1ab_0    conda-forge
joblib                    1.0.1                    pypi_0    pypi
jpeg                      9b                   h024ee3a_2
json5                     0.9.5              pyh9f0ad1d_0    conda-forge
jsonschema                3.2.0              pyhd8ed1ab_3    conda-forge
jupyter_client            6.1.12             pyhd8ed1ab_0    conda-forge
jupyter_core              4.7.1            py38h578d9bd_0    conda-forge
jupyter_server            1.8.0              pyhd8ed1ab_0    conda-forge
jupyterlab                3.0.16             pyhd8ed1ab_0    conda-forge
jupyterlab_pygments       0.1.2              pyh9f0ad1d_0    conda-forge
jupyterlab_server         2.6.0              pyhd8ed1ab_0    conda-forge
kiwisolver                1.3.1            py38h1fd1430_1    conda-forge
lame                      3.100                h7b6447c_0
lcms2                     2.12                 h3be6417_0
ld_impl_linux-64          2.33.1               h53a641e_7
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.3.0               h2828fa1_19    conda-forge
libglib                   2.68.2               h3e27bee_2    conda-forge
libgomp                   9.3.0               h2828fa1_19    conda-forge
libiconv                  1.16                 h516909a_0    conda-forge
libidn2                   2.3.1                h27cfd23_0
libpng                    1.6.37               hbc83047_0
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              9.3.0               h6de172a_19    conda-forge
libtasn1                  4.16.0               h27cfd23_0
libtiff                   4.2.0                h85742a9_0
libunistring              0.9.10               h27cfd23_0
libuuid                   2.32.1            h7f98852_1000    conda-forge
libuv                     1.40.0               h7b6447c_0
libwebp-base              1.2.0                h27cfd23_0
libxcb                    1.13              h7f98852_1003    conda-forge
libxml2                   2.9.12               h72842e0_0    conda-forge
llvmlite                  0.36.0                   pypi_0    pypi
lz4-c                     1.9.3                h2531618_0
markupsafe                2.0.1            py38h497a2fe_0    conda-forge
matplotlib-base           3.4.2            py38hcc49a3a_0    conda-forge
matplotlib-inline         0.1.2              pyhd8ed1ab_2    conda-forge
mistune                   0.8.4           py38h497a2fe_1003    conda-forge
mkl                       2021.2.0           h06a4308_296
mkl-service               2.3.0            py38h27cfd23_1
mkl_fft                   1.3.0            py38h42c9631_2
mkl_random                1.2.1            py38ha9443f7_2
mypy-extensions           0.4.3                    pypi_0    pypi
nbclassic                 0.3.1              pyhd8ed1ab_1    conda-forge
nbclient                  0.5.3              pyhd8ed1ab_0    conda-forge
nbconvert                 6.0.7            py38h578d9bd_3    conda-forge
nbformat                  5.1.3              pyhd8ed1ab_0    conda-forge
ncurses                   6.2                  he6710b0_1
nest-asyncio              1.5.1              pyhd8ed1ab_0    conda-forge
nettle                    3.7.2                hbbd107a_1
networkx                  2.5.1                    pypi_0    pypi
ninja                     1.10.2               hff7bd54_1
notebook                  6.4.0              pyha770c72_0    conda-forge
numba                     0.53.1                   pypi_0    pypi
numpy                     1.20.2           py38h2d18471_0
numpy-base                1.20.2           py38hfae3a4d_0
olefile                   0.46                       py_0
openh264                  2.1.0                hd408876_0
openpyxl                  3.0.7                    pypi_0    pypi
openssl                   1.1.1k               h7f98852_0    conda-forge
packaging                 20.9               pyh44b312d_0    conda-forge
pandas                    1.2.4                    pypi_0    pypi
pandoc                    2.14.0.1             h7f98852_0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
parso                     0.8.2              pyhd8ed1ab_0    conda-forge
pcre                      8.44                 he1b5a44_0    conda-forge
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.2.0            py38he98fc37_0
pip                       21.1.1           py38h06a4308_0
pixman                    0.40.0               h36c2ea0_0    conda-forge
prometheus_client         0.11.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.18             pyha770c72_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pycairo                   1.20.1           py38hf61ee4a_0    conda-forge
pycparser                 2.20               pyh9f0ad1d_2    conda-forge
pygments                  2.9.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 20.0.1             pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7              pyh9f0ad1d_0    conda-forge
pyrsistent                0.17.3           py38h497a2fe_2    conda-forge
pysocks                   1.7.1            py38h578d9bd_3    conda-forge
python                    3.8.10               hdb3f193_7
python-dateutil           2.8.1                      py_0    conda-forge
python-louvain            0.15                     pypi_0    pypi
python_abi                3.8                      1_cp38    conda-forge
pytorch                   1.8.1           py3.8_cuda10.1_cudnn7.6.3_0    pytorch
pytz                      2021.1             pyhd8ed1ab_0    conda-forge
pyzmq                     22.1.0           py38h2035c66_0    conda-forge
rdflib                    5.0.0                    pypi_0    pypi
rdkit                     2021.03.2        py38hf8acc3d_0    conda-forge
readline                  8.1                  h27cfd23_0
reportlab                 3.5.67           py38hadf75a6_0    conda-forge
requests                  2.25.1             pyhd3deb0d_0    conda-forge
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.6.3                    pypi_0    pypi
send2trash                1.5.0                      py_0    conda-forge
setuptools                52.0.0           py38h06a4308_0
shap                      0.39.0                   pypi_0    pypi
six                       1.15.0           py38h06a4308_0
slicer                    0.0.7                    pypi_0    pypi
sniffio                   1.2.0            py38h578d9bd_1    conda-forge
sqlalchemy                1.4.17           py38h497a2fe_0    conda-forge
sqlite                    3.35.4               hdfb4753_0
tabulate                  0.8.9                    pypi_0    pypi
terminado                 0.10.0           py38h578d9bd_0    conda-forge
testpath                  0.5.0              pyhd8ed1ab_0    conda-forge
threadpoolctl             2.1.0                    pypi_0    pypi
tk                        8.6.10               hbc83047_0
torch-cluster             1.5.9                    pypi_0    pypi
torch-geometric           1.7.0                    pypi_0    pypi
torch-scatter             2.0.6                    pypi_0    pypi
torch-sparse              0.6.9                    pypi_0    pypi
torch-spline-conv         1.2.1                    pypi_0    pypi
torchaudio                0.8.1                      py38    pytorch
torchvision               0.9.1                py38_cu101    pytorch
tornado                   6.1              py38h497a2fe_1    conda-forge
tqdm                      4.61.0                   pypi_0    pypi
traitlets                 5.0.5                      py_0    conda-forge
typed-argument-parser     1.5.4                    pypi_0    pypi
typing-inspect            0.7.0                    pypi_0    pypi
typing_extensions         3.7.4.3            pyha847dfd_0
tzdata                    2020f                h52ac0ba_0
urllib3                   1.26.5             pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                      py_1    conda-forge
websocket-client          0.57.0           py38h578d9bd_4    conda-forge
wheel                     0.36.2             pyhd3eb1b0_0
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.3             hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxau               1.0.9                h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.5                h7b6447c_0
zeromq                    4.3.4                h9c3ff4c_0    conda-forge
zipp                      3.4.1              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11               h7b6447c_3
zstd                      1.4.9                haebb681_0

Thanks in advance!

fidelity score is negative in PGExplainer on BA-shapes

I ran your program xgraph/PGExplainer on the BA-shapes dataset using the command python pipeline.py --dataset_name BA_shapes --random_split True --latent_dim 20 20 20 --concate True --adj_normlize False --emb_normlize True provided in scripts.sh. However, I get a negative fidelity score of -0.0706 when I set top_k to 6, which is the number of edges in the motifs, as the paper illustrates:

fetch network parameters from the saved files training time is 0.007172066951170564s 100%|████████████████████████████████████████████████| 42/42 [00:06<00:00, 6.76it/s] fidelity score: -0.0706, sparsity score: 0.6932

Secondly, when I add the 'auc' metric following the official code of the paper, I get 0.5348, but the result reported in the paper is 0.963, and I get higher results when I run the paper's official code.

Why is that?
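For reference, here is an editor's sketch of the fidelity metric as commonly defined in the explainability literature (the exact DIG implementation may differ): the average drop in the predicted probability of the original class once the explanation is removed from the input graph.

def fidelity(probs_full, probs_without_explanation):
    # Both lists hold the predicted probability of the originally
    # predicted class, one entry per explained instance.
    n = len(probs_full)
    return sum(p - q for p, q in zip(probs_full, probs_without_explanation)) / n

Under this definition, a negative score simply means the model became more confident after the explanation edges were removed, which can legitimately happen for a very small top_k.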

Run GNNExplainer

I am trying to run the GNNExplainer code. Although I used install.batch to generate the conda environment, I am having issues with the packages, namely with tap (in the line "from tap import Tap"). I tried installing it myself, and I can import tap, but it doesn't seem to contain anything called Tap. Could you guide me here, please?

Also, I was going through the example and was wondering which settings I should change to run it on graph-level predictions. Is there an example for that? And is it possible to get the feature importances and the important subgraphs? I didn't see any outputs from the explain functions.

Thanks in advance!

SphereNet xyz_to_dat issue

Dear DIG maintainers,

First of all many thanks for this repo!

I was checking out your implementation of the SphereNet architecture, and I noticed that in some cases I cannot successfully perform a forward pass of the network. The problem appears to lie in the xyz_to_dat function called within that submodule. Here is a minimal example of the issue.

I start with 3D coordinates in cartesian space and compute edge indices with torch_geometric.nn.pool.radius_graph:

In [13]: g.coords
Out[13]: 
tensor([[ 4.5877,  1.2124,  0.9045],
        [ 3.4939,  2.0393,  0.7065],
        [ 2.5577,  1.7294, -0.2622],
        [ 2.6985,  0.5842, -1.0372],
        [ 1.6548,  0.2541, -2.0763],
        [ 0.3290,  0.2897, -1.5940],
 ...]])

In [14]: g.coords.shape
Out[14]: torch.Size([41, 3])


In [16]: edge_index = radius_graph(g.coords, r=5.0, batch=torch.zeros(g.atomids.size(0), dtype=int))

In [17]: edge_index
Out[17]: 
tensor([[20,  1, 19,  ...,  2,  4,  3],
        [ 0,  0,  0,  ..., 40, 40, 40]])

In [18]: edge_index.shape
Out[18]: torch.Size([2, 937])

I then try and call the xyz_to_dat routine present in the forward pass, and I encounter this error:

In [20]:     out = xyz_to_dat(pos=g.coords, edge_index=edge_index, num_nodes=g.atomids.size(0), use_torsion=True)
    ...: 
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-e5d12c7af566> in <module>
----> 1 out = xyz_to_dat(pos=g.coords, edge_index=edge_index, num_nodes=g.atomids.size(0), use_torsion=True)

~net.py in xyz_to_dat(pos, edge_index, num_nodes, use_torsion)
    114     repeat = num_triplets - 1
    115     num_triplets_t = num_triplets.repeat_interleave(repeat)
--> 116     idx_i_t = idx_i.repeat_interleave(num_triplets_t)
    117     idx_j_t = idx_j.repeat_interleave(num_triplets_t)
    118     idx_k_t = idx_k.repeat_interleave(num_triplets_t)

RuntimeError: repeats must have the same size as input along dim

Interestingly, this problem seems to become more frequent the larger the graph I consider; the calculation succeeds for some graphs (screenshot omitted).


I'm wondering whether the problem is on my side and I'm misunderstanding how these functions should be called. Any help is appreciated!
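An editor's note on one plausible cause, offered as an assumption to verify rather than a confirmed diagnosis: torch_geometric's radius_graph keeps at most max_num_neighbors incoming edges per node (32 by default), so on larger or denser graphs edges are silently dropped and the resulting edge set need not be symmetric, which could break the triplet bookkeeping in xyz_to_dat and would explain why larger graphs fail more often. Lifting the cap restores a full radius graph:

from torch_geometric.nn import radius_graph

# Hypothetical fix to verify: raise the default per-node neighbor cap (32)
# so no edges are silently dropped on dense graphs.
edge_index = radius_graph(g.coords, r=5.0,
                          batch=torch.zeros(g.atomids.size(0), dtype=int),
                          max_num_neighbors=1000)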

'Batch' object has no attribute 'mask'

First, thanks for open-sourcing your wonderful code, but I have a few questions.

I use the xgraph in the 'main' branch, and I got AttributeError: 'Batch' object has no attribute 'mask'.
What should I do to solve this problem?

Moreover, I want to use other GNN models, such as GAT and APPNP.
How can I use other GNN models, rather than the checkpoints?

Thank you, have a nice day!

Code for 3d graph reproduction of results on MD17

Thanks for this very useful library. Would it be possible to add some code, similar to benchmarks/threedgraph/threedgraph.ipynb, for running SphereNet on MD17, together with the parameters needed to reproduce the results from the paper? Thanks!

Negative JSELoss

I was initially confused by getting a negative loss from the JSELoss, and by the implementation here, where the Jensen-Shannon divergence is shifted:

log_2 = np.log(2.)
if positive:
    score = log_2 - F.softplus(-masked_d_prime)
else:
    score = F.softplus(-masked_d_prime) + masked_d_prime - log_2
return score

To save others from stumbling over this: apparently, this choice was made for consistency with other divergences, and it does not affect training or model performance.
This issue in the Deep InfoMax library explains why the shift is included: rdevon/DIM#19
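To make the "does not affect training" point concrete, here is a small editor's check (not part of the library) showing that the log 2 shift is a constant offset with identical gradients:

import numpy as np
import torch
import torch.nn.functional as F

log_2 = np.log(2.)
d = torch.linspace(-3., 3., steps=7, requires_grad=True)

shifted = (log_2 - F.softplus(-d)).sum()   # positive-pair score with the shift
unshifted = (-F.softplus(-d)).sum()        # the same score without log(2)

g1 = torch.autograd.grad(shifted, d, retain_graph=True)[0]
g2 = torch.autograd.grad(unshifted, d)[0]
assert torch.allclose(g1, g2)              # constant shifts leave gradients untouched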

TypeError: draw_networkx_nodes() got an unexpected keyword argument 'num_nodes'

I finished the installation and tried to run the GradCAM demo, but the error below occurred.

$ python -m benchmark.kernel.pipeline --task explain --model_name GCN_3l --dataset_name tox21 --target_idx 2 --explainer GradCAM --sparsity 0.5 --debug --vis --nolabel
Add /home/ubuntu/DIG/dig/xgraph/metrics as a system path.
DEBUG: 04/26/2021 09:26:13 AM : Parse arguments.

INFO: -----------------------------------
    Task: explain
Mon Apr 26 09:26:13 2021
INFO: Load Dataset tox21
DEBUG: 04/26/2021 09:26:13 AM :  Data(edge_attr=[22, 3], edge_index=[2, 22], smiles="CCCc1ccc(OC)cc1", x=[11, 9], y=[1])
INFO: Loading model...
INFO: GCN_3l(
  (conv1): GCNConv(9, 300)
  (convs): ModuleList(
    (0): GCNConv(300, 300)
    (1): GCNConv(300, 300)
  )
  (relu1): ReLU()
  (relus): ModuleList(
    (0): ReLU()
    (1): ReLU()
  )
  (readout): GlobalMeanPool()
  (ffn): Sequential(
    (0): Linear(in_features=300, out_features=300, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=300, out_features=2, bias=True)
  )
  (dropout): Dropout(p=0.5, inplace=False)
)
INFO: Config the model used to be explained...
INFO: Loading best Checkpoint 932...
INFO: Create explainer: GradCAM...
INFO: Begin explain
INFO: explain graph line 1667
DEBUG: 04/26/2021 09:26:16 AM : Mask Calculate...
DEBUG: 04/26/2021 09:26:16 AM : Predict...
DEBUG: 04/26/2021 09:26:16 AM : Explainer prediction time: 0.4564
Input Tensor 0 did not already require gradients, required_grads has been set automatically.
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/env/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/ubuntu/anaconda3/envs/env/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/DIG/dig/xgraph/GradCAM/benchmark/kernel/pipeline.py", line 88, in <module>
    sample_explain(explainer, data, explain_collector, sparsity=args['explain'].sparsity)
  File "/home/ubuntu/DIG/dig/xgraph/GradCAM/benchmark/kernel/explain.py", line 116, in sample_explain
    y=data.x[:, 0] if node_idx is None else data.y, num_nodes=data.x.shape[0])
  File "/home/ubuntu/DIG/dig/xgraph/GradCAM/benchmark/models/explainers.py", line 290, in visualize_graph
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, **kwargs)
TypeError: draw_networkx_nodes() got an unexpected keyword argument 'num_nodes'
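An editor's reading, offered as a hypothesis: the traceback shows visualize_graph forwarding its **kwargs (including num_nodes, which is meant for sample_explain) straight into nx.draw_networkx_nodes, which rejects unknown keyword arguments. A hypothetical patch inside visualize_graph would drop that key before forwarding:

# Hypothetical patch in benchmark/models/explainers.py, visualize_graph():
kwargs.pop('num_nodes', None)  # networkx does not accept this keyword
nx.draw_networkx_nodes(G, pos, node_color=node_colors, **kwargs)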

Will the XGNN model be added?

Thanks for this great project.
I notice that XGNN and this project are both your team's work. I wonder whether you plan to add XGNN to this project soon?
Thanks!

Errors in running the bash files in SubgraphX in the main branch

Describe the bug

I followed the instructions in the README in SubgraphX, created the environment, and installed the requirements. However, I still cannot reproduce the results: when I run the script files, I get the following error and don't know why.

$ cd DIG/xgraph/SubgraphX
$ source ./scripts.sh
/home/liyang/anaconda3/lib/python3.7/site-packages/numba/decorators.py:146: RuntimeWarning: Caching is not available when the 'parallel' target is in use. Caching is now being disabled to allow execution to continue.
  warnings.warn(msg, RuntimeWarning)
Traceback (most recent call last):
  File "/home/liyang/anaconda3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/liyang/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/liyang/Documents/csprogram/DIG-main/dig/xgraph/SubgraphX/forgraph/subgraphx.py", line 6, in <module>
    from forgraph.mcts import MCTS, reward_func
  File "/home/liyang/Documents/csprogram/DIG-main/dig/xgraph/SubgraphX/forgraph/mcts.py", line 4, in <module>
    from Configures import mcts_args
  File "/home/liyang/Documents/csprogram/DIG-main/dig/xgraph/SubgraphX/Configures.py", line 3, in <module>
    from tap import Tap
  File "/home/liyang/anaconda3/lib/python3.7/site-packages/tap.py", line 6, in <module>
    from mc_bin_client import mc_bin_client, memcacheConstants as Constants
  File "/home/liyang/anaconda3/lib/python3.7/site-packages/mc_bin_client/mc_bin_client.py", line 278
    except MemcachedError, e:
                         ^
SyntaxError: invalid syntax

The same error also occurred when I ran another bash file.

$ cd DIG/xgraph/SubgraphX
$ source ./models/train_gnns.sh 
/home/liyang/anaconda3/lib/python3.7/site-packages/numba/decorators.py:146: RuntimeWarning: Caching is not available when the 'parallel' target is in use. Caching is now being disabled to allow execution to continue.
  warnings.warn(msg, RuntimeWarning)
Traceback (most recent call last):
  File "/home/liyang/anaconda3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/liyang/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/liyang/Documents/csprogram/DIG-main/dig/xgraph/SubgraphX/models/train_gnns.py", line 9, in <module>
    from Configures import data_args, train_args, model_args
  File "/home/liyang/Documents/csprogram/DIG-main/dig/xgraph/SubgraphX/Configures.py", line 3, in <module>
    from tap import Tap
  File "/home/liyang/anaconda3/lib/python3.7/site-packages/tap.py", line 6, in <module>
    from mc_bin_client import mc_bin_client, memcacheConstants as Constants
  File "/home/liyang/anaconda3/lib/python3.7/site-packages/mc_bin_client/mc_bin_client.py", line 278
    except MemcachedError, e:
                         ^
SyntaxError: invalid syntax

Configuration

OS: Release Linux Mint 20.2 Uma 64-bit
python version: 3.7.0
torch and cuda version: 1.9.0+cu102
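An editor's note, hedged: the tracebacks here (and in the "Run GNNExplainer" issue above) show tap resolving to a memcached client module (mc_bin_client) rather than an argument-parsing package, which suggests the wrong tap distribution is installed in the environment. Since from tap import Tap is provided by the typed-argument-parser package (it appears in the environment listings elsewhere in these issues), running pip uninstall tap followed by pip install typed-argument-parser may resolve both failures.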

dynamic graph

Hi, I wonder whether the XGNN toolkit can be applied to graphs with temporally varying features. Currently, my model is trained with a GCN-LSTM algorithm for a regression problem, not classification. Thanks in advance.

The model checkpoint is not saved but is used later on

from dig.xgraph.models import GCN_2l
model = GCN_2l(model_level='node', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GCN_2l', '0', 'GCN_2l_best.ckpt')
model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device))['state_dict'])

For the above code (given in the tutorial for using explainability models), the following error is produced:

FileNotFoundError Traceback (most recent call last)
in
3 model.to(device)
4 ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GCN_2l', '0', 'GCN_2l_best.ckpt')
----> 5 model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device))['state_dict'])

~\anaconda3\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
592 pickle_load_args['encoding'] = 'utf-8'
593
--> 594 with _open_file_like(f, 'rb') as opened_file:
595 if _is_zipfile(opened_file):
596 # The zipfile reader is going to advance the current file position.

~\anaconda3\lib\site-packages\torch\serialization.py in _open_file_like(name_or_buffer, mode)
228 def _open_file_like(name_or_buffer, mode):
229 if _is_path(name_or_buffer):
--> 230 return _open_file(name_or_buffer, mode)
231 else:
232 if 'w' in mode:

~\anaconda3\lib\site-packages\torch\serialization.py in __init__(self, name, mode)
209 class _open_file(_opener):
210 def __init__(self, name, mode):
--> 211 super(_open_file, self).__init__(open(name, mode))
212
213 def __exit__(self, *args):

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints\ba_shapes\GCN_2l\0\GCN_2l_best.ckpt'

Because of this error, I am not able to follow the rest of the tutorial.
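An editor's note: the tutorial expects a pre-trained checkpoint at that relative path, so the fix is to obtain the checkpoints that accompany the tutorial or to train the model and save one yourself first. A minimal hedged sketch of the latter, assuming model has been trained and ckpt_path is as in the snippet above:

import os
import os.path as osp
import torch

os.makedirs(osp.dirname(ckpt_path), exist_ok=True)         # create checkpoints/ba_shapes/GCN_2l/0/
torch.save({'state_dict': model.state_dict()}, ckpt_path)  # key matches the load_state_dict call above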

SubgraphX

Hello,

I would like to ask you for detailed information, especially on how to identify the subgraph, with the following questions:

  1. How do we select N_0 at initialization, which you define as the root of the search tree?
    N_i denotes a tree node; in Figure 1 of the SubgraphX paper, which node is the N_0 used to initialize the root?

  2. It is mentioned that h(N_i) denotes the associated subgraph of tree node N_i. I would like to know how you define "associated subgraph", and whether Figure 1 depicts h(N_i).

  3. To get a concrete understanding of the notation S: in Figure 1, S as the possible coalition set of players is (node 1), (node 6), and (nodes 1 and 6), respectively. But in the pseudo-code of Algorithm 1, S_l is the leaf set, so the leaf graph (whose node list is [2, 4] and edge list is [(2, 4)]) belongs to S_l.

Thank you !

edge_mask operation on GNNExplainer

Hi,

When I read the GNNExplainer source code, I was a little confused about how the edge_mask is applied in the forward step.

For instance, consider the demo: https://github.com/divelab/DIG/blob/dig/benchmarks/xgraph/gnnexplainer.ipynb. I found that we call the model GCN_2l, and gnn_explainer_alg optimizes a loss function that includes both the feature mask and the edge mask. It seems the code does not use the class GCN_2l_mask, which makes direct use of the edge_mask.

Thank you so much for your time.

Best,
Chao
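An editor's sketch of the idea behind the question, hedged: GNNExplainer-style methods typically do not need a dedicated masked model class, because set_masks (seen in the traceback of the ValueError issue above) can attach a trainable edge mask to the model's message-passing layers, so that every message is scaled by a sigmoid-squashed mask value during propagation. Conceptually:

import torch

# Conceptual sketch only (not DIG's exact set_masks implementation):
num_edges = 22
edge_mask = torch.nn.Parameter(torch.randn(num_edges))

# During propagation, the message along edge e is scaled by weights[e];
# optimizing edge_mask (plus the feature mask) then learns which edges matter,
# without ever swapping in a special model class such as GCN_2l_mask.
weights = edge_mask.sigmoid()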

Parameters to reproduce published results of Spherical 3D MPNN and DimeNet ?

Would it be possible to publish the full set of parameters required to reproduce the published results of the Spherical 3D MPNN on QM9 and MD17? I tried the default parameters in https://github.com/divelab/DIG/blob/dig/benchmarks/threedgraph/threedgraph.ipynb and played with some variations (lr, batch size, etc.), but the results I got were quite far from the published ones, for both SphereNet and DimeNet. Worse, the models reach a plateau quite early in the optimization (e.g., after 20 epochs or so). It would be really great if you could publish the full set of hyperparameters used to obtain the results in your paper. Thank you!

a few questions regarding graph generation tutorial

Hi, this is a nice library. I have two questions regarding the Graph Generation Tutorial.

  1. For the Property Optimization part, the code uses a pretrained model, as shown below:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
pretrain_path = 'saved_ckpts/prop_opt/pretrain_plogp.pth'  <------ pretrained model ***
lr = 0.0001
wd = 0
warm_up = 0
max_iters = 200
save_interval = 20
save_dir = 'prop_opt_plogp'
runner.train_prop_opt(lr=lr, wd=wd, max_iters=max_iters, warm_up=warm_up,
    model_conf_dict=model_conf_dict, pretrain_path=pretrain_path,
    save_interval=save_interval, save_dir=save_dir)

Is the pretrained model obtained from the first part, "Random Generation Example"?

  2. How much RAM is needed to load the ZINC250k dataset (dataset = ZINC250k(one_shot=False, use_aug=True)) in the Random Generation Example? My program keeps crashing because my 16 GB of RAM fills up.

Subgraphx_overfitting

Hi:
In SubgraphX/models/train_gnn.py, you first train a graph neural network for the later calculation of each subgraph's Shapley value. When I run train_gnn.py on the MUTAG dataset, I find there is an overfitting problem: in this training process, the MUTAG train accuracy is near 100 percent and the gap between train and validation accuracy is large. I want to know whether overfitting occurs here.

Installation of Packages

Hello,
I wish to use sslgraph from DIG.
I downloaded the code from GitHub. However, I cannot install the packages by running
pip install git+git://github.com/divelab/DIG.git
inside a conda environment.
This immediately gives errors:
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

Do you have suggestions? Is there a setup.py file available?
Thank you,
Saheli
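An editor's note: as the Installation section above describes, DIG is published on PyPI as dive-into-graphs, so pip install dive-into-graphs may be a simpler route than installing from the git URL; the install-from-source instructions above (git clone followed by pip install .) also indicate the repository root is pip-installable.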

GraphCL example error

While running the GraphCL code at the line evaluator.evaluate(learning_model=graphcl, encoder=encoder),
I get this error:

Traceback (most recent call last):
  File "train.py", line 18, in <module>
    evaluator.evaluate(learning_model=graphcl, encoder=encoder)
  File "C:\ProgramData\Anaconda3\envs\dig\lib\site-packages\dig\sslgraph\evaluation\eval_graph.py", line 394, in evaluate
    for fold, train_loader, test_loader, val_loader in k_fold(
  File "C:\ProgramData\Anaconda3\envs\dig\lib\site-packages\dig\sslgraph\evaluation\eval_graph.py", line 561, in k_fold
    test_loader = DataLoader(dataset[test_indices[i]], batch_size, shuffle=False)
  File "C:\ProgramData\Anaconda3\envs\dig\lib\site-packages\torch_geometric\data\dataset.py", line 198, in __getitem__
    return self.index_select(idx)
  File "C:\ProgramData\Anaconda3\envs\dig\lib\site-packages\torch_geometric\data\dataset.py", line 224, in index_select
    raise IndexError(
IndexError: Only integers, slices (':'), list, tuples, torch.tensor and np.ndarray of dtype long or bool are valid indices (got 'Tensor')

I run this code:

from dig.sslgraph.utils import Encoder
from dig.sslgraph.evaluation import GraphSemisupervised, GraphUnsupervised
from dig.sslgraph.dataset import get_dataset
from dig.sslgraph.method import GraphCL

dataset, dataset_pretrain = get_dataset('NCI1', task='semisupervised')
feat_dim = dataset[0].x.shape[1]
embed_dim = 128

print(dataset)

encoder = Encoder(feat_dim, embed_dim, n_layers=3, gnn='resgcn')
graphcl = GraphCL(embed_dim, aug_1='subgraph', aug_2='subgraph')

evaluator = GraphSemisupervised(dataset, dataset_pretrain, label_rate=0.01)

evaluator.evaluate(learning_model=graphcl, encoder=encoder)

Torch version: 1.9.0 (cpu)

Any help will be highly appreciated.
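An editor's note, hedged: the IndexError is raised by torch_geometric's Dataset.index_select when it rejects the index tensor built by DIG's k_fold, which points at a version mismatch between the two packages; upgrading both DIG and torch_geometric to the versions pinned in the Installation section above (PyG >= 2.0.0, latest DIG release) is the first thing to try.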

device mismatch in GraphAF benchmark

Thanks for providing this amazing work.

When I run the GraphAF benchmark, I encounter this error:

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

After checking the code, I found that in line 32 of graphflow.py, the model MaskedGraphAF is still on the CPU. I can remove the above error by adding:

self.flow_core = self.flow_core.cuda()

Just want to make sure whether this is a bug. Thanks.
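As an editor's sketch, a device-agnostic variant of the same one-line fix (assuming flow_core should simply live on the same device as the rest of the model):

# Hypothetical variant of the one-line fix inside graphflow.py:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.flow_core = self.flow_core.to(device)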

Tutorial issue

The following error comes up while executing your tutorial.
The code block:

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector()

index = -1
node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()
data = dataset[0]

from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result, k_hop_subgraph_with_default_whole_graph
plotutils = PlotUtils(dataset_name='ba_shapes')

# Visualization
max_nodes = 5
node_idx = node_indices[6]
print(f'explain graph node {node_idx}')
data.to(device)
logits = model(data.x, data.edge_index)
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = \
    explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)

plotutils = PlotUtils(dataset_name='ba_shapes')
explainer.visualization(explanation_results,
                        prediction,
                        max_nodes=max_nodes,
                        plot_utils=plotutils,
                        y=data.y)

The error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-7e5867836373> in <module>
     20 
     21 _, explanation_results, related_preds = \
---> 22     explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
     23 result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)
     24 

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in __call__(self, x, edge_index, **kwargs)
    671                 payoff_func = self.get_reward_func(value_func, node_idx=self.mcts_state_map.node_idx)
    672                 self.mcts_state_map.set_score_func(payoff_func)
--> 673                 results = self.mcts_state_map.mcts(verbose=False)
    674 
    675                 tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts(self, verbose)
    465             print(f"The nodes in graph is {self.graph.number_of_nodes()}")
    466         for rollout_idx in range(self.n_rollout):
--> 467             self.mcts_rollout(self.root)
    468             if verbose:
    469                 print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts_rollout(self, tree_node)
    450                     tree_node.children.append(new_node)
    451 
--> 452             scores = compute_scores(self.score_func, tree_node.children)
    453             for child, score in zip(tree_node.children, scores):
    454                 child.P = score

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in compute_scores(score_func, children)
    163     for child in children:
    164         if child.P == 0:
--> 165             score = score_func(child.coalition, child.data)
    166         else:
    167             score = child.P

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in mc_l_shapley(coalition, data, local_raduis, value_func, subgraph_building_method, sample_num)
    216     include_mask = np.stack(set_include_masks, axis=0)
    217     marginal_contributions = \
--> 218         marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
    219 
    220     mc_l_shapley_value = (marginal_contributions).mean().item()

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
     73     marginal_contribution_list = []
     74 
---> 75     for exclude_data, include_data in dataloader:
     76         exclude_values = value_func(exclude_data)
     77         include_values = value_func(include_data)

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
--> 559             data = _utils.pin_memory.pin_memory(data)
    560         return data
    561 

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in <listcomp>(.0)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
---> 57         return data.pin_memory()
     58     else:
     59         return data

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in pin_memory(self, *keys)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366 
    367     def debug(self):

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in apply(self, func, *keys)
    324         """
    325         for key, item in self(*keys):
--> 326             self[key] = self.__apply__(item, func)
    327         return self
    328 

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in __apply__(self, item, func)
    303     def __apply__(self, item, func):
    304         if torch.is_tensor(item):
--> 305             return func(item)
    306         elif isinstance(item, SparseTensor):
    307             # Not all apply methods are supported for `SparseTensor`, e.g.,

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in <lambda>(x)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366 
    367     def debug(self):

RuntimeError: cannot pin 'torch.cuda.LongTensor' only dense CPU tensors can be pinned
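An editor's note on a likely cause, hedged: the traceback bottoms out in pin_memory, which only works for CPU tensors, yet the snippet moved the graph to the GPU with data.to(device) before calling the explainer, so the DataLoader inside shapley.py tries to pin CUDA tensors. One workaround to try is running the explainer call with CPU tensors (the model may need to move to CPU as well):

# Hypothetical workaround: keep data (and model) on CPU for the SubgraphX call
# so the internal DataLoader can pin memory.
data = data.to('cpu')
model = model.to('cpu')
_, explanation_results, related_preds = \
    explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)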

My installed packages:

# packages in environment at /home/*/anaconda3/envs/dig:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             4.5                       1_gnu
anyio                     3.1.0            py38h578d9bd_0    conda-forge
argon2-cffi               20.1.0           py38h497a2fe_2    conda-forge
ase                       3.21.1                   pypi_0    pypi
async_generator           1.10                       py_0    conda-forge
attrs                     21.2.0             pyhd8ed1ab_0    conda-forge
babel                     2.9.1              pyh44b312d_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                        py_2    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl
bleach                    3.3.0              pyh44b312d_0    conda-forge
boost                     1.74.0           py38hc10631b_3    conda-forge
boost-cpp                 1.74.0               hc6e9bd1_3    conda-forge
brotlipy                  0.7.0           py38h497a2fe_1001    conda-forge
bzip2                     1.0.8                h7b6447c_0
ca-certificates           2021.5.30            ha878542_0    conda-forge
cairo                     1.16.0            h6cf1ce9_1008    conda-forge
captum                    0.2.0                    pypi_0    pypi
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cffi                      1.14.5           py38ha65f79e_0    conda-forge
chardet                   4.0.0            py38h578d9bd_1    conda-forge
cilog                     1.2.0                    pypi_0    pypi
cloudpickle               1.6.0                    pypi_0    pypi
cryptography              3.4.7            py38ha5dfef3_0    conda-forge
cudatoolkit               10.1.243             h6bb024c_0
cycler                    0.10.0                     py_2    conda-forge
decorator                 4.4.2                    pypi_0    pypi
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
dive-into-graphs          0.0.4                    pypi_0    pypi
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
et-xmlfile                1.1.0                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
fontconfig                2.13.1            hba837de_1005    conda-forge
freetype                  2.10.4               h5ab3b9f_0
gettext                   0.19.8.1          h0b5b191_1005    conda-forge
gmp                       6.2.1                h2531618_2
gnutls                    3.6.15               he1e5248_0
googledrivedownloader     0.4                      pypi_0    pypi
greenlet                  1.1.0            py38h709712a_0    conda-forge
h5py                      3.2.1                    pypi_0    pypi
icu                       68.1                 h58526e2_0    conda-forge
idna                      2.10               pyh9f0ad1d_0    conda-forge
importlib-metadata        4.5.0            py38h578d9bd_0    conda-forge
intel-openmp              2021.2.0           h06a4308_610
ipykernel                 5.5.5            py38hd0cf306_0    conda-forge
ipython                   7.24.1           py38hd0cf306_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
isodate                   0.6.0                    pypi_0    pypi
jedi                      0.18.0           py38h578d9bd_2    conda-forge
jinja2                    3.0.1              pyhd8ed1ab_0    conda-forge
joblib                    1.0.1                    pypi_0    pypi
jpeg                      9b                   h024ee3a_2
json5                     0.9.5              pyh9f0ad1d_0    conda-forge
jsonschema                3.2.0              pyhd8ed1ab_3    conda-forge
jupyter_client            6.1.12             pyhd8ed1ab_0    conda-forge
jupyter_core              4.7.1            py38h578d9bd_0    conda-forge
jupyter_server            1.8.0              pyhd8ed1ab_0    conda-forge
jupyterlab                3.0.16             pyhd8ed1ab_0    conda-forge
jupyterlab_pygments       0.1.2              pyh9f0ad1d_0    conda-forge
jupyterlab_server         2.6.0              pyhd8ed1ab_0    conda-forge
kiwisolver                1.3.1            py38h1fd1430_1    conda-forge
lame                      3.100                h7b6447c_0
lcms2                     2.12                 h3be6417_0
ld_impl_linux-64          2.35.1               h7274673_9
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.3.0               h5101ec6_17
libglib                   2.68.3               h3e27bee_0    conda-forge
libgomp                   9.3.0               h5101ec6_17
libiconv                  1.16                 h516909a_0    conda-forge
libidn2                   2.3.1                h27cfd23_0
libpng                    1.6.37               hbc83047_0
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              9.3.0               hd4cf53a_17
libtasn1                  4.16.0               h27cfd23_0
libtiff                   4.2.0                h85742a9_0
libunistring              0.9.10               h27cfd23_0
libuuid                   2.32.1            h7f98852_1000    conda-forge
libuv                     1.40.0               h7b6447c_0
libwebp-base              1.2.0                h27cfd23_0
libxcb                    1.13              h7f98852_1003    conda-forge
libxml2                   2.9.12               h72842e0_0    conda-forge
llvmlite                  0.36.0                   pypi_0    pypi
lz4-c                     1.9.3                h2531618_0
markupsafe                2.0.1            py38h497a2fe_0    conda-forge
matplotlib-base           3.4.2            py38hcc49a3a_0    conda-forge
matplotlib-inline         0.1.2              pyhd8ed1ab_2    conda-forge
mistune                   0.8.4           py38h497a2fe_1003    conda-forge
mkl                       2021.2.0           h06a4308_296
mkl-service               2.3.0            py38h27cfd23_1
mkl_fft                   1.3.0            py38h42c9631_2
mkl_random                1.2.1            py38ha9443f7_2
mypy-extensions           0.4.3                    pypi_0    pypi
nbclassic                 0.3.1              pyhd8ed1ab_1    conda-forge
nbclient                  0.5.3              pyhd8ed1ab_0    conda-forge
nbconvert                 6.0.7            py38h578d9bd_3    conda-forge
nbformat                  5.1.3              pyhd8ed1ab_0    conda-forge
ncurses                   6.2                  he6710b0_1
nest-asyncio              1.5.1              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1
networkx                  2.5.1                    pypi_0    pypi
ninja                     1.10.2               hff7bd54_1
notebook                  6.4.0              pyha770c72_0    conda-forge
numba                     0.53.1                   pypi_0    pypi
numpy                     1.20.2           py38h2d18471_0
numpy-base                1.20.2           py38hfae3a4d_0
olefile                   0.46                       py_0
openh264                  2.1.0                hd408876_0
openpyxl                  3.0.7                    pypi_0    pypi
openssl                   1.1.1k               h7f98852_0    conda-forge
packaging                 20.9               pyh44b312d_0    conda-forge
pandas                    1.2.4                    pypi_0    pypi
pandoc                    2.14.0.1             h7f98852_0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
parso                     0.8.2              pyhd8ed1ab_0    conda-forge
pcre                      8.44                 he1b5a44_0    conda-forge
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.2.0            py38he98fc37_0
pip                       21.1.2           py38h06a4308_0
pixman                    0.40.0               h36c2ea0_0    conda-forge
prometheus_client         0.11.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.18             pyha770c72_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pycairo                   1.20.1           py38hf61ee4a_0    conda-forge
pycparser                 2.20               pyh9f0ad1d_2    conda-forge
pygments                  2.9.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 20.0.1             pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7              pyh9f0ad1d_0    conda-forge
pyrsistent                0.17.3           py38h497a2fe_2    conda-forge
pysocks                   1.7.1            py38h578d9bd_3    conda-forge
python                    3.8.10               h12debd9_8
python-dateutil           2.8.1                      py_0    conda-forge
python-louvain            0.15                     pypi_0    pypi
python_abi                3.8                      1_cp38    conda-forge
pytorch                   1.8.1           py3.8_cuda10.1_cudnn7.6.3_0    pytorch
pytz                      2021.1             pyhd8ed1ab_0    conda-forge
pyzmq                     22.1.0           py38h2035c66_0    conda-forge
rdflib                    5.0.0                    pypi_0    pypi
rdkit                     2021.03.3        py38hf8acc3d_0    conda-forge
readline                  8.1                  h27cfd23_0
reportlab                 3.5.67           py38hadf75a6_0    conda-forge
requests                  2.25.1             pyhd3deb0d_0    conda-forge
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.6.3                    pypi_0    pypi
send2trash                1.5.0                      py_0    conda-forge
setuptools                52.0.0           py38h06a4308_0
shap                      0.39.0                   pypi_0    pypi
six                       1.15.0           py38h06a4308_0
slicer                    0.0.7                    pypi_0    pypi
sniffio                   1.2.0            py38h578d9bd_1    conda-forge
sqlalchemy                1.4.18           py38h497a2fe_0    conda-forge
sqlite                    3.35.4               hdfb4753_0
tabulate                  0.8.9                    pypi_0    pypi
terminado                 0.10.1           py38h578d9bd_0    conda-forge
testpath                  0.5.0              pyhd8ed1ab_0    conda-forge
threadpoolctl             2.1.0                    pypi_0    pypi
tk                        8.6.10               hbc83047_0
torch-cluster             1.5.9                    pypi_0    pypi
torch-geometric           1.7.0                    pypi_0    pypi
torch-scatter             2.0.7                    pypi_0    pypi
torch-sparse              0.6.9                    pypi_0    pypi
torch-spline-conv         1.2.1                    pypi_0    pypi
torchaudio                0.8.1                      py38    pytorch
torchvision               0.9.1                py38_cu101    pytorch
tornado                   6.1              py38h497a2fe_1    conda-forge
tqdm                      4.61.0                   pypi_0    pypi
traitlets                 5.0.5                      py_0    conda-forge
typed-argument-parser     1.5.4                    pypi_0    pypi
typing-inspect            0.7.1                    pypi_0    pypi
typing_extensions         3.7.4.3            pyha847dfd_0
tzdata                    2020f                h52ac0ba_0
urllib3                   1.26.5             pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                      py_1    conda-forge
websocket-client          0.57.0           py38h578d9bd_4    conda-forge
wheel                     0.36.2             pyhd3eb1b0_0
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.3             hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxau               1.0.9                h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.5                h7b6447c_0
zeromq                    4.3.4                h9c3ff4c_0    conda-forge
zipp                      3.4.1              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11               h7b6447c_3
zstd                      1.4.9                haebb681_0

I have installed the latest version of DIG from source.
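
For context, a hedged workaround sketch (the assumption: pin_memory can only pin dense CPU tensors, so it fails when the dataset tensors were already moved to the GPU before the DataLoader was built; dataset below is a placeholder for the PyG dataset in use):

from torch_geometric.data import DataLoader  # PyG 1.7.x; in PyG >= 2.0 use torch_geometric.loader

dataset.data = dataset.data.to('cpu')        # pinning only works on dense CPU tensors
loader = DataLoader(dataset, batch_size=32, pin_memory=True)

# or simply avoid pinning altogether:
# loader = DataLoader(dataset, batch_size=32, pin_memory=False)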

Reproduction Of XGNN

Dear author, I have two questions about XGNN.

First, why do the explanation results differ so greatly between generations? Moreover, the generated molecular graphs belonging to the mutagenic class often contain no carbon ring or nitro group. Are there effective hyperparameters for generating better explanations?

Second, I found that the loss value alternates between positive and negative. Why does this happen?

No module named 'vocab'

Hello DIG team,

In the JT-VAE sub-repo (https://github.com/divelab/DIG/tree/main/dig/ggraph/JT-VAE#i-building-the-vocabulary), when I run python mol_tree.py --data_file "moses.csv", it outputs:

File "mol_tree.py", line 4, in <module>
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
File "/ssd2/yuning/graphcl/pregnn/chem/dataset/jt_vae/JT-VAE/chemutils.py", line 7, in <module>
from vocab import Vocab
ModuleNotFoundError: No module named 'vocab'
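
For context, a hedged workaround sketch: vocab.py appears to live next to chemutils.py inside the JT-VAE directory, so the plain top-level import only resolves when that directory is on sys.path, e.g. when the script is run from inside it (the path below is a placeholder):

import sys
sys.path.insert(0, "/path/to/DIG/dig/ggraph/JT-VAE")  # hypothetical path to the local clone

from vocab import Vocab  # should now resolve without ModuleNotFoundError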

Evaluation of SubgraphX vs rest

Is there a reason why some of the methods are evaluated for fidelity by zeroing out the node features, while for SubgraphX the fidelity appears to be reported as the model output on the subgraph itself? That is not consistent with the definition of fidelity. Could you please elaborate on this? Also, zeroing out the features changes the data distribution itself, since the model may never have seen nodes with all-zero features, so why use zeroing?
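
To make the question concrete, here is a minimal sketch of the zero-feature fidelity variant being asked about; the exact formula used by DIG is not confirmed here, and model, data, and node_mask are placeholders:

import torch

def fidelity_zero_features(model, data, node_mask):
    # Probability of the predicted class on the full graph.
    prob_full = model(data.x, data.edge_index).softmax(dim=-1)
    y = prob_full.argmax(dim=-1, keepdim=True)
    # Zero out the features of the nodes marked important; keep the structure.
    x_masked = data.x.clone()
    x_masked[node_mask] = 0.0
    prob_masked = model(x_masked, data.edge_index).softmax(dim=-1)
    # Fidelity: how much the predicted-class probability drops.
    return (prob_full.gather(-1, y) - prob_masked.gather(-1, y)).mean().item()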

pgexplainer.ipynb/subgraphx.ipynb - UnboundLocalError: local variable 'x' referenced before assignment

from dig.xgraph.method import PGExplainer
explainer = PGExplainer(model, in_channels=900, device=device, explain_graph=False)
explainer.train_explanation_network(splitted_dataset)  # train the explanation network
torch.save(explainer.state_dict(), 'tmp.pt')           # save the trained weights
state_dict = torch.load('tmp.pt')
explainer.load_state_dict(state_dict)                  # reload them into the explainer

When running the block of code above, it returns the following error:

---------------------------------------------------------------------------
UnboundLocalError                         Traceback (most recent call last)
/tmp/ipykernel_1299148/3789046373.py in <module>
      2 explainer = PGExplainer(model, in_channels=900, device=device, explain_graph=False)
      3 
----> 4 explainer.train_explanation_network(splitted_dataset)
      5 torch.save(explainer.state_dict(), 'tmp.pt')
      6 state_dict = torch.load('tmp.pt')

~/anaconda3/envs/xai_conda_env/lib/python3.8/site-packages/dig/xgraph/method/pgexplainer.py in train_explanation_network(self, dataset)
    594                     x, edge_index, y, subset, _ = \
    595                         self.get_subgraph(node_idx=node_idx, x=data.x, edge_index=data.edge_index, y=data.y)
--> 596                     logits = self.model(data.x, data.edge_index)
    597                     emb = self.model.get_emb(data.x, data.edge_index)
    598 

~/anaconda3/envs/xai_conda_env/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/xai_conda_env/lib/python3.8/site-packages/dig/xgraph/models/models.py in forward(self, *args, **kwargs)
    270         :return:
    271         """
--> 272         x, edge_index, batch = self.arguments_read(*args, **kwargs)
    273 
    274 

~/anaconda3/envs/xai_conda_env/lib/python3.8/site-packages/dig/xgraph/models/models.py in arguments_read(self, *args, **kwargs)
     42             elif len(args) == 2:
     43                 x, edge_index, batch = args[0], args[1], \
---> 44                                        torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
     45             elif len(args) == 3:
     46                 x, edge_index, batch = args[0], args[1], args[2]

UnboundLocalError: local variable 'x' referenced before assignment

A similar error occurs with SubgraphX, but not with GNNExplainer. What is the problem?
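
The traceback above points at the likely cause: in arguments_read, the right-hand side of the tuple assignment is evaluated before x is bound, so x.device raises UnboundLocalError whenever the model is called with exactly two positional arguments. A hedged sketch of a fix (not the library's committed patch):

import torch

def arguments_read_fixed(*args):
    # Bind x first, then use x.device; the original evaluated x.device on the
    # right-hand side of the tuple assignment before x existed.
    if len(args) == 2:
        x, edge_index = args
        batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)
    elif len(args) == 3:
        x, edge_index, batch = args
    else:
        raise ValueError("expected (x, edge_index[, batch])")
    return x, edge_index, batch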

Is that a typo in SubgraphX?

In line 683 of subgraphx.py: "related_preds.append({'masked': maskout_score,"
Should it perhaps be "related_preds.append({'maskout': maskout_score,"?

issue with installation

I am facing the following error when running the installation commands given here.

'''
pip install dive-into-graphs
Collecting dive-into-graphs
Using cached https://files.pythonhosted.org/packages/4b/82/0eaa1cbc15778016ec0b3d40c7d7d28be3cbf0b5bc351cb0dca6b537b1eb/dive_into_graphs-0.1.2-py3-none-any.whl
Collecting captum==0.2.0 (from dive-into-graphs)
Downloading https://files.pythonhosted.org/packages/42/de/c018e206d463d9975444c28b0a4f103c9ca4b2faedf943df727e402a1a1e/captum-0.2.0-py3-none-any.whl (1.4MB)
|████████████████████████████████| 1.4MB 9.6MB/s
Collecting shap (from dive-into-graphs)
Using cached https://files.pythonhosted.org/packages/b9/f4/c5b95cddae15be80f8e58b25edceca105aa83c0b8c86a1edad24a6af80d3/shap-0.39.0.tar.gz
Collecting rdkit-pypi (from dive-into-graphs)
ERROR: Could not find a version that satisfies the requirement rdkit-pypi (from dive-into-graphs) (from versions: none)
ERROR: No matching distribution found for rdkit-pypi (from dive-into-graphs)

'''

It would be great if you could fix the installation instructions here.
I think the issue has to do with the cross-platform nature of rdkit. From what I understand, the conda alternative (e.g., conda install -c conda-forge rdkit) works fine.

GradCAM example on ipynb: graph visualization

Hi DIG,

I have tried to adapt the visualization from the ipynb examples for PGExplainer and SubgraphX to GradCAM and GNNExplainer, but I still don't know how to start.

When I try to visualize the GradCAM example:

from dig.xgraph.method.base_explainer import ExplainerBase
explainer_base = ExplainerBase(model)

visualize_graph = explainer_base.visualize_graph(node_idx=5, edge_index=data.edge_index, edge_mask=masks[0], nolabel=True)

It shows this TypeError:


TypeError Traceback (most recent call last)
in
2 explainer_base = ExplainerBase(model)
3
----> 4 visualize_graph = explainer_base.visualize_graph(node_idx=5, edge_index=data.edge_index, edge_mask=masks[0], nolabel=True)

~.conda\envs\py37\lib\site-packages\dig\xgraph\method\base_explainer.py in visualize_graph(self, node_idx, edge_index, edge_mask, y, threshold, nolabel, **kwargs)
234 connectionstyle="arc3,rad=0.08", # rad control angle
235 ))
--> 236 nx.draw_networkx_nodes(G, pos, node_color=node_colors, **kwargs)
237 # define node labels
238 if self.molecule:

TypeError: draw_networkx_nodes() got an unexpected keyword argument 'with_labels'


Could you perhaps give me some tips on how to do the visualization for the GradCAM and GNNExplainer ipynb examples?
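
Until the helper handles newer networkx versions, here is a minimal sketch of a direct networkx/matplotlib visualization; it assumes edge_index and edge_mask are the usual PyG tensors and avoids passing with_labels to draw_networkx_nodes, which recent networkx releases reject:

import matplotlib.pyplot as plt
import networkx as nx

def draw_edge_mask(edge_index, edge_mask, threshold=0.5):
    # Build an undirected graph from the PyG edge_index tensor.
    edges = edge_index.t().tolist()
    G = nx.Graph()
    G.add_edges_from(edges)
    pos = nx.spring_layout(G, seed=0)
    # Thicker edges for mask values above the threshold.
    widths = [3.0 if m > threshold else 0.5 for m in edge_mask.tolist()]
    nx.draw_networkx_edges(G, pos, edgelist=edges, width=widths)
    nx.draw_networkx_nodes(G, pos, node_size=100)  # no extra kwargs passed here
    nx.draw_networkx_labels(G, pos)                # labels are drawn separately
    plt.axis('off')
    plt.show()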

Loss function of XGNN

Hello,

I would like to ask about the loss function of XGNN in gnn_explain.py, with the following question:

Specifically, at lines 100–119 in gnn_explain.py:
When total_reward < 0, the loss is also negative, but shouldn't the loss be non-negative? Is the negative reward preprocessed anywhere? (I could not find that part in the code.)

if total_reward < 0:

Thank you for your help in advance!
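
For reference, a REINFORCE-style objective of the form loss = -reward * log_prob is negative whenever the reward is negative, so a sign-alternating loss is not automatically a bug. Whether XGNN uses exactly this objective is an assumption here; the sketch below only illustrates the sign behavior:

import torch

# Hypothetical numbers, only to show the sign of a REINFORCE-style loss.
log_prob = torch.log(torch.tensor(0.3))   # log-probability of a sampled action, always < 0
for total_reward in (1.0, -1.0):
    loss = -total_reward * log_prob       # reward > 0 -> loss > 0; reward < 0 -> loss < 0
    print(total_reward, loss.item())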

GCNConv self.weight missing

In GCNConv, inside the forward function, the following operation is performed:
x = torch.matmul(x, self.weight)

However, PyTorch Geometric has since replaced the weight attribute with a linear layer.
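
A hedged compatibility shim (assumption: in recent PyG versions GCNConv stores its transformation in a Linear submodule named lin, while older versions expose a raw weight matrix):

import torch

def apply_gcn_weight(conv, x):
    # PyG >= 2.0: GCNConv keeps its transformation in a Linear submodule `lin`.
    if hasattr(conv, 'lin'):
        return conv.lin(x)
    # PyG 1.x: fall back to the old explicit weight matrix.
    return torch.matmul(x, conv.weight)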

Shapley value calculation mismatch with the SubgraphX paper

Hello there,

May I ask a question about the Shapley value calculation, especially the following lines of code beginning with set_exclude_mask = np.ones(num_nodes)?

set_exclude_mask = np.ones(num_nodes)
set_exclude_mask[local_region] = 0.0
if node_exclude_subset:
    set_exclude_mask[list(node_exclude_subset)] = 1.0

From the implementation above, it seems the mask is initialized to include all nodes in the graph. Then all nodes in node_exclude_subset are included, while the other nodes in local_region are not. However, doesn't this include all the nodes outside the local region as well?

Based on this description from the paper,

[screenshot of the relevant description from the SubgraphX paper]

I feel like line 122 above should be changed to set_exclude_mask = np.zeros(num_nodes). Or maybe I am missing something here? I appreciate your help.
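
To make the concern concrete, a tiny worked example with made-up values (6 nodes, local_region = [1, 2, 3], a sampled coalition {2}):

import numpy as np

num_nodes, local_region, node_exclude_subset = 6, [1, 2, 3], {2}

mask_ones = np.ones(num_nodes)
mask_ones[local_region] = 0.0
mask_ones[list(node_exclude_subset)] = 1.0
print(mask_ones)   # [1. 0. 1. 0. 1. 1.] -> nodes outside the local region stay included

mask_zeros = np.zeros(num_nodes)
mask_zeros[list(node_exclude_subset)] = 1.0
print(mask_zeros)  # [0. 0. 1. 0. 0. 0.] -> only the sampled coalition is included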

assertion error when running dimenet

Hi,

Thanks for the nice work. When trying to run threedgraph with DimeNet and target='mu', I get the following assertion error:

File "/usr/local/lib/python3.9/site-packages/dig/threedgraph/method/run.py", line 71, in run
valid_mae = self.val(model, valid_loader, energy_and_force, p, evaluation, device)
File "/usr/local/lib/python3.9/site-packages/dig/threedgraph/method/run.py", line 174, in val
return evaluation.eval(input_dict)['mae']
File "/usr/local/lib/python3.9/site-packages/dig/threedgraph/evaluation/eval.py", line 30, in eval
assert(len(y_true.shape) == 1)
AssertionError

Thanks for looking into this!
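
For what it's worth, a hedged guess at the cause: eval.py asserts a 1-D target, so the assertion fires if the selected target tensor has shape [N, 1] rather than [N]. A minimal workaround sketch, assuming the dataset is loaded as in the README example (whether 'mu' is actually stored 2-D is unverified):

from dig.threedgraph.dataset import QM93D

dataset = QM93D(root='dataset/')
target = 'mu'
dataset.data.y = dataset.data[target].view(-1)  # flatten a possible [N, 1] target to [N]
assert len(dataset.data.y.shape) == 1           # mirrors the check in eval.py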

[QUESTION]: XGNN missing from dig branch

Hey folks,

Quick question regarding XGNN: it seems to be missing from the dig branch (but is available in the main branch). Are there specific reasons behind its (temporary) unavailability?

Cheers,

V

xgraph/PGExplainer edge mask values

I find that the edge mask values in the PGExplainer implementation can be larger than 1 when I run the code on the BA_shapes dataset. In the original paper and the official implementation, the values are in (0, 1). Would this be a problem?
An example edge mask output:
ori_node_idx = 699
tensor([6.2408, 6.2408, 6.2408, 6.2408, 6.2408, 6.2408, 6.2408, 6.2408, 6.2408,
6.2408, 7.1816, 7.1816, 5.8144, 5.4829, 5.3828, 5.8144, 5.2910, 5.1266,
5.2910, 4.9597, 5.4829, 4.9597, 5.3828, 5.1266], device='cuda:0')
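
If the returned values are raw mask logits rather than probabilities (an assumption, not confirmed from the DIG code), applying a sigmoid would map them back into (0, 1), matching the original PGExplainer formulation:

import torch

raw = torch.tensor([6.2408, 7.1816, 5.8144])  # first values from the output above
print(torch.sigmoid(raw))                     # tensor([0.9981, 0.9992, 0.9970])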

Run DIG/benchmarks/xgraph/gnnexplainer.ipynb

Hi,
When I try to run gnnexplainer.ipynb, the following error appears:
from inspect import isclass
----> 9 import dig.xgraph.models.models as models
10 import torch

AttributeError: module 'dig.xgraph' has no attribute 'models'

Any suggestions? Thanks
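
One hedged workaround (an assumption that the failure is an import-binding quirk rather than a genuinely missing submodule): a from-import falls back to importing the submodule when the attribute lookup on the package fails, so rewriting the import may sidestep the AttributeError:

# instead of: import dig.xgraph.models.models as models
from dig.xgraph.models import models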

Explaining a batch of graphs in SubgraphX

Hi,

I read the paper "On Explainability of Graph Neural Networks via Subgraph Explorations" and find your work quite exciting. I took a look at your code and found that SubgraphX can only explain graphs one at a time. Is it possible to send a batch of graphs to SubgraphX so that the explanation can be sped up?

Problems while running the threedgraph.ipynb - AttributeError

I installed DIG and then cloned the repo locally. I tried to run the threedgraph.ipynb tutorial as is, but it threw an error in the fourth cell, which reads as follows:

model = SphereNet(energy_and_force=False, cutoff=5.0, num_layers=4,
                  hidden_channels=128, out_channels=1, int_emb_size=64,
                  basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8,
                  out_emb_channels=256, num_spherical=3, num_radial=6, envelope_exponent=5,
                  num_before_skip=1, num_after_skip=2, num_output_layers=3,
                  use_node_features=True)
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

the error I get is:


AttributeError Traceback (most recent call last)
in
4 num_spherical=3, num_radial=6, envelope_exponent=5,
5 num_before_skip=1, num_after_skip=2, num_output_layers=3,
----> 6 use_node_features=True
7 )
8 loss_func = torch.nn.L1Loss()

~/.local/lib/python3.6/site-packages/dig/threedgraph/method/spherenet/spherenet.py in init(self, energy_and_force, cutoff, num_layers, hidden_channels, out_channels, int_emb_size, basis_emb_size_dist, basis_emb_size_angle, basis_emb_size_torsion, out_emb_channels, num_spherical, num_radial, envelope_exponent, num_before_skip, num_after_skip, num_output_layers, act, output_init, use_node_features)
263 self.init_v = update_v(hidden_channels, out_emb_channels, out_channels, num_output_layers, act, output_init)
264 self.init_u = update_u()
--> 265 self.emb = emb(num_spherical, num_radial, self.cutoff, envelope_exponent)
266
267 self.update_vs = torch.nn.ModuleList([

~/.local/lib/python3.6/site-packages/dig/threedgraph/method/spherenet/spherenet.py in init(self, num_spherical, num_radial, cutoff, envelope_exponent)
22 super(emb, self).init()
23 self.dist_emb = dist_emb(num_radial, cutoff, envelope_exponent)
---> 24 self.angle_emb = angle_emb(num_spherical, num_radial, cutoff, envelope_exponent)
25 self.torsion_emb = torsion_emb(num_spherical, num_radial, cutoff, envelope_exponent)
26 self.reset_parameters()

~/.local/lib/python3.6/site-packages/dig/threedgraph/method/spherenet/features.py in init(self, num_spherical, num_radial, cutoff, envelope_exponent)
196 # self.envelope = Envelope(envelope_exponent)
197
--> 198 bessel_forms = bessel_basis(num_spherical, num_radial)
199 sph_harm_forms = real_sph_harm(num_spherical)
200 self.sph_funcs = []

~/.local/lib/python3.6/site-packages/dig/threedgraph/method/spherenet/features.py in bessel_basis(n, k)
56 normalizer += [normalizer_tmp]
57
---> 58 f = spherical_bessel_formulas(n)
59 x = sym.Symbol('x')
60 bess_basis = []

~/.local/lib/python3.6/site-packages/dig/threedgraph/method/spherenet/features.py in spherical_bessel_formulas(n)
35
36 def spherical_bessel_formulas(n):
---> 37 x = sym.Symbol('x')
38
39 f = [sym.sin(x) / x]

AttributeError: 'NoneType' object has no attribute 'symbols'

I did try one suggested fix, changing sym.symbols('x') to sym.Symbol('x'), which is the correct usage according to the SymPy documentation. However, this doesn't solve the problem and gives the exact same error on the same line.

I'm running on a Linux pod using Docker and Kubernetes; other notebooks from the DIG tutorials have worked.

Pytorch 1.8.0
Python 3.6
PyG 2.0.1
#91-Ubuntu SMP
5.4.0-81-generic
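
A hedged diagnosis: the 'NoneType' object has no attribute 'symbols' message suggests that sym itself is None, which would happen if sympy is imported inside a try/except block and is not actually installed in the notebook's environment (an assumption about the code path). A minimal check:

try:
    import sympy as sym
except ImportError:
    sym = None

if sym is None:
    raise ImportError("sympy is required by SphereNet's basis functions; "
                      "install it, e.g. with `pip install sympy`")
print(sym.Symbol('x'))  # works once sympy is importable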
