
aaai18-code's Introduction

Learning Structured Representation for Text Classification via Reinforcement Learning

Tianyang Zhang*, Minlie Huang, Li Zhao

Representation learning is a fundamental problem in natural language processing. This paper studies how to learn a structured representation for text classification. Unlike most existing representation models that either use no structure or rely on pre-specified structures, we propose a reinforcement learning (RL) method to learn sentence representation by discovering optimized structures automatically. We demonstrate two attempts to build structured representation: Information Distilled LSTM (ID-LSTM) and Hierarchically Structured LSTM (HS-LSTM). ID-LSTM selects only important, task-relevant words, and HS-LSTM discovers phrase structures in a sentence. Structure discovery in the two representation models is formulated as a sequential decision problem: current decision of structure discovery affects following decisions, which can be addressed by policy gradient RL. Results show that our method can learn task-friendly representations by identifying important words or task-relevant structures without explicit structure annotations, and thus yields competitive performance.
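The sequential keep/delete decisions that ID-LSTM makes can be sketched as follows (a minimal illustration only, not the repo's TensorFlow code; the words and probabilities are made up):

```python
import random

def id_lstm_rollout(keep_probs, words):
    """Sketch of ID-LSTM's sequential decisions: at each word the policy
    samples keep (1) or delete (0); deleted words are skipped, so the
    downstream classifier only sees the kept, task-relevant words."""
    kept = []
    for word, p in zip(words, keep_probs):
        if random.random() < p:   # sample a Bernoulli keep-decision
            kept.append(word)     # only kept words feed the downstream LSTM
    return kept

# a policy that has learned to drop uninformative function words:
id_lstm_rollout([0.0, 1.0, 1.0], ["the", "movie", "excels"])
# -> ["movie", "excels"]
```

Because each decision changes the state seen by later decisions, the whole rollout is an episode whose delayed reward (the classification outcome) drives the policy-gradient update.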

@inproceedings{zhang2018learning,
  title={Learning Structured Representation for Text Classification via Reinforcement Learning},
  author={Zhang, Tianyang and Huang, Minlie and Zhao, Li},
  booktitle={AAAI},
  year={2018}
}

AGnews dataset used in the experiment: https://drive.google.com/open?id=1becf7pzfuLL7qgWqv4q-TyDYjSzodWfR

aaai18-code's People

Contributors

keavil


aaai18-code's Issues

ID-LSTM

Hello author,
Reading your code, I get the impression that the actor network's update in the ID-LSTM code forgot to multiply in the reward term. Perhaps my understanding is wrong; I hope you can clarify.

Is Actor-critic used here?

I am confused by your code.

In the paper it is mentioned that a policy gradient method [1] is used, but looking at the code, I think it is actually implemented as Actor-Critic.

If I am wrong, please tell me.

[1] Sutton, R. S.; McAllester, D. A.; Singh, S. P.; and Mansour, Y. 2000. Policy gradient methods for reinforcement learning with function approximation. In NIPS, 1057–1063.
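For context on both of the issues above: the plain policy-gradient (REINFORCE) update from [1] scales the gradient of the sampled actions' log-probabilities by the episode return, whereas an actor-critic variant replaces the raw return with a critic's advantage estimate. A minimal NumPy sketch (illustrative only, not the repo's TensorFlow code):

```python
import numpy as np

def reinforce_grad(keep_probs, actions, ret):
    """Gradient of ret * sum_t log pi(a_t) for Bernoulli keep-probabilities:
    d log p(a) / dp = a/p - (1 - a)/(1 - p), scaled by the return `ret`.
    An actor-critic method would use (ret - V(s)), the advantage, instead."""
    p = np.asarray(keep_probs, dtype=float)
    a = np.asarray(actions, dtype=float)
    return ret * (a / p - (1 - a) / (1 - p))

reinforce_grad([0.5, 0.8], [1, 0], 2.0)  # -> array([  4., -10.])
```

Either way, the return (or advantage) must multiply the log-probability gradient; without it, the update would simply make the sampled actions more likely regardless of outcome.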

The accuracy I get when running the code is not very high. What could be the reason?

training : total 19 nodes.
LSTM_only 0 ----test: 0.1 | dev: 0.1
training : total 19 nodes.
LSTM_only 1 ----test: 0.3 | dev: 0.3
LSTM pretrain OK
epsilon 0.05
training : total 19 nodes.
RL pretrain 0 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 1 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 2 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 3 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 4 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 5 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 6 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 7 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 8 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 9 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 10 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 11 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 12 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 13 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 14 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 15 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 16 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 17 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 18 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
RL pretrain 19 ----test: 0.3 | dev: 0.3
RL pretrain OK
training : total 19 nodes.
epoch 0 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
epoch 1 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
epoch 2 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
epoch 3 ----test: 0.3 | dev: 0.3
training : total 19 nodes.
epoch 4 ----test: 0.0 | dev: 0.0

Process finished with exit code 0

PNet and CNet Implementations

Hi

I was going through the paper and the implementation briefly. The implementation contains folders for the ID_LSTM and HS_LSTM parts of the solution. Are the PNet and CNet implementations part of these folders, or are they to be implemented separately? Also, on page 5 it is mentioned that "For HS-LSTM, we split a sentence into phrases shorter than the square root of sentence length, and also use some very simple heuristics". Can you please elaborate on what the heuristics are?
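Regarding the square-root splitting: the length rule alone can be sketched as below (the paper's additional heuristics are unspecified, so this shows only the length constraint):

```python
import math

def split_by_sqrt(words):
    """Split a sentence into consecutive phrases no longer than the
    square root of the sentence length (only the length rule; the
    paper's extra heuristics are not described)."""
    max_len = max(1, int(math.sqrt(len(words))))
    return [words[i:i + max_len] for i in range(0, len(words), max_len)]

split_by_sqrt(["a", "b", "c", "d", "e", "f", "g", "h", "i"])
# -> [['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']]
```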

Thank You

Dataset .res files

I downloaded the MR dataset from this link: http://www.cs.cornell.edu/people/pabo/movie-review-data/. However, the dataset does not contain the train, dev and test splits in the *.res file format expected by your code. Can you please explain how to create these files from the available data?

I see now. I also implemented a version based on your paper. On the AG test set, my ID-LSTM reaches 0.9113 -> 0.9130 -> 0.9154 across the three training stages, which is a smaller improvement than reported in your paper, so I have a few questions:

1. I noticed that although the code defines a batch size, both pretraining and joint training actually update on single samples. If all three stages were switched to batch updates (my implementation uses batch updates), would that affect the reinforcement learning stage or the final joint update of all parameters?
2. What are the learning rates for the three stages? Which stage contributes most to the improvement?
3. Which factors have the largest impact on the results?

These questions have puzzled me for a while; I would greatly appreciate your answers.

Originally posted by @extremin in #5 (comment)

train.res file

Hi,
I downloaded the train/dev/test.res files of AG News corpus as indicated in your repo. However, I cannot understand the structure of the file. It looks like a simple parse tree, but does not resemble the standard parse trees that we can obtain from the parsers available. The AG News' main website also does not contain these .res files.

Could you please tell me how the preprocessing was done, i.e., how to convert the text into this format of parse trees? Any help would be appreciated.

Regards

An exception at runtime

Hello, I get the following exception when running your code. How should I handle it?
Traceback (most recent call last):
File "/home/chenjiafeng/Desktop/5.2/AAAI18-code-master/ID_LSTM/main.py", line 21, in
train_data, dev_data, test_data = dataManager.getdata(args.grained, args.maxlenth)
File "/home/chenjiafeng/Desktop/5.2/AAAI18-code-master/ID_LSTM/datamanager.py", line 72, in getdata
solution = one_hot_vector(int(sent['rating']))
File "/home/chenjiafeng/Desktop/5.2/AAAI18-code-master/ID_LSTM/datamanager.py", line 52, in one_hot_vector
s[r] += 1.0
IndexError: index 3 is out of bounds for axis 0 with size 2
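The traceback indicates that the one-hot label vector has fewer slots than the dataset has classes: with a 2-class setting, a rating of 3 cannot be encoded. A minimal reproduction of the shape mismatch (a sketch mirroring datamanager.py's `one_hot_vector`, not the repo code itself):

```python
import numpy as np

def one_hot_vector(r, grained):
    """Build a one-hot label with `grained` classes; requires 0 <= r < grained."""
    s = np.zeros(grained, dtype=np.float32)
    s[r] += 1.0
    return s

one_hot_vector(1, 2)       # fine: two classes, rating 1
try:
    one_hot_vector(3, 2)   # rating 3 with only 2 classes
except IndexError as e:
    print(e)               # index 3 is out of bounds for axis 0 with size 2
# a --grained value matching the dataset's actual class count avoids this
```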

main.py sampling_RL

Hello,
In the sampling_RL function in ID_LSTM/main.py I see critic being used, but critic is not in the function's parameter list, and I don't see critic defined anywhere before it is used. I would appreciate your advice.
Code snippet:
def sampling_RL(sess, actor, inputs, vec, lenth, epsilon=0., Random=True):
    current_lower_state = np.zeros((1, 2*args.dim), dtype=np.float32)
    actions = []
    states = []
    # sampling actions
    for pos in range(lenth):
        predicted = actor.predict_target(current_lower_state, [vec[0][pos]])
        states.append([current_lower_state, [vec[0][pos]]])
        if Random:
            if random.random() > epsilon:
                action = (0 if random.random() < predicted[0] else 1)
            else:
                action = (1 if random.random() < predicted[0] else 0)
        else:
            action = np.argmax(predicted)
        actions.append(action)
        if action == 1:
            out_d, current_lower_state = critic.lower_LSTM_target(current_lower_state, [[inputs[pos]]])

    Rinput = []
    for (i, a) in enumerate(actions):
        if a == 1:
            Rinput.append(inputs[i])
    Rlenth = len(Rinput)
    if Rlenth == 0:
        actions[lenth-2] = 1
        Rinput.append(inputs[lenth-2])
        Rlenth = 1
    Rinput += [0] * (args.maxlenth - Rlenth)
    return actions, states, Rinput, Rlenth

Dataset

During my research I came across your valuable paper, "Learning Structured Representation for Text Classification via Reinforcement Learning". To continue my research, would you please provide the MR, SST, Subj and AG datasets, together with the word vectors, on the site?
While running your code, I get this error: "FileNotFoundError: [Errno 2] No such file or directory: '../WordVector/vector.300dim'"
Can you help me solve it?

datamanager.py

I would like to ask: in get_wordvector() in this file, what does the line `n, dim = map(int, fr.readline().split())` mean? Doesn't map return only a single value? Moreover, fr contains floating-point numbers and English words, which surely cannot be cast to int?
I hope to get an answer, thank you!
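For what it's worth, that line reads the header of a word2vec-style text file: the first line holds exactly two integers (vocabulary size and vector dimension), so map over its two tokens yields two values that tuple-unpack into n and dim; only the later lines mix words with floats. A small sketch (the header values are made up):

```python
header = "400000 300"              # first line of a word2vec text file: "<vocab> <dim>"
n, dim = map(int, header.split())  # one int per token; two tokens unpack into two names
print(n, dim)                      # 400000 300

line = "hello 0.1 -0.2 0.3"        # later lines: a word followed by dim floats
parts = line.split()
word, vec = parts[0], [float(x) for x in parts[1:]]
```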

Does the reward definition in the code differ from the paper?

Hello author,
In the paper, the ID-LSTM reward is defined as r = log P(c_g|X) + γ L'/L, where L' is the number of deleted words. But in the code, L' seems to be the number of retained words, and the L'/L term is also squared. Have I misunderstood something? If so, please correct me. Thanks.
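As written in the paper, the ID-LSTM reward combines the classifier's log-likelihood with a term that encourages deleting words; a sketch of that formula (gamma is a trade-off weight I introduce for illustration, not a value taken from the repo):

```python
import math

def id_lstm_reward(p_correct, n_deleted, sent_len, gamma=0.1):
    """r = log P(c_g | X) + gamma * L'/L, with L' the number of deleted
    words, as stated in the paper. (Whether the released code uses
    retained words and squares the ratio instead is the open question.)"""
    return math.log(p_correct) + gamma * n_deleted / sent_len

id_lstm_reward(0.9, 5, 10)
```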

What are input_l and input_d in the actor_network?

What are input_l and input_d in the actor_network?

In the paper, the policy network accepts states as input and outputs actions. However, the actor network takes two inputs, input_l and input_d. What are they, respectively?
