text-classification's Issues

Use pre-trained embeddings instead of random ones

Hi there, as I saw in your attn_bi_lstm.py (and maybe the other approaches as well), the word embeddings are initialized randomly. I tried to use pre-trained embeddings instead, but I do not know how to set it up, and I get an error saying "must have rank at least 3" (sorry, I am new to TensorFlow). Thank you, much appreciated.

Word embedding

    embeddings_var = tf.Variable(tf.random_uniform([self.vocab_size, self.embedding_size], -1.0, 1.0),
                                 trainable=True)
    batch_embedded = tf.nn.embedding_lookup(embeddings_var, self.x)

    rnn_outputs, _ = bi_rnn(BasicLSTMCell(self.hidden_size),
                            BasicLSTMCell(self.hidden_size),
                            inputs=batch_embedded, dtype=tf.float32)

My trial

    import numpy as np

    def generate_embedding(word_index, model_embedding, EMBEDDING_DIM):
        count6 = 0
        countNot6 = 0
        # embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
        embedding_matrix = np.asarray([np.random.uniform(-0.01, 0.01, EMBEDDING_DIM)
                                       for _ in range(len(word_index) + 1)])
        list_oov = []
        for word, i in word_index.items():
            try:
                embedding_vector = model_embedding[word]
            except KeyError:          # assumes model_embedding raises KeyError for OOV words
                list_oov.append(word)
                countNot6 += 1
                continue
            if embedding_vector is not None:
                count6 += 1
                embedding_matrix[i] = embedding_vector
        return embedding_matrix

    batch_embedded = generate_embedding(word_index, word_embedding, EMBEDDING_DIM)
    rnn_outputs, _ = bi_rnn(BasicLSTMCell(self.hidden_size),
                            BasicLSTMCell(self.hidden_size),
                            inputs=batch_embedded, dtype=tf.float32)

Note that I get the error at inputs=batch_embedded.
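
For reference, a common way to plug a pre-trained matrix into this graph is to keep the tf.nn.embedding_lookup call and only change how embeddings_var is initialized; the "must have rank at least 3" error above is most likely raised because the 2-D embedding matrix itself is passed as inputs to bi_rnn, which expects a [batch, time, embedding] tensor. A minimal sketch under that assumption (pretrained_matrix is a placeholder name, not something from the repository):

    # Minimal sketch, assuming TF 1.x and a NumPy matrix of shape [vocab_size, embedding_size]
    # produced by generate_embedding() above.
    pretrained_matrix = generate_embedding(word_index, word_embedding, EMBEDDING_DIM)

    embeddings_var = tf.Variable(pretrained_matrix.astype(np.float32),  # pre-trained weights instead of random_uniform
                                 trainable=True)                        # set False to freeze the embeddings

    # Keep the lookup: self.x holds word ids of shape [batch, max_len], so the result is rank 3.
    batch_embedded = tf.nn.embedding_lookup(embeddings_var, self.x)

    rnn_outputs, _ = bi_rnn(BasicLSTMCell(self.hidden_size),
                            BasicLSTMCell(self.hidden_size),
                            inputs=batch_embedded, dtype=tf.float32)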

Error in attn_bi_lstm.py while feeding data label during training

> python attn_bi_lstm.py

InvalidArgumentError (see above for traceback): Received a label value of -2147483648
which is outside the valid range of [0, 15).  Label values:
-2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648
-2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648
-2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648
-2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648 -2147483648
     [[{{node SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}
      = SparseSoftmaxCrossEntropyWithLogits[T=DT_FLOAT, Tlabels=DT_INT32,
        _device="/job:localhost/replica:0/task:0/device:CPU:0"](xw_plus_b, _arg_Placeholder_1_0_1)]]

Printing y_train values:


5164    NaN
3458    NaN
3236    NaN
3118    NaN
1555    NaN
930     NaN
3188    NaN
2899    NaN
2918    NaN
1431    NaN
2373    NaN
1205    NaN
2734    NaN
2560    NaN
1495    NaN
5430    NaN
2912    NaN
2098    NaN
2410    NaN
4482    NaN
1045    1.0
2469    NaN
1703    NaN
250     NaN
5214    NaN
4767    NaN
849     NaN
976     NaN
5489    NaN
5545    NaN
5241    NaN
3128    NaN

Thanks
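
A likely explanation, offered as a guess rather than a confirmed fix: NaN labels cast to int32 typically become -2147483648, which matches the values in the error above, so the labels are probably being lost before training (for instance an index mismatch when y_train is built). A quick sanity check on a pandas Series might look like:

    # Hedged sketch: verify the labels before feeding them; the loading code here is assumed.
    import numpy as np

    print(y_train.isna().sum())                        # how many labels are NaN
    y_train = y_train.dropna().astype(np.int32)        # or fix the join/reindex that produced the NaNs
    assert y_train.min() >= 0 and y_train.max() < 15   # labels must lie in [0, n_class)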

cannot load the dataset

The tf.contrib module is deprecated, so the data can't be loaded.
Also the link given to download the data is not working anymore.
Could you please check this?
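
As a possible workaround, the dbpedia_csv files (three columns: class index, title, content, no header) can be read without tf.contrib, for example with pandas; this is a hedged sketch of a replacement loader, not the repository's own code, and the column layout is assumed from the standard dbpedia_csv release:

    # Hedged sketch: load dbpedia_csv without tf.contrib.
    import pandas as pd

    def load_dbpedia_csv(path):
        df = pd.read_csv(path, header=None, names=["class", "title", "content"])
        texts = (df["title"] + " " + df["content"]).values
        labels = df["class"].values   # 1-based ids in the CSV; shift by -1 if the model expects [0, n_class)
        return texts, labels

    x_train, y_train = load_dbpedia_csv("dbpedia_csv/train.csv")
    x_test, y_test = load_dbpedia_csv("dbpedia_csv/test.csv")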

Is the y_hat part of the attn_bi_lstm.py model written incorrectly?

FC_W = tf.Variable(tf.truncated_normal([self.hidden_size, self.max_len], stddev=0.1))

At this point the code is about to compute the prediction y_hat, and there are only 15 classes, not 32 (the max length).
The shape of FC_W here should be
[self.hidden_size, self.n_class]
rather than
[self.hidden_size, self.max_len]
right? The FC_b below has the same problem.
Could you take a look and check whether this is a bug, or whether I have misunderstood something? @TobiasLee
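
If the shapes are indeed swapped, the corrected projection would presumably look like the sketch below (final_output is a placeholder name for the [batch, hidden_size] feature produced by the attention layer):

    # Sketch of the suggested fix: project hidden_size -> n_class rather than max_len.
    FC_W = tf.Variable(tf.truncated_normal([self.hidden_size, self.n_class], stddev=0.1))
    FC_b = tf.Variable(tf.constant(0.1, shape=[self.n_class]))
    y_hat = tf.nn.xw_plus_b(final_output, FC_W, FC_b)   # final_output: [batch, hidden_size]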

attn_bi_lstm.py

Excuse me, in the code of attn_bi_lstm.py, in the graph of ABLSTM, I don't see where "attention" is used. Maybe you only use a Bi-LSTM?
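
For context only, and not a claim about what attn_bi_lstm.py actually contains: the attention part of an ABLSTM-style model is typically a small softmax weighting applied on top of the Bi-LSTM outputs, roughly like the sketch below (shapes and variable names are assumptions):

    # Generic attention over Bi-LSTM outputs; assumed shape of H is [batch, max_len, 2*hidden].
    H = tf.concat(rnn_outputs, 2)                                  # concat forward/backward outputs
    M = tf.tanh(H)
    w = tf.Variable(tf.random_normal([2 * self.hidden_size]))      # attention query vector
    alpha = tf.nn.softmax(tf.tensordot(M, w, axes=1))              # [batch, max_len] attention weights
    r = tf.reduce_sum(H * tf.expand_dims(alpha, -1), axis=1)       # weighted sum -> [batch, 2*hidden]
    final_output = tf.tanh(r)                                      # what the FC layer would consume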

ValueError: Cannot feed value of shape (32, 15) for Tensor 'Placeholder_1:0', which has shape '(?,)'

I am getting the above error when I run the Bi-LSTM attention program on the DBpedia dataset.

Traceback (most recent call last):
  File "attn_bi_lstm.py", line 112, in <module>
    return_dict = run_train_step(classifier, sess, (x_batch, y_batch))
  File "/home/kbk/Desktop/BudddiHealth/higher models/Text-Classification-master/models/utils/model_helper.py", line 26, in run_train_step
    return sess.run(to_return, feed_dict)
  File "/home/kbk/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/home/kbk/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1076, in _run
    str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (32, 15) for Tensor 'Placeholder_1:0', which has shape '(?,)'
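
A guess at the cause, based only on the shapes in the error: the placeholder has shape (?,) and the loss is SparseSoftmaxCrossEntropyWithLogits, so the model seems to expect integer class ids, while a one-hot batch of shape (32, 15) is being fed. A minimal sketch of the usual conversion:

    # Sketch: convert one-hot labels to integer class ids before feeding them.
    import numpy as np

    y_batch_ids = np.argmax(y_batch, axis=1)   # (32, 15) one-hot -> (32,) class ids
    return_dict = run_train_step(classifier, sess, (x_batch, y_batch_ids))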

How can I load data

When I run the script, it gives me a NameError:
"NameError: name 'FLAGS' is not defined"
Could you please tell me what 'FLAGS' refers to?
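
In TF 1.x projects, FLAGS usually refers to command-line flags defined via tf.app.flags; the sketch below shows the general pattern only, and the flag names are placeholders rather than the ones this repository defines:

    # Hedged sketch of a typical FLAGS definition; flag names are made up.
    import tensorflow as tf

    tf.app.flags.DEFINE_string("data_dir", "./dbpedia_csv", "directory containing train.csv/test.csv")
    tf.app.flags.DEFINE_integer("batch_size", 32, "training batch size")
    FLAGS = tf.app.flags.FLAGS

    print(FLAGS.data_dir, FLAGS.batch_size)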

Data does not exist

I ran your model adversarial_abblstm.py, but it reports that the file /dbpedia_data/dbpedia_csv/train.csv does not exist. Could you tell me where to download train.csv and test.csv?

Test Accuracy is lower than the Performance in Readme

Epoch 19 start !
Train Epoch time: 108.163 s
validation accuracy: 0.932
Epoch 20 start !
Train Epoch time: 107.236 s
validation accuracy: 0.936
Training finished, time consumed : 2188.0112912654877 s
Start evaluating:

Test accuracy : 93.619048 %

This is the performance of attn_bi_lstm.py. Why is it lower than the 98.23% reported in the README?
Thanks!

validation and testing accuracy=0

First of all, thank you for your effort. My problem is that the implementation runs correctly when I use the DBpedia dataset, but when I try my own dataset I get an accuracy of 0. My dataset is in Arabic.

A question about the code for "Adversarial Training Methods for Semi-Supervised Text Classification"

Hello, I would like to ask about the following:

    logits, self.cls_loss = cal_loss_logit(batch_embedded, self.keep_prob, reuse=False)
    embedding_perturbated = self._add_perturbation(batch_embedded, self.cls_loss)
    adv_logits, self.adv_loss = cal_loss_logit(embedding_perturbated, self.keep_prob, reuse=True)

Why is reuse=False when no perturbation is added, but reuse=True when computing the loss on the perturbed embedding?

Thanks!
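
One plausible reading, offered as background rather than the author's own answer: in TF 1.x variable scoping, the first call with reuse=False creates the classifier's weights, and the second call with reuse=True shares those same weights, so the adversarial loss is computed by the identical network instead of a second copy. A toy illustration of the scoping rule:

    # Toy sketch of TF 1.x variable sharing; not code from this repository.
    import tensorflow as tf

    def logits_fn(x, reuse):
        with tf.variable_scope("clf", reuse=reuse):
            W = tf.get_variable("W", shape=[10, 15])   # created on the first call, shared afterwards
            return tf.matmul(x, W)

    x = tf.placeholder(tf.float32, [None, 10])
    clean_logits = logits_fn(x, reuse=False)       # creates clf/W
    adv_logits = logits_fn(x + 0.01, reuse=True)   # reuses the same clf/W for the perturbed input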

wrong with cnn.py

Could you tell me what is wrong when I run cnn.py? It exits with "Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)".

Wrong output dimension of the embedding_lookup table

I checked again and again: the output dimension of tf.nn.embedding_lookup should be [?, 256, 128], but the comments say [?, 256, 100]. Either I am misunderstanding the concept or something is really wrong there. Please clarify.
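
For what it's worth, the output of the lookup inherits its last dimension from the embedding matrix, so with embedding_size = 128 the shape should be [batch, max_len, embedding_size] = [?, 256, 128]; a quick check (the vocabulary size below is an arbitrary placeholder):

    # Sketch: the last dimension of the lookup equals embedding_size.
    x = tf.placeholder(tf.int32, [None, 256])       # [batch, max_len]
    emb = tf.get_variable("emb", [50000, 128])      # [vocab_size, embedding_size]; vocab size is made up
    looked_up = tf.nn.embedding_lookup(emb, x)
    print(looked_up.shape)                          # (?, 256, 128)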

How do I save the model?

How should this model be saved? I want to write a separate prediction function. During training I save with saver.save(sess, '../save_model/atttn_lstm_hierarchical/'), and when loading the model for prediction I use:

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./save_model/atttn_lstm_hierarchical/-1000.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./save_model/atttn_lstm_hierarchical'))

But I get the following error:
[screenshot of the error: QQ截图20220720094928]
What is the problem?
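
As a rough reference only: this is a sketch of the standard TF 1.x save/restore pattern, not this repository's prediction code; the checkpoint prefix, tensor names, and x_test_batch below are assumptions.

    # Hedged sketch of the usual TF 1.x save/restore flow.
    # Training side: give the checkpoint a file prefix (here "model"), not just a directory.
    saver = tf.train.Saver()
    saver.save(sess, '../save_model/atttn_lstm_hierarchical/model', global_step=step)  # step: training step counter

    # Prediction side: import the graph, restore the latest checkpoint, look tensors up by name.
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./save_model/atttn_lstm_hierarchical/model-1000.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./save_model/atttn_lstm_hierarchical'))
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name('Placeholder:0')           # assumed input placeholder name
        prediction = graph.get_tensor_by_name('prediction:0')   # assumed output tensor name
        print(sess.run(prediction, feed_dict={x: x_test_batch}))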
