
Branch-Train-Merge: Embarrassingly Parallel Training of Expert Language Models

Below are instructions to access the code and models for the paper "Branch-Train-Merge: Embarrassingly Parallel Training of Expert Language Models".

Code

This code is based on Fairseq and includes a hard fork of it in the fairseq folder.

Setup

For basic setup of our code:

git clone https://github.com/hadasah/btm.git
cd btm/fairseq
pip install -e .

Note that this will uninstall any existing Fairseq installation in your environment. For additional installation options, see the Fairseq repository README.
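
To verify that the editable install is the one being picked up (an optional sanity check):

python -c "import fairseq; print(fairseq.__version__)"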

Data

Most of the experiments are conducted with data from DEMix. You can also train on your own data by following the Fairseq instructions for data preprocessing.
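
For reference, a minimal sketch of binarizing one domain with fairseq-preprocess (all paths and the dictionary file are illustrative): --only-source is used because language modeling consumes a single text stream, and --srcdict reuses a shared vocabulary so that experts remain compatible across domains.

# Binarize one pre-tokenized domain for language modeling (paths illustrative)
fairseq-preprocess \
    --only-source \
    --srcdict /path/to/dict.txt \
    --trainpref /path/to/domain/train.txt \
    --validpref /path/to/domain/valid.txt \
    --testpref /path/to/domain/test.txt \
    --destdir /path/to/data-bin/domain_name \
    --workers 16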

Model Training

To train a Transformer-LM baseline with domain data balancing, or to conduct seed LM training:


MODEL_FOLDER_NAME=project_name;
NUM_GPUS=16;
ARCHITECTURE=transformer_lm_gpt3_small;
DATA_FOLDER=/path/to/data;
DATA_DOMAIN_NAME=data_domain_name;
SAVE_MODEL_FOLDER=/path/to/new/model/checkpointing;
NUM_UPDATES=80000;
UPDATE_FREQ=32;
LR=5e-4;
SAVE_INTERVAL_UPDATES=2000;
PORT=55555;
WANDB_PROJECT=project_name;
BTM_CODE_PATH=/path/to/this/repo;
RANDOM_SEED=1;
UNIQUE_RUN_ID=unique_run_name;

cd $BTM_CODE_PATH;
bash btm_shell_scripts/btm_train.sh $MODEL_FOLDER_NAME $NUM_GPUS \
$ARCHITECTURE dense $DATA_FOLDER $DATA_DOMAIN_NAME None \
$SAVE_MODEL_FOLDER None None None False \
$NUM_UPDATES $UPDATE_FREQ $LR $SAVE_INTERVAL_UPDATES $PORT \
$WANDB_PROJECT $BTM_CODE_PATH $RANDOM_SEED $UNIQUE_RUN_ID ;

To branch and train from an existing checkpoint:

MODEL_FOLDER_NAME=project_name;
NUM_GPUS=2;
ARCHITECTURE=transformer_lm_gpt3_small;
DATA_FOLDER=/path/to/data;
DATA_DOMAIN_NAME=data_domain;
INIT_CHECKPOINT_FOLDER=/path/to/seed/model/folder;
SAVE_MODEL_FOLDER=/path/to/new/model/checkpointing;
SEED_PHASE_COMPUTE_SHARE=None;
SEED_PHASE_UPDATE_NUM=model_update_number;
NUM_UPDATES=80000;
UPDATE_FREQ=32;
LR=5e-4;
SAVE_INTERVAL_UPDATES=2000;
PORT=55555;
WANDB_PROJECT=project_name;
BTM_CODE_PATH=/path/to/this/repo;
RANDOM_SEED=1;
UNIQUE_RUN_ID=unique_run_name2;

cd $BTM_CODE_PATH;
bash btm_shell_scripts/btm_train.sh $MODEL_FOLDER_NAME $NUM_GPUS \
$ARCHITECTURE branch $DATA_FOLDER $DATA_DOMAIN_NAME $INIT_CHECKPOINT_FOLDER \
$SAVE_MODEL_FOLDER . $SEED_PHASE_COMPUTE_SHARE $SEED_PHASE_UPDATE_NUM True \
$NUM_UPDATES $UPDATE_FREQ $LR $SAVE_INTERVAL_UPDATES $PORT \
$WANDB_PROJECT $BTM_CODE_PATH $RANDOM_SEED $UNIQUE_RUN_ID ;

Model Evaluation

To evaluate a single LM:

DATA_FOLDER=/path/to/data;
MODEL_FOLDER=/path/to/new/model/checkpointing/project_name/unique_run_name;
CHECKPOINT_FILE_NAME=checkpoint_last.pt;
DATA_SPLIT=test;
DATA_DOMAIN_NAME=data_domain;

bash btm_shell_scripts/eval_pipeline.sh $DATA_FOLDER $MODEL_FOLDER \
$DATA_SPLIT $DATA_DOMAIN_NAME ;

To evaluate an ensemble of LMs, where the ensemble is weighted by the domain posterior (this requires jq):

NUM_EXPERTS=8;
DATA_FOLDER=/path/to/data;
MODEL_PATHS=/path/to/expert1:/path/to/expert2:etc;
DATA_DOMAIN_NAME=data_domain;
ENSEMBLE_TYPE=cached_prior;
RESULTS_OUTPUT_FOLDER=/path/to/output/folder;

bash btm_shell_scripts/ensemble_eval.sh $NUM_EXPERTS $DATA_FOLDER \
$MODEL_PATHS $DATA_DOMAIN_NAME $ENSEMBLE_TYPE $RESULTS_OUTPUT_FOLDER ;
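
Schematically (notation ours, following the paper), this ensemble mixes the experts' next-token distributions under an estimate of the domain posterior:

p(x_t | x_{<t}) = Σ_{j=1}^{k} p(x_t | x_{<t}, D = j) · p(D = j | x_{<t})

where k is the number of experts (NUM_EXPERTS above) and p(D = j | x_{<t}) is the estimated domain posterior. With ENSEMBLE_TYPE=cached_prior, this posterior is estimated on held-out data and cached rather than recomputed for each context (our reading of the flag).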

To parameter-average LMs:

RESULTING_MODEL_OUTPUT_FOLDER=/path/to/output/folder;
WEIGHTS=0.1,0.9;
MODEL_PATHS=/path/to/expert1:/path/to/expert2;
python btm_utils/average.py --output-dir $RESULTING_MODEL_OUTPUT_FOLDER \
--weights $WEIGHTS --model-files $MODEL_PATHS ;
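
Parameter averaging merges the experts element-wise: every parameter tensor of the output model is a weighted sum of the corresponding expert tensors, using the weights above (which we assume should sum to 1, as in the 0.1, 0.9 example):

θ_merged = Σ_i w_i · θ_i,   with Σ_i w_i = 1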

To evaluate the parameter-averaged LM, use the single-LM evaluation command above.
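
For example, a minimal sketch reusing that command, assuming the merged checkpoint lands in the output folder written by average.py (paths illustrative):

DATA_FOLDER=/path/to/data;
MODEL_FOLDER=/path/to/output/folder;
DATA_SPLIT=test;
DATA_DOMAIN_NAME=data_domain;

bash btm_shell_scripts/eval_pipeline.sh $DATA_FOLDER $MODEL_FOLDER \
$DATA_SPLIT $DATA_DOMAIN_NAME ;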

Models

We release all trained ELMs, along with the Transformer-LM and DEMix baselines, at the 125M, 350M, 750M, and 1.3B parameter scales, as well as the 350M-parameter ELMs trained on our 64-domain corpus.

To download one of the Transformer-LMs:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=transformer_lm;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/
cd btm_models/models/${model_scale}/${model_architecture}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last.pt
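
To fetch the baseline at all four scales in one pass (same URL pattern as above, with model_architecture set as before):

model_architecture=transformer_lm;
for model_scale in 125M 350M 750M 1.3B; do
  mkdir -p btm_models/models/${model_scale}/${model_architecture}/
  wget -c -P btm_models/models/${model_scale}/${model_architecture}/ \
    https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last.pt
done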

To download the DEMix models (one checkpoint per rank, plus shared parameters):

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=demix;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/
cd btm_models/models/${model_scale}/${model_architecture}/
for rank in 0 1 2 3 4 5 6 7; do
  wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-rank-${rank}.pt;
done
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/checkpoint_last-shared.pt;

To download one of the ELMs:

# 125M, 350M, 750M or 1.3B
model_scale=125M;
model_architecture=elmforest;
# one of the domains specified in btm_utils/constants.py
domain=1b;

mkdir -p btm_models/models/${model_scale}/${model_architecture}/${domain}/
cd btm_models/models/${model_scale}/${model_architecture}/${domain}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${model_architecture}/${domain}/checkpoint_last.pt

To download one of the ELMs from the 64-domain experiments:

model_scale=64_domain_curriculum;
# one of the domains specified in btm_utils/constants.py
domain=2021newscrawl;

mkdir -p btm_models/models/${model_scale}/${domain}/
cd btm_models/models/${model_scale}/${domain}/
wget -c https://dl.fbaipublicfiles.com/btm/models/${model_scale}/${domain}/checkpoint_last.pt
