kangxiatao / kgat-pytorch-master
Knowledge Graph Attention Network
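Hello,
I am training on the douban250 data, but it does not seem to come with a pretrained npz, so I used the pretrained npz from one of the other three datasets. Can a single mf.npz be reused when training on a different dataset? I got the following error: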
2023-11-09 20:30:34,516 - root - INFO - Namespace(K=20, aggregation_type='bi-interaction', cf_batch_size=1024, cf_l2loss_lambda=1e-05, cf_print_every=1, conv_dim_list='[64, 32, 16]', data_dir='datasets/', data_name='douban250', entity_dim=64, evaluate_every=1, kg_batch_size=1024, kg_l2loss_lambda=1e-05, kg_print_every=1, local_rank=0, lr=0.0001, mess_dropout='[0.1, 0.1, 0.1]', n_epoch=1000, pretrain_embedding_dir='datasets/pretrain/', pretrain_model_path='trained_model/KGAT/amazon-book/entitydim64_relationdim64_bi-interaction_64-32-16_lr0.0001_pretrain1/model_epoch1.pth', relation_dim=64, save_dir='trained_model/KGAT/douban250/entitydim64_relationdim64_bi-interaction_64-32-16_lr0.0001_pretrain1/', seed=2020, stopping_steps=10, test_batch_size=1024, use_graph=1, use_pretrain=1)
All logs will be saved to trained_model/KGAT/douban250/entitydim64_relationdim64_bi-interaction_64-32-16_lr0.0001_pretrain1/log1.log
device: cpu n_gpu: 0
load data ...
-- datasets/douban250 --
2023-11-09 20:30:34 -- cf data finish --
2023-11-09 20:30:34 -- kg data load --
2023-11-09 20:30:36 -- kg data finish --
70679
Traceback (most recent call last):
  File "D:/KGAT-pytorch-master-main/main_kgat.py", line 276, in <module>
    train(args)
  File "D:/KGAT-pytorch-master-main/main_kgat.py", line 86, in train
    data = DataLoaderKGAT(args, logging)
  File "D:\KGAT-pytorch-master-main\utility\loader_kgat.py", line 90, in __init__
    self.load_pretrained_data()
  File "D:\KGAT-pytorch-master-main\utility\loader_kgat.py", line 352, in load_pretrained_data
    assert self.user_pre_embed.shape[0] == self.n_users
AssertionError
Also, how should kg_final.txt be generated?
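A note on the AssertionError in the traceback above: the failing line compares the number of rows in the pretrained user embedding with the number of users in the current dataset. Pretrained MF embeddings are tied to the users and items of the dataset they were trained on, so an npz taken from another dataset will generally not match douban250. The usual ways out are to train MF embeddings on douban250 itself or to run without pretraining (the `use_pretrain` argument shown in the Namespace above). Below is a minimal sketch of a check to run before training. It assumes the npz stores its arrays under the keys `user_embed` and `item_embed`; the function name and the toy sizes are hypothetical.

```python
import os
import tempfile

import numpy as np


def pretrain_shape_problems(npz_path, n_users, n_items):
    """Return human-readable mismatches between a pretrained npz and a dataset."""
    # Assumption: the file stores its arrays under 'user_embed' and 'item_embed'.
    with np.load(npz_path) as data:
        user_embed = data["user_embed"]
        item_embed = data["item_embed"]

    problems = []
    if user_embed.shape[0] != n_users:
        problems.append(
            f"user_embed has {user_embed.shape[0]} rows, dataset has {n_users} users"
        )
    if item_embed.shape[0] != n_items:
        problems.append(
            f"item_embed has {item_embed.shape[0]} rows, dataset has {n_items} items"
        )
    return problems


# Toy demonstration: embeddings for 5 users and 7 items, checked against a
# dataset with 8 users and 7 items.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "mf.npz")
    np.savez(path, user_embed=np.zeros((5, 64)), item_embed=np.zeros((7, 64)))
    problems = pretrain_shape_problems(path, n_users=8, n_items=7)

for problem in problems:
    print(problem)
```

If the returned list is not empty, the file does not belong to the dataset being trained, and the assertion in `load_pretrained_data` will fail in the same way.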
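Hello, I have a question.
When I reach this step:
kg_batch_loss = model('calc_kg_loss', kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail)
an error is raised.
I then tried printing pos_score inside the calc_kg_loss method and traced the problem to the step that computes pos_score.
Part of the error output is as follows: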
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [70,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [71,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [72,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [73,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [74,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [75,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [76,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [77,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [78,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [79,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:658: indexSelectLargeIndex: block: [13,0,0], thread: [80,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
...
File "train.py", line 137, in
kg_batch_neg_tail)
File , line 727, in _call_impl
result = self.forward(*input, **kwargs)
File , line 190, in forward
return self.calc_kg_loss(*input)
File , line 168, in calc_kg_loss
The preceding r_mul_h, r_embed, and r_mul_pos_t all print with size (256, 64).
But the step torch.sum(torch.pow(r_mul_h + r_embed - r_mul_pos_t, 2), dim=1) fails to execute. Could you help me see where the problem is? Thank you very much!
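A note on the `srcIndex < srcSelectDimSize` assertion above: it is raised when an index passed to an embedding lookup or `index_select` is not smaller than the number of rows in the table being indexed. CUDA kernels run asynchronously, so the Python line named in the traceback can come after the lookup that actually failed; running on CPU or with the environment variable `CUDA_LAUNCH_BLOCKING=1` usually gives a more exact location. The sketch below is one way to look for offending ids before the model is called. It is plain Python, and the function names and the two size parameters (`n_entity_rows`, `n_relation_rows`) are hypothetical stand-ins for the row counts of the model's entity and relation embedding tables.

```python
def out_of_range(indices, table_size):
    """Return the sorted, distinct indices that cannot address a table of `table_size` rows."""
    return sorted({i for i in indices if i < 0 or i >= table_size})


def check_kg_batch(heads, relations, pos_tails, neg_tails,
                   n_entity_rows, n_relation_rows):
    """Map each batch field to the ids that would overflow its embedding table."""
    fields = (
        ("heads", heads, n_entity_rows),
        ("relations", relations, n_relation_rows),
        ("pos_tails", pos_tails, n_entity_rows),
        ("neg_tails", neg_tails, n_entity_rows),
    )
    report = {}
    for name, ids, size in fields:
        bad = out_of_range(ids, size)
        if bad:
            report[name] = bad
    return report


# Toy batch: relation id 3 and tail id 10 do not fit tables of 3 and 10 rows.
report = check_kg_batch(
    heads=[0, 5, 9],
    relations=[0, 3],
    pos_tails=[2, 10],
    neg_tails=[1, 4],
    n_entity_rows=10,
    n_relation_rows=3,
)
print(report)
```

With PyTorch tensors, each batch can be passed as `tensor.tolist()`. An id reported here usually points to a mismatch between the ids in the data files and the sizes used to build the embedding tables.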
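In main_kgat.py, I noticed that you first compute the attention, then train kg_Loss, and finally train cf_Loss.
What is the reason for this order?
I also found that without --use_pretrain, performance degrades badly, to the point that the model can hardly be trained at all.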
Hello, I have a question about the high-order propagation code. Could you please explain it?
for idx, layer in enumerate(self.aggregator_layers):
    ego_embed = layer(ego_embed, self.A_in)
    norm_embed = F.normalize(ego_embed, p=2, dim=1)
    all_embed.append(norm_embed)
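Is Equation 10 in the paper the line ego_embed = layer(ego_embed, self.A_in)? If so, how do the neighbor nodes in the paper give rise to high-order propagation? It looks as if only first-order neighbors are involved in learning the embeddings.
I look forward to your reply.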
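One way to read the loop quoted above: each layer only aggregates over direct neighbors, but its input is the output of the previous layer, which already contains information from that layer's neighbors. Stacking k layers therefore lets a node receive information from nodes up to k hops away. The toy sketch below illustrates this with a plain sum over neighbors on a path graph; it is a simplification for illustration and not the repository's aggregator, and all names in it are hypothetical.

```python
def propagate(adj, embed):
    """One layer: every node adds its direct neighbours' vectors to its own."""
    out = []
    for node in range(len(embed)):
        row = list(embed[node])
        for neighbour in adj[node]:
            for d, value in enumerate(embed[neighbour]):
                row[d] += value
        out.append(row)
    return out


def sources(embed, node):
    """Nodes whose one-hot signal is present in `node`'s vector."""
    return [j for j, value in enumerate(embed[node]) if value != 0]


# Path graph 0 - 1 - 2 - 3 with one-hot initial embeddings, so each dimension
# records which node a signal originally came from.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
embed = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]

layer_outputs = []
for _ in range(3):
    # The output of one layer is the input of the next, as in the quoted loop.
    embed = propagate(adj, embed)
    layer_outputs.append(embed)

for depth, output in enumerate(layer_outputs, start=1):
    print(f"after layer {depth}, node 0 has signal from nodes {sources(output, 0)}")
```

Node 0 is only adjacent to node 1, yet after the second layer its vector contains signal from node 2, and after the third layer from node 3.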
How did you build the knowledge graph information for the dataset you made yourself, especially kg_final.txt? Thank you!
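A rough sketch of one way such a file could be produced, under assumptions that should be checked against the repository's data loader: that kg_final.txt holds one triple per line as three space-separated integer ids (head, relation, tail), and that item ids in the interaction data coincide with the first entity ids. The function name, the raw triples, and the item names below are made up for illustration.

```python
def build_kg_final(raw_triples, item_names):
    """Turn (head, relation, tail) string triples into integer-id lines.

    Items receive the first entity ids, in the order given by `item_names`
    (assumed to contain no duplicates); other entities and all relations are
    numbered in order of first appearance.
    """
    entity2id = {name: idx for idx, name in enumerate(item_names)}
    relation2id = {}
    lines = []
    for head, relation, tail in raw_triples:
        for name in (head, tail):
            entity2id.setdefault(name, len(entity2id))
        relation2id.setdefault(relation, len(relation2id))
        lines.append(f"{entity2id[head]} {relation2id[relation]} {entity2id[tail]}")
    return lines, entity2id, relation2id


# Hypothetical raw knowledge about two movies.
raw_triples = [
    ("Movie A", "directed_by", "Director X"),
    ("Movie B", "directed_by", "Director X"),
    ("Movie A", "genre", "Drama"),
]
lines, entity2id, relation2id = build_kg_final(raw_triples, ["Movie A", "Movie B"])
print("\n".join(lines))
```

Writing `"\n".join(lines)` to a file would give the assumed layout. The entity and relation mappings are returned so that they can be saved alongside it and reused when the interaction files are built.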
Why does the ndcg20 metric differ so much from the original paper? Thank you!
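To everyone who opened issues or sent private messages: I am truly sorry that I have mostly stopped replying. This project is rather old and I have forgotten most of it; I am now working in another direction and cannot give it much attention. The KGAT code as a whole is quite clear, and its comments map to the paper, so you can step through it in a debugger and read it slowly. I wish you all many papers, little hair loss, and success with both grants and promotions!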