
r-gsn's People

Contributors

xjtuwxliang

Forkers

zbn123

r-gsn's Issues

CUDA out of memory

Hi, thank you for releasing your code. When I run R-GSN, I get the error "RuntimeError: CUDA out of memory. Tried to allocate 562.00 MiB (GPU 1; 10.76 GiB total capacity; 8.98 GiB already allocated; 470.56 MiB free; 9.19 GiB reserved in total by PyTorch)".

I tried reducing the batch size: batch_size is now 64 and test_batch_size is now 4, but I still get the same error. I am using a GeForce RTX 2080. Can you tell me why this happens and how to fix it? Thanks a lot!

Environment

numpy==1.18.5
scipy==1.6.2
ogb==1.3.1
texttable==1.6.3
torch==1.7.0+cu110
torchvision==0.8.0
torch-cluster==1.5.9
torch-geometric==1.7.0
torch-scatter==2.0.7
torch-sparse==0.6.9
torch-spline-conv==1.2.1

Full error output

Using backend: pytorch
+-----------------+-------+
| Parameter       | Value |
+-----------------+-------+
| device          | 1     |
+-----------------+-------+
| num_layers      | 2     |
+-----------------+-------+
| hidden_channels | 64    |
+-----------------+-------+
| dropout         | 0.500 |
+-----------------+-------+
| lr              | 0.004 |
+-----------------+-------+
| epochs          | 3     |
+-----------------+-------+
| runs            | 10    |
+-----------------+-------+
| batch_size      | 64    |
+-----------------+-------+
| test_batch_size | 4     |
+-----------------+-------+
| opt             | adamw |
+-----------------+-------+
| early_stop      | 1     |
+-----------------+-------+
| feat_dir        | feat  |
+-----------------+-------+
| conv_name       | rgsn  |
+-----------------+-------+
| Norm4           | 1     |
+-----------------+-------+
| FDFT            | 1     |
+-----------------+-------+
| use_attack      | 1     |
+-----------------+-------+
Data(
  edge_index_dict={
    ('author', 'affiliated_with', 'institution')=[2, 1043998],
    ('author', 'writes', 'paper')=[2, 7145660],
    ('paper', 'cites', 'paper')=[2, 5416271],
    ('paper', 'has_topic', 'field_of_study')=[2, 7505078]
  },
  edge_reltype={
    ('author', 'affiliated_with', 'institution')=[1043998, 1],
    ('author', 'writes', 'paper')=[7145660, 1],
    ('paper', 'cites', 'paper')=[5416271, 1],
    ('paper', 'has_topic', 'field_of_study')=[7505078, 1]
  },
  node_year={
    paper=[736389, 1]
  },
  num_nodes_dict={
    author=1134649,
    field_of_study=59965,
    institution=8740,
    paper=736389
  },
  x_dict={
    paper=[736389, 128]
  },
  y_dict={
    paper=[736389, 1]
  }
)
preprocess finished
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py:550: UserWarning: Setting attributes on ParameterDict is not supported.
  warnings.warn("Setting attributes on ParameterDict is not supported.")
Model #Params: 154373028
Attack Epoch 01: 100%|███████████████| 629571/629571 [1:01:40<00:00, 170.13it/s]
* infer valid_test exact :  86%|█████▏| 629655/736389 [01:41<2:44:33, 10.81it/s]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/data-input/houl/R-GSN/rgsn.py", line 284, in infer
    out = model(n_id, x_dict, adjs, edge_type, node_type, local_node_idx)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data-input/houl/R-GSN/models.py", line 266, in forward
    x = conv((x, x_target), edge_index, edge_type[e_id], node_type, src_node_type)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data-input/houl/R-GSN/models.py", line 124, in forward
    msg_from_i = F.normalize(self.propagate(ei, x=x, edge_type=i, src_node_type = src_node_type, a=a))
  File "/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 237, in propagate
    out = self.message(**msg_kwargs)
  File "/data-input/houl/R-GSN/models.py", line 163, in message
    res = a.unsqueeze(-1) * self.rel_lins[edge_type](x_j)  ######## Message Transform
RuntimeError: CUDA out of memory. Tried to allocate 310.00 MiB (GPU 1; 10.76 GiB total capacity; 9.36 GiB already allocated; 54.56 MiB free; 9.59 GiB reserved in total by PyTorch)
python-BaseException
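For a rough sense of scale, here is a back-of-the-envelope sketch (not a diagnosis). Assume, as an unverified guess, that the tensor being allocated in the failing line `res = a.unsqueeze(-1) * self.rel_lins[edge_type](x_j)` is a float32 matrix with one row per edge and hidden_channels = 64 columns. Its size is then rows × 64 × 4 bytes, and inverting the reported 310.00 MiB gives the number of edges of a single relation in that inference mini-batch. The helper names `tensor_mib` and `rows_for_mib` are hypothetical.

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2  # PyTorch reports allocations in MiB (2**20 bytes)


def tensor_mib(num_rows, num_cols, bytes_per_elem=BYTES_PER_FLOAT32):
    """Size in MiB of a dense [num_rows, num_cols] tensor."""
    return num_rows * num_cols * bytes_per_elem / MIB


def rows_for_mib(mib, num_cols, bytes_per_elem=BYTES_PER_FLOAT32):
    """Number of rows a dense tensor with num_cols columns needs to reach `mib` MiB."""
    return mib * MIB / (num_cols * bytes_per_elem)


# Assumption: one float32 row of width hidden_channels per edge of the relation.
hidden_channels = 64
edges_in_failing_batch = rows_for_mib(310, hidden_channels)
```

Under this assumption the 310.00 MiB allocation corresponds to 1,269,760 edge messages, and the 562.00 MiB allocation in the first error to about 2.3 million.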

Question about Message Passing Part in model.py

Dear R-GSN authors,

I have a question about how the attention score is calculated in this part:
    def message(self, edge_index_i, x_i, x_j, src_node_type_j, edge_type: int, a=None):
        if a == None:
            res = x_j
        else:
            if x_i.size(0) == 0:
                return self.rel_lins[edge_type](x_j)
            a = softmax(a, edge_index_i)
            res = a.unsqueeze(-1) * self.rel_lins[edge_type](x_j)  ######## Message Transform
        return res
In this part, there is the line
    a = softmax(a, edge_index_i)
I think edge_index_i indexes the target-type nodes. Shouldn't it be edge_index_j (the index of the source-type nodes)? I ask because, according to your paper, the attention is calculated as:
[Screenshot (2022-01-23): attention formula from the R-GSN paper]
I'm a little confused. Could you please explain this part?
Thanks!
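To make the question concrete, below is a minimal pure-Python sketch of a grouped ("segment") softmax, which is what `torch_geometric.utils.softmax(a, index)` computes: scores are normalized separately within each group of entries that share the same index. In PyG's default source-to-target flow, the `_i` suffix refers to target nodes and `_j` to source nodes. Grouping by `edge_index_i` therefore makes the weights of all edges entering a target node sum to 1, while grouping by `edge_index_j` would make the weights of all edges leaving a source node sum to 1. The function `segment_softmax` and the toy edge list are hypothetical illustrations, not the R-GSN code.

```python
import math
from collections import defaultdict


def segment_softmax(scores, index):
    """Softmax of `scores`, computed independently within each group of equal `index`."""
    # Per-group maximum, subtracted before exponentiating for numerical stability.
    group_max = defaultdict(lambda: -math.inf)
    for s, g in zip(scores, index):
        group_max[g] = max(group_max[g], s)
    exps = [math.exp(s - group_max[g]) for s, g in zip(scores, index)]
    # Per-group normalizer.
    group_sum = defaultdict(float)
    for e, g in zip(exps, index):
        group_sum[g] += e
    return [e / group_sum[g] for e, g in zip(exps, index)]


# Hypothetical toy graph with 4 edges, each directed source j -> target i.
edge_index_j = [0, 1, 2, 2]  # source nodes
edge_index_i = [3, 3, 3, 4]  # target nodes
a = [1.0, 2.0, 3.0, 0.5]     # unnormalized attention scores, one per edge

by_target = segment_softmax(a, edge_index_i)  # sums to 1 over edges entering each target
by_source = segment_softmax(a, edge_index_j)  # sums to 1 over edges leaving each source
```

With grouping by target, the three edges into node 3 share one distribution and the single edge into node 4 gets weight 1.0; with grouping by source, nodes 0 and 1 each have a single outgoing edge of weight 1.0, and node 2's two outgoing edges share one distribution.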
