Comments (13)
How much CPU and GPU memory does your environment have?
What batch size are you using?
Which model?
On which line do you see this problem?
from text_classification.
1. @brightmart My GPU is a GTX 1060 with 6 GB, and I have 16 GB of RAM; I later switched to a classmate's 1080 Ti with 32 GB of RAM and it still crashed.
2. I'm running a02_TextCNN. The error occurs in load_data_multilabel_new() in data_util_zhihu.py, i.e. when generating the label y — the 1999-dimensional labels are too large.
3. Also, the training set I use is train-zhihu6-title-desc.txt, generated by your code, and the word vectors were trained with word2vec on 2.99 million title+desc entries (characters and words combined). Is that right?
4. About this part: creat_vocabulary_label() in data_util_zhihu.py.
I don't quite understand what this function produces, so I rewrote it as:
import codecs

def create_voabulary_label():
    # build label->index and index->label dicts, ordered by label frequency
    voabulary_label = './train-zhihu6-title-desc.txt'
    zhihu_f_train = codecs.open(voabulary_label, 'r', 'utf8')
    lines = zhihu_f_train.readlines()
    vocabulary_word2index_label = {}
    vocabulary_index2word_label = {}
    vocabulary_label_count_dict = {}  # {label: count}
    for i, line in enumerate(lines):
        if '__label__' in line:  # e.g. '__label__-2051131023989903826'
            label = line[line.index('__label__') + len('__label__'):].strip().replace("\n", "").split(' ')
            for item in label:
                if vocabulary_label_count_dict.get(item, None) is not None:
                    vocabulary_label_count_dict[item] = vocabulary_label_count_dict[item] + 1
                else:
                    vocabulary_label_count_dict[item] = 1
    # sort_by_value() (defined in data_util_zhihu.py) sorts labels by frequency, descending
    list_label = sort_by_value(vocabulary_label_count_dict)
    print("length of list_label:", len(list_label))
    countt = 0
    for i, label in enumerate(list_label):
        if i < 50:  # print only the 50 most frequent labels
            count_value = vocabulary_label_count_dict[label]
            print("label:", label, "count_value:", count_value)
            countt = countt + count_value
        indexx = i
        vocabulary_word2index_label[label] = indexx
        vocabulary_index2word_label[indexx] = label
    return vocabulary_word2index_label, vocabulary_index2word_label
The first 10 lines it prints after the change:
label: 7476760589625268543 count_value: 62952
label: 4697014490911193675 count_value: 46987
label: -4653836020042332281 count_value: 43165
label: -8175048003539471998 count_value: 40206
label: -8377411942628634656 count_value: 37611
label: -7046289575185911002 count_value: 36330
label: -5932391056759866388 count_value: 33563
label: 2787171473654490487 count_value: 31693
label: -7129272008741138808 count_value: 30344
label: -5513826101327857645 count_value: 29668
So in vocabulary_index2word_label, the key is one of the 1999 topics and the value is the number of times it appears in the training set — is that right?
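To make the intent concrete, here is a minimal pure-Python sketch (the function and variable names are my own, not the repo's) of what a frequency-sorted label vocabulary looks like:

```python
from collections import Counter

def build_label_vocab(lines):
    # count every token after the '__label__' marker, then index labels
    # from most to least frequent (index 0 = most common label)
    counts = Counter()
    for line in lines:
        if '__label__' in line:
            labels = line.split('__label__', 1)[1].strip().split(' ')
            counts.update(labels)
    label2index = {}
    index2label = {}
    for i, (label, _) in enumerate(counts.most_common()):
        label2index[label] = i
        index2label[i] = label
    return label2index, index2label
```

Note the returned index2label dict maps an index (0 = most frequent) to a label id; the occurrence counts live only in the intermediate counter, not in the returned dicts.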
@brightmart If I don't modify create_voabulary_label() and just feed it voabulary_label = './train-zhihu6-title-desc.txt', then in the next function, transform_multilabel_as_multihot(), the line result=np.zeros(label_size) raises "list index out of range". So I want to understand what create_voabulary_label() actually generates — and doesn't this function take the training set train-zhihu6-title-desc.txt as input?
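That error is consistent with a label index falling outside label_size. A hypothetical re-implementation of transform_multilabel_as_multihot() (my own sketch, not the repo's exact code) shows the failure mode:

```python
import numpy as np

def transform_multilabel_as_multihot(label_indices, label_size=1999):
    # turn a list of integer label indices into a multi-hot vector,
    # e.g. [0, 2] with label_size=4 -> [1, 0, 1, 0]
    result = np.zeros(label_size)
    for idx in label_indices:
        result[idx] = 1  # IndexError if idx >= label_size
    return result
```

If the label vocabulary was built from a file containing more than 1999 distinct labels, some indices exceed 1998 and indexing into result blows up — which is why how the label vocabulary is built matters here.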
1. I also took part in the Zhihu "Kanshan Cup" competition, also using the model from Yoon Kim's CNN paper https://arxiv.org/abs/1510.03820
2. But the accuracy of my final submission was extremely low.
4. Running my own model also runs out of memory (either GPU memory or RAM), so I only trained on the first 1 million training samples.
5. My approach is similar to yours: X is represented as dictionary indices, the first network layer is an embedding whose output is 3-D (sample, sentence_length, embedding_dim), and Y is multi-hot (e.g. [1,1,0,1,0,0,........]). Maybe because prediction has to pick a subset of the 1999 topics, the accuracy is extremely low.
6. Is the multi_label setup you use the same as what I described? And what does the single-label setup look like?
7. If multi-label training runs out of memory and I switch to single-label: I see part of your code builds samples like (question title+desc, topic desc, 1) or (question title+desc, topic desc, 0). The positive samples are easy, but how are the negative samples generated? And at prediction time, if a question has multiple topics, how do you predict them?
About the practical problems (memory issues, vocabulary size, casting the multi-label task to a single-label problem):
- We use 128 GB of RAM and an 8 GB GPU, and did not run into memory issues.
- You can also cast the multi-label task into a single-label problem and get performance close to that of the multi-label setup. During training you simply treat each sample as single-label; at prediction time, take the top-k labels from the logits.
- If you hit a memory issue when creating the one-hot labels, you can change your code: for example, instead of tf.nn.softmax_cross_entropy_with_logits (which needs dense one-hot labels), use tf.nn.sparse_softmax_cross_entropy_with_logits (which takes integer label indices, so the one-hot vectors are never materialized).
- We use word2vec to get the word embeddings, but we ignore rare words, e.g. by setting min_count=20, so the total vocabulary size is not big — around 200k or 400k.
- creat_vocabulary_label() just puts unique label-index pairs into a dict, sorted by frequency, because that ordering is required when computing the loss for multi-label. As long as you create your own label-index pairs it will be fine.
As you can see, the input of this method is 'train-zhihu4-only-title-all.txt'.
Sample data looks like:
w32080 w54 w5782 w3595 w54 w6350 w2335 w23 w54 w4016 __label__-3862493943988412715
w39612 w33 w19811 w1361 w111 __label__8722585993781275470
w20780 w12633 w269 w1482 w10147 w111 __label__6277846260132427404
A multi-label target is the same as: [1,1,0,1,0,0,........]
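The train-single-label / predict-top-k trick above can be sketched in plain Python (the function name is mine, not the repo's):

```python
def top_k_labels(logits, k=5):
    # indices of the k largest logits, highest score first;
    # used at prediction time when the model was trained single-label
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
```

For example, top_k_labels([0.1, 2.0, -1.0, 0.5], k=2) returns [1, 3]. In TensorFlow the same thing would typically be done with tf.nn.top_k on the logits tensor.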
1. I generated 'train-zhihu4-only-title-all.txt' myself as you described; the program runs correctly, but it still fails — not enough memory.
2. Looks like I should try single-label.
3. Are you running this on a server?
4. I cut the training data down to 1 million samples; it loads now, but it errors out when CUDA starts up.
You can set the minimum word frequency to 20 ("min_count=20") or higher when building the word vocabulary with word2vec, so the vocabulary size will be less than 100k.
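min_count pruning just drops every token that appears fewer times than the threshold before indexing. A rough sketch (my own helper, using min_count=2 on toy data; a real run would pass min_count=20 to the word2vec trainer):

```python
from collections import Counter

def build_word_vocab(sentences, min_count=2):
    # keep only tokens seen at least min_count times, then index them
    counts = Counter(tok for sent in sentences for tok in sent.split())
    kept = sorted(tok for tok, c in counts.items() if c >= min_count)
    return {tok: i for i, tok in enumerate(kept)}
```

Rare tokens fall out of the vocabulary entirely, which is what keeps both the embedding matrix and the lookup dict small.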
Thank you for all the help — I have learned a lot from your code. I'll give it a try.
Hi, could you help me with this error:
InvalidArgumentError (see above for traceback): logits and labels must be same size: logits_size=[512,1999] labels_size=[1,512]
In single-label mode, it fails here:
p7_TextCNN_model.py:
losses = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y, logits=self.logits);
For single-label, I think the labels should have shape [512] — one integer class id per example — and the loss should be tf.nn.sparse_softmax_cross_entropy_with_logits; check that method's documentation for more detail.
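For reference, here is a numpy sketch (my own, not TensorFlow's implementation) of the shapes sparse softmax cross-entropy expects: logits [batch, num_classes] and integer labels [batch]:

```python
import numpy as np

def sparse_softmax_cross_entropy(logits, labels):
    # logits: [batch, num_classes] float scores
    # labels: [batch] integer class ids -- NOT one-hot, NOT [1, batch]
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    assert logits.ndim == 2 and labels.shape == (logits.shape[0],)
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]  # per-example loss, shape [batch]
```

With logits of shape [512, 1999], the labels here must be a flat vector of 512 integer class ids; a [1, 512] tensor fails the shape check, which matches the size-mismatch error above.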
I still don't get it 😓. If that's the case, shouldn't this be defined as [batch_size, ]?
self.input_y = tf.placeholder(tf.int32, [None,], name="input_y") # y:[None,num_classes]
self.input_y_multilabel = tf.placeholder(tf.float32,[None,self.num_classes], name="input_y_multilabel")
The first line is single-label, the second is multi-label. So why do I still get: logits_size=[512,1999] labels_size=[1,512]?
My input:
X:
[[ 80009 105626 129221 ..., 0 0 0]
[ 80009 105626 129221 ..., 0 0 0]
[ 40245 106620 103746 ..., 0 0 0]
...,
[ 67771 16202 138143 ..., 0 0 0]
[118126 136109 88785 ..., 0 0 0]
[ 67771 132615 70578 ..., 0 0 0]]
Y (single labels):
[744, 36, 1386, 326, 1365, 98, 85, 123, 1725, 202......]
Could you take a look for me?
You can check the files under a07_Transformer. These are for single-label classification:
a2_transformer_classification.py: the model
a2_train_classification.py: training
a2_predict_classification.py: prediction
@ynuwm, may I ask how you solved the TextCNN error in the single-label case?