
Comments (13)

brightmart commented on May 12, 2024

how much memory of cpu and gpu in your env?
what's batch size?
which model?
which line you find this problem?

from text_classification.

ynuwm commented on May 12, 2024

1. @brightmart My GPU is a GTX 1060 (6 GB) with 16 GB RAM. I later switched to a classmate's 1080 Ti with 32 GB RAM and still got the error.
2. I was running a02_TextCNN. The error occurs in load_data_multilabel_new() in data_util_zhihu.py, that is, when generating the label y; the 1999 dimensions are too large.
3. Also, the training set I used is train-zhihu6-title-desc.txt generated by your code, and the word vectors were trained with word2vec on 2.99 million title+desc entries (characters combined with words). Is that right?
4. About creat_vocabulary_label() in data_util_zhihu.py: I don't quite understand what this function produces, so I rewrote it as:

import codecs

def create_voabulary_label():
    voabulary_label = './train-zhihu6-title-desc.txt'
    zhihu_f_train = codecs.open(voabulary_label, 'r', 'utf8')
    lines = zhihu_f_train.readlines()
    vocabulary_word2index_label = {}
    vocabulary_index2word_label = {}
    vocabulary_label_count_dict = {}  # {label: count}
    for i, line in enumerate(lines):
        if '__label__' in line:  # e.g. '__label__-2051131023989903826'
            label = line[line.index('__label__') + len('__label__'):].strip().replace("\n", "").split(' ')
            for item in label:
                if vocabulary_label_count_dict.get(item, None) is not None:
                    vocabulary_label_count_dict[item] = vocabulary_label_count_dict[item] + 1
                else:
                    vocabulary_label_count_dict[item] = 1

    list_label = sort_by_value(vocabulary_label_count_dict)  # labels sorted by count, descending
    print("length of list_label:", len(list_label))
    countt = 0
    for i, label in enumerate(list_label):
        if i < 50:
            count_value = vocabulary_label_count_dict[label]
            print("label:", label, "count_value:", count_value)
            countt = countt + count_value
        vocabulary_word2index_label[label] = i
        vocabulary_index2word_label[i] = label
    # return outside the loop, otherwise only the first label is indexed
    return vocabulary_word2index_label, vocabulary_index2word_label

After the change, the first 10 printed entries are:
label: 7476760589625268543 count_value: 62952
label: 4697014490911193675 count_value: 46987
label: -4653836020042332281 count_value: 43165
label: -8175048003539471998 count_value: 40206
label: -8377411942628634656 count_value: 37611
label: -7046289575185911002 count_value: 36330
label: -5932391056759866388 count_value: 33563
label: 2787171473654490487 count_value: 31693
label: -7129272008741138808 count_value: 30344
label: -5513826101327857645 count_value: 29668

That is, in vocabulary_index2word_label the key is one of the 1999 topics and the value is the number of times it appears in the training set. Is that correct?


ynuwm commented on May 12, 2024

@brightmart If I don't modify create_voabulary_label() and just feed voabulary_label = './train-zhihu6-title-desc.txt' into it, then in the next function, transform_multilabel_as_multihot(), the line result=np.zeros(label_size) raises "list index out of range". So I want to understand what create_voabulary_label() actually generates. Also, isn't the data this function takes the training set train-zhihu6-title-desc.txt?
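For reference, a minimal numpy sketch of the invariant a transform_multilabel_as_multihot() along these lines must satisfy (the exact repo implementation may differ): every label index must already be smaller than label_size, otherwise indexing raises the "index out of range" error seen above.

```python
import numpy as np

def transform_multilabel_as_multihot(label_indices, label_size=1999):
    # Multi-hot target: 1.0 at each of the sample's label indices.
    # Every index must be < label_size, which is why the label
    # vocabulary (label -> index) has to be built before this step.
    result = np.zeros(label_size)
    result[label_indices] = 1.0
    return result

v = transform_multilabel_as_multihot([0, 3, 5], label_size=8)
print(v.tolist())  # [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0]
```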


ynuwm commented on May 12, 2024

1. I also entered the Zhihu "Kanshan Cup" competition, and I also used the CNN model from Yoon Kim's paper https://arxiv.org/abs/1510.03820
2. But the accuracy of my final submission was very low.
4. Running my own model also hits out-of-memory (either GPU memory or RAM), so I only used the first 1 million training samples.
5. My approach is similar to yours: X is first represented as dictionary indices, the first network layer is an embedding whose output is 3-D (sample, sentence_length, embedding_dim), and Y is multi-hot (e.g. [1,1,0,1,0,0,........]). Perhaps because prediction has to pick a subset of the 1999 labels, the accuracy is extremely low.
6. Is the multi-label setup you use the same as what I described? And if single-label, how does it work?
7. If multi-label training runs out of memory and I switch to single-label: I see part of your code builds samples like (question title+desc, topic desc, 1) or (question title+desc, topic desc, 0). Positive samples are easy to build, but where do the negative samples come from? And if a question contains multiple topics, how do you predict at test time?
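On point 7: the thread doesn't spell out how the repo builds its negatives, but a common way to produce (question, topic, 0/1) pairs is random negative sampling. The sketch below is an assumption, not the repo's actual code; make_pairs and its parameters are hypothetical names.

```python
import random

def make_pairs(question, true_topics, all_topics, n_neg=3, seed=0):
    # Positives: the question paired with each of its real topics.
    pairs = [(question, t, 1) for t in true_topics]
    # Negatives: topics sampled at random from those the question
    # does NOT have (random negative sampling, an assumption here).
    rng = random.Random(seed)
    candidates = [t for t in all_topics if t not in set(true_topics)]
    for t in rng.sample(candidates, min(n_neg, len(candidates))):
        pairs.append((question, t, 0))
    return pairs

pairs = make_pairs("w32080 w54 w5782", ["topicA"],
                   ["topicA", "topicB", "topicC"], n_neg=1)
print(pairs)  # one positive pair followed by one negative pair
```

For a question with several topics, prediction under this pairwise setup would score every (question, topic) pair and keep the top-k highest-scoring topics.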


brightmart commented on May 12, 2024

About the practical problems (memory issues, vocabulary size, casting the multi-label task to a single-label problem):

  1. We use 128 GB RAM and an 8 GB GPU, and did not hit memory issues.

  2. You can also cast the multi-label task into a single-label problem and get performance close to the multi-label setup's performance:
    just treat it as single-label during training, but at prediction time take the top-k from the logits.

  3. If you hit a memory issue when creating one-hot style labels, you can change your code: for example, instead of tf.nn.softmax_cross_entropy_with_logits (which needs one-hot labels), use tf.nn.sparse_softmax_cross_entropy_with_logits with integer label indices.

  4. We use word2vec to get word embeddings, but ignore rare words, for example by setting min_count=20, so the total vocabulary size is not big, around 200k or 400k.

  5. creat_vocabulary_label() just puts unique label-index pairs into a dict, sorted by frequency, because that is required when computing the loss for multi-label. As long as you can create your label-index pairs it will be fine.
    As you can see, the input of this method is 'train-zhihu4-only-title-all.txt'; sample data looks like:
    w32080 w54 w5782 w3595 w54 w6350 w2335 w23 w54 w4016 label-3862493943988412715
    w39612 w33 w19811 w1361 w111 __label__8722585993781275470
    w20780 w12633 w269 w1482 w10147 w111 __label__6277846260132427404

  6. Multi-label targets look like: [1,1,0,1,0,0,........]
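The "top-k from logits" idea in point 2 can be sketched as follows (a minimal numpy sketch, not the repo's code):

```python
import numpy as np

def predict_topk(logits, k=5):
    """Return indices of the k highest logits per sample.

    The model is trained as single-label (softmax over the 1999
    topics); at prediction time we take the top-k topics as the
    multi-label output.
    """
    # argsort the negated logits (descending order), keep first k
    return np.argsort(-logits, axis=1)[:, :k]

# toy batch: 2 samples, 6 "topics"
logits = np.array([[0.1, 2.0, 0.3, 1.5, -1.0, 0.0],
                   [3.0, -0.5, 0.2, 0.1, 2.5, 0.4]])
top3 = predict_topk(logits, k=3)
print(top3.tolist())  # [[1, 3, 2], [0, 4, 5]]
```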


ynuwm commented on May 12, 2024

1. I generated the file 'train-zhihu4-only-title-all.txt' as you described; the program itself is fine, but it still cannot run, there is not enough memory.
2. Looks like I should try single-label.
3. Are you running this on a server?
4. I cut the training data down to 1 million samples; it loads now, but it errors when initializing CUDA.


brightmart commented on May 12, 2024

You can set the word-frequency threshold to 20 ("min_count=20") or higher when building the word vocabulary with word2vec, so the vocabulary size will be under 100k.
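The frequency cutoff can be sketched in plain Python (a toy sketch; word2vec applies the same min_count rule internally when it builds its vocabulary):

```python
from collections import Counter

def build_vocab(tokenized_docs, min_count=20):
    # Keep only words seen at least min_count times; rare words are
    # dropped, which shrinks both the vocabulary and the embedding
    # matrix that has one row per word.
    counts = Counter(w for doc in tokenized_docs for w in doc)
    vocab = {}
    for word, c in counts.most_common():
        if c >= min_count:
            vocab[word] = len(vocab)  # index by frequency rank
    return vocab

docs = [["w1", "w2", "w1"], ["w1", "w3"], ["w2", "w1"]]
print(build_vocab(docs, min_count=2))  # {'w1': 0, 'w2': 1}
```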


ynuwm commented on May 12, 2024

Thank you, I have learned a lot from your code. I'll give it a try.


ynuwm commented on May 12, 2024

Hi, could you help me with this error:
InvalidArgumentError (see above for traceback): logits and labels must be same size: logits_size=[512,1999] labels_size=[1,512]

In the single-label case, it fails at this line:
p7_TextCNN_model.py:
losses = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y, logits=self.logits);


brightmart commented on May 12, 2024

For single label, I think labels_size should be [512], using sparse_softmax_cross_entropy_with_logits with integer labels. You can check that method's documentation for more detail.
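The underlying convention: tf.nn.softmax_cross_entropy_with_logits expects labels with the same shape as the logits ([batch, num_classes], i.e. one-hot or multi-hot), while integer class ids of shape [batch] belong to tf.nn.sparse_softmax_cross_entropy_with_logits. A small numpy sketch of the two label conventions (illustrative, not TensorFlow code):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerically stable
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def dense_ce(onehot_labels, logits):
    # like tf.nn.softmax_cross_entropy_with_logits:
    # labels shape [batch, num_classes]
    return -(onehot_labels * np.log(softmax(logits))).sum(axis=1)

def sparse_ce(index_labels, logits):
    # like tf.nn.sparse_softmax_cross_entropy_with_logits:
    # labels shape [batch], holding integer class ids
    p = softmax(logits)
    return -np.log(p[np.arange(len(index_labels)), index_labels])

logits = np.array([[2.0, 0.5, 0.1], [0.2, 1.0, 3.0]])
y_idx = np.array([0, 2])      # integer labels, shape [batch]
y_hot = np.eye(3)[y_idx]      # one-hot labels, shape [batch, 3]
assert np.allclose(dense_ce(y_hot, logits), sparse_ce(y_idx, logits))
```

Feeding y_idx (shape [512]) into the dense variant produces exactly the "logits and labels must be same size" error above.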


ynuwm commented on May 12, 2024

I still don't get it (⊙﹏⊙). If I follow what you said, should this be defined as [batch_size, ]?

self.input_y = tf.placeholder(tf.int32, [None,], name="input_y")   # y:[None,num_classes]
self.input_y_multilabel = tf.placeholder(tf.float32,[None,self.num_classes], name="input_y_multilabel") 

The first is single-label, the second is multi-label. Why do I still get: logits_size=[512,1999] labels_size=[1,512]?
My input:
X():
[[ 80009 105626 129221 ..., 0 0 0]
[ 80009 105626 129221 ..., 0 0 0]
[ 40245 106620 103746 ..., 0 0 0]
...,
[ 67771 16202 138143 ..., 0 0 0]
[118126 136109 88785 ..., 0 0 0]
[ 67771 132615 70578 ..., 0 0 0]]

Y (the single labels):
[744, 36, 1386, 326, 1365, 98, 85, 123, 1725, 202......]

Could you take a look?


brightmart commented on May 12, 2024

You can check the files under a07_Transformer; the following are for single-label classification:
a2_transformer_classification.py: model
a2_train_classification.py: for training
a2_predict_classification.py: for prediction


kevinsay commented on May 12, 2024

@ynuwm, may I ask how you solved the TextCNN error in the single-label case?

