trade-dst's Introduction

TRADE Multi-Domain and Unseen-Domain Dialogue State Tracking

This is the PyTorch implementation of the paper: Transferable Multi-Domain State Generator for Task-Oriented Dialogue Systems. Chien-Sheng Wu, Andrea Madotto, Ehsan Hosseini-Asl, Caiming Xiong, Richard Socher and Pascale Fung. ACL 2019. [PDF: https://arxiv.org/abs/1905.08743]

This code has been written using PyTorch >= 1.0. If you use any source code or datasets included in this toolkit in your work, please cite the following paper. The BibTeX entry is listed below:

@InProceedings{WuTradeDST2019,
  author    = "Wu, Chien-Sheng and Madotto, Andrea and Hosseini-Asl, Ehsan and Xiong, Caiming and Socher, Richard and Fung, Pascale",
  title     = "Transferable Multi-Domain State Generator for Task-Oriented Dialogue Systems",
  booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
  year      = "2019",
  publisher = "Association for Computational Linguistics"
}

Abstract

Over-dependence on domain ontology and lack of knowledge sharing across domains are two practical and yet less studied problems of dialogue state tracking. Existing approaches generally fall short in tracking unknown slot values during inference and often have difficulties in adapting to new domains. In this paper, we propose a Transferable Dialogue State Generator (TRADE) that generates dialogue states from utterances using a copy mechanism, facilitating knowledge transfer when predicting (domain, slot, value) triplets not encountered during training. Our model is composed of an utterance encoder, a slot gate, and a state generator, which are shared across domains. Empirical results demonstrate that TRADE achieves state-of-the-art joint goal accuracy of 48.62% for the five domains of MultiWOZ, a human-human dialogue dataset. In addition, we show its transferring ability by simulating zero-shot and few-shot dialogue state tracking for unseen domains. TRADE achieves 60.58% joint goal accuracy in one of the zero-shot domains, and is able to adapt to few-shot cases without forgetting already trained domains.

Model Architecture

The architecture of the proposed TRADE model includes (a) an utterance encoder, (b) a state generator, and (c) a slot gate, all of which are shared across domains. The state generator decodes J times independently, once for each possible (domain, slot) pair. At the first decoding step, the state generator takes the j-th (domain, slot) embedding as input to generate the corresponding slot value and slot gate. The slot gate predicts whether the j-th (domain, slot) pair is triggered by the dialogue.
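For intuition only, here is a minimal, self-contained PyTorch sketch of this decode-J-times pattern. All class, method, and parameter names are illustrative assumptions rather than the repository's actual code, and the copy mechanism is omitted here for brevity (it is sketched separately in the issues section below).

import torch
import torch.nn as nn
import torch.nn.functional as F

class TradeSketch(nn.Module):
    """Toy illustration: shared utterance encoder, J independent slot decodes, and a slot gate."""

    def __init__(self, vocab_size, hidden, num_domain_slot_pairs, num_gates=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden)
        self.encoder = nn.GRU(hidden, hidden, batch_first=True)      # (a) utterance encoder
        self.decoder = nn.GRU(hidden, hidden, batch_first=True)      # (b) state generator
        self.slot_emb = nn.Embedding(num_domain_slot_pairs, hidden)  # (domain, slot) embeddings
        self.gate = nn.Linear(hidden, num_gates)                     # (c) slot gate
        self.vocab_proj = nn.Linear(hidden, vocab_size)

    def forward(self, utterance, max_value_len=4):
        batch = utterance.size(0)
        enc_out, enc_hidden = self.encoder(self.embedding(utterance))
        gates, values = [], []
        for j in range(self.slot_emb.num_embeddings):           # decode J times, once per (domain, slot) pair
            hidden = enc_hidden
            inp = self.slot_emb.weight[j].expand(batch, 1, -1)   # first input: the j-th (domain, slot) embedding
            words = []
            for step in range(max_value_len):
                dec_out, hidden = self.decoder(inp, hidden)
                if step == 0:                                    # slot gate predicted at the first decoding step
                    gates.append(self.gate(dec_out.squeeze(1)))
                p_vocab = F.softmax(self.vocab_proj(dec_out.squeeze(1)), dim=-1)
                word = p_vocab.argmax(dim=-1)                    # greedy decode; copy mechanism omitted
                words.append(word)
                inp = self.embedding(word).unsqueeze(1)
            values.append(torch.stack(words, dim=1))
        # gates: (batch, J, num_gates); values: (batch, J, max_value_len)
        return torch.stack(gates, dim=1), torch.stack(values, dim=1)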

Data

Download the MultiWOZ dataset and the processed DST version:

❱❱❱ python3 create_data.py

An example of multi-domain dialogue state tracking in a conversation. The solid arrows on the left are the single-turn mapping, and the dotted arrows on the right are the multi-turn mapping. The state tracker needs to track slot values mentioned by the user for all the slots in all the domains.

Dependency

Check the required packages or simply run the command:

❱❱❱ pip install -r requirements.txt

If you run into an error related to Cython, try to upgrade it first.

❱❱❱ pip install --upgrade cython

Multi-Domain DST

Training

❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1

Testing

❱❱❱ python3 myTest.py -path=${save_path}
  • -bsz: batch size
  • -dr: dropout ratio
  • -lr: learning rate
  • -le: whether to load pretrained embeddings
  • -path: path to the saved model
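For reference, ${save_path} is the run directory created under save/ during training; the folder name appears to encode the hidden size, batch size, dropout ratio, and dev accuracy. A hypothetical example (the accuracy suffix will differ from run to run; the format follows the paths reported in the issues below):

❱❱❱ python3 myTest.py -path=save/TRADE-multiwozdst/HDD400BSZ32DR0.2ACC-0.4455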

[2019.08 Update] The decoder can now generate all the (domain, slot) pairs in one batch at the same time to speed up the decoding process. You can set the flag "--parallel_decode=1" to decode all (domain, slot) pairs in one batch.

Unseen Domain DST

Zero-Shot DST

Training

❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1 -exceptd=${domain}

Testing

❱❱❱ python3 myTest.py -path=${save_path} -exceptd=${domain}
  • -exceptd: excluded domain; choose one from {hotel, train, attraction, restaurant, taxi}.
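For example, to hold out the hotel domain during zero-shot training (any of the domains listed above can be substituted):

❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1 -exceptd=hotel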

Few-Shot DST with CL (Continual Learning)

Training (Naive fine-tuning)

❱❱❱ python3 fine_tune.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain}

EWC

❱❱❱ python3 EWC_train.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain} -fisher_sample=10000 -l_ewc=${lambda}

GEM

❱❱❱ python3 GEM_train.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain}
  • -l_ewc: lambda value in EWC training
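For readers unfamiliar with EWC, the following is a generic sketch of how a lambda-weighted EWC penalty is typically added to the task loss. The function and variable names are hypothetical; this is not the repository's EWC_train.py implementation.

def ewc_penalty(model, fisher, old_params, lam):
    # lam * sum_i F_i * (theta_i - theta*_i)^2, summed over parameters shared with the old domains
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (param - old_params[name]) ** 2).sum()
    return lam * penalty

# Hypothetical usage inside the training loop:
# total_loss = task_loss + ewc_penalty(model, fisher, old_params, lam)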

Other Notes

  • We found that there may be some variance across different runs, especially in the few-shot setting. For our own experiments, we used only one random seed (seed=10) for the results reported in the paper. Please check the results averaged over three runs in our ACL presentation.

Bug Report

Feel free to create an issue or send email to [email protected]

License

Copyright 2019-present https://jasonwu0731.github.io/

Permission is hereby granted, free of charge, to any person obtaining a copy 
of this software and associated documentation files (the "Software"), to deal 
in the Software without restriction, including without limitation the rights 
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
copies of the Software, and to permit persons to whom the Software is 
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all 
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 
SOFTWARE.

trade-dst's People

Contributors

jasonwu0731, naveross


trade-dst's Issues

Killed training during evaluation

Dear Sir,
I am reproducing the results from TRADE, and the training process gets killed during evaluation. Another issue I am facing is that during the testing phase a "list index out of range" error occurs. Could you please advise? Thank you.

Missing turn in dialogue acts file

When working with the dialogue_acts.json file (data/multi-woz/dialogue_acts.json), I found out that the first turn in each dialogue (turn_id=0) is nowhere to be found!

Modifications required for running experiments on MultiWOZ2.1

Hi,

I am wondering which files need to be changed to evaluate on the MultiWOZ 2.1 dataset.

Based on a preliminary analysis, I found the following files:

  1. create_data.py
  2. utils/utils_multiWOZ_DST.py, and
  3. the evaluation part in models/TRADE.py.

Please let me know if there are any other files that need to be modified.

Also, in case the changes already exist, please point me to them.

Thanks.

CUDA out of memory

When I run this command python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1, I keep getting

RuntimeError: CUDA out of memory. Tried to allocate 1.06 GiB (GPU 0; 15.75 GiB total capacity; 13.09 GiB already allocated; 546.88 MiB free; 1.08 GiB cached)

It only uses a single GPU card although the machine has 4 cards.
How much GPU memory do I need to train the model? Is there a way to use multi-GPU cards?

Thanks,
Hao
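A generic workaround, not specific to this repository: CUDA out-of-memory errors can often be mitigated by lowering the batch size, e.g.

❱❱❱ python3 myTrain.py -dec=TRADE -bsz=8 -dr=0.2 -lr=0.001 -le=1

For multi-GPU usage, see the DataParallel sketch under the "Feature Request: Multi-GPU version" issue below.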

How to generate words beyond the vocabulary?

As the paper says at the end of Section 2.2 (State Generator):

Note that due to Eq (2), our model is able to generate words even if they are not pre-defined in the vocabulary.

but as the code shows, p_context_ptr = torch.zeros(p_vocab.size()), so does that mean it restricts the generated words to the pre-defined vocabulary?

For example, the vocabulary size is 100, but a word in the story has id 103.

Thanks a lot
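For context, here is a generic pointer-generator style combination of the two distributions, in the spirit of the p_context_ptr line quoted above. This is a sketch under the assumption that all story token ids are already indexed into the same (possibly extended) vocabulary as p_vocab; it is not the repository's exact code.

import torch

def combine_distributions(p_vocab, p_gen, attn, story_ids):
    # p_vocab:   (batch, vocab)  softmax over the vocabulary
    # p_gen:     (batch, 1)      generation probability in [0, 1]
    # attn:      (batch, T)      attention weights over the dialogue history
    # story_ids: (batch, T)      history token ids, indexed into the same vocabulary
    p_context_ptr = torch.zeros_like(p_vocab)
    p_context_ptr.scatter_add_(1, story_ids, attn)   # add copy mass onto words appearing in the history
    return p_gen * p_vocab + (1 - p_gen) * p_context_ptr

Under this formulation, a word with id 103 can only receive copy probability if the distribution is sized to include it, which is exactly the concern raised in this issue.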

Unexpected results in the multi-domain DST experiment

When I run the multi-domain DST experiment, the result is as follows:
[screenshot of the evaluation output]

As you can see, the joint and turn accuracies are quite low although the loss is reasonable (e.g. 0.5). Is this normal? Could you help me with the unexpected results? Besides, when I run the unseen-domain DST experiment, it works well.


Few-shot training error

Hi @andreamad8 @jasonwu0731 ,

I am facing a path error; I don't know if the command format is correct for "few-shot DST".
Here is the error log:

[screenshot of the traceback]

Here is my command:

python3 fine_tune.py -bsz=8 -dr=0.2 -lr=0.001 -path=${root/trade-dst/save/TRADE-multiwozdst/HDD400BSZ32DR0.2ACC-0.4455} -exceptd=${taxi}

Python3 Incompatibility

The README indicates a python3 interpreter should be used to run the repo, but it is not actually python3-compatible.

What is MD example save path?

The README.md states:

Training

❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1

Testing

❱❱❱ python3 myTest.py -path=${save_path}

Based on the training command, what should save_path be set to?

I ran the myTrain.py command successfully. But the testing command fails if I don't supply a path:

$ python myTest.py
{'dataset': 'multiwoz', 'task': 'dst', 'path': None, 'sample': None, 'patience': 6, 'earlyStop': 'BLEU', 'all_vocab': 1, 'imbalance_sampler': 0, 'data_ratio': 100, 'unk_mask': 1, 'batch': None, 'run_dev_testing': 0, 'vizualization': 0, 'genSample': 0, 'evalp': 1, 'addName': '', 'eval_batch': 0, 'use_gate': 1, 'load_embedding': 0, 'fix_embedding': 0, 'parallel_decode': 0, 'decoder': None, 'hidden': 400, 'learn': None, 'drop': None, 'limit': -10000, 'clip': 10, 'teacher_forcing_ratio': 0.5, 'lambda_ewc': 0.01, 'fisher_sample': 0, 'all_model': False, 'domain_as_task': False, 'run_except_4d': 1, 'strict_domain': False, 'except_domain': '', 'only_domain': ''}
Traceback (most recent call last):
  File "myTest.py", line 8, in <module>
    directory = args['path'].split("/")
AttributeError: 'NoneType' object has no attribute 'split'

The target value vocabulary

It seems that the slot values (e.g., dontcare) are not added to the vocabulary; only the system and user utterance tokens are added. How does the model generate tokens that do not come from the dialogue history context?

Problem with train_dials.json

Hi,

Thanks for sharing the code. In train_dials.json, for each dialogue, the last utterance from the agent is directly dropped. Why?
Looking forward to your reply.

Best,
Mian

BadZipFile error when running myTrain.py with Python 3.7

$ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1

[Warning] Using hidden size = 400 for pretrained word embedding (300 + 100)...
{'dataset': 'multiwoz', 'task': 'dst', 'path': None, 'sample': None, 'patience': 6, 'earlyStop': 'BLEU', 'all_vocab': 1, 'imbalance_sampler': 0, 'data_ratio': 100, 'unk_mask': 1, 'batch': 32, 'run_dev_testing': 0, 'vizualization': 0, 'genSample': 0, 'evalp': 1, 'addName': '', 'eval_batch': 0, 'use_gate': 1, 'load_embedding': 1, 'fix_embedding': 0, 'parallel_decode': 0, 'decoder': 'TRADE', 'hidden': 400, 'learn': 0.001, 'drop': 0.2, 'limit': -10000, 'clip': 10, 'teacher_forcing_ratio': 0.5, 'lambda_ewc': 0.01, 'fisher_sample': 0, 'all_model': False, 'domain_as_task': False, 'run_except_4d': 1, 'strict_domain': False, 'except_domain': '', 'only_domain': ''}
folder_name save/TRADE-multiwozdst/
Reading from data/train_dials.json
domain_counter {'hotel': 3381, 'train': 3103, 'attraction': 2717, 'restaurant': 3813, 'taxi': 1654}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Dumping pretrained embeddings...
Traceback (most recent call last):
  File "myTrain.py", line 22, in <module>
    train, dev, test, test_special, lang, SLOTS_LIST, gating_dict, max_word = prepare_data_seq(True, args['task'], False, batch_size=int(args['batch']))
  File "/home//trade-dst/utils/utils_multiWOZ_DST.py", line 426, in prepare_data_seq
    dump_pretrained_emb(lang.word2index, lang.index2word, emb_dump_path)
  File "/home//trade-dst/utils/utils_multiWOZ_DST.py", line 362, in dump_pretrained_emb
    embeddings = [GloveEmbedding(), KazumaCharEmbedding()]
  File "/home//miniconda3/lib/python3.7/site-packages/embeddings/glove.py", line 48, in __init__
    self.load_word2emb(show_progress=show_progress)
  File "/home//miniconda3/lib/python3.7/site-packages/embeddings/glove.py", line 65, in load_word2emb
    with zipfile.ZipFile(fin_name) as fin:
  File "/home//miniconda3/lib/python3.7/zipfile.py", line 1258, in __init__
    self._RealGetContents()
  File "/home//miniconda3/lib/python3.7/zipfile.py", line 1325, in _RealGetContents
    raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file

why not consider the request accuracy?

Hi, @jasonwu0731 :

I have a question about your nice work: why not evaluate the request accuracy in the experiments? The original dataset contains not only inform slots (the states in your paper) but also request slots (domain + slot). It seems this could be achieved by using only the slot gates without the state generator. So why not report the request accuracy?

Lambda value for training the EWC model

hi @andreamad8 @jasonwu0731,

I am trying to train the EWC model.
python3 EWC_train.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain} -fisher_sample=10000 -l_ewc=${lambda}

What should the value of -l_ewc=${lambda} be?

In the paper it is mentioned that "For EWC, we set different values of lambda for all the domains, and the optimal value is selected using the validation set", but that doesn't make sense to me; maybe I am missing something.

Can you please give me the lambda value that you used?

About the choice of the hyperparameters

Hi Jason,

When I train in the unseen-domain zero-shot setting, the results on the four domains are sometimes not desirable, e.g. when excluding the attraction domain. Is the model somewhat unstable, and do the hyperparameters need tuning? Could you give me some advice for reproducing the results?

More specifically, I train for 13 epochs excluding the attraction domain, and the joint and turn accuracies are 9.2 and 87.5. Besides, for the BM except hotel, the fine-tuned results on the 4 domains and the new domain are (18.79, 89.13) and (26.8, 77.95).

Thanks for your code release!

License info

Hi,

I've found the MIT license notice in the README.
Could you provide your copyright notice?

Thanks!

Typo in Table 1

In Table 1 of your paper, "leave by" in the Taxi domain should be "leave at." ("leave by" cannot be found in other domains)

the emb18311.json file

What is the emb18311.json file under the data folder? Can you help me? Thank you in advance.

'EncoderRNN' object is not callable

TRADE.py line 135: you are trying to call an object. Did you perhaps mean not
encoded_outputs, encoded_hidden = self.encoder(story.transpose(0, 1), data['context_len'])
but
encoded_outputs, encoded_hidden = self.encoder.forward(story.transpose(0, 1), data['context_len'])
?
I wonder how this error exists in all versions of TRADE.py and how it ever worked at all.

Evaluation

Thanks a lot @jasonwu0731 for sharing your code with the community. I have a quick question: I get much lower performance following the same steps you provided in the readme. Is there anything I need to do?

Again thanks so much!

Issue with Equation 2 in paper

How did you get the R^|V| vector from P_{jk}^history? From what I understood, (1 - P_{jk}^gen) is a scalar and P_{jk}^history is a vector of size |X_t|. How does that add to P_{jk}^vocab?
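For what it is worth, one reading of Eq. (2) that is consistent with pointer-generator formulations (a reconstruction, not the paper's exact notation) is that the history distribution is first scattered into R^|V| before the scalar gate is applied:

P^{final}_{jk} = p^{gen}_{jk} \cdot P^{vocab}_{jk} + (1 - p^{gen}_{jk}) \cdot \tilde{P}^{history}_{jk},
\qquad
\tilde{P}^{history}_{jk}(w) = \sum_{t : X_t = w} P^{history}_{jk}(t)

so both terms are |V|-dimensional (vectors over the vocabulary) before they are added.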

Error in requirements.txt

Hi @andreamad8 @jasonwu0731 ,

In the requirements file, the mkl-fft==1.0.12 package is pinned, and this version is no longer available; you can see my error prompt:
Could not find a version that satisfies the requirement mkl-fft==1.0.12 (from -r requirements.txt (line 11)) (from versions: 1.0.0.17, 1.0.2, 1.0.6) No matching distribution found for mkl-fft==1.0.12 (from -r requirements.txt (line 11))
Actually, I am building a docker image to reproduce this system on a remote server, and it is impossible for me to edit this file. Instead, you could just mention the package name, like "mkl-fft", or an explicit command:
pip3 install mkl-fft

Can you please update the requirements file accordingly so I can build an image from it?

Feature Request: Multi-GPU version

Considering the training time on a single GPU for the demo, it would be useful to have a version of trade-dst which detects if multiple GPUs are available via torch.cuda.device_count and makes use of them using the torch.nn.DataParallel method.
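A minimal, generic sketch of what such detection could look like, using only standard PyTorch calls. This is hypothetical glue code rather than a patch to trade-dst, and wrapping the TRADE model may require further changes to its decoding loop.

import torch
import torch.nn as nn

def maybe_data_parallel(model: nn.Module) -> nn.Module:
    # Wrap the model so each batch is split across all visible GPUs.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model.cuda() if torch.cuda.is_available() else model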

About License

May I know about the license?
We intend to use your code in one of our projects. Kindly mention the license so that we can understand whether this usage is allowed.
