
Sōseki

Sōseki is an implementation of an end-to-end question answering (QA) system.

Currently, Sōseki makes use of Binary Passage Retriever (BPR), an efficient passage retrieval model for large collections of documents. BPR was originally developed to improve the computational efficiency of the QA system submitted to the Systems under 6GB track of the NeurIPS 2020 EfficientQA competition.
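
At retrieval time, BPR works in two stages: candidate passages are first fetched by Hamming distance between binary codes, and the candidates are then reranked with the dense question embedding (this is what the --binary_k option below controls). A minimal NumPy sketch of the idea, with illustrative shapes and names rather than the soseki API:

# Two-stage retrieval as in BPR: Hamming-distance candidate generation,
# then reranking of the candidates with the dense question embedding.
import numpy as np

rng = np.random.default_rng(0)
dim, num_passages, binary_k, top_k = 256, 10_000, 512, 10

passage_codes = np.sign(rng.standard_normal((num_passages, dim)))  # {-1, +1} codes
question_emb = rng.standard_normal(dim)                            # dense embedding

# Stage 1: candidate generation by Hamming distance between binary codes.
question_code = np.sign(question_emb)
hamming = (passage_codes != question_code).sum(axis=1)
candidates = np.argsort(hamming)[:binary_k]

# Stage 2: rerank candidates by the inner product between the dense
# question embedding and the binary passage codes.
scores = passage_codes[candidates] @ question_emb
top = candidates[np.argsort(-scores)[:top_k]]
print(top)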

Installation

# Before installation, upgrade pip and setuptools.
$ pip install -U pip setuptools

# Install the PyTorch package.
# You may want to check the installation options for your CUDA environment.
# https://pytorch.org/get-started/locally/
$ pip install 'torch==1.11.0'

# Install other dependencies.
$ pip install -r requirements.txt

# Install the soseki package.
$ pip install .
# Or if you want to install it in editable mode:
$ pip install -e .

Note: If your environment uses a CUDA version other than 10.2, you may need to reinstall PyTorch following the official documentation.
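
Before running the multi-GPU commands below, it is worth checking that PyTorch can see your GPUs:

# Sanity check of the PyTorch installation and CUDA visibility.
import torch

print(torch.__version__)          # e.g. 1.11.0
print(torch.cuda.is_available())  # should be True on a GPU machine
print(torch.cuda.device_count())  # e.g. 4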

Example Usage

Before you start, you need to download the datasets available in the DPR repository into <DPR_DATASET_DIR>.

We used four GPUs with 12 GB of memory each for the experiments below.

1. Build passage database

$ python build_passage_db.py \
    --passage_file <DPR_DATASET_DIR>/wikipedia_split/psgs_w100.tsv \
    --db_file <WORK_DIR>/passages.db \
    --db_map_size 21000000000
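
The --db_map_size value presumably sets the maximum size of the database in bytes, as in LMDB (about 21 GB here). The input follows the DPR passage format, a tab-separated file with id, text, and title columns, which you can sanity-check before starting the long build:

# Inspect the first row of the DPR passage file (tab-separated: id, text, title).
# Adjust the path to your <DPR_DATASET_DIR>.
import csv

with open("psgs_w100.tsv", newline="") as f:
    reader = csv.reader(f, delimiter="\t")
    print(next(reader))             # header: ['id', 'text', 'title']
    row = next(reader)
    print(row[0], row[2], row[1][:80])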

2. Train a biencoder

$ python train_biencoder.py \
    --train_file <DPR_DATASET_DIR>/retriever/nq-train.json \
    --val_file <DPR_DATASET_DIR>/retriever/nq-dev.json \
    --output_dir <WORK_DIR>/biencoder \
    --max_question_length 64 \
    --max_passage_length 192 \
    --num_negative_passages 1 \
    --shuffle_hard_negative_passages \
    --shuffle_normal_negative_passages \
    --base_pretrained_model bert-base-uncased \
    --binary \
    --train_batch_size 16 \
    --eval_batch_size 16 \
    --learning_rate 1e-5 \
    --warmup_proportion 0.1 \
    --gradient_clip_val 2.0 \
    --max_epochs 40 \
    --gpus 4 \
    --precision 16 \
    --strategy ddp
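
The --binary flag makes the biencoder emit binary codes. Since the sign function used for binarization has no useful gradient, training needs an approximation; the BPR paper anneals a scaled tanh, and a straight-through estimator is another common choice. A minimal PyTorch sketch of the straight-through variant, not necessarily the exact soseki implementation:

# Straight-through sign: binary codes in the forward pass, identity
# gradient in the backward pass, so the encoder stays trainable.
import torch

class STESign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

emb = torch.randn(4, 256, requires_grad=True)
codes = STESign.apply(emb)   # values in {-1, +1} (exact zeros are vanishingly rare)
codes.sum().backward()
print(emb.grad.shape)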

3. Build passage embeddings

$ python build_passage_embeddings.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --output_file <WORK_DIR>/passage_embeddings.idx \
    --max_passage_length 192 \
    --batch_size 2048 \
    --device_ids 0 1 2 3
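
Binary embeddings are compact because each dimension fits in a single bit: a 768-dimensional code takes 96 bytes instead of 3 KB as float32. A sketch of the packing idea with NumPy (the actual on-disk format of passage_embeddings.idx may differ):

# Pack {-1, +1} codes into bits and restore them losslessly.
import numpy as np

codes = np.sign(np.random.default_rng(0).standard_normal((3, 768)))
packed = np.packbits(codes > 0, axis=1)                   # shape (3, 96), dtype uint8
restored = np.unpackbits(packed, axis=1).astype(np.float32) * 2 - 1
assert (restored == codes).all()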

4. Evaluate the retriever and create datasets for reader

$ mkdir <WORK_DIR>/reader_data

$ python evaluate_retriever.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --qa_file <DPR_DATASET_DIR>/retriever/qas/nq-train.csv \
    --output_file <WORK_DIR>/reader_data/nq_train.jsonl \
    --batch_size 64 \
    --max_question_length 64 \
    --top_k 1 2 5 10 20 50 100 \
    --binary \
    --binary_k 2048 \
    --answer_match_type dpr_string \
    --include_title_in_passage \
    --device_ids 0 1 2 3
# The result should be logged as follows:
# Recall at 1: 0.4993 (39532/79168)
# Recall at 2: 0.6175 (48886/79168)
# Recall at 5: 0.7353 (58213/79168)
# Recall at 10: 0.7919 (62690/79168)
# Recall at 20: 0.8288 (65613/79168)
# Recall at 50: 0.8597 (68061/79168)
# Recall at 100: 0.8751 (69281/79168)

$ python evaluate_retriever.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --qa_file <DPR_DATASET_DIR>/retriever/qas/nq-dev.csv \
    --output_file <WORK_DIR>/reader_data/nq_dev.jsonl \
    --batch_size 64 \
    --max_question_length 64 \
    --top_k 1 2 5 10 20 50 100 \
    --binary \
    --binary_k 2048 \
    --answer_match_type dpr_string \
    --include_title_in_passage \
    --device_ids 0 1 2 3
# The result should be logged as follows:
# Recall at 1: 0.4047 (3544/8757)
# Recall at 2: 0.5143 (4504/8757)
# Recall at 5: 0.6398 (5603/8757)
# Recall at 10: 0.7117 (6232/8757)
# Recall at 20: 0.7595 (6651/8757)
# Recall at 50: 0.8134 (7123/8757)
# Recall at 100: 0.8420 (7373/8757)

$ python evaluate_retriever.py \
    --biencoder_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --qa_file <DPR_DATASET_DIR>/retriever/qas/nq-test.csv \
    --output_file <WORK_DIR>/reader_data/nq_test.jsonl \
    --batch_size 64 \
    --max_question_length 64 \
    --top_k 1 2 5 10 20 50 100 \
    --binary \
    --binary_k 2048 \
    --answer_match_type dpr_string \
    --include_title_in_passage \
    --device_ids 0 1 2 3
# The result should be logged as follows:
# Recall at 1: 0.4136 (1493/3610)
# Recall at 2: 0.5208 (1880/3610)
# Recall at 5: 0.6452 (2329/3610)
# Recall at 10: 0.7194 (2597/3610)
# Recall at 20: 0.7737 (2793/3610)
# Recall at 50: 0.8283 (2990/3610)
# Recall at 100: 0.8518 (3075/3610)
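
The --answer_match_type dpr_string option follows DPR's convention: a retrieved passage counts as a hit if any gold answer string occurs in it after normalization. Roughly (the actual matcher tokenizes the text first, so this is only an approximation):

# Rough approximation of DPR-style string matching.
import re

def normalize(text):
    return re.sub(r"\s+", " ", text.lower()).strip()

def has_answer(passage, answers):
    passage = normalize(passage)
    return any(normalize(a) in passage for a in answers)

print(has_answer("Paris is the capital of France.", ["paris"]))  # True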

5. Train a reader

$ python train_reader.py \
    --train_file <WORK_DIR>/reader_data/nq_train.jsonl \
    --val_file <WORK_DIR>/reader_data/nq_dev.jsonl \
    --output_dir <WORK_DIR>/reader \
    --train_num_passages 24 \
    --eval_num_passages 100 \
    --max_input_length 256 \
    --shuffle_positive_passage \
    --shuffle_negative_passage \
    --num_dataloader_workers 1 \
    --base_pretrained_model bert-base-uncased \
    --answer_normalization_type dpr \
    --train_batch_size 1 \
    --eval_batch_size 2 \
    --learning_rate 1e-5 \
    --warmup_proportion 0.1 \
    --accumulate_grad_batches 4 \
    --gradient_clip_val 2.0 \
    --max_epochs 20 \
    --gpus 4 \
    --precision 16 \
    --strategy ddp
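
The reader consumes the top retrieved passages and extracts an answer span from them. At its core, span extraction picks the best (start, end) pair from per-token logits under a maximum span length; a minimal sketch of that selection step (soseki's reader additionally scores whole passages with a classifier):

# Select the best answer span from start/end logits, capped at max_len tokens.
import torch

def best_span(start_logits, end_logits, max_len=10):
    seq = start_logits.size(0)
    scores = start_logits[:, None] + end_logits[None, :]   # (seq, seq) pair scores
    idx = torch.arange(seq)
    allowed = (idx[None, :] >= idx[:, None]) & (idx[None, :] - idx[:, None] < max_len)
    scores = scores.masked_fill(~allowed, float("-inf"))
    start, end = divmod(int(scores.argmax()), seq)
    return start, end

print(best_span(torch.randn(32), torch.randn(32)))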

6. Evaluate the reader

$ python evaluate_reader.py \
    --reader_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --test_file <WORK_DIR>/reader_data/nq_dev.jsonl \
    --test_num_passages 100 \
    --test_max_load_passages 100 \
    --test_batch_size 4 \
    --gpus 4 \
    --strategy ddp
# The result should be printed as follows:
# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
# ┃        Test metric        ┃       DataLoader 0        ┃
# ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
# │   test_answer_accuracy    │    0.39294278621673584    │
# │ test_classifier_precision │    0.5889003276824951     │
# └───────────────────────────┴───────────────────────────┘
$ python evaluate_reader.py \
    --reader_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --test_file <WORK_DIR>/reader_data/nq_test.jsonl \
    --test_num_passages 100 \
    --test_max_load_passages 100 \
    --test_batch_size 4 \
    --gpus 4 \
    --strategy ddp
# The result should be printed as follows:
# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
# ┃        Test metric        ┃       DataLoader 0        ┃
# ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
# │   test_answer_accuracy    │    0.3900277018547058     │
# │ test_classifier_precision │    0.5836564898490906     │
# └───────────────────────────┴───────────────────────────┘
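
test_answer_accuracy is presumably exact match after answer normalization (the --answer_normalization_type dpr option above suggests DPR-style normalization: lowercasing and stripping articles and punctuation). A sketch of that kind of normalization:

# DPR/SQuAD-style answer normalization, as commonly used for exact match.
import re
import string

def normalize_answer(s):
    s = re.sub(r"\b(a|an|the)\b", " ", s.lower())
    s = "".join(ch for ch in s if ch not in string.punctuation)
    return " ".join(s.split())

print(normalize_answer("The Eiffel Tower!") == normalize_answer("eiffel tower"))  # True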

7. (optional) Convert the trained models into ONNX format

$ python convert_models_to_onnx.py \
    --biencoder_ckpt_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --reader_ckpt_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --output_dir <WORK_DIR>/onnx
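
The exported models can then be served with ONNX Runtime. The file name below is an assumption; check the actual contents of <WORK_DIR>/onnx:

# Load an exported encoder with ONNX Runtime and inspect its inputs.
import onnxruntime as ort

# Hypothetical file name; list <WORK_DIR>/onnx for the real outputs.
session = ort.InferenceSession("<WORK_DIR>/onnx/question_encoder.onnx")
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)
# Once the input names are confirmed, run inference with something like:
# outputs = session.run(None, {"input_ids": ids, "attention_mask": mask})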

8. Run demo

$ streamlit run demo.py --browser.serverAddress localhost --browser.serverPort 8501 -- \
    --biencoder_ckpt_file <WORK_DIR>/biencoder/lightning_logs/version_0/checkpoints/last.ckpt \
    --reader_ckpt_file <WORK_DIR>/reader/lightning_logs/version_0/checkpoints/best.ckpt \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx \
    --device cuda:0

or if you have exported the models to ONNX format:

$ streamlit run demo.py --browser.serverAddress localhost --browser.serverPort 8501 -- \
    --onnx_model_dir <WORK_DIR>/onnx \
    --passage_db_file <WORK_DIR>/passages.db \
    --passage_embeddings_file <WORK_DIR>/passage_embeddings.idx

Then open http://localhost:8501.

The demo can also be launched with Docker:

$ docker build -t soseki --build-arg TRANSFORMERS_BASE_MODEL_NAME='bert-base-uncased' .
$ docker run --rm -v $(realpath <WORK_DIR>):/app/model -p 8501:8501 -it soseki \
    streamlit run demo.py --browser.serverAddress localhost --browser.serverPort 8501 -- \
        --onnx_model_dir /app/model/onnx \
        --passage_db_file /app/model/passages.db \
        --passage_embeddings_file /app/model/passage_embeddings.idx

License

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

Citation

If you find this work useful, please cite the following paper:

Efficient Passage Retrieval with Hashing for Open-domain Question Answering

@inproceedings{yamada2021bpr,
  title={Efficient Passage Retrieval with Hashing for Open-domain Question Answering},
  author={Ikuya Yamada and Akari Asai and Hannaneh Hajishirzi},
  booktitle={ACL},
  year={2021}
}
