In-Context Retrieval-Augmented Language Models

This repo contains the code for reproducing the experiments on WikiText-103 from AI21 Labs' paper In-Context Retrieval-Augmented Language Models (In-Context RALM), published in the Transactions of the Association for Computational Linguistics (TACL).

Our code is mainly based on the Transformers and Pyserini libraries.
We tested it with Python 3.8.

Table of Contents

  • Setup
  • Retrieval
  • Evaluation
  • Question Answering Experiments
  • Citation

Setup

To install the required libraries in our repo, run:

pip install -r requirements.txt

If you need a PyTorch build that matches your CUDA version, install it before running the command above.
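
For example, for CUDA 11.8 you could install PyTorch from the matching wheel index (the cu118 suffix is an assumption; use the index that matches your CUDA toolkit, as listed in the official PyTorch install instructions):

pip install torch --index-url https://download.pytorch.org/whl/cu118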

Retrieval

BM25

Our BM25 preparation script relies on Pyserini, so Java 11 is required (see the Pyserini installation guide).
Once Java 11 is installed, make sure your JAVA_HOME environment variable points to it. On a Linux system, the path might look something like /usr/lib/jvm/java-11.
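For example (the exact path is an assumption; it depends on your distribution and how Java was installed):

export JAVA_HOME=/usr/lib/jvm/java-11
"$JAVA_HOME"/bin/java -version   # should report a Java 11 runtime
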
Then run:

python prepare_retrieval_data.py \
--retrieval_type sparse \
--tokenizer_name $MODEL_NAME \
--max_length 1024 \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--index_name wikipedia-dpr \
--forbidden_titles_path ralm/retrievers/wikitext103_forbidden_titles.txt \
--stride 4 \
--output_file $RETRIEVAL_FILE \
--num_tokens_for_query 32 \
--num_docs 16 
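
In the command above, pick a single value wherever brackets list alternatives (e.g. validation or test for --dataset_split). $MODEL_NAME and $RETRIEVAL_FILE are shell variables you define yourself; a minimal example, with an illustrative output filename:

export MODEL_NAME=gpt2
export RETRIEVAL_FILE=retrieval_wikitext103_validation.json   # any output path works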

Evaluation

List of Language Models

In the paper, we report results for the following models (replace $MODEL_NAME with one of them).
Note that the larger models may need model parallelism (on a 40GB A100, we used model parallelism for OPT-30B and OPT-66B).
See details below on how to apply this option.

  • GPT-2: gpt2, gpt2-medium, gpt2-large, gpt2-xl
  • GPT-Neo/GPT-J: EleutherAI/gpt-neo-1.3B, EleutherAI/gpt-neo-2.7B, EleutherAI/gpt-j-6B
  • OPT: facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b

Evaluate models w/o retrieval

To run evaluation on models without retrieval, use the following command (you can increase --stride to 32 for faster evaluation):

python eval_lm.py \
--model_name $MODEL_NAME \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
[--model_parallelism]
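
For example, a sketch of evaluating OPT-30B on the test set with model parallelism enabled (the model choice is illustrative, and enough total GPU memory is assumed):

python eval_lm.py \
--model_name facebook/opt-30b \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split test \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
--model_parallelism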

Evaluate models with retrieval

To run models with retrieval, use the $RETRIEVAL_FILE output from the prepare_retrieval_data.py script:

python eval_lm.py \
--model_name $MODEL_NAME \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
[--model_parallelism] \
--retrieved_file $RETRIEVAL_FILE

Note: Our main retrieval flow assumes you want to use the top-scored passage from your retrieval file (--ranking_strategy first).

Reranking

Currently, we support two ranking strategies: logprob (the zero-shot method described in subsection 6.1 of the paper) and oracle (to estimate the potential gains from reranking).

For reranking, first make sure you ran the retrieval script with --num_docs 16 (or however many documents you want to rerank). If multiple GPUs are available, data parallelism is applied automatically (each GPU conditions on a different subset of the retrieved documents). Then run:

python eval_lm.py \
--model_name $MODEL_NAME \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
[--model_parallelism] \
--retrieved_file $RETRIEVAL_FILE \
--ranking_strategy [logprob, oracle] \
--num_docs_to_rank 16 \
--ranking_logprob_past_tokens 16
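
For example, a concrete reranking run on two GPUs with the logprob strategy (selecting devices through the standard CUDA_VISIBLE_DEVICES variable is an assumption about a typical multi-GPU setup, not a flag of this repo, and the model choice is illustrative):

CUDA_VISIBLE_DEVICES=0,1 python eval_lm.py \
--model_name gpt2 \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split test \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
--retrieved_file $RETRIEVAL_FILE \
--ranking_strategy logprob \
--num_docs_to_rank 16 \
--ranking_logprob_past_tokens 16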

Question Answering Experiments

To run our QA experiments on Natural Questions, start by downloading the test set augmented with DPR retrieval results:

wget https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz
gzip -d ./nq-test.json.gz

To run our QA experiments on TriviaQA, install gsutil and copy the DPR-augmented dataset:

gsutil cp gs://ai21-publishing-public-models/in-context-ralm/trivia-test-dpr-results.json ./trivia-test.json
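
If gsutil is not already available, one option is to install it with pip (an assumption about a pip-based environment; the Google Cloud SDK also provides it):

pip install gsutil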

Then run the evaluation script:

python eval_qa.py \
--model_name $MODEL_NAME \
--dataset_path [nq-test.json,trivia-test.json] \
--output_dir $OUTPUT_DIR \
--num_docs [0,1,2] \
[--model_parallelism]

where num_docs is the number of retrieved documents to include in-context (num_docs=0 is the closed-book setting, and num_docs>=1 is the open-book setting).
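
For example, a closed-book run on Natural Questions (the model choice is illustrative):

python eval_qa.py \
--model_name gpt2 \
--dataset_path nq-test.json \
--output_dir $OUTPUT_DIR \
--num_docs 0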

Citation

If you find our paper or code helpful, please cite our paper:

@article{ram-etal-2023-context,
    title = "In-Context Retrieval-Augmented Language Models",
    author = "Ram, Ori  and
      Levine, Yoav  and
      Dalmedigos, Itay  and
      Muhlgay, Dor  and
      Shashua, Amnon  and
      Leyton-Brown, Kevin  and
      Shoham, Yoav",
    journal = "Transactions of the Association for Computational Linguistics",
    volume = "11",
    year = "2023",
    address = "Cambridge, MA",
    publisher = "MIT Press",
    url = "https://aclanthology.org/2023.tacl-1.75",
    doi = "10.1162/tacl_a_00605",
    pages = "1316--1331",
}
