Multi-Task Deep Neural Networks for Natural Language Understanding

This PyTorch package implements the Multi-Task Deep Neural Networks (MT-DNN) for Natural Language Understanding, as described in:

Xiaodong Liu*, Pengcheng He*, Weizhu Chen and Jianfeng Gao
Multi-Task Deep Neural Networks for Natural Language Understanding
arXiv version: arXiv:1901.11504
*: Equal contribution

Xiaodong Liu, Pengcheng He, Weizhu Chen and Jianfeng Gao
Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding
arXiv version: arXiv:1904.09482

Quickstart

Setup Environment

Install via pip:

  1. Install Python 3.6
    Download and installation instructions: https://www.python.org/downloads/release/python-360/

  2. install requirements
    > pip install -r requirements.txt

Use docker:

  1. Pull docker
    > docker pull allenlao/pytorch-mt-dnn:v0.1

  2. Run docker
    > docker run -it --rm --runtime nvidia allenlao/pytorch-mt-dnn:v0.1 bash
    If you are new to Docker, see: https://docs.docker.com/

Train a toy MT-DNN model

  1. Download data
    > sh download.sh
    The GLUE dataset is available at: https://gluebenchmark.com/

  2. Preprocess data
    > python prepro.py

  3. Training
    > python train.py

Note that we ran experiments on 4 V100 GPUs for base MT-DNN models. You may need to reduce batch size for other GPUs.
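
As a rough illustration of that trade-off, the sketch below computes how many gradient-accumulation steps would keep the effective batch size unchanged when the per-step batch has to shrink to fit a smaller GPU. Gradient accumulation is a general technique and an assumption here: this README does not say whether train.py exposes such an option, and the function name is hypothetical.

```python
import math


def accumulation_steps(target_batch_size: int, per_step_batch_size: int) -> int:
    """Return how many micro-batches to accumulate to reach the target batch size."""
    if target_batch_size <= 0 or per_step_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    # Round up so the effective batch is never smaller than the target.
    return math.ceil(target_batch_size / per_step_batch_size)


if __name__ == "__main__":
    # Hypothetical numbers: a target batch of 32 when only 8 examples fit per step.
    steps = accumulation_steps(32, 8)
    print(f"accumulate {steps} steps of 8 examples each")
```

When the division is not exact, the effective batch ends up slightly larger than the target, which is usually the safer direction.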

Reproduce GLUE Results

  1. MTL refinement: refine MT-DNN (shared layers), initialized with the pre-trained BERT model, via MTL using all GLUE tasks excluding WNLI to learn a new shared representation.
    Note that we ran this experiment on 8 V100 GPUs (32G) with a batch size of 32.

    • Preprocess GLUE data via the aforementioned script
    • Training:
      > scripts/run_mt_dnn.sh
  2. Finetuning: finetune MT-DNN to each of the GLUE tasks to get task-specific models.
    Here, we provide two examples, STS-B and RTE. You can use similar scripts to finetune all the GLUE tasks.

    • Finetune on the STS-B task
      > scripts/run_stsb.sh
      You should get about 90.5/90.4 on STS-B dev in terms of Pearson/Spearman correlation.
    • Finetune on the RTE task
      > scripts/run_rte.sh
      You should get about 83.8 on RTE dev in terms of accuracy.
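
For readers who want to check these dev numbers against their own predictions, here is a minimal, dependency-free sketch of the three metrics mentioned above: Pearson correlation, Spearman correlation (Pearson on average ranks), and accuracy. The function names and sample values are illustrative assumptions and are not taken from the MT-DNN code base.

```python
import math
from typing import List, Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("inputs must have the same length (at least 2)")
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / math.sqrt(var_x * var_y)


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks, where tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j to cover the whole run of tied values.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    return ranks


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman correlation: Pearson correlation of the average ranks."""
    return pearson(average_ranks(x), average_ranks(y))


def accuracy(gold: Sequence[int], pred: Sequence[int]) -> float:
    """Fraction of predictions that match the gold labels."""
    if len(gold) != len(pred) or not gold:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(g == p for g, p in zip(gold, pred)) / len(gold)


if __name__ == "__main__":
    # Made-up similarity scores and labels, for illustration only.
    gold_scores, pred_scores = [0.5, 2.0, 3.5, 5.0], [0.7, 1.8, 3.9, 4.6]
    print(pearson(gold_scores, pred_scores), spearman(gold_scores, pred_scores))
    print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))
```

These functions return values in [-1, 1] or [0, 1]; the figures quoted above appear to follow the common GLUE convention of reporting them multiplied by 100.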

Reproduce SciTail & SNLI Results (Domain Adaptation)

  1. Domain Adaptation on SciTail
    > scripts/scitail_domain_adaptation_bash.sh

  2. Domain Adaptation on SNLI
    > scripts/snli_domain_adaptation_bash.sh

TODO

[ ] Release code/models for MT-DNN with Knowledge Distillation.
[ ] Publish pretrained TensorFlow checkpoints.

FAQ

Did you share the pretrained mt-dnn models?

Yes, we released the pretrained shared embeddings learned via MTL, which are aligned to the BERT base/large models: mt_dnn_base.pt and mt_dnn_large.pt.
To obtain similar models:

  1. Run > sh scripts/run_mt_dnn.sh, then pick the best checkpoint based on the average dev performance on MNLI/RTE.
  2. Strip the task-specific layers via scripts/strip_model.py.
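
The two steps above can be sketched in plain Python as follows. This is only an illustration under stated assumptions: the checkpoint names, dev scores, parameter names, and the `task_head.` prefix are all hypothetical, plain dictionaries stand in for real model state, and the logic is not taken from strip_model.py.

```python
from typing import Any, Dict


def pick_best_checkpoint(dev_scores: Dict[str, Dict[str, float]]) -> str:
    """Return the checkpoint with the highest mean of its MNLI and RTE dev scores."""
    def mean_score(name: str) -> float:
        scores = dev_scores[name]
        return (scores["mnli"] + scores["rte"]) / 2

    return max(dev_scores, key=mean_score)


def strip_task_layers(state: Dict[str, Any], task_prefix: str = "task_head.") -> Dict[str, Any]:
    """Keep only shared parameters by dropping keys under the (hypothetical) task prefix."""
    return {name: value for name, value in state.items() if not name.startswith(task_prefix)}


if __name__ == "__main__":
    # Made-up dev results for three checkpoints.
    dev_scores = {
        "epoch_0": {"mnli": 83.1, "rte": 75.0},
        "epoch_1": {"mnli": 84.0, "rte": 78.2},
        "epoch_2": {"mnli": 84.2, "rte": 79.5},
    }
    print("best checkpoint:", pick_best_checkpoint(dev_scores))

    # Made-up parameter names; lists stand in for weight tensors.
    state = {
        "encoder.layer0.weight": [0.1, 0.2],
        "task_head.0.weight": [0.3],
        "task_head.1.weight": [0.4],
    }
    print("shared parameters:", sorted(strip_task_layers(state)))
```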

Why is SAN not enabled for SciTail/SNLI?

For the SciTail/SNLI tasks, the goal is to test how well the learned embedding generalizes and how easily it adapts to a new domain, rather than to evaluate complicated model structures; this allows a direct comparison with BERT. Thus, we use a linear projection in all domain adaptation settings.
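
To make the idea of a linear projection concrete, here is a minimal sketch in plain Python of a linear classification head applied to a fixed-size sentence embedding, followed by a softmax. The function names, dimensions, and numbers are illustrative assumptions and are not taken from the MT-DNN code.

```python
import math
from typing import List, Sequence


def linear_projection(
    embedding: Sequence[float],
    weights: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[float]:
    """Compute logits = W x + b, with one row of W and one bias entry per class."""
    return [
        sum(w * x for w, x in zip(row, embedding)) + b
        for row, b in zip(weights, bias)
    ]


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn logits into probabilities, subtracting the max for numerical stability."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Hypothetical 2-dimensional embedding and a 2-class head.
    embedding = [1.0, 2.0]
    weights = [[1.0, 0.0], [0.0, 1.0]]
    bias = [0.0, 0.5]
    probs = softmax(linear_projection(embedding, weights, bias))
    print(probs)
```

With a head this simple, differences between models on the new domain mostly reflect the quality of the shared embedding rather than the capacity of the output layer.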

What is the difference between V1 and V2?

The difference is in the QNLI dataset. Please refer to the official GLUE homepage for more details.

Did you fine-tune single task for your GLUE leaderboard submission?

We can use the multi-task refinement model to run prediction and produce a reasonable result, but achieving a better result requires fine-tuning on each task. It is worth noting that the paper on arXiv is a little outdated and is based on the old GLUE dataset. We will update the paper as mentioned below.

Notes and Acknowledgments

The PyTorch BERT implementation is from: https://github.com/huggingface/pytorch-pretrained-BERT
BERT: https://github.com/google-research/bert
We also used some code from: https://github.com/kevinduh/san_mrc

How do I cite MT-DNN?

For now, please cite the arXiv version:

@article{liu2019mt-dnn,
  title={Multi-Task Deep Neural Networks for Natural Language Understanding},
  author={Liu, Xiaodong and He, Pengcheng and Chen, Weizhu and Gao, Jianfeng},
  journal={arXiv preprint arXiv:1901.11504},
  year={2019}
}

A new version of the paper will be shared later.

@article{liu2019mt-dnn-kd,
  title={Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding},
  author={Liu, Xiaodong and He, Pengcheng and Chen, Weizhu and Gao, Jianfeng},
  journal={arXiv preprint arXiv:1904.09482},
  year={2019}
}

Typo: there is no activation function in Equation 2.

Contact Information

For help or issues using MT-DNN, please submit a GitHub issue.

For personal communication related to MT-DNN, please contact Xiaodong Liu ([email protected]), Pengcheng He ([email protected]), Weizhu Chen ([email protected]) or Jianfeng Gao ([email protected]).
