BART version of closed-book QA

This is a BART-based sequence-to-sequence model for open-domain QA in a closed-book setup, built on PyTorch and Hugging Face's Transformers.

The model is a sequence-to-sequence model that takes a question as an input and outputs the answer, without reading any external resource (e.g. passages). Please refer to Roberts et al., 2020, How Much Knowledge Can You Pack Into the Parameters of a Language Model? to learn more about closed-book QA setup and the original model based on T5. Their code and model checkpoints are available here.

The model is based on BART-large. Please refer to Lewis et al., ACL 2020, BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension to learn more about BART.

We experiment with Natural Questions open-domain data (NQ-open), but the code should work on any QA data with question-answer pairs.
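
To make the setup concrete, here is a minimal sketch of closed-book inference with Hugging Face's Transformers. The checkpoint name and generation settings are illustrative (a vanilla facebook/bart-large has not been fine-tuned for QA), and this is not the exact code path in cli.py.

# Minimal sketch of closed-book QA inference (illustrative, not cli.py itself).
from transformers import BartForConditionalGeneration, BartTokenizer

model_name = "facebook/bart-large"  # placeholder; substitute a fine-tuned checkpoint
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name).eval()

question = "who wrote the declaration of independence"
inputs = tokenizer(question, return_tensors="pt")

# The answer is generated from the model parameters alone; no passages are read.
output_ids = model.generate(**inputs, num_beams=4, max_length=20, early_stopping=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))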

Requirements

This code is tested on Python 3.6.9.

Install PyTorch and Transformers:

pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install transformers==4.3.3

Download NQ-open data:

chmod +x download_data.sh; ./download_data.sh
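
After downloading, you can sanity-check the data. A minimal sketch, assuming each file is a JSON list of dicts with "question" and "answer" fields (the field names are an assumption; inspect your local copy to confirm):

# Peek at the downloaded NQ-open data.
# Assumption: each file is a JSON list of question-answer dicts.
import json

with open("data/nqopen-dev.json") as f:
    data = json.load(f)

print(len(data), "examples")
print(data[0])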

Training

python cli.py --do_train --output_dir out/nq-bart-closed-qa \
        --train_file data/nqopen-train.json \
        --predict_file data/nqopen-dev.json \
        --train_batch_size ${train_bs} \
        --predict_batch_size ${test_bs} \
        --append_another_bos

The script will save the log and the best checkpoint inside out/nq-bart-closed-qa.

Other useful flags (please refer to cli.py for the full list):

  • --eval_period: interval at which to evaluate on the dev data
  • --verbose: print a progress bar
  • --debug: train and evaluate on a subset of the dev data for debugging purposes

You can set train_batch_size and predict_batch_size depending on GPU availability. With one 16GB GPU, you can use train_batch_size=64 and predict_batch_size=64. The model whose results are reported below was trained with train_batch_size=1024 and predict_batch_size=256 on eight 32GB GPUs; training took roughly 34 hours.

Note:

  • The script saves the pre-tokenized data in data/ the first time the question-answer pairs are tokenized, so subsequent runs can reuse it.
  • The model gives the best result when an extra BOS token is prepended to the input (--append_another_bos); see the sketch after this list.
  • Inference on multiple GPUs is currently not supported; we will update the code once it is fixed.
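
A rough illustration of the extra-BOS trick mentioned above, assuming it amounts to prepending another <s> token to the already-encoded question (an assumption about the preprocessing, not a copy of cli.py):

# Rough illustration of --append_another_bos (assumed behavior).
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
question = "who wrote the declaration of independence"

ids = tokenizer(question)["input_ids"]                # [<s>, ..., </s>]
ids_with_extra_bos = [tokenizer.bos_token_id] + ids   # [<s>, <s>, ..., </s>]

print(tokenizer.convert_ids_to_tokens(ids_with_extra_bos))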

Inference

python cli.py --do_predict --output_dir out/nq-bart-closed-qa \
        --predict_file data/nqopen-dev.json \
        --predict_batch_size ${test_bs} \
        --append_another_bos --prefix dev_
python cli.py --do_predict --output_dir out/nq-bart-closed-qa \
        --predict_file data/nqopen-test.json \
        --predict_batch_size ${test_bs} \
        --append_another_bos --prefix test_

It will save the prediction file as out/nq-bart-closed-qa/{dev|test}_predictions.json.
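
As a sanity check, here is a minimal sketch of computing Exact Match from the saved predictions. It assumes dev_predictions.json is a JSON list of predicted answer strings aligned with the examples in data/nqopen-dev.json, and that each example carries an "answer" list of acceptable gold strings; both are assumptions, so check your files before relying on this.

# Minimal sketch of Exact Match evaluation (format assumptions noted above).
import json, re, string

def normalize(text):
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

with open("out/nq-bart-closed-qa/dev_predictions.json") as f:
    predictions = json.load(f)
with open("data/nqopen-dev.json") as f:
    data = json.load(f)

em = [
    any(normalize(pred) == normalize(gold) for gold in ex["answer"])
    for pred, ex in zip(predictions, data)
]
print("Exact Match: %.2f" % (100.0 * sum(em) / len(em)))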

Result

The final Exact Match score we get is 25.05 on the dev data and 24.10 on the test data.

We have made the best model checkpoint and the predictions on the dev/test data available.

Note that the T5-based model obtains 27.0, 29.8, 32.1 and 34.5 on the test set with Base, Large, 3B and 11B, respectively, according to the original paper. Several factors could explain the gap: (i) T5 has more parameters and is pretrained on more data, and (ii) the original paper includes the dev data for training, whereas our codebase trains only on the train data and uses the dev data to choose the best checkpoint. We also did not perform any hyperparameter tuning, as our goal is to provide a basic codebase rather than to achieve the best possible performance; we leave that for future work.

Note that the original paper also includes ablations that exclude supervised data from T5 pre-training and reports comparable (or better) numbers; see Appendix C of the original paper for details.

Contact

Please email Sewon Min or open a GitHub issue for any questions.

