thunlp-mt / document-transformer


Improving the Transformer translation model with document-level context

License: BSD 3-Clause "New" or "Revised" License

Python 100.00%
neural-machine-translation document-level-translation

document-transformer's Introduction

Improving the Transformer Translation Model with Document-Level Context

Contents

  • Introduction
  • Usage
  • Citation
  • FAQ

Introduction

This is the implementation of our work, which extends the Transformer model to integrate document-level context [paper]. The implementation is built on top of THUMT.

Usage

Note: the usage is not yet user-friendly and may be improved later.

  1. Train a standard Transformer model; please refer to the user manual of THUMT. Suppose that model_baseline/model.ckpt-30000 performs best on the validation set.

  2. Generate a dummy improved Transformer model with the following command:

python THUMT/thumt/bin/trainer_ctx.py --inputs [source corpus] [target corpus] \
                                      --context [context corpus] \
                                      --vocabulary [source vocabulary] [target vocabulary] \
                                      --output model_dummy --model contextual_transformer \
                                      --parameters train_steps=1
  3. Generate the initial model by merging the standard Transformer model into the dummy model, then create a checkpoint file (a conceptual sketch of this merge is shown after step 5):
python THUMT/thumt/scripts/combine_add.py --model model_dummy/model.ckpt-0 \
                                         --part model_baseline/model.ckpt-30000 --output train
printf 'model_checkpoint_path: "new-0"\nall_model_checkpoint_paths: "new-0"' > train/checkpoint
  4. Train the improved Transformer model with the following command:
python THUMT/thumt/bin/trainer_ctx.py --inputs [source corpus] [target corpus] \
                                      --context [context corpus] \
                                      --vocabulary [source vocabulary] [target vocabulary] \
                                      --output train --model contextual_transformer \
                                      --parameters start_steps=30000,num_context_layers=1
  5. Translate with the improved Transformer model:
python THUMT/thumt/bin/translator_ctx.py --inputs [source corpus] --context [context corpus] \
                                         --output [translation result] \
                                         --vocabulary [source vocabulary] [target vocabulary] \
                                         --model contextual_transformer --checkpoints [model path] \
                                         --parameters num_context_layers=1
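
As referenced in step 3, the following is a minimal, hypothetical sketch of what the checkpoint merge does conceptually; the actual logic lives in THUMT/thumt/scripts/combine_add.py. The idea is to keep every variable of the dummy contextual model and overwrite those that also exist in the trained baseline with the baseline's values. The sketch assumes TensorFlow 1.x and the placeholder checkpoint paths used above; it is not the repository's implementation.

# Hypothetical sketch of the checkpoint merge, not the actual combine_add.py.
# Assumes TensorFlow 1.x and the placeholder checkpoint paths used above.
import tensorflow as tf

DUMMY_CKPT = "model_dummy/model.ckpt-0"            # dummy contextual model (step 2)
BASELINE_CKPT = "model_baseline/model.ckpt-30000"  # trained standard Transformer (step 1)
OUTPUT_CKPT = "train/new-0"                        # name referenced by the checkpoint file created above

dummy_reader = tf.train.load_checkpoint(DUMMY_CKPT)
baseline_reader = tf.train.load_checkpoint(BASELINE_CKPT)
baseline_names = {name for name, _ in tf.train.list_variables(BASELINE_CKPT)}

merged = []
for name, _ in tf.train.list_variables(DUMMY_CKPT):
    # Use the trained baseline value when the variable exists in both models;
    # context-specific parameters keep their freshly initialized dummy values.
    reader = baseline_reader if name in baseline_names else dummy_reader
    merged.append(tf.Variable(reader.get_tensor(name), name=name))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver(merged).save(sess, OUTPUT_CKPT, write_meta_graph=False)

After such a merge, only the context-specific parameters remain untrained, which is what the second training stage in step 4 addresses.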

Citation

Please cite the following paper if you use the code:

@InProceedings{Zhang:18,
  author    = {Zhang, Jiacheng and Luan, Huanbo and Sun, Maosong and Zhai, Feifei and Xu, Jingfang and Zhang, Min and Liu, Yang},
  title     = {Improving the Transformer Translation Model with Document-Level Context},
  booktitle = {Proceedings of EMNLP},
  year      = {2018},
}

FAQ

  1. What is the context corpus?

The context corpus file contains one context "sentence" per line. Normally, a context sentence consists of the several preceding source sentences within the same document. For example, if the original document-level corpus is:

==== source ====
<document id=XXX>
<seg id=1>source sentence #1</seg>
<seg id=2>source sentence #2</seg>
<seg id=3>source sentence #3</seg>
<seg id=4>source sentence #4</seg>
</document>

==== target ====
<document id=XXX>
<seg id=1>target sentence #1</seg>
<seg id=2>target sentence #2</seg>
<seg id=3>target sentence #3</seg>
<seg id=4>target sentence #4</seg>
</document>

The inputs to our system should be processed as follows (suppose that 2 preceding source sentences are used as context; a small script that performs this conversion is sketched after the example):

==== train.src ==== (source corpus)
source sentence #1
source sentence #2
source sentence #3
source sentence #4

==== train.ctx ==== (context corpus)
(the first line is empty)
source sentence #1
source sentence #1 source sentence #2 (there is only a space between the two sentences)
source sentence #2 source sentence #3

==== train.trg ==== (target corpus)
target sentence #1
target sentence #2
target sentence #3
target sentence #4
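
The context file can therefore be generated automatically from the sentence-level source file once document boundaries are known. Below is a small, hypothetical helper (not part of this repository) that writes train.ctx from train.src, assuming an extra file doc_lens.txt that lists the number of sentences in each document, one integer per line; window=2 reproduces the example above.

# Hypothetical helper, not part of this repository: builds a context corpus
# from a sentence-level source file. "doc_lens.txt" is assumed to contain the
# number of sentences of each document, one integer per line. Each context
# line is the space-separated concatenation of up to `window` preceding source
# sentences from the same document; the first sentence of a document gets an
# empty context line.

def build_context(src_path, doc_lens_path, ctx_path, window=2):
    with open(src_path, encoding="utf-8") as f:
        sentences = [line.strip() for line in f]
    with open(doc_lens_path, encoding="utf-8") as f:
        doc_lens = [int(line) for line in f if line.strip()]
    assert sum(doc_lens) == len(sentences), "document lengths must cover the corpus"

    with open(ctx_path, "w", encoding="utf-8") as out:
        start = 0
        for length in doc_lens:
            for i in range(start, start + length):
                context = sentences[max(start, i - window):i]
                out.write(" ".join(context) + "\n")
            start += length

if __name__ == "__main__":
    build_context("train.src", "doc_lens.txt", "train.ctx", window=2)

Whatever tool is used, each line of train.ctx must stay aligned with the corresponding lines of train.src and train.trg.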

document-transformer's People

Contributors

Glaceon31, xc-kiwiberry


document-transformer's Issues

Data preprocessing

Hello,
I am a student at Peking University working on document-level NMT. If convenient, could you provide the datasets mentioned in the paper? Also, could you release the data preprocessing code, and what preprocessing should be applied to Chinese?
Looking forward to your reply.
Best regards.

Negative loss

Hello, following the README I performed the following steps:

  1. Trained a baseline Transformer model on sentence-level and document-level parallel corpora; here I used a Transformer previously trained with tensor2tensor (t2t).
  2. Trained a dummy contextual Transformer model on the document-level parallel corpus.
  3. Merged the baseline Transformer model from step 1 into the dummy model from step 2 to initialize the contextual model.
  4. Trained the contextual model.

For steps 2, 3 and 4 I used only the document-level parallel corpus. The vocabulary in step 1 was built with BPE, and steps 2, 3 and 4 all use the same BPE vocabulary.
In addition, thumt/data/dataset.py raised an error under Python 2.

params.mapping["target"][params.unk]

It raised a KeyError, so I simply changed it to

default_value=-1

So I am not sure whether the negative loss is caused by the mismatched vocabularies or by my code change.

Also, I would like to ask what a better way to set default_value would be.

To help the authors locate my error, I attach my training scripts below:

python2 thumt/bin/trainer_ctx.py \
--input corpus/train.en corpus/train.zh \
--context corpus/train.ctx.en \
--vocabulary vocab/vocab.en vocab/vocab.zh \
--output models/dummy \
--model contextual_transformer \
--parameters train_steps=1

python2 thumt/scripts/combine_add.py \
--model models/dummy \
--part models/transformer/model.ckpt-300000 \
--output models/train/

python2 thumt/bin/trainer_ctx.py \
--input corpus/train.en corpus/train.zh \
--context corpus/train.ctx.en \
--output models/sentence_doc \
--vocabulary vocab/vocab.en vocab/vocab.zh \
--model contextual_transformer \
--parameters start_steps=95000,num_context_layers=1,batch_size=6250,train_steps=100000,save_checkpoint_steps=5000,keep_checkpoint_max=50,beam_size=5

I would really appreciate it if someone could take a look and tell me whether something in my procedure caused the training error. Many thanks!

Dataset

How do I get the dataset you mentioned for training the dummy improved Transformer model? And how do I build the context corpus? Will it be data with multiple documents, converted to sentence 1, sentence 2, ..., sentence n of each document with a delimiter?

Where is the context parameter for the validation set specified?

The paper says that training is done in two steps:

  1. Train a Transformer model on the full training set (both sentence-level and document-level parallel sentence pairs).
  2. Freeze the parameters of the Transformer from step 1 and train only the context-input module on the document-level parallel pairs.

I have two questions about step 2:

  1. The --input argument of trainer_ctx.py takes the source and target corpora. For --context, the paper says that using the target-side corpus works somewhat better; is it also fine to pass the source-side corpus in this code? Also, is the context corpus the same as the source or target corpus given to --input?

  2. While training the contextual_transformer model I want to monitor performance on the validation set. How do I specify the context file for the validation set?

Error in the second training stage with the document-level parallel corpus

@Glaceon31
Hello, following the README I performed the following steps:

  1. Trained a baseline Transformer model on sentence-level and document-level parallel corpora.
  2. Trained a dummy contextual Transformer model on the document-level parallel corpus.
  3. Merged the baseline Transformer model from step 1 into the dummy model from step 2 to initialize the contextual model.
  4. Trained the contextual model.

For steps 2, 3 and 4 I used only the document-level parallel corpus, with the same vocabulary as in step 1. In step 4, however, I got an error that "context" could not be found when reading the features. The full error message is:

INFO:tensorflow:Restoring parameters from models/en2cn_t2t/sentence_doc/model.ckpt-95001
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:loss = 4.9753056, step = 95001
INFO:tensorflow:Saving checkpoints for 95002 into models/en2cn_t2t/sentence_doc/model.ckpt.
INFO:tensorflow:Saving checkpoints for 95002 into models/en2cn_t2t/sentence_doc/model.ckpt.
INFO:tensorflow:Validating model at step 95002
building context graph
use self attention
building encoder graph
building context graph
use self attention
building encoder graph
building context graph
use self attention
building encoder graph
building context graph
use self attention
building encoder graph
Traceback (most recent call last):
File "thumt/bin/trainer_ctx.py", line 431, in
main(parse_args())
File "thumt/bin/trainer_ctx.py", line 427, in main
sess.run(ops["train_op"])
File "/usr/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 546, in run
run_metadata=run_metadata)
File "/usr/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1022, in run
run_metadata=run_metadata)
File "/usr/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1113, in run
raise six.reraise(*original_exc_info)
File "/usr/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1098, in run
return self._sess.run(*args, **kwargs)
File "/usr/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1178, in run
run_metadata=run_metadata))
File "/disk1/wangfang/transformer/document-level/context-aware/thumt/Document-Transformer/thumt/utils/hooks.py", line 266, in after_run
self._session_config)
File "/disk1/wangfang/transformer/document-level/context-aware/thumt/Document-Transformer/thumt/utils/hooks.py", line 142, in _evaluate
predictions = eval_fn(placeholders)
File "thumt/bin/trainer_ctx.py", line 405, in
[model.get_inference_func()], f, params
File "/disk1/wangfang/transformer/document-level/context-aware/thumt/Document-Transformer/thumt/utils/inference.py", line 289, in create_inference_graph
states.append(model_fn0)
File "/disk1/wangfang/transformer/document-level/context-aware/thumt/Document-Transformer/thumt/models/contextual_transformer.py", line 500, in encoding_fn
context_output, encoder_output = encoding_graph(features, "infer", params)
File "/disk1/wangfang/transformer/document-level/context-aware/thumt/Document-Transformer/thumt/models/contextual_transformer.py", line 260, in encoding_graph
ctx_seq = features["context"]
KeyError: 'context'

To help the authors locate my error, I attach my training scripts below:

  • Dummy training script for step 2:

python thumt/bin/trainer_ctx.py
--input corpus/cn_en/trainData/doc_total.sub.en corpus/cn_en/trainData/doc_total.sub.cn
--context corpus/cn_en/trainData/doc_total.sub.en.ctx
--output models/en2cn_t2t/sentence_doc/dummy
--vocabulary corpus/cn_en/trainData/vocab.en corpus/cn_en/trainData/vocab.cn
--model contextual_transformer
--parameters train_steps=1

  • Combine script for step 3:

python thumt/scripts/combine_add.py
--model models/en2cn_t2t/sentence_doc/dummy
--part models/en2cn_t2t/sentence_only
--output models/en2cn_t2t/sentence_doc

  • Actual training script for step 4:

python thumt/bin/trainer_ctx.py
--input corpus/cn_en/trainData/doc_total.sub.en corpus/cn_en/trainData/doc_total.sub.cn
--context corpus/cn_en/trainData/doc_total.sub.en.ctx
--output models/en2cn_t2t/sentence_doc
--vocabulary corpus/cn_en/trainData/vocab.en corpus/cn_en/trainData/vocab.cn
--model contextual_transformer
--parameters start_steps=95000,num_context_layers=1,batch_size=6250,device_list=[0,1,2,3],train_steps=1000000,save_checkpoint_steps=5000,keep_checkpoint_max=50,beam_size=5,eval_steps=5000

I would really appreciate it if someone could take a look and tell me whether something in my procedure caused the training error. Many thanks!

About the parameter settings for IWSLT Fr-En

Hello,
I am very interested in your work and have recently been trying to reproduce its results on IWSLT Fr-En, but I ran into some problems. For the SAN-based model I get 34.89 BLEU, and 35.55 / 35.57 with one-sentence and two-sentence context, which is still some way from the results reported in the paper.
So I would like to confirm the following parameters. My settings are:
1. share_target_embedding_and_softmax=true; should this be enabled?
2. attention_dropout=0.15, relu_dropout=0.1, residual_dropout=0.1, or should I use the default settings?
3. learning_rate=1.0, warmup_steps=8000.
In addition, does layer_norm need to be changed to preprocess?
For the second training stage:
1. Does it make any difference whether learning_rate is 1.0 or 0.1? The combine_add script seems to also read Adam's parameters, so this should count as continued training and the learning rate should follow the pre-trained model; 30,000 steps is also well past warmup_steps and within the decay region, so I would not expect it to matter. In my results, however, lr=1.0 is somewhat better than lr=0.1.
2. Since there are no evaluation results to refer to, how should I select the best model? I set save_checkpoint_steps=1000; what value do you use? It feels easy to miss the best checkpoint.

Looking forward to your reply.
Thanks

Is this available in a Python 3 version?

Since I don't have the deep learning and other libraries installed for Python 2, I can't run it.
I also tried to run it on Google Colab with a Python 2 kernel, but there I was not able to add the path of the THUMT folder to the Python path.
Can you suggest a remedy?

Thanks a lot.

About translation and decoding

When I use the context-level model for decoding and testing, is the model path given to --checkpoints a folder that contains all checkpoints, or a single model file? If it is the former, does decoding use a checkpoint average over all the models in the folder?

How to set parameters when training on a 940k Ch-En corpus?

Compared with training on a 2M Ch-En corpus, what parameters should I use when training the model on a 940k Ch-En corpus? I have tried batch_size=25k with learning_rate=1.0 and batch_size=25k with learning_rate=0.5, but only got BLEU 3.6 and 33.7 on MT06, respectively.

Optimal training parameters

Hello,
I am training your model on a combination of IWSLT, Europarl and News Commentary French, Russian and German data. Do you have any tips on the optimal training configuration parameters beyond what is already in the README?

Thanks!
Barbara

About the training corpus format

Hello,
When I use this code to train a model, what format should the source corpus, target corpus and context corpus be in? Should they be tokenized and BPE-segmented? Could you send me a demo?
Thank you very much.

context corpus

Hello, I have already trained the Transformer model and need to continue training following the instructions in the README, but I don't know the exact format of the source corpus, target corpus and context corpus files. If possible, could you provide the relevant files? Looking forward to your reply. Best regards. My email is [email protected]
