devjwsong / recosa-dialogue-generation-pytorch

A PyTorch implementation of ReCoSa (the Relevant Contexts with Self-attention) for dialogue generation, using multi-head attention and a GRU.

Home Page: https://songstudio.info/tech/tech-34

License: MIT License

Python 98.01% Shell 1.99%
pytorch transformer natural-language-processing natural-language-generation nlp nlg multiturn dialogue-generation


recosa-dialogue-generation-pytorch

This is a multi-turn chatbot project based on the ReCoSa structure introduced in the paper "ReCoSa: Detecting the Relevant Contexts with Self-Attention for Multi-turn Dialogue Generation"[1].

The model detects the relevant parts of the dialogue history with the self-attention mechanism. The self-attention is applied by a history-level (utterance-level) transformer encoder, not at the word level.
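To make the utterance-level idea concrete, here is a minimal, framework-free sketch of single-head scaled dot-product self-attention over a list of utterance vectors. It assumes each history utterance has already been encoded into one fixed-size vector (in this project, a GRU plays that word-level role). Learned query/key/value projections, multiple heads, and the positional embedding are deliberately omitted, and all function names are illustrative, not taken from the repository.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def utterance_self_attention(utterances: List[Vector]) -> List[List[float]]:
    """Return the attention weight matrix between dialogue turns.

    Each row i holds how much utterance i attends to every utterance.
    Attention runs across turns (one vector per utterance), not across words.
    Simplification: queries, keys and values are the raw vectors (no projections).
    """
    d = len(utterances[0])
    weights = []
    for q in utterances:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in utterances]
        weights.append(softmax(scores))
    return weights


def attend(utterances: List[Vector], weights: List[List[float]]) -> List[Vector]:
    # Each output is the attention-weighted sum of the utterance vectors.
    d = len(utterances[0])
    return [[sum(w * v[j] for w, v in zip(row, utterances)) for j in range(d)] for row in weights]


if __name__ == "__main__":
    history = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.1]]
    w = utterance_self_attention(history)
    for row in w:
        print([round(x, 3) for x in row])
    print(attend(history, w))
```

In this toy history, the first utterance gives more weight to the similar third utterance than to the orthogonal second one, which is the kind of "relevant context" selection ReCoSa relies on.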

The details of the structure are shown below.

(Figure: The description of the ReCoSa structure.)

Arguments

Arguments for training

| Argument | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int | The random seed for training. | 0 |
| data_dir | str | The name of the parent directory where all data files are stored. | "data" |
| task | str | The name of the specific task (dataset): "daily_dialog", "empathetic_dialogues", "persona_chat", or "blended_skill_talk". | YOU MUST SPECIFY |
| pad_token | str | The pad token. | "<pad>" |
| bos_token | str | The bos token. | "<bos>" |
| eos_token | str | The eos token. | "<eos>" |
| sp1_token | str | The speaker 1 token. | "<sp1>" |
| sp2_token | str | The speaker 2 token. | "<sp2>" |
| learning_rate | float | The initial learning rate. | 5e-4 |
| warmup_ratio | float | The ratio of warmup steps. | 0.0 |
| max_grad_norm | float | The maximum gradient norm for gradient clipping. | 1.0 |
| train_batch_size | int | The batch size for training. | 32 |
| eval_batch_size | int | The batch size for evaluation. | 8 |
| num_workers | int | The number of workers for data loading. | 0 |
| num_epochs | int | The number of training epochs. | 10 |
| src_max_len | int | The maximum length of each input utterance. | 128 |
| max_turns | int | The maximum number of utterances to include. | 10 |
| trg_max_len | int | The maximum length of a target response. | 128 |
| num_heads | int | The number of heads for multi-head attention. | 8 |
| num_encoder_layers | int | The number of layers in the utterance-level encoder. | 6 |
| num_gru_layers | int | The number of layers in the word-level (GRU) encoder. | 2 |
| gru_dropout | float | The dropout rate of the word-level (GRU) encoder. | 0.1 |
| num_decoder_layers | int | The number of layers in the decoder. | 2 |
| d_model | int | The hidden size inside the transformer module. | 768 |
| d_pos | int | The hidden size of the positional embedding. | 256 |
| d_ff | int | The intermediate hidden size of each feed-forward layer. | 2048 |
| dropout | float | The dropout rate of the transformer modules. | 0.1 |
| gpus | str | The indices of the GPUs to use, given as a string of comma-separated indices (e.g. "0, 1, 2, 3"). | "0" |
| num_nodes | int | The number of machines. | 1 |
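For reference, the sketch below mirrors a subset of these training arguments with Python's standard `argparse`. It is not the repository's actual training script: the flag names and defaults are copied from the table, while `build_train_parser`, `parse_gpus`, `validate` and `num_warmup_steps` are hypothetical helpers. Two further assumptions are labeled in the code: that `d_model` is the attention embedding size (multi-head attention needs it to be divisible by `num_heads`), and that `warmup_ratio` is multiplied by the total number of optimizer steps.

```python
import argparse
from typing import List

TASKS = ["daily_dialog", "empathetic_dialogues", "persona_chat", "blended_skill_talk"]


def build_train_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: a subset of the documented flags with their documented defaults.
    parser = argparse.ArgumentParser(description="ReCoSa training (illustrative sketch)")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--task", type=str, required=True, choices=TASKS)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.0)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--src_max_len", type=int, default=128)
    parser.add_argument("--max_turns", type=int, default=10)
    parser.add_argument("--num_heads", type=int, default=8)
    parser.add_argument("--d_model", type=int, default=768)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--gpus", type=str, default="0")
    return parser


def parse_gpus(gpus: str) -> List[int]:
    # "0, 1, 2, 3" -> [0, 1, 2, 3]; int() tolerates the spaces around each index.
    return [int(idx) for idx in gpus.split(",") if idx.strip()]


def validate(args: argparse.Namespace) -> None:
    # Assumption: d_model is the multi-head attention embedding size,
    # which must split evenly across heads (768 / 8 = 96 per head).
    if args.d_model % args.num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")


def num_warmup_steps(warmup_ratio: float, total_steps: int) -> int:
    # Assumption: warmup length = ratio x total number of optimizer steps.
    return int(warmup_ratio * total_steps)


if __name__ == "__main__":
    args = build_train_parser().parse_args()
    validate(args)
    print(args, parse_gpus(args.gpus))
```

For example, `--task daily_dialog --gpus "0, 1"` would yield GPU indices `[0, 1]` with every other value at its documented default.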

Arguments for inference

| Argument | Type | Description | Default |
| --- | --- | --- | --- |
| pad_token | str | The pad token. | "<pad>" |
| bos_token | str | The bos token. | "<bos>" |
| eos_token | str | The eos token. | "<eos>" |
| sp1_token | str | The speaker 1 token. | "<sp1>" |
| sp2_token | str | The speaker 2 token. | "<sp2>" |
| src_max_len | int | The maximum length of each input utterance. | 128 |
| max_turns | int | The maximum number of utterances to include. | 10 |
| trg_max_len | int | The maximum length of a target response. | 128 |
| gpus | str | The index of the GPU to use. Inference runs on a single GPU only; specifying multiple GPUs raises an assertion error. | "0" |
| top_p | float | The top-p value for nucleus sampling. | 0.9 |
| end_command | str | The command that ends the conversation during inference. | "Abort!" |
| log_idx | int | The index of the Lightning log directory that contains the checkpoints to use. | YOU MUST SPECIFY |
| ckpt_file | str | The full file name of the trained checkpoint to use for inference. | YOU MUST SPECIFY |
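Nucleus (top-p) sampling keeps only the smallest set of most probable next tokens whose cumulative probability reaches `top_p`, renormalizes their probabilities, and samples from that set. The plain-Python sketch below illustrates that rule, together with one plausible way of applying `max_turns` and `src_max_len` to the dialogue history. The function names are hypothetical, the real model works on PyTorch logits rather than Python lists, and the truncation policy (keep the last `max_turns` utterances and the first `src_max_len` tokens of each) is an assumption, not necessarily what this repository does.

```python
import random
from typing import List, Sequence, Tuple


def top_p_filter(probs: Sequence[float], top_p: float) -> Tuple[List[int], List[float]]:
    """Keep the smallest set of highest-probability tokens whose mass reaches top_p.

    Returns the kept token indices (most probable first) and their renormalized probabilities.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for idx in order:
        kept.append(idx)
        cumulative += probs[idx]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return kept, [probs[i] / mass for i in kept]


def sample_top_p(probs: Sequence[float], top_p: float, rng: random.Random) -> int:
    # Sample one token id from the renormalized nucleus.
    kept, renorm = top_p_filter(probs, top_p)
    return rng.choices(kept, weights=renorm, k=1)[0]


def build_history(history: List[List[int]], max_turns: int, src_max_len: int) -> List[List[int]]:
    # Assumed policy: keep the most recent max_turns utterances, each cut to src_max_len token ids.
    recent = history[-max_turns:] if max_turns > 0 else []
    return [utt[:src_max_len] for utt in recent]


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.05, 0.5, 0.15, 0.3]
    print(top_p_filter(probs, 0.9))
    print([sample_top_p(probs, 0.9, rng) for _ in range(10)])
```

With `top_p = 0.9` and the toy distribution above, the lowest-probability token (index 0) can never be sampled, because the other three tokens already cover 0.95 of the probability mass.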


Datasets

By default, this repository provides code for downloading and preprocessing the datasets.

The following four datasets are supported by default:


  • DailyDialog[2]
  • EmpatheticDialogues[3]
  • Persona-Chat[4]
  • BlendedSkillTalk[5]

This project uses the ParlAI[6] platform, developed by Facebook, to download the required datasets.

This repository also provides a parsing script for each downloaded dataset.

Detailed instructions for using ParlAI can be found in the official documentation[7].



How to run

  1. Install all required packages.

    pip install -r requirements.txt

  2. Clone the official ParlAI repository into your project directory.

    git clone https://github.com/facebookresearch/ParlAI.git && cd ParlAI

  3. Set up ParlAI and download the data.

    python setup.py develop
    parlai display_data --task dailydialog
    parlai display_data --task empathetic_dialogues
    parlai display_data --task personachat
    parlai display_data --task blended_skill_talk
    cd ..

    ParlAI provides many other useful dialogue corpora besides the four datasets mentioned above.

    You can find the list of supported tasks in the ParlAI documentation.


  4. Parse each dataset and save the results into *.pickle and *.json files. (After parsing, you can delete the ParlAI repository.)

    python src/parse_data.py --data_dir=DATA_DIR
    • --data_dir: The name of the parent directory where the whole data files are stored. (default: "data")

  5. Run the following command to train the model.

    sh exec_train.sh

  6. Run the following command to perform inference with the trained model.

    sh exec_infer.sh
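Steps 3 and 4 can also be scripted. The sketch below only assembles the same commands listed in those steps and, optionally, runs them with Python's `subprocess`; by default it performs a dry run that just prints them. It assumes steps 1 and 2 and `python setup.py develop` have already been done. The pairing between ParlAI task names and the `--task` values accepted by this repository is inferred from the names in this README, and the helper names are hypothetical.

```python
import shlex
import subprocess
from typing import Dict, List

# ParlAI task names used in step 3, paired (by assumption) with this repository's --task values.
PARLAI_TASKS: Dict[str, str] = {
    "dailydialog": "daily_dialog",
    "empathetic_dialogues": "empathetic_dialogues",
    "personachat": "persona_chat",
    "blended_skill_talk": "blended_skill_talk",
}


def download_commands() -> List[List[str]]:
    # One `parlai display_data` call per dataset, as in step 3.
    return [["parlai", "display_data", "--task", task] for task in PARLAI_TASKS]


def parse_command(data_dir: str = "data") -> List[str]:
    # Step 4: parse the downloaded data into *.pickle and *.json files.
    return ["python", "src/parse_data.py", f"--data_dir={data_dir}"]


def run(commands: List[List[str]], cwd: str = ".", dry_run: bool = True) -> None:
    for cmd in commands:
        print(f"[{cwd}] $ {shlex.join(cmd)}")
        if not dry_run:
            # check=True stops at the first failing command instead of continuing silently.
            subprocess.run(cmd, cwd=cwd, check=True)


if __name__ == "__main__":
    # Step 3 runs inside the cloned ParlAI directory; step 4 runs from the project root.
    run(download_commands(), cwd="ParlAI")
    run([parse_command()])
```

Passing `dry_run=False` executes the commands; keeping the commands as argument lists (rather than shell strings) avoids quoting problems.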


References

[1] Zhang, H., Lan, Y., Pang, L., Guo, J., & Cheng, X. (2019). Recosa: Detecting the relevant contexts with self-attention for multi-turn dialogue generation. arXiv preprint arXiv:1907.05339. (https://arxiv.org/abs/1907.05339)

[2] Li, Y., Su, H., Shen, X., Li, W., Cao, Z., & Niu, S. (2017). Dailydialog: A manually labelled multi-turn dialogue dataset. arXiv preprint arXiv:1710.03957. (https://arxiv.org/abs/1710.03957)

[3] Rashkin, H., Smith, E. M., Li, M., & Boureau, Y. L. (2018). Towards empathetic open-domain conversation models: A new benchmark and dataset. arXiv preprint arXiv:1811.00207. (https://arxiv.org/abs/1811.00207)

[4] Zhang, S., Dinan, E., Urbanek, J., Szlam, A., Kiela, D., & Weston, J. (2018). Personalizing dialogue agents: I have a dog, do you have pets too?. arXiv preprint arXiv:1801.07243. (https://arxiv.org/abs/1801.07243)

[5] Smith, E. M., Williamson, M., Shuster, K., Weston, J., & Boureau, Y. L. (2020). Can You Put it All Together: Evaluating Conversational Agents' Ability to Blend Skills. arXiv preprint arXiv:2004.08449. (https://arxiv.org/abs/2004.08449)

[6] Miller, A. H., Feng, W., Fisch, A., Lu, J., Batra, D., Bordes, A., ... & Weston, J. (2017). Parlai: A dialog research software platform. arXiv preprint arXiv:1705.06476. (https://arxiv.org/abs/1705.06476)

[7] https://parl.ai/docs/index.html


