foundation-model-stack / fms-fsdp

Demonstrate throughput of PyTorch FSDP

Home Page: https://pytorch.org/docs/stable/fsdp.html

License: Apache License 2.0

Python 98.92% Shell 1.08%
distributed-training llm pytorch

fms-fsdp's Introduction

Foundation Model Stack

Foundation Model Stack is a collection of components for development, inference, training, and tuning of foundation models leveraging PyTorch native components. For inference optimizations we aim to support PyTorch compile, accelerated transformers, and tensor parallelism. At training time we aim to support FSDP, accelerated transformers, and PyTorch compile. To enable these optimizations, we will provide reimplementations of several popular model architectures starting with Llama and GPT-BigCode.

Models Supported

Model family    Inference    Tuning and Training
LLaMA           ✔️           ✔️
GPT-BigCode     ✔️
RoBERTa         ✔️

Installation

We recommend running this on Python 3.11 and CUDA 12.1 for best performance, as the CPU overheads of the models are reduced significantly.

PyPI

pip install ibm-fms

Local

Requires PyTorch >= 2.1.

pip install -e .

or

python setup.py install

Inference

Approach

Our approach for inference optimization is to use PyTorch compile, accelerated transformers, and tensor parallelism. PyTorch compile compiles the code into optimized kernels, accelerated transformers leverages scaled_dot_product_attention (SDPA) for accelerating attention computation while saving memory, and tensor parallelism is necessary for larger models.

To enable the Llama models to compile, we had to reimplement RoPE encodings without complex numbers. With this change, Llama model inference is able to leverage model compilation for latency reduction.

Inference latency

We measured inference latencies with a 1024-token prompt and generation of 256 tokens on AWS P4de instance nodes with eight 80 GB A100 GPUs, and report the median latency in the table below.

Model    # GPUs    Median latency (ms)
7B       1         14
13B      1         22
70B      8         30

To reproduce these latencies, run scripts/benchmark_inference.py; the details are described in the inference documentation.

For more information on reproducing the benchmarks and running some examples, see here

HF Model Support

The support for HF models is provided by our HF model adapter. One can obtain similar latencies as tabulated above with HF models using our HF model adapter:

from fms.models import get_model
from fms.models.hf import to_hf_api
import torch
from transformers import pipeline
# fms model
llama = get_model("llama", "13b")

# huggingface model backed by fms internals
llama_hf = to_hf_api(llama)

# compile the model -- in HF, the decoder only
llama_hf.decoder = torch.compile(llama_hf.decoder)

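# generate some text -- the first time will be slow since the model needs to be compiled, but subsequent generations should be faster.
# NOTE: `tokenizer` is not defined in this snippet; it is assumed to be a Hugging Face tokenizer for the model, loaded beforehand.
llama_generator = pipeline(task="text-generation", model=llama_hf, tokenizer=tokenizer)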
llama_generator("""q: how are you? a: I am good. How about you? q: What is the weather like today? a:""")

A detailed example is provided here.

Tuning

To fine-tune LLaMA, use the scripts/train_causal.py training script. Here's an example of that command.

torchrun --nproc_per_node=2 \
        scripts/train_causal.py \
        --architecture=llama \
        --variant=7b \
        --tokenizer=~/models/tokenizer.model \
        --model_path=~/models/7B/ \
        --report_steps=10 \
        --checkpoint_format=meta \
        --distributed=fsdp

See options in the script for other ways to train and tune.

Structure and contents of this Repository

  • fms/models/ - Pure pytorch implementations of popular model architectures, without requiring any specific common interface beyond nn.Module. Each model configuration is registered with fms.models.register_model() so that instances can be obtained through fms.models.get_model('architecture', 'variant', '/path/to/data'). Each model can also register sources/formats/versions of data to load (e.g. checkpoints provided by meta, HF, or trained from this repo). Users of the repo (e.g. fms-extras) can register their own model architectures as well.
  • fms/models/hf/ - Adapters that compose our native PyTorch FMS model architecture implementations in HF-compatible wrapper interfaces. Each FMS model implements an adapter, and adapted instances are obtained via fms.models.hf.to_hf_api(model)
  • fms/datasets/ - Code for loading data for pre-training and fine-tuning. Individual datasets are retrieved by fms.datasets.get_dataset('name', tokenizer, 'optional path or other data reference'). The expected tokenizer conforms to an fms.utils.tokenizers.BaseTokenizer interface.
  • fms/modules/ - Components extending nn.Module used in our model architecture implementations. Each Module has a corresponding TPModule so that modules can be sharded using a tensor-parallel distribution strategy. FMS modules should all support torch.compile without graph breaks.
  • fms/training/ - Pre-training and fine-tuning code.
  • fms/utils/ - Other operators useful in working with LLMs. These include a generate() function, Tensor subclasses, code for dealing with LLM checkpoints that might be saved/sharded in a variety of formats, tokenization code, and various other useful helper functions.
  • scripts/ - Various scripts for inference, benchmarking, and evaluation, as well as an entry-point for tuning/training.

Extensions and Use Cases

This library is used by three dependent projects at IBM.

  • fms-fsdp - This repo shares training code that has been used to pretrain an fms implementation of LLaMA on IBM internal data.
  • fms-extras - This repo shares code for additional fms-based models trained by IBM. This repo will also be a home for other extensions, and may also include research or in-development work intended for eventual upstreaming to fms.
  • TGIS - This inference server includes support for serving fms models.

Open Issues

  • pytorch/pytorch#107824 prevents training/finetuning from working with torch.compile.
  • In addition, there are several open issues we are tracking to improve stability and memory footprint of inference

fms-fsdp's Issues

revert "raise Dynamo accumulated cache size limit"

We recently added a commit that raises the Dynamo accumulated cache size limit so that compile works with large models like 70b, whose num_layer is greater than the default limit (64): #45 (comment).

This limit has now been officially raised in PyTorch (related PR) from 64 to 256, so we can revert the commit because the new default is sufficient.

I am holding the revert for now: the torch PR was only merged in today's nightly, and most environments do not have this "fix" yet.

This is to be revisited once most of our environments have picked up that PR; then we can revert.

add wandb

we should add wandb logging for better external log sharing.

add FLOP counter

add a flop counter to the code with a bool flag.

it is already available in the flop_counter branch but will require some extra work to prettify it and integrate it with a flag.

low priority. to be visited in the future.

add Rank0-only profiler

With the current profiler, each GPU will write its own trace. This can be sometimes unnecessary/unwanted, as one might want to avoid writing 1024 traces (each hundreds of MBs) to the same shared location at the same time.

We should provide a new flag that controls whether the profiler trace is written from the rank0 GPU only. This should cover most use cases.
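
A minimal sketch of how such a flag could gate trace writing. The flag name rank0_only and the helper should_write_trace are hypothetical and not part of the repo; the rank is read from the RANK environment variable that torchrun sets for each process.

```python
import os


def should_write_trace(rank0_only, rank=None):
    """Hypothetical helper: decide whether this process writes a profiler trace."""
    if rank is None:
        # torchrun exports RANK for every worker; default to 0 for single-process runs.
        rank = int(os.environ.get("RANK", "0"))
    # Without the flag every rank writes; with it, only rank 0 does.
    return (not rank0_only) or rank == 0


if __name__ == "__main__":
    for r in (0, 1, 1023):
        print(r, should_write_trace(rank0_only=True, rank=r))
```

The profiler setup would then only be entered when this returns True, so a 1024-GPU job would write one trace instead of 1024.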

Unable to Replicate MFU for 7B on 80gb A100

Issue

I'm not able to replicate the reported MFU results for the 7B model on a single 8GPU 80gb A100 node. Rather, I'm getting ~15% MFU. Despite the lower MFU, I'm getting approximately the same reported throughput of 3.7k.

Details

Here are the config settings I'm using. They're close to the settings in train.sh except I switch low_cpu_fsdp = True and use_dummy_dataset = True.

        "use_dummy_dataset": True,
        "ckpt_load_path": "/lustre/pretrain/ckpt",
        "ckpt_save_path": "/lustre/pretrain/ckpt",
        "data_path": "/lustre/bluepile-processing/rel0_7/tokens/llama2/high_quality_rerun_fuzzy_deduped",
        "fsdp_activation_checkpointing": False,
        "selective_checkpointing": 1,
        "sharding_strategy": "hsdp",
        "low_cpu_fsdp": True,
        "batch_size": 2,
        "report_interval": 5,
        "checkpoint_interval": 20000,
        "use_torch_compile": False,
        "use_profiler": False,

MFU Measurement

I'm measuring the MFU using the class below, which I wrote based on the MFU calculation in nanoGPT. I initialize this class as shown below for the 7b model. To track MFU, I add a print statement to the training loop inside the train function as given below.

MFU Initialization

    mfu = ModelFlopsUtilization(
        n=sum([p.numel() for p in model.parameters()]),
        t=4096,
        l=32,
        h=32,
        q=128,
        gpu_type="A100_bf16",
        num_gpus_per_replica=1,
    )

MFU Print Statement

        t_start = time.time()
        input = input.to(local_rank)
        label = label.to(local_rank)

        optimizer.zero_grad()
        output = model(input)
        ce_loss = torch.nn.CrossEntropyLoss()
        loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())

        loss.backward()
        ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
        optimizer.step()
        scheduler.step()

        t_end = time.time()
        print("mfu:", mfu(cfg.batch_size, 1, t_end - t_start))

MFU Class

class ModelFlopsUtilization(object):

    GPU_TYPE_TO_MAX_FLOPS_PER_SEC = {
        # H100 GPU with bfloat16, this is the SXM version
        "H100_SXM_bf16": 1979e12,
        # H100 GPU with bfloat16, this is the PCIe version
        "H100_PCIe_bf16": 1513e12,
        # V100 GPU with FP 32
        "V100_fp32": 14e12,
        # A100 GPU with FP 32
        "A100_fp32": 19.5e12,
        # A100 GPU with tensor float 32 peak flops is 156 TFLOPS
        "A100_tf32": 156e12,
        # A100 GPU bfloat16 peak flops is 312 TFLOPS
        "A100_bf16": 312e12,
        # 0 result for CPU
        "cpu": float("-inf"),
    }

    def __init__(self, n, t, l, h, q, gpu_type, num_gpus_per_replica) -> None:
        self.model_flops_per_token = 6 * n + 12 * l * h * q * t
        self.model_flops_per_input = t * self.model_flops_per_token
        self.gpu_max_flops_per_sec = self.GPU_TYPE_TO_MAX_FLOPS_PER_SEC[gpu_type] * num_gpus_per_replica

    def __call__(self, batch_size, steps_per_iter, iter_time):
        """
        Returns the fraction of theoretical flops-per-sec utilized
        """
        model_flops_per_step = self.model_flops_per_input * batch_size
        model_flops_per_iter = model_flops_per_step * steps_per_iter
        model_flops_per_sec = model_flops_per_iter / iter_time
        mfu = model_flops_per_sec / self.gpu_max_flops_per_sec
        return mfu
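
As a cross-check, the same FLOPs-per-token expression can be used to convert a throughput figure directly into an MFU estimate. This is only a sketch under two assumptions that the issue does not state: the 7B model has roughly 6.74e9 parameters, and "3.7k" means tokens per second per GPU. The helper names are hypothetical.

```python
A100_BF16_PEAK_FLOPS = 312e12  # per-GPU peak, same value as in the class above


def flops_per_token(n, t, l, h, q):
    # Same expression as ModelFlopsUtilization.__init__:
    # 6*n for the dense matmuls plus 12*l*h*q*t for attention.
    return 6 * n + 12 * l * h * q * t


def mfu_from_throughput(tokens_per_sec_per_gpu, n, t, l, h, q, peak=A100_BF16_PEAK_FLOPS):
    # Achieved FLOPs per second on one GPU divided by that GPU's peak.
    return tokens_per_sec_per_gpu * flops_per_token(n, t, l, h, q) / peak


# Assumed inputs (not taken from the issue): n = 6.74e9, 3700 tokens/s/GPU.
per_token = flops_per_token(n=6.74e9, t=4096, l=32, h=32, q=128)
estimate = mfu_from_throughput(3700, n=6.74e9, t=4096, l=32, h=32, q=128)
print(f"{per_token:.3e} FLOPs/token, MFU estimate {estimate:.2f}")
```

Under these assumptions the estimate comes out at roughly 0.56. Since this formula is linear in throughput, a much lower printed MFU at the same throughput would point to a mismatch in the batch size or step time passed to the calculation rather than to slower training.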

Hardware Info

In case it's relevant, I've included the output of nvidia-smi on my machine below.

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000001:00:00.0 Off |                    0 |
| N/A   32C    P0              62W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000002:00:00.0 Off |                    0 |
| N/A   31C    P0              63W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM4-80GB          On  | 00000003:00:00.0 Off |                    0 |
| N/A   31C    P0              66W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM4-80GB          On  | 00000004:00:00.0 Off |                    0 |
| N/A   30C    P0              61W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM4-80GB          On  | 0000000B:00:00.0 Off |                    0 |
| N/A   31C    P0              62W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM4-80GB          On  | 0000000C:00:00.0 Off |                    0 |
| N/A   31C    P0              63W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM4-80GB          On  | 0000000D:00:00.0 Off |                    0 |
| N/A   31C    P0              60W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM4-80GB          On  | 0000000E:00:00.0 Off |                    0 |
| N/A   32C    P0              62W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

A revisit on improving the performance of Data Loader

We have been noticing a slowdown in training that was introduced by our dataloader. Upon further checking, we traced the issue to the fact that our dataset class maintains a number of very large lists.

Background

Each logical shard maintains a list of (dataset_id, shard_id, doc_id) tuples in order to track documents, e.g. ("c4", 3, 110) refers to the 110th document inside the file dataset_root_folder/dataset=c4/xxx.part3.arrow. When we distribute billions of documents over thousands of logical shard workers, each worker gets a list of millions of such tuples. In total, we are maintaining hundreds of GBs worth of lists internally.
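
A rough back-of-the-envelope check of the "hundreds of GBs" figure, measured with sys.getsizeof. This is a sketch only: the document count of 2 billion is an assumption standing in for "billions", the helper name is hypothetical, and the per-object sizes depend on the Python build.

```python
import sys


def bytes_per_entry(dataset_id="c4", shard_id=3_000, doc_id=110_000):
    """Hypothetical helper: approximate memory held by one (dataset_id, shard_id, doc_id) entry."""
    entry = (dataset_id, shard_id, doc_id)
    # The tuple object itself plus the 8-byte slot in the list that points to it
    # (8 bytes assumes a 64-bit build).
    size = sys.getsizeof(entry) + 8
    # Integer ids above 256 are separate objects; dataset_id strings are shared
    # across entries, so they are not counted here.
    size += sys.getsizeof(shard_id) + sys.getsizeof(doc_id)
    return size


num_docs = 2_000_000_000  # assumption: "billions of documents"
total_gb = bytes_per_entry() * num_docs / 1e9
print(f"{bytes_per_entry()} bytes per entry, about {total_gb:.0f} GB in total")
```

On a typical 64-bit CPython each entry costs on the order of a hundred bytes, which is consistent with totals in the hundreds of GBs.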

Why did we do this in the first place? Datasets are assumed to be unshuffled, so we need to shuffle our billions of (dataset_id, shard_id, doc_id) tuples, and each logical shard maintains a shuffled list containing millions of them. Such a list has to be materialized at some point (even with lazy init or something similar) to keep our dataloader stateful: we need to know and checkpoint exactly which documents have been visited, which are still to be visited, and in what order, so that we can resume training deterministically.

Solution

If we peel the onion completely, the question boils down to:
how can we maintain a list that is truly stateful, supports random reads, and is easy to checkpoint and recover?
This leads us to an LCG (linear congruential generator): we use the statefulness of the LCG to make the list stateful.

A quick overview of the LCG we built for an arbitrary sized list:

# ZX81, cc65, Knuth and H. W. Lewis
LCG_PARAMS = [(2 ** 16 + 1, 75, 74), (2 ** 23, 65793, 4282663), (2 ** 32, 1664525, 1013904223)]

class LCG:
    def __init__(self, size, seed=42):
        self.size = size
        self.state = seed
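        for params in LCG_PARAMS:
            # pick the smallest modulus that covers the list size
            if size <= params[0]:
                self.m, self.a, self.c = params
                break
        else:
            raise ValueError(f"size {size} exceeds the largest supported modulus")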

    def _next(self):
        self.state = (self.a * self.state + self.c) % self.m
        return self.state

    def next(self):
        while True:
            res = self._next()
            if res < self.size:
                return res

and validation:

selector = LCG(1000000)
res = [selector.next() for _ in range(1000000)]
expected = list(range(1000000))
assert sorted(res) == expected
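
The checkpoint-and-recover property can be illustrated with a small sketch. The class CheckpointableLCG and its state_dict/load_state_dict methods are hypothetical names added for illustration; the recurrence and parameters are the ones quoted above.

```python
# (modulus m, multiplier a, increment c), copied from the snippet above.
LCG_PARAMS = [(2**16 + 1, 75, 74), (2**23, 65793, 4282663), (2**32, 1664525, 1013904223)]


class CheckpointableLCG:
    """Hypothetical wrapper: same recurrence as above, plus state save/restore."""

    def __init__(self, size, seed=42):
        self.size = size
        self.state = seed
        # Smallest modulus that covers `size`.
        self.m, self.a, self.c = next(p for p in LCG_PARAMS if size <= p[0])

    def next(self):
        # Rejection step: skip states that fall outside [0, size).
        while True:
            self.state = (self.a * self.state + self.c) % self.m
            if self.state < self.size:
                return self.state

    def state_dict(self):
        # A single integer is the entire checkpoint.
        return {"size": self.size, "state": self.state}

    def load_state_dict(self, sd):
        assert sd["size"] == self.size
        self.state = sd["state"]


# Uninterrupted run.
full = CheckpointableLCG(1000)
reference = [full.next() for _ in range(1000)]

# Interrupted run: stop after 400 draws, checkpoint, resume in a fresh object.
first = CheckpointableLCG(1000)
head = [first.next() for _ in range(400)]
ckpt = first.state_dict()

resumed = CheckpointableLCG(1000, seed=0)
resumed.load_state_dict(ckpt)
tail = [resumed.next() for _ in range(600)]

print(head + tail == reference)
```

The design point is that the visiting order never has to be stored: one integer of state is enough to resume the same sequence exactly where it stopped.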

RuntimeError: CUDA driver error: an illegal memory access was encountered

@daviswer It seems that calling model.reset_parameters() after the FSDP call raises the following error.

Can you take a look?

[rank8]: Traceback (most recent call last):
[rank8]:   File "/lustre/lchu/fms-fsdp/main_training.py", line 168, in <module>
[rank8]:     fire.Fire(main)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
[rank8]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
[rank8]:     component, remaining_args = _CallAndUpdateTrace(
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
[rank8]:     component = fn(*varargs, **kwargs)
[rank8]:   File "/lustre/lchu/fms-fsdp/main_training.py", line 130, in main
[rank8]:     model.reset_parameters()
[rank8]:   File "/lustre/t5/public/foundation-model-stack/fms/models/llama.py", line 237, in reset_parameters
[rank8]:     m.reset_parameters()
[rank8]:   File "/lustre/t5/public/foundation-model-stack/fms/modules/embedding.py", line 95, in reset_parameters
[rank8]:     nn.init.trunc_normal_(getattr(self, layer).weight, mean=0.0, std=0.02)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/torch/nn/init.py", line 205, in trunc_normal_
[rank8]:     return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/torch/nn/init.py", line 47, in _no_grad_trunc_normal_
[rank8]:     tensor.erfinv_()
[rank8]: RuntimeError: CUDA driver error: an illegal memory access was encountered

Moving the call before the FSDP call avoids this error.

change package name

As discussed offline, we want to change the package name from pretraining to fms_fsdp

Faulty type handling for 'weight' kwarg

The current scheme for overriding the defaults in the training config file is to pass keyword arguments to the training script and then update the config with those new values. This scheme breaks down for weights due to a lack of type handling: the weights field takes a string of comma-separated numbers, but passing --weights="1,2,3" to the training script yields the tuple (1,2,3) rather than the required string, because Python interprets the argument as a tuple. This causes dataloader construction failures. We should either enforce typing on the training script args, or update the config and dataloader code to match Python's default arg handling.
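
A minimal sketch of the second option, normalizing whatever the argument parser hands over back into the comma-separated string the config expects. The helper normalize_weights is hypothetical, and ast.literal_eval is used here only to mimic the literal evaluation of the argument value that the issue describes.

```python
import ast


def normalize_weights(value):
    """Hypothetical helper: return weights as a comma-separated string."""
    if isinstance(value, str):
        return value
    if isinstance(value, (int, float)):
        # A single weight such as --weights=1 arrives as a bare number.
        return str(value)
    # Tuples or lists such as (1, 2, 3) are joined back into "1,2,3".
    return ",".join(str(v) for v in value)


# Mimic the parsing problem: the text 1,2,3 evaluates to a tuple, not a string.
parsed = ast.literal_eval("1,2,3")
print(type(parsed).__name__, normalize_weights(parsed))
```

Calling such a helper once when the config is updated would let the dataloader keep its string-based interface.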

add 1.4B config

As discussed offline, we want to add a "comparatively small" variant (1.4B) in the configs.

make fms-to-hf support for "compiled" model

A compiled FSDP model uses use_orig_params=True and works on the original parameters. As a result, the state_dict in the saved checkpoint is inconsistent with that of a non-compiled checkpoint.

Although some work has been done to resolve some of these inconsistencies automatically, it does not help our fms-to-hf conversion script, because we hardcode the key mapping.

We should add a flag indicating whether the checkpoint comes from a compiled model, and use a load-and-offload approach to massage the state dict so that it works with our script.

Revert low_cpu_fsdp implementation

We had two implementations of low_cpu_fsdp:

  1. load the full model on rank0 and use the meta device on all other ranks, then use sync_module_states to broadcast the weights.
  2. use the meta device on all ranks, random-init the model in the FSDP call, and post-load the state dict.

We switched from 1 to 2 when we moved from FSDP to 2D (TP + FSDP), as TP's parallelize_module does not like implementation 1 because variables end up on different devices during tp-parallelize.

However, since we cut TP and 2D from this open source version of training, we want to revert this and use the first implementation, which will make post_init easier.

A write-up on Meta Device Init x Pretraining

Scope

This write-up only applies to the initial model init. Cases that require loading a checkpoint (continued pretraining, fine-tuning, and inference) do not need it, as any init would be overwritten by the checkpoint. Therefore, it mainly targets pretraining from scratch.

Background on meta device init

There are two ways to leverage the meta device to save CPU memory during model init:

  1. create a full model copy on rank0 while putting the model on the meta device on all other ranks.
  2. put the model on the meta device on all ranks, including rank0.

The first method initializes a full model on rank0 and uses sync_module_states=True during the FSDP call to broadcast the model from rank0 to all other ranks. This reduces CPU memory from world_size copies in total to a single copy.
The second method puts the model on the meta device on all ranks (including rank0) and relies on a proper param_init_fn and/or post-FSDP init.
Compared to the first method, the second one not only saves CPU memory (zero copies) but also greatly reduces model init time, as it avoids initializing a full model on CPU (for large models like 70b, this can take 20 minutes).

Method 2 is both faster and lighter on CPU memory. However, it can be very tricky to set up properly for pretraining, and it might cause silent problems. Unlike continued pretraining, fine-tuning, and inference, where model init isn't important because it is overwritten by the loaded checkpoint, pretraining crucially depends on a proper model init. And model init for method 2 can be tricky no matter at which stage you apply it:

pre-FSDP init

This isn't possible with method 2, as all ranks are on the meta device before the FSDP call. This is also the reason method 1 is much safer: you can do whatever you want before the FSDP call, since the model is still a full copy sitting on CPU. You can perform any init you need, and it will be properly broadcast to the other ranks during the FSDP call. But again, we want method 2 and we don't want any CPU copy, so we will pass on this.

during-FSDP init

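This is achieved by leveraging param_init_fn, which is applied to the "to be materialized modules". Since the full model is on the meta device, each module first has to be materialized and put on the device, so such a param_init_fn typically looks something like:

def param_init_fn(module):
    # materialize the module's meta tensors on the current GPU (values are uninitialized)
    module.to_empty(device=torch.cuda.current_device())
    # run whatever init the module defines, e.g. reset_parameters()
    module.reset_parameters()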

Here comes the tricky part, where we might get silent problems. param_init_fn is applied to every to-be-materialized module, and these are popped/dequeued in a top-down, parent-then-children order (reference). This is already a great improvement over how things were when we started this work (this has a very detailed explanation of some old issues that we also observed and had to work around), yet the current design still relies on an implicit contract with the user: "param_init_fn should only initialize a module's own parameters/buffers, not those of any of its sub-modules". Another implicit requirement is that such an init must be defined on every possible module. So what happens if we don't strictly follow these rules, as is the case in FMS today?
Sub-modules get re-initialized multiple times. Our reset_parameters() is designed so that calling model.reset_parameters() initializes the full model with the true/desired init. Similarly, Llama_Block.reset_parameters() initializes the full block. This is desirable, since we typically want a single-line, model-wide init, and it works well for method 1. But imagine what happens if we use it as param_init_fn: recall that the "to be materialized modules" will be something like [LLaMABlock, MultiHeadAttention, Linear, Linear, Linear, etc.], so child modules like Linear are re-initialized multiple times, which can be problematic:

  1. issues discussed in the reference I shared above.
  2. more importantly: silent problems if our init does not provide FULL coverage. Again, recall that we defined our init at the model level (llama.reset_parameters()) and at the "key module" level (attn_block, embed, mlp_block, layer_norm), as that was typically sufficient. But these will be silently overwritten by lower-level modules (e.g. Linear), because basic modules like Linear have their own implementation of reset_parameters(). So during the re-init on these leaf nodes, the wrong init overwrites our true init, and this happens silently!

post-FSDP init

This can be even trickier. It is less preferred than using param_init_fn, so I am not going into too much detail. Doing post-FSDP init involves manipulating model params outside forward/backward, where you will run into issues like "illegal memory access" because the model is already sharded. You could technically leverage FSDP.summon_full_params() with writeback to achieve some of this, but that is less efficient and less user-friendly than leveraging param_init_fn. So this is also not what we want.

what to do with FMS

So it seems "during-FSDP init with param_init_fn" is the way to go, but we would have to meet the contract:

  1. rewrite ALL init (reset_parameters) to be non-recursive.
  2. provide FULL coverage for init.

Is there a way to avoid doing so, and potentially reuse our existing recursive version? The answer is yes, and the trick turns out to be simple: we just need to add a "filter" so that the recursive init inside param_init_fn is applied only to a set of modules that are mutually exclusive but together cover 100% of the params. This way, no re-init ever happens.

    def param_init_fn(module):
        
        if (
                # provide the modules that are mutually exclusive but also cover 100% of the model params
                isinstance(module, MultiHeadAttention)
                or isinstance(module, WordEmbedding)
                or isinstance(module, GatedLinearUnit)
                or isinstance(module, LayerNormParameterized)
        ):
            module.to_empty(device=torch.cuda.current_device())
            with torch.no_grad():
                module.reset_parameters()
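
The effect of the filter can be illustrated without PyTorch. The following toy sketch uses a hypothetical ToyModule class that only counts how often each node's init runs; it is not FMS or FSDP code, and it models only the repeated-init aspect, not the overwriting by a leaf module's own default init.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module that counts init calls."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.init_count = 0

    def modules(self):
        # Top-down traversal: parent first, then children.
        yield self
        for child in self.children:
            yield from child.modules()

    def reset_parameters(self):
        # Recursive init, in the style described above for FMS.
        self.init_count += 1
        for child in self.children:
            child.reset_parameters()


def build():
    attn = ToyModule("attn", [ToyModule("q"), ToyModule("k"), ToyModule("v")])
    mlp = ToyModule("mlp", [ToyModule("w1"), ToyModule("w2")])
    return ToyModule("block", [attn, mlp])


def apply_init(root, selected=None):
    # Mimics param_init_fn being called on every module, optionally filtered by name.
    for module in root.modules():
        if selected is None or module.name in selected:
            module.reset_parameters()
    return {module.name: module.init_count for module in root.modules()}


naive = apply_init(build())
filtered = apply_init(build(), selected={"attn", "mlp"})
print("naive:   ", naive)
print("filtered:", filtered)
```

Without the filter, the leaves are initialized once per ancestor plus once for themselves; with a filter over the mutually exclusive modules attn and mlp, every parameter-holding node is initialized exactly once.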

optimize profiler trace generation

By default, a trace is generated for each GPU. In a multi-GPU setup, this means many threads are added concurrently during the steps in which traces are generated. This can make those steps slower than expected, which makes the step timings unreliable.

In most cases this is fine, but in some scenarios the captured steps are significantly slower than normal because of the overhead of the added threads.

We should revisit the trace generation part of the code and make it generate a trace only on rank0 to avoid this.

maximize mistral throughput

Instructlab backend currently focuses on mistral fine tuning and I'm trying to maximize throughput for that. If anyone notices anything obvious or has any suggestions I'd truly appreciate it. @raghukiran1224 mentioned that posting an issue here would potentially help.

I'm currently seeing a throughput of around 90 samples per second at max context length of 2600 tokens (but on average is only around 500 tokens) on 80 GPUs in prod vela. On a single node I get a throughput of around 11.2 samples per second and the best way is to do shard_op (zero stage 2) and no gradient checkpointing.

The main bottleneck is networking, so using the largest possible batch size maximizes throughput, since the communication cost stays almost the same regardless of the batch size. For that reason I ended up using HYBRID_SHARD_ZERO2 and enabling checkpointing to get a batch size of 20 samples per GPU at 2600 max length.

These are the main parts to look at:

Model setup

Currently using HYBRID_SHARD_ZERO2 but have experimented with all the possibilities. Couldn't get torch.compile to work. And had to enable gradient checkpointing to maximize batch size.

def setup_model(model_name, tokenizer):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    if len(tokenizer) > model.config.vocab_size:
        print(
            f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
        )
        model.resize_token_embeddings(
            int(8 * math.ceil(len(tokenizer) / 8.0))
        )  # make the vocab size multiple of 8 for sharding the embedding layer.

    assert model.__class__.__name__ in [
        "MistralForCausalLM"
    ], f"Model class name: {model.__class__.__name__} is not supported."

    model = FSDP(
        model,
        auto_wrap_policy=partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                MistralDecoderLayer,
            },
        ),
        # use_orig_params=True,
        limit_all_gathers=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2,
        device_id=torch.cuda.current_device(),
    )
    model.gradient_checkpointing_enable()
    # model = torch.compile(model)
    return model

training loop

importantly the use_cache=False, even though it is commented out gets set to True because only the gradient checkpointing works.

        for batch in train_loader:
            start = time.time()
            for k in batch:
                batch[k] = batch[k].to(local_rank)

            output = model(
                **batch,
                # use_cache=False,
            )

            loss = output["loss"]
            loss.backward()

            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

make selective ac more flexible.

The current design of selective ac is to checkpoint one block every k blocks:
K=1 (i.e. full ac): ac, ac, ac, ac, ...
K=2: ac, no-ac, ac, no-ac, ac, no-ac, ...
K=3: ac, no-ac, no-ac, ac, no-ac, no-ac, ...
...
K=100: this is essentially no-ac for the whole model if num_layer < 100

The granularity gets finer and finer towards the no-ac side, so this favors smaller models: we can carefully tune K to achieve the best perf while pushing memory to the limit.

However, what's obviously missing here is granularity at the other end. As we move towards full ac, the granularity is quite coarse: there is plenty of room to tune between K=1 and K=2, K=2 and K=3, etc. For example, what if we want to do something like:
K=2/3: checkpoint 2 blocks every 3 blocks, i.e. ac, ac, no-ac, ac, ac, no-ac, ... .

Such a "fractional k" can be beneficial for larger models, as larger models almost always need near-full ac. So we need finer granularity on the full-ac side.

To achieve this, we can change K from integer to a tuple (m, n), and instead of doing:

if block_idx % k == 0: ac
else: no-ac

we do

if block_idx % n in range(0, m): ac
else: no-ac

This way, we can achieve arbitrary selective ac.
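
A minimal sketch of the proposed (m, n) rule. The function name ac_mask and the helper show are hypothetical; the selection condition is the one given above.

```python
def ac_mask(num_blocks, m, n):
    """Return one bool per block; True means the block is checkpointed."""
    if n < 1 or not 0 <= m <= n:
        raise ValueError("expected n >= 1 and 0 <= m <= n")
    # Checkpoint the first m blocks out of every group of n.
    return [block_idx % n in range(m) for block_idx in range(num_blocks)]


def show(mask):
    return ", ".join("ac" if flag else "no-ac" for flag in mask)


print("(2, 3):", show(ac_mask(6, 2, 3)))  # the K=2/3 pattern
print("(1, 2):", show(ac_mask(6, 1, 2)))  # same as the old K=2
print("(1, 1):", show(ac_mask(4, 1, 1)))  # full ac
```

The old integer K corresponds to the tuple (1, K), so existing configurations would keep their behavior under this scheme.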

Will push a PR soon.

Clean up training configs

A list of minor items around training configs that should be fixed later.

  1. this should really be > num_step rather than == num_step, as our batch_idx starts from 1 rather than 0. The current implementation skips the required last step, so the last checkpoint is never written.
  2. we should remove cfg. here
  3. remove sharding_group_size from the training configs, as we dropped support for both SSDP and TP in the open source version.
  4. re-order the training configs into a more meaningful grouping.
