Branch-Train-Merge: Embarrassingly Parallel Training of Expert Language Models

Below are instructions to access the code and models for the paper "Branch-Train-Merge: Embarrassingly Parallel Training of Expert Language Models".

Code

This code is based on Fairseq, and includes a hard fork of Fairseq in the fairseq folder.

Setup

For basic setup of our code:

git clone https://github.com/hadasah/btm.git
cd btm/fairseq
pip install -e .

Note that this will uninstall any existing Fairseq install in your environment. For additional install options, see the Fairseq repository README.
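As a quick sanity check (a suggestion of ours, not part of the original instructions), you can confirm that Python now resolves the hard fork shipped with this repo:

python -c "import fairseq; print(fairseq.__file__)"  # should point into btm/fairseq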

Data

Most of the experiments are conducted with data from DEMix. You can easily train on your own data by following the Fairseq instructions for data preprocessing.
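For reference, a typical Fairseq language-modeling preprocessing call looks like the one below; all paths are placeholders, and the full set of flags is documented in the Fairseq README:

fairseq-preprocess --only-source \
    --trainpref /path/to/raw/train.txt \
    --validpref /path/to/raw/valid.txt \
    --testpref /path/to/raw/test.txt \
    --destdir /path/to/data-bin/data_domain_name \
    --workers 16 ;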

Model Training

To train a Transformer-LM baseline with domain data balancing, or to conduct seed LM training:

MODEL_FOLDER_NAME=project_name;
NUM_GPUS=16;
ARCHITECTURE=transformer_lm_gpt3_small;
DATA_FOLDER=/path/to/data;
DATA_DOMAIN_NAME=data_domain_name;
SAVE_MODEL_FOLDER=/path/to/new/model/checkpointing;
NUM_UPDATES=80000;
UPDATE_FREQ=32;
LR=5e-4;
SAVE_INTERVAL_UPDATES=2000;
PORT=55555;
WANDB_PROJECT=project_name;
BTM_CODE_PATH=/path/to/this/repo;
RANDOM_SEED=1;
UNIQUE_RUN_ID=unique_run_name;

cd $BTM_CODE_PATH;
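# btm_train.sh takes positional arguments. Comparing this call with the
# branching call below, the 4th argument selects the training mode ("dense" =
# train from scratch, for baselines and seed LMs; "branch" = continue from a
# seed checkpoint), and the None/False placeholders fill the branch-only
# slots (init checkpoint folder and seed-phase options).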
bash btm_shell_scripts/btm_train.sh $MODEL_FOLDER_NAME $NUM_GPUS \
$ARCHITECTURE dense $DATA_FOLDER $DATA_DOMAIN_NAME None \
$SAVE_MODEL_FOLDER None None None False \
$NUM_UPDATES $UPDATE_FREQ $LR $SAVE_INTERVAL_UPDATES $PORT \
$WANDB_PROJECT $BTM_CODE_PATH $RANDOM_SEED $UNIQUE_RUN_ID ;

To branch and train from an existing checkpoint:

MODEL_FOLDER_NAME=project_name;
NUM_GPUS=2;
ARCHITECTURE=transformer_lm_gpt3_small;
DATA_FOLDER=/path/to/data;
DATA_DOMAIN_NAME=data_domain;
INIT_CHECKPOINT_FOLDER=/path/to/seed/model/folder;
SAVE_MODEL_FOLDER=/path/to/new/model/checkpointing;
SEED_PHASE_COMPUTE_SHARE=None;
SEED_PHASE_UPDATE_NUM=model_update_number;
NUM_UPDATES=80000;
UPDATE_FREQ=32;
LR=5e-4;
SAVE_INTERVAL_UPDATES=2000;
PORT=55555;
WANDB_PROJECT=project_name;
BTM_CODE_PATH=/path/to/this/repo;
RANDOM_SEED=1;
UNIQUE_RUN_ID=unique_run_name2;

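# In "branch" mode, the None/False placeholders of the dense call are filled in:
# the expert is initialized from the seed checkpoint in $INIT_CHECKPOINT_FOLDER,
# with $SEED_PHASE_UPDATE_NUM (as far as we can tell) selecting which seed-model
# update to branch from.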
bash btm_shell_scripts/btm_train.sh $MODEL_FOLDER_NAME $NUM_GPUS \
$ARCHITECTURE branch $DATA_FOLDER $DATA_DOMAIN_NAME $INIT_CHECKPOINT_FOLDER \
$SAVE_MODEL_FOLDER . $SEED_PHASE_COMPUTE_SHARE $SEED_PHASE_UPDATE_NUM True \
$NUM_UPDATES $UPDATE_FREQ $LR $SAVE_INTERVAL_UPDATES $PORT \
$WANDB_PROJECT $BTM_CODE_PATH $RANDOM_SEED $UNIQUE_RUN_ID ;

Model Evaluation

To evaluate a single LM:

DATA_FOLDER=/path/to/data;
MODEL_FOLDER=/path/to/new/model/checkpointing/project_name/unique_run_name;
CHECKPOINT_FILE_NAME=checkpoint_last.pt;
DATA_SPLIT=test;
DATA_DOMAIN_NAME=data_domain;

bash btm_shell_scripts/eval_pipeline.sh $DATA_FOLDER $MODEL_FOLDER \
$DATA_SPLIT $DATA_DOMAIN_NAME ;

To evaluate an ensemble of LMs, where the ensemble is weighted by the domain posterior (this requires jq):

NUM_EXPERTS=8;
DATA_FOLDER=/path/to/data;
MODEL_PATHS=/path/to/expert1:/path/to/expert2:etc;
DATA_DOMAIN_NAME=data_domain;
ENSEMBLE_TYPE=cached_prior;
RESULTS_OUTPUT_FOLDER=/path/to/output/folder;

bash btm_shell_scripts/ensemble_eval.sh $NUM_EXPERTS $DATA_FOLDER \
$MODEL_PATHS $DATA_DOMAIN_NAME $ENSEMBLE_TYPE $RESULTS_OUTPUT_FOLDER ;
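With ENSEMBLE_TYPE=cached_prior, the experts are combined under an estimate of the domain posterior. Paraphrasing the paper's ensembling rule (this is a summary, not an excerpt from the code): each expert k is a domain-conditional LM, and the ensemble's next-token distribution is

p(x_t \mid x_{<t}) = \sum_{k=1}^{K} p(x_t \mid x_{<t}, D = k) \cdot p(D = k \mid x_{<t}),

where the posterior p(D = k \mid x_{<t}) over the K experts is estimated from earlier tokens of the sequence.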

To parameter average LMs:

RESULTING_MODEL_OUTPUT_FOLDER=/path/to/output/folder;
WEIGHTS=0.1,0.9;
MODEL_PATHS=/path/to/expert1:/path/to/expert2;
python btm_utils/average.py --output-dir $RESULTING_MODEL_OUTPUT_FOLDER \
--weights $WEIGHTS --model-files $MODEL_PATHS ;
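Here average.py forms a weighted average of the expert parameters; with the example weights above (which, as in the paper, should sum to 1), the merged model is

\theta_{\text{avg}} = \sum_i w_i \, \theta_i = 0.1\,\theta_1 + 0.9\,\theta_2 .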

To evaluate the parameter-averaged LM, use the single-LM evaluation command above.

Models

We release all trained ELMs, the Transformer-LM and DEMix baselines at the 125M, 350M, 750M, and 1.3B parameter scales, as well as the 350M-parameter ELMs trained on our 64-domain corpus.

To download one of the Transformer-LM baselines:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=transformer_lm;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/
cd btm_models/models/${model_scale}/${model_architecture}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last.pt
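To fetch all four scales at once, the same URL pattern can be looped over (a convenience sketch, not part of the original instructions):

for model_scale in 125M 350M 750M 1.3B; do
    mkdir -p btm_models/models/${model_scale}/${model_architecture}/
    wget -c -P btm_models/models/${model_scale}/${model_architecture}/ \
        https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last.pt
done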

To download the DEMix models:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=demix;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/
cd btm_models/models/${model_scale}/${model_architecture}/
for rank in {0..7}; do
    wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-${rank}.pt;
done
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-shared.pt;

To download one of the ELMs:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=elmforest;
# one of the domains specified in btm_utils/constants.py
domain=1b;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/${domain}/
cd btm_models/models/${model_scale}/${model_architecture}/${domain}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/${domain}/checkpoint_last.pt

To download one of the ELMs from the 64-domain experiments:

model_scale=64_domain_curriculum;
# one of the domains specified in btm_utils/constants.py
domain=2021newscrawl;

mkdir -p btm_models/models/${model_scale}/${domain}/
cd btm_models/models/${model_scale}/${domain}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${domain}/checkpoint_last.pt

btm's Issues

Unable to find the contents of the fairseq directory

Hello.

I am very interested in the contributions of your paper.
However, I cannot reproduce the experiments because the fairseq directory is empty.
Could you please add some information about the fairseq directory?

`expert_probs` in btm/ensemble_eval_lm.py: 104

Hi,

Thanks for open-sourcing this great work! I have some questions about how the posterior probability for the experts is calculated. From this line, it seems that the expert probabilities are computed inside sequence_scorer.py. I didn't find any expert_probs in this file, but found it in another repo. It seems that the expert probs are the probabilities of the last position in the sequence, weights[:, :, -1]; why is that the case? I imagine the last position's probability does not carry any particular meaning. It would be very helpful if you could clarify in more detail how the posterior probability is calculated :) Thanks!

750M ELM model download links lead to 403 forbidden

Hello,

Recently I've been trying to reproduce the experiments and have also tried the pretrained ELMs, but it turns out that some of the links are not working. Below are the models that cannot be downloaded successfully.

  • 1b
  • anonymized_openwebtext
  • anonymized_realnews
  • anonymized_reviews

All of them use model_scale=750M and model_architecture=elmforest.

Here's how I downloaded those files:

model_scales=("125M" "350M" "750M" "1.3B")
model_architecture=elmforest;
domains=("1b" "anonymized_openwebtext" "anonymized_realnews" "anonymized_reviews" "cs" "legal" "med" "reddit")

for model_scale in "${model_scales[@]}"
do
    for domain in "${domains[@]}"
    do
        mkdir -p btm_models/models/${model_scale}/${model_architecture}/${domain}/
        cd btm_models/models/${model_scale}/${model_architecture}/${domain}/
        wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/${domain}/checkpoint_last.pt
        cd -
    done
done

Thank you for open-sourcing this work!
