
Blockwise Parallel Transformer for Long Context Large Models

Hao Liu, Pieter Abbeel

Paper: https://arxiv.org/abs/2305.19370

This is the implementation of the Blockwise Parallel Transformer (BPT) model. The model is described in the paper Blockwise Parallel Transformer for Long Context Large Models.

Updates

  • 2023/7/12: Added an implementation of BPT on top of the LLaMA language model under llamabpt/. It has simpler code and supports BPT, memeff/flashattention, and vanilla attention computation. See llamabpt/ for more details.

Requirements

Install the requirements with:

conda env create -f gpu_requirements.yml

or set up a TPU VM with:

sh tpu_requirements.sh

Code structure

The code is organized as follows:

  • scripts/ contains the requirements and scripts for preparing the data.
  • bpt/ contains the original implementation of the BPT model on top of the GPT architecture.
  • llamabpt/ contains the implementation of BPT on top of the LLaMA language model.

We recommend the implementation in llamabpt/: it is simpler, has better-optimized sharding annotations for distributed FSDP training, and supports BPT, memeff/flashattention, and vanilla computation. See llamabpt/ for more details.

For the implementation in bpt/, the code in bpt/blocks/blockwise_parallel.py is optimized for simplicity and readability. A more optimized version, bpt/blocks/blockwise_parallel_v1.py, is also provided; it explicitly instructs the compiler to fuse certain ops and can serve as a reference for porting the code to low-level kernels. In our tests, blockwise_parallel_v1 performs slightly better than blockwise_parallel, but the difference is not significant, because with the blockwise parallel transformer the compiler is already able to fuse the blockwise FFN and blockwise attention automatically.
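
To make the blockwise computation concrete, below is a minimal JAX sketch of chunked attention with a running softmax, which is the core idea behind BPT's memory savings. It is only an illustration under simplifying assumptions, not the repository's implementation: it omits causal masking, the fused blockwise FFN, and all sharding, and the function and argument names are invented for this example.

import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, q_chunk=2048, k_chunk=4096):
    """Chunked attention with a streaming softmax (illustrative sketch only).

    q, k, v: [seq_len, num_heads, head_dim]; seq_len must be divisible by the chunk sizes.
    """
    seq_len, num_heads, head_dim = q.shape
    scale = head_dim ** -0.5
    q_blocks = q.reshape(seq_len // q_chunk, q_chunk, num_heads, head_dim)
    k_blocks = k.reshape(seq_len // k_chunk, k_chunk, num_heads, head_dim)
    v_blocks = v.reshape(seq_len // k_chunk, k_chunk, num_heads, head_dim)

    def per_query_block(q_blk):
        # Running numerator, denominator, and max for a numerically stable softmax.
        init = (
            jnp.zeros((q_chunk, num_heads, head_dim)),  # weighted sum of values
            jnp.zeros((q_chunk, num_heads)),            # softmax denominator
            jnp.full((q_chunk, num_heads), -jnp.inf),   # running max of scores
        )

        def scan_kv(carry, kv_blk):
            num, den, m = carry
            k_blk, v_blk = kv_blk
            s = jnp.einsum('qhd,khd->qkh', q_blk, k_blk) * scale  # [q_chunk, k_chunk, heads]
            m_new = jnp.maximum(m, s.max(axis=1))
            correction = jnp.exp(m - m_new)
            p = jnp.exp(s - m_new[:, None, :])
            num = num * correction[..., None] + jnp.einsum('qkh,khd->qhd', p, v_blk)
            den = den * correction + p.sum(axis=1)
            return (num, den, m_new), None

        (num, den, _), _ = jax.lax.scan(scan_kv, init, (k_blocks, v_blocks))
        return num / den[..., None]

    out = jax.lax.map(per_query_block, q_blocks)  # loop over query chunks
    return out.reshape(seq_len, num_heads, head_dim)

The point of the sketch is that only one (q_chunk, k_chunk) block of attention scores exists at a time, so the full seq_len x seq_len score matrix is never materialized.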

Usage

An example script for training a 3B Transformer with a 65536-token context window on a single A100 80GB GPU using the blockwise parallel transformer is shown below. First, generate the training data (e.g., OpenWebText) using scripts/prepare_data.py, then run the following script:

#! /bin/bash

export PYTHONPATH="$PYTHONPATH:$PWD"
export WANDB_API_KEY=''
export project_id='bpt'
export experiment_note=''
export XLA_PYTHON_CLIENT_MEM_FRACTION=.95
export experiment_id='3b,blockwise_parallel,65536,2048,4096'

python3 -m bpt.train \
    --mesh_dim='1,1,-1' \
    --total_steps=1000 \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --log_freq=2 \
    --save_model_freq=-1 \
    --load_gpt_config='3b' \
    --profile_steps=0 \
    --stop_after_profile=True \
    --gpt.scan_layers=True \
    --gpt.param_scan_axis=0 \
    --gpt.q_chunk_size=2048 \
    --gpt.k_chunk_size=4096 \
    --gpt.attn_type='blockwise_parallel' \
    --gpt.gradient_checkpointing='nothing_saveable' \
    --gpt.n_positions=65536 \
    --text_processor.fields='text' \
    --train_dataset.path='./local/owt/openwebtext_train.jsonl' \
    --train_dataset.seq_length=65536 \
    --train_dataset.batch_size=1 \
    --eval_steps=0 \
    --logger.online=False \
    --logger.project_id="$project_id" \
    --logger.experiment_id="$experiment_id" \
    --logger.append_uuid=False \
    --logger.experiment_note="$experiment_note" \
    --logger.output_dir="$HOME/experiment_output/$project_id" \
    --logger.wandb_dir="$HOME/experiment_output/$project_id" \
    --logger.profile_dir="$HOME/experiment_output/$project_id/jax_profile"

Explanation of the arguments:

  • mesh_dim: the device mesh dimensions. The first dimension is usually used for data parallelism, the second for fully sharded data parallelism (FSDP), and the third for tensor parallelism.
  • total_steps: the total number of training steps.
  • optimizer.accumulate_gradient_steps: the number of gradient accumulation steps.
  • load_gpt_config: the GPT model configuration, can be one of the configs in GPT_STANDARD_CONFIGS in bpt/model.py.
  • gpt.scan_layers: whether to scan the layers in the GPT model. If True, compilation time is significantly reduced.
  • gpt.param_scan_axis: the axis to scan the parameters over. If 0, the parameters are scanned along the first axis, which is the default setting.
  • gpt.q_chunk_size: the chunk size for the query sequence length.
  • gpt.k_chunk_size: the chunk size for the key sequence length. Both q_chunk_size and k_chunk_size control memory usage and computation speed. We found that memory usage is not very sensitive to these values, while computation speed can be improved by using larger chunk sizes; for example, with seq_length=65536 the settings above split the queries into 32 chunks of 2048 and the keys into 16 chunks of 4096. The best values depend on the specific hardware and model size, so we leave them as hyperparameters to be tuned.
  • gpt.attn_type: the attention type, one of bp, vanilla, and mem_efficient: bp is the blockwise parallel transformer, vanilla is the vanilla transformer, and mem_efficient is the memory-efficient transformer.
  • gpt.gradient_checkpointing: the gradient checkpointing strategy. Both the memory-efficient transformer and the blockwise parallel transformer require gradient checkpointing to be enabled; otherwise the memory savings from reorganizing the computation are lost, because all intermediate results would still be stored. See get_gradient_checkpoint_policy for available options; use nothing_saveable for the lowest memory usage (a minimal illustration of this policy follows this list).
  • gpt.n_positions: the maximum sequence length.
  • gpt.seq_length: the sequence length, which should be smaller than or equal to n_positions.
  • gpt.float32_logits: whether to compute logits in float32. If True, logits are computed in float32 to avoid numerical issues with bfloat16.
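
As referenced in the gradient_checkpointing item above, here is a minimal, hypothetical sketch of how a policy such as nothing_saveable maps onto JAX rematerialization. The transformer_block function is a stand-in, and the repository's get_gradient_checkpoint_policy may wrap this differently.

import jax
import jax.numpy as jnp

def transformer_block(x, w):
    # Stand-in for an attention + FFN block; the real blocks live under bpt/blocks/.
    return jnp.tanh(x @ w)

# Save no residuals: every intermediate is recomputed on the backward pass,
# which gives the lowest memory usage at the cost of extra compute.
block_nothing_saveable = jax.checkpoint(
    transformer_block,
    policy=jax.checkpoint_policies.nothing_saveable,
)

x = jnp.ones((4, 8))
w = jnp.ones((8, 8))
grads = jax.grad(lambda w: block_nothing_saveable(x, w).sum())(w)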

For large-scale distributed training on TPU, or on GPU clusters with good interconnect, we recommend starting with FSDP and adding tensor parallelism only when the model is very large or the global batch size is too large. Parallelism is controlled via mesh_dim.
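
As a rough illustration of how a mesh_dim string could translate into a three-axis device mesh, here is a hedged sketch; the axis names ('dp', 'fsdp', 'tp') and the make_mesh helper are hypothetical, and the repository's actual mesh construction may differ in details.

import numpy as np
import jax
from jax.sharding import Mesh

def make_mesh(mesh_dim='1,1,-1'):
    # Parse a string like '1,1,-1'; a single -1 is filled in so the product of
    # all axes equals the number of available devices.
    dims = [int(d) for d in mesh_dim.split(',')]
    known = int(np.prod([d for d in dims if d != -1]))
    dims = [jax.device_count() // known if d == -1 else d for d in dims]
    devices = np.array(jax.devices()).reshape(dims)
    return Mesh(devices, axis_names=('dp', 'fsdp', 'tp'))

mesh = make_mesh('1,1,-1')  # e.g. on 8 devices: dp=1, fsdp=1, tp=8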

This code relies on autodiff and rematerialization. To achieve optimal performance, especially additional speedups, the backward pass should be written in low-level kernels such as Triton or CUDA. This is on our TODO list, and PRs are welcome.

Reference

If you find our work relevant to your research, please cite:

@article{liu2023blockwise,
    title={Blockwise Parallel Transformer for Long Context Large Models},
    author={Hao Liu and Pieter Abbeel},
    year={2023},
    journal={arXiv preprint arXiv:2305.19370}
}
