Hao Liu, Pieter Abbeel
Paper: https://arxiv.org/abs/2305.19370
This is the implementation of the Blockwise Parallel Transformer (BPT) model. The model is described in the paper Blockwise Parallel Transformer for Long Context Large Models.
- 2023/7/12: Added the implementation of BPT on top of the LLaMA language model; the code is under `llamabpt/`. It has simpler code and supports BPT, memeff/flashattention, and vanilla computation. See here for more details.
Install the requirements with:

```sh
conda env create -f gpu_requirements.yml
```

or set up a TPU VM with:

```sh
sh tpu_requirements.sh
```
The code is organized as follows:

- `scripts/` contains the requirements and scripts for preparing the data.
- `bpt/` contains the original implementation of the BPT model on top of the GPT architecture.
- `llamabpt/` contains the implementation of BPT on top of the LLaMA language model.
We recommend using the implementation in `llamabpt/`. It is simpler, has better optimized sharding annotations for distributed FSDP training, and supports BPT, memeff/flashattention, and vanilla computation. See here for more details.
For the implementation in `bpt/`, the code in `bpt/blocks/blockwise_parallel.py` is optimized for simplicity and readability. A more optimized version in `bpt/blocks/blockwise_parallel_v1.py` is also provided; it explicitly instructs the compiler to fuse certain ops and can serve as a reference for porting the code to low-level kernels. In our tests, the performance of `blockwise_parallel_v1` is slightly better than `blockwise_parallel`, but the difference is not significant, because with the blockwise parallel transformer the compiler is able to fuse blockwise FFN and blockwise attention automatically.
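To make the blockwise idea concrete, here is a minimal, self-contained NumPy sketch of chunked attention (not the repository's implementation — chunk sizes and names here are illustrative). It processes one query block at a time against one key/value block at a time with a streaming softmax, so the full `(seq, seq)` score matrix is never materialized:

```python
import numpy as np

def blockwise_attention(q, k, v, q_chunk=2, k_chunk=2):
    """Numerically stable attention computed one (q_chunk, k_chunk) score
    block at a time, using running max/normalizer statistics."""
    seq, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q)
    for qs in range(0, seq, q_chunk):
        qb = q[qs:qs + q_chunk]                        # (q_chunk, d)
        m = np.full((qb.shape[0], 1), -np.inf)         # running max
        l = np.zeros((qb.shape[0], 1))                 # running normalizer
        acc = np.zeros_like(qb)                        # running weighted sum
        for ks in range(0, seq, k_chunk):
            kb, vb = k[ks:ks + k_chunk], v[ks:ks + k_chunk]
            s = qb @ kb.T * scale                      # (q_chunk, k_chunk)
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            p = np.exp(s - m_new)
            corr = np.exp(m - m_new)                   # rescale old statistics
            l = l * corr + p.sum(axis=-1, keepdims=True)
            acc = acc * corr + p @ vb
            m = m_new
        out[qs:qs + q_chunk] = acc / l
    return out
```

The result matches ordinary softmax attention exactly (up to floating-point error); BPT additionally interleaves the blockwise feedforward computation with these attention blocks.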
An example script for training a 3B Transformer with a 65536-token context window on one A100 80GB GPU using the blockwise parallel transformer is as follows. First generate the training data (e.g. OpenWebText) using `scripts/prepare_data.py`, then run the following script:
```sh
#! /bin/bash
export PYTHONPATH="$PYTHONPATH:$PWD"
export WANDB_API_KEY=''
export project_id='bpt'
export experiment_note=''
export XLA_PYTHON_CLIENT_MEM_FRACTION=.95
export experiment_id='3b,blockwise_parallel,65536,2048,4096'
python3 -m bpt.train \
    --mesh_dim='1,1,-1' \
    --total_steps=1000 \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --log_freq=2 \
    --save_model_freq=-1 \
    --load_gpt_config='3b' \
    --profile_steps=0 \
    --stop_after_profile=True \
    --gpt.scan_layers=True \
    --gpt.param_scan_axis=0 \
    --gpt.q_chunk_size=2048 \
    --gpt.k_chunk_size=4096 \
    --gpt.attn_type='blockwise_parallel' \
    --gpt.gradient_checkpointing='nothing_saveable' \
    --gpt.n_positions=65536 \
    --text_processor.fields='text' \
    --train_dataset.path='./local/owt/openwebtext_train.jsonl' \
    --train_dataset.seq_length=65536 \
    --train_dataset.batch_size=1 \
    --eval_steps=0 \
    --logger.online=False \
    --logger.project_id="$project_id" \
    --logger.experiment_id="$experiment_id" \
    --logger.append_uuid=False \
    --logger.experiment_note="$experiment_note" \
    --logger.output_dir="$HOME/experiment_output/$project_id" \
    --logger.wandb_dir="$HOME/experiment_output/$project_id" \
    --logger.profile_dir="$HOME/experiment_output/$project_id/jax_profile"
```
Explanation of the arguments:

- `mesh_dim`: the mesh dimensions. The first dimension is usually used for data parallelism, the second for fully sharded data parallelism (FSDP), and the third for tensor parallelism.
- `total_steps`: the total number of training steps.
- `optimizer.accumulate_gradient_steps`: the number of gradient accumulation steps.
- `load_gpt_config`: the GPT model configuration; can be one of the configs in `GPT_STANDARD_CONFIGS` in `bpt/model.py`.
- `gpt.scan_layers`: whether to scan the layers in the GPT model. If `True`, compile time is significantly reduced.
- `gpt.param_scan_axis`: the axis along which to scan the parameters. If `0`, the parameters are scanned along the first axis, which is the default setting.
- `gpt.q_chunk_size`: the chunk size for the query sequence length.
- `gpt.k_chunk_size`: the chunk size for the key sequence length. Both `q_chunk_size` and `k_chunk_size` control memory usage and computation speed. We found that memory usage is not very sensitive to their values, and computation can be accelerated by using larger `q_chunk_size` and `k_chunk_size`. Which values perform best depends on the specific hardware and model size, so we leave them as hyperparameters to be tuned.
- `gpt.attn_type`: the type of Transformer attention; can be one of `bp`, `vanilla`, and `mem_efficient`. `bp` is the blockwise parallel transformer, `vanilla` is the vanilla transformer, and `mem_efficient` is the memory-efficient transformer.
- `gpt.gradient_checkpointing`: the gradient checkpointing strategy. Both the memory-efficient transformer and the blockwise parallel transformer require gradient checkpointing to be enabled; otherwise the memory saved by reorganizing the compute is lost to storing all the intermediate results. See `get_gradient_checkpoint_policy` for available options; use `nothing_saveable` for the lowest memory usage.
- `gpt.n_positions`: the maximum sequence length.
- `gpt.seq_length`: the sequence length, which should be smaller than or equal to `n_positions`.
- `gpt.float32_logits`: whether to compute logits in float32 to avoid numerical issues with bfloat16.
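A quick back-of-the-envelope calculation (illustrative only, not repository code) shows why the chunk sizes, not the sequence length, bound the attention-score memory: the largest score block held at once is `q_chunk_size × k_chunk_size`, regardless of the full context length.

```python
def score_block_elems(seq_len, q_chunk, k_chunk):
    """Number of (q_chunk, k_chunk) score blocks and elements per block,
    assuming the chunk sizes divide the sequence length evenly."""
    assert seq_len % q_chunk == 0 and seq_len % k_chunk == 0
    n_blocks = (seq_len // q_chunk) * (seq_len // k_chunk)
    return n_blocks, q_chunk * k_chunk

# Settings from the example script above.
blocks, per_block = score_block_elems(65536, 2048, 4096)
print(blocks, per_block)           # 512 8388608
print((65536 * 65536) // per_block)  # 512: peak score memory is 512x smaller
```

The blocks are processed sequentially, so only one `2048 × 4096` score block needs to live in memory at a time instead of the full `65536 × 65536` matrix.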
For large-scale distributed training on TPUs or on a GPU cluster with good interconnect, we recommend first using FSDP and adding tensor parallelism when the model is very large or the global batch size is too large. You can control parallelism with `mesh_dim`.
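As a sketch of how a `mesh_dim` string like `'1,1,-1'` can map onto devices (the helper below is hypothetical, not part of this repo; the axis names are illustrative), a `-1` entry takes all remaining devices:

```python
import numpy as np
import jax
from jax.sharding import Mesh

def make_mesh(mesh_dim='1,1,-1'):
    """Build a 3D device mesh from a comma-separated spec, where at most
    one entry may be -1 (resolved to the leftover device count)."""
    dims = [int(d) for d in mesh_dim.split(',')]
    assert dims.count(-1) <= 1
    n_devices = len(jax.devices())
    known = int(np.prod([d for d in dims if d != -1]))
    dims = [n_devices // known if d == -1 else d for d in dims]
    devices = np.array(jax.devices()).reshape(dims)
    # axis names chosen for illustration: data, FSDP, tensor parallelism
    return Mesh(devices, axis_names=('dp', 'fsdp', 'tp'))

mesh = make_mesh('1,1,-1')
print(dict(mesh.shape))  # e.g. {'dp': 1, 'fsdp': 1, 'tp': 1} on a single-device host
```

With 8 devices, `'1,1,-1'` gives pure tensor parallelism, `'1,-1,1'` pure FSDP, and `'2,4,1'` mixes data parallelism with FSDP.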
This code uses autodiff and rematerialization. To achieve optimal performance, especially for speedup, the backward pass should be written in low-level kernels such as Triton or CUDA. This is on our TODO list, and PRs are welcome.
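Rematerialization in JAX can be as simple as wrapping a block in `jax.checkpoint`: activations inside the wrapped function are recomputed during the backward pass instead of being stored. The sketch below uses plain `jax.checkpoint` rather than this repo's `gpt.gradient_checkpointing` policies, and the toy MLP block is purely illustrative:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint  # recompute activations in the backward pass instead of storing them
def mlp_block(x, w1, w2):
    return jnp.tanh(x @ w1) @ w2

def loss(x, w1, w2):
    return jnp.sum(mlp_block(x, w1, w2) ** 2)

x = jnp.ones((4, 8))
w1 = jnp.ones((8, 16)) * 0.01
w2 = jnp.ones((16, 8)) * 0.01
grads = jax.grad(loss, argnums=(1, 2))(x, w1, w2)
print(grads[0].shape, grads[1].shape)  # (8, 16) (16, 8)
```

This trades extra forward compute for lower activation memory, which is why the memory-efficient and blockwise parallel attention types require checkpointing to realize their savings.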
If you find our work relevant to your research, please cite:

```bibtex
@article{liu2023blockwise,
    title={Blockwise Parallel Transformer for Long Context Large Models},
    author={Hao Liu and Pieter Abbeel},
    year={2023},
    journal={arXiv preprint arXiv:2305.19370}
}
```