alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.

Home Page: https://alpa.ai

License: Apache License 2.0

Languages: Python 95.01%, Shell 0.51%, Jupyter Notebook 3.86%, Dockerfile 0.46%, Starlark 0.15%
Topics: alpa, auto-parallelization, compiler, deep-learning, distributed-computing, distributed-training, high-performance-computing, jax, llm, machine-learning

alpa's People

Contributors

babychousr, blair-johnson, comaniac, crazyboycjr, ddxxdd-code, dumpmemory, eltociear, frankxyy, jiahaoyao, jiaodong, jubilantjerry, koyamasohei, merrymercy, pkuflyingpig, reinaw1012, sammeralomair, suquark, tarzanzhao, vatshank, vinlnx, wgimperial, woosukkwon, yf225, ymwangg, zhanyuanucb, zhisbug, zhuohan123, zsc, zw123han, zyhowell


alpa's Issues

Send / Recv Order

We want to schedule the order of send/recv when a stage wants to send tensors to multiple stages.

Principle:

  • Sender: Send to earlier stages first
  • Receiver: Receive from earlier stages first

This still needs more thought.

Further optimization for distributed compilation

#162 does not improve the compilation time much. For example, to fully profile an 8-layer BERT model on a 1x8 mesh, the time is only reduced from 190 s to 130 s. This is because the largest part of the compilation time comes from the codegen from XLA to CUDA, which right now still happens in the worker and is not parallelized. Currently, the whole profiling process works as follows:

Input: stages to profile
Output: the measured computation cost of each stage.

  1. (Parallelized) Compile each stage to the fully optimized version, but still in HLO (e.g. finish the autosharding pass).
  2. (Single-threaded) Translate HLO to CUDA binary for each stage (the CUDA binary in XLA cannot be easily serialized right now).
  3. (Single-threaded) Profile the CUDA binary on the physical mesh. (This step cannot be parallelized because we might only have one physical mesh).

There are several potential ways to parallelize this:

  1. Make the codegen multi-threaded. (This should be the XLA team's job, not ours.)
  2. Find a way to serialize the compiled result, so we can distribute the codegen of the different layers to be profiled across multiple processes. This is also hard because there is no obvious way to do it.
  3. Perform a dirty hack to work around the serialization of the compiled result, as in the following:
    import ray

    @ray.remote
    class RemoteWorker:
        def compile(self, stage):
            # Run codegen inside the worker so it happens in parallel.
            self.compiled = stage.compile()

        def profile(self):
            # `run_profiling` is a placeholder for the actual profiling routine.
            return run_profiling(self.compiled)

    def profile_all(stages):
        remote_workers = [RemoteWorker.remote() for _ in range(len(stages))]
        # Launch all compilations in parallel.
        compile_jobs = [worker.compile.remote(stage)
                        for stage, worker in zip(stages, remote_workers)]
        ray.get(compile_jobs)
        # Profile one worker at a time so the measurements stay accurate.
        return [ray.get(worker.profile.remote()) for worker in remote_workers]
    For this method, we need to hack Ray's resource management to give all remote workers access to all GPU devices. This lets us parallelize the codegen process while making sure the compiled result is runnable on the mesh. We keep the profiling accurate by running the remote workers on the mesh one at a time.

Send/recv for tied embedding

Solution

  1. Patch current send/recv design
  2. Use replicated distributed array and all-reduce

Workaround

  1. Disable tied embedding

Make a more general between-stage communication class/abstraction

In our current design of 3D parallelism, it is possible that:

  • stage 0 runs on a mesh with 1 node and 2 GPUs
  • stage 1 runs on a mesh with 2 nodes and 2 GPUs each

On each mesh, the model might be sharded differently.

The communication between stage 0 and stage 1 is not necessarily a single device-to-device send/recv, because it may involve communication from multiple devices to multiple devices.

We should add an abstraction in the pipeline runtime to handle this communication pattern; a sketch of the communication plan such an abstraction could compute is given below.
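The following is a minimal, self-contained sketch of that idea (the tile/sharding representation and all function names are illustrative, not the parax runtime API): tiles are per-dimension (start, stop) index ranges, and a sharding maps each device id to the tile it owns.

def intersect(tile_a, tile_b):
    """Return the overlapping sub-tile of two tiles, or None if they are disjoint."""
    overlap = []
    for (a0, a1), (b0, b1) in zip(tile_a, tile_b):
        lo, hi = max(a0, b0), min(a1, b1)
        if lo >= hi:
            return None
        overlap.append((lo, hi))
    return tuple(overlap)

def build_resharding_plan(src_sharding, dst_sharding):
    """For every destination device, list the (src_device, dst_device, sub_tile)
    transfers needed so the destination sharding is fully covered."""
    plan = []
    for dst_dev, dst_tile in dst_sharding.items():
        for src_dev, src_tile in src_sharding.items():
            sub = intersect(src_tile, dst_tile)
            if sub is not None:
                plan.append((src_dev, dst_dev, sub))
    return plan

# Example: stage 0 shards a (4, 8) tensor over 2 GPUs by rows;
# stage 1 shards the same tensor over 4 GPUs by columns.
src = {0: ((0, 2), (0, 8)), 1: ((2, 4), (0, 8))}
dst = {0: ((0, 4), (0, 2)), 1: ((0, 4), (2, 4)),
       2: ((0, 4), (4, 6)), 3: ((0, 4), (6, 8))}
print(build_resharding_plan(src, dst))  # 8 device-to-device transfers

Each entry of the resulting plan would then be executed as one send/recv pair (or a local copy when source and destination live on the same host).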

Pipeline Layer Discovery Proposal

Conventional graph clustering algorithms (e.g. spectral clustering or k-means) only consider edge weights and don't take vertex weights into account. In other words, they are used to find "minimum cuts" of a graph. In our case, we need to take both the vertex weights (FLOPs used by each operator, i.e. computation cost) and the edge weights (tensors to be sent between stages, i.e. communication cost) into consideration.

To take vertex weights into consideration, one potential algorithmic choice is Balanced Partitioning. In simple words, we need to partition the vertices of a given graph into k almost equal clusters while minimizing the number of edges that are cut by the partition.

One practical algorithm proposed for this problem on large graphs (blog: https://ai.googleblog.com/2018/03/balanced-partitioning-and-hierarchical.html, paper: https://dl.acm.org/doi/pdf/10.1145/2835776.2835829) works as follows:

  1. Linearize all graph nodes into a sequence.
  2. Perform local swaps to fine-tune the sequence.
  3. Partition the nodes into k almost equally sized groups. Each group consists of nodes from a consecutive sub-sequence.

(figure: balanced-partitioning)

To fit this algorithm to our pipeline stage discovery setting, the linearization step (step 1) can simply be skipped: we can just use the sequential Jaxpr order. If we would like to work on XLA in the future, we can use the execution order that we used for the memory constraints. We can consider the swapping step (step 2) later. The main difference between our case and the original algorithm is that our graph is a DAG, so when we swap, we should make sure the sequence remains a valid execution order.

We can directly follow the partition step (step 3) in the original algorithm, which is implemented by the following dynamic programming algorithm. Let A[l, r, q] denote the smallest cut size, using exactly q partitions, achievable for the subgraph induced by the vertices at positions l to r in the sequence.

To begin with, we set A[l, r, 1] = 0 if and only if the total node weight from l to r is smaller than (1 + ε) * total_node_weight / k; otherwise we set A[l, r, 1] = +∞. Then the recurrence for computing A[l, r, q] can be written as:
A[l, r, q] = min { A[l, k, 1] + A[k + 1, r, q - 1] + C(l, k, r) | l <= k < r }, or
A[l, r, q] = min { A[l, k, q - 1] + A[k + 1, r, 1] + C(l, k, r) | l <= k < r } (with this second form we can avoid counting replicated input tensors in the computation of C),
where C(l, k, r) is the total edge cost between the ranges l..k and k+1..r.

---- Another way to formulate the DP problem ----

Let A[r, q] denote the smallest cut size, using q partitions, achievable for the subgraph induced by the vertices at positions 1 to r in the sequence. We set another table B[l, r] = 0 if the total node weight from l to r is smaller than (1 + ε) * total_node_weight / k; otherwise we set B[l, r] = +∞. We first set A[0, 0] = 0 and A[r, 0] = +∞ for r > 0. Then the recurrence for computing A[r, q] can be written as:
A[r, q] = min { A[k, q - 1] + B[k + 1, r] + C(1, k, r) | 0 <= k < r },
where C(1, k, r) is the total edge cost between the ranges 1..k and k+1..r. Specially, we set C(1, 0, r) = 0 for all r.
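Below is a minimal, runnable sketch of this second formulation, assuming the graph has already been linearized; node weights are given per position and edges as (u, v, cost) triples with u < v in sequence order. The function and variable names are illustrative, not the eventual parax implementation.

import math

def balanced_partition_cost(node_weights, edges, k, eps=0.1):
    """A[r][q]: smallest cut cost of splitting the prefix 1..r into exactly q clusters."""
    n = len(node_weights)
    budget = (1 + eps) * sum(node_weights) / k

    # Prefix sums of node weights (1-indexed to match the formulas above).
    prefix_w = [0.0] * (n + 1)
    for i in range(1, n + 1):
        prefix_w[i] = prefix_w[i - 1] + node_weights[i - 1]

    def B(l, r):
        # 0 if nodes l..r fit in one cluster, +inf otherwise.
        return 0.0 if prefix_w[r] - prefix_w[l - 1] <= budget else math.inf

    def C(cut, r):
        # Total cost of edges crossing from 1..cut to cut+1..r, i.e. C(1, cut, r) above.
        if cut == 0:
            return 0.0
        return sum(c for u, v, c in edges if u <= cut < v <= r)

    A = [[math.inf] * (k + 1) for _ in range(n + 1)]
    A[0][0] = 0.0
    for q in range(1, k + 1):
        for r in range(1, n + 1):
            for cut in range(q - 1, r):          # the last cluster is cut+1..r
                A[r][q] = min(A[r][q], A[cut][q - 1] + B(cut + 1, r) + C(cut, r))
    return A[n][k]

# Tiny example: a 6-node chain with unit weights and unit edge costs, k = 3.
print(balanced_partition_cost([1] * 6, [(i, i + 1, 1.0) for i in range(1, 6)], k=3))  # 2.0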

I currently lean toward implementing this in Jax, mainly because we only need to run the process on the forward pass and can automatically get the pipeline stages for the backward pass. If we implement it in XLA, the only downside right now is handling the backward pass; the main advantage of XLA is a more accurate FLOPs estimation.

The initial implementation plan for pipeline stage discovery:

  • Start working from Jax. Implement a simple FLOPs counter and edge cost analyzer in Jax.
  • Implement the DP algorithm in step 3. Test it with the original Jaxpr order without step 2.
  • See whether step 2 can help the performance.

Support model-parallel inference

This is a TODO item after the paper deadline.
Currently, our framework only works when the program follows a "forward, backward, and apply" pattern. We should also support model-parallel inference, since otherwise the trained model will not be usable in practice.

DistributedArray

    class DistributedArray:
        def __init__(self):
            self.remote_buffers: List[RemoteBuffer]
            self.sharding_specification

        def __del__(self):
            # Release the remote buffers on their hosts.
            delete(self.remote_buffers)

    class RemoteBuffer:
        def __init__(self):
            self.host_id
            self.local_buffers  # CUDA pointers

    for i in range(n_epoch):
        distributed_array = device_mesh.execute(distributed_array)
        dump(distributed_array)

Split and Merge Hlo

(figure omitted)

Tips

  // insert before the statement "return true;"
  auto [forward_module, backward_module] = partition_hlo_module(module);

  PyGILState_STATE gstate = PyGILState_Ensure();
  {
    py::object submodule = py::module_::import("parax.auto_sharding");
    py::object set_forward_backward_module = submodule.attr("set_forward_backward_module");
    set_forward_backward_module(forward_module, backward_module);
  }
  PyGILState_Release(gstate);

in parax/auto_sharding.py

received_forward_module = None
received_backward_module = None
def set_forward_backward_module(forward_module, backward_module):
    global received_forward_module, received_backward_module
    received_forward_module = forward_module
    received_backward_module = backward_module

def split_and_compile(forward_hlo, backward_hlo):
    global received_forward_module, received_backward_module
    merged_hlo = merge(forward_hlo, backward_hlo)
    compile_with_auto_sharding(merged_hlo)  # this will set received_forward_module and received_backward_module
    forward_binary = compile_without_auto_sharding(received_forward_module)
    backward_binary = compile_without_auto_sharding(received_backward_module)
    return forward_binary, backward_binary

[FEATURE] Distributed Data Loader

Task 1:

Implement a data loader for the resnet benchmark

class Dataloader:
    def __init__(self,
                 file_path: Union[str, tensorflow.dataloader],
                 physical_mesh: parax.device_mesh.PhysicalMesh,
                 sharding_specs: jax.interpreters.pxla.ShardingSpec,
                 pre_process_func: Callable):
        pass

    def __iter__(self) -> Iterable[Dict[str, DistributedArray]]:
        yield batch

Task 2:

Implement a data loader for the BERT benchmark
parax's benchmark: https://github.com/parax-project/parax/blob/master/benchmark/parax/benchmark_2d_one_case_bert.py
huggingface's script: https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_mlm_flax.py

Test parallelize decorator for imagenet examples

Goal

Test our interface on ResNet50 and imagenet.
The major challenge is how to deal with data loading/prefetching.

Todo

  • Migrate the flax/imagenet example into our examples folder
  • Use our auto-parallel interface to match the performance of the original pmap-based data-parallel version.

Gradient Accumulation

Current gradient accumulation can introduce extra all-reduce overhead when combined with data parallelism, as illustrated below.
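An illustrative back-of-the-envelope sketch (not the parax code path) of where the overhead comes from: if the cross-replica gradient all-reduce stays inside the per-microbatch computation, it runs once per microbatch, whereas with local accumulation it only needs to run once per iteration on the accumulated gradients.

def allreduce_traffic(grad_bytes, num_microbatches, allreduce_per_microbatch):
    # Number of times the full gradient is all-reduced in one training iteration.
    calls = num_microbatches if allreduce_per_microbatch else 1
    return calls * grad_bytes

GRAD_BYTES = 4 * 350_000_000  # e.g. a 350M-parameter model in fp32 (illustrative)
print(allreduce_traffic(GRAD_BYTES, 8, True))   # all-reduce inside the microbatch loop: 8x traffic
print(allreduce_traffic(GRAD_BYTES, 8, False))  # single all-reduce on the accumulated gradients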

Project Roadmap

  • June

    • Benchmark baselines
      • FlexFlow (failed)
      • Megatron-LM
      • ZeRO (deepspeed)
    • Match the performance of the baselines with SPMD parallelism (auto-sharding) on BERT.
    • Implement a profiling and search API for saving and loading the profiling results and the parallelization strategy. (TODO: add a doc)
    • End-to-end 3D parallelism benchmark
  • July

    • Pipeline stage partitioning, placement, and scheduling algorithms.
    • Support ZeRO Optimizer
    • Integrate with memory optimizations (swapping, rematerialization)
  • August

    • Large-scale experiments on several models
    • First draft of the paper

Pipeline Stage Composition Algorithm

Hao's algorithm in #59 is composed of two levels, a small loop inside a big loop, where:

  1. The big loop: We first enumerate different sub-mesh shapes. Then, for each shape (e.g. a 1x2 mesh), we try to cover the full mesh with sub-meshes of this specific shape. For each shape, we find one specific covering solution.
  2. The small loop: For each covering solution, we find a way to assign different pipeline layers to different pipeline stages that minimizes the communication and computation costs.

However, the algorithm in the big loop is slightly misaligned with the goal of the project: For normal GPU clusters, we can always find a greedy way to cover the full mesh: for example, for 8-GPU nodes, we can have 4 1x2 meshes in a node and repeat this pattern for each node. In addition, for most neural networks, the cost of communication in pipeline parallelism is smaller than the cost of communication in data parallelism. So the greedy solution should perform well in most cases.

This algorithm also doesn't consider that we can mix multiple different sub-mesh shapes, which should be important for networks with non-uniform layer shapes (e.g. ResNet).


However, finding the optimal way to slice a 2D mesh with minimized pipeline execution latency is a very hard problem, and we can't find a polynomial-time DP algorithm (polynomial w.r.t. the total mesh size) that directly solves it. Our current proposal is to put constraints on both the cluster mesh shape and the sub-mesh shapes. Specifically, we have:

  1. For cluster mesh shape, we assume it's of the shape n x 2^m.
  2. The possible sub-mesh shapes are 1 x 1, 1 x 2, 1 x 4, 1 x 8, ... 1 x 2^m, 2 x 2^m, 3 x 2^m, ... n x 2^m.

Then we can utilize a 1D DP algorithm to get the solution. More specifically, we transform the original problem into the following one: Assume we have in total n*2^m devices, find an optimal way to assign layers to device groups, where each device group can have 1, 2, 4, 8, ..., 2^m, 2*2^m, ..., n*2^m devices. This can be solved by defining the DP state as DP[i][k] that represents the optimal cost of putting layers 0 - (i-1) on k devices. Then we can derive

DP[i][k] = min_{1 <= j <= i, s <= k, s is a feasible device group size} { DP[i - j][k - s] + computation cost of putting layers (i - j) to (i - 1) on s devices + communication cost }.

Because of our specific selection of sub-mesh shapes, we can guarantee that we can map the 1D solution back to the 2D mesh.

Cons of this method:

  1. Constraints on the cluster mesh shape. We might be able to relax this constraint by generalizing to n x m meshes and making sure the size of each small sub-mesh is a factor of m.
  2. We still assume that only consecutive layers can be put on a sub-mesh. This doesn't cover the case in the updated megatron-lm paper.
  3. We assume the cost of communication in pipeline parallelism is smaller than the cost of communication in data parallelism, which might not be true for all networks.

The issue with the above algorithm is that the total runtime of a pipeline is determined by the following formula:

pipeline time = (sum of all stages' times) + (#microbatches - 1) * (maximum stage time)

Some other points:

  • Each stage can receive inputs from multiple previous stages.
  • Communication can be overlapped with computation.
for t_max in all_possible_maximum_stage_times:
    # f[i][j]: minimal total stage time for layers 0..i-1 on j devices,
    # with every stage's (compute + communication) cost below t_max.
    for i in range(1, n_layers + 1):
        for j in range(1, n_devices + 1):
            for k in range(i):
                for s in possible_submesh_sizes:
                    if s > j:
                        continue
                    stage_cost = compute_cost(k, i, s) + communication_cost(k, s)
                    if stage_cost < t_max:
                        f[i][j] = min(f[i][j], f[k][j - s] + stage_cost)
    # B: number of microbatches.
    cost_for_t_max = f[n_layers][n_devices] + (B - 1) * t_max

How to get communication cost:

  1. Only count the receiver's receiving bandwidth; assume the sender has infinite bandwidth.
  2. Use a greedy solution: directly use the mesh shape that is optimal for each DP subproblem.

Some other issues:

  1. Right now, profiling the computation cost of a 10-layer BERT on a single node with 4 GPUs takes 10 minutes without any parallelism.

Parallelize `generate_sharded_xla_computations`

See the profile below.

 - Prepare input: 0.79 s
 - Create train state: 3.67 s
WARNING:parax.pipeline_parallel.apply_grad:Cannot donate lj (shape: ())
/home/ubuntu/project/pycharm/parax/parax/shard_parallel/auto_sharding.py:732: UserWarning: Detect unexpected behaviors in the auto-sharding pass.
  warn("Detect unexpected behaviors in the auto-sharding pass.")
>>>>> profile generate_sharded_xla_computations: 50.908265590667725 
/home/ubuntu/project/pycharm/parax/parax/shard_parallel/auto_sharding.py:732: UserWarning: Detect unexpected behaviors in the auto-sharding pass.
  warn("Detect unexpected behaviors in the auto-sharding pass.")
>>>>> profile generate_sharded_xla_computations: 51.498419761657715 
2021-11-10 19:08:05,074	WARNING worker.py:1239 -- It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace="e3d4b6c7-22f8-4f1e-b2fc-cf40026d1550", ...)
 - Compile (driver): 135.90 s
 - Compile (worker): 27.27 s

Profile stats

Add some if-conditions in recv_tile to expedite special cross-mesh resharding case

Define the following terms:

  • TBR = to be received
  • tensor: the tensor of interest in cross-mesh resharding
  • tile: an arbitrary sub-block (a regular n-d array) of that tensor

When either of the following two conditions is met:
(1) the TBR tile has the same shape as the TBR tensor (this holds in many homogeneous-model cases, e.g. GPT);
(2) the underlying memory of the TBR tile is a contiguous subset of the TBR tensor's memory, e.g. the tile contains just a few columns of the original tensor and the tensor is column-major;

then we do not have to allocate a temporary buffer for the NCCL receive. Instead, we can receive-write in place into the corresponding (contiguous) memory. A sketch of the check follows.
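A self-contained sketch of this check, assuming the destination tensor is stored row-major (C order); the tile is given as per-dimension (start, stop) index ranges, and all names are illustrative:

def can_recv_in_place(tensor_shape, tile):
    # Condition (1): the tile covers the whole tensor.
    if all(lo == 0 and hi == dim for (lo, hi), dim in zip(tile, tensor_shape)):
        return True
    # Condition (2): the tile occupies contiguous memory. In row-major layout
    # this holds iff, before some dimension d, every sliced dimension has
    # length 1, and after d every dimension is fully taken.
    lengths = [hi - lo for lo, hi in tile]
    d = 0
    while d < len(tensor_shape) and lengths[d] == 1:
        d += 1
    if d < len(tensor_shape):
        d += 1  # dimension d itself may be an arbitrary range
    return all(lengths[i] == tensor_shape[i] for i in range(d, len(tensor_shape)))

# Examples on a (4, 8) row-major tensor:
print(can_recv_in_place((4, 8), ((1, 2), (0, 8))))  # True: one full row is contiguous
print(can_recv_in_place((4, 8), ((0, 4), (0, 2))))  # False: a column block is strided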

This optimization has the effect shown in the following visual comparison:

Before: (figure omitted)

After: (figure omitted)

Unfortunately, we cannot do better if neither of the two conditions is met.

[API] Remove `parax.grad` and `@layer_construction` markers and replace them with a monkey patch of `jax.grad`

Besides the @parallelize decorator, current parax has the following decorators that the user needs to write explicitly:

  1. parax.grad, which replaces jax.grad.
  2. @manual_layer_construction or @automatic_layer_construction.

Code example:

@parallelize
def f(...):
  @automatic_layer_construction
  def loss_func():
    # network body
    ...
  grad = parax.grad(loss_func)
  # apply grad
  ...
  return new_weights

I'm thinking about removing these two extra decorators. Then a user only needs to add the @parallelize decorator to any jax training function to make it model parallel, which provides a much nicer API:

@parallelize
def f(...):
  # basically now f can be any normal jax training step function
  def loss_func():
    # network body
    ...
  grad = jax.grad(loss_func)
  # apply grad
  ...
  return new_weights

The exact modification is as the following:

  1. Use a monkey patch of jax.grad to implement the functionality of parax.grad. With the monkey patch, jax.grad behaves like parax.grad when it is traced inside a @parallelize-decorated function. We will set some state (e.g. the in_parallelize flag below) in the @parallelize decorator.
    original_jax_grad = jax.grad

    def new_grad(*args, **kwargs):
      if in_parallelize:
        return parax.grad(*args, **kwargs)
      else:
        # Call the saved original to avoid infinite recursion after patching.
        return original_jax_grad(*args, **kwargs)
    
    jax.grad = new_grad
  2. For the @layer_construction decorators, we can move them into the new jax.grad:
    def new_grad(f, *args, **kwargs):
      if in_parallelize:
        if auto_layer:
          f = automatic_layer_construction(f)
        elif manual_layer:
          f = manual_layer_construction(f)
        return parax.grad(f, *args, **kwargs)
      else:
        # Again, call the saved original instead of the patched jax.grad.
        return original_jax_grad(f, *args, **kwargs)
    
    jax.grad = new_grad

Some potential downsides of this modification:

  1. We will not be able to stack @parallelize with other decorators. For example, the following code will not work after the modification:
    @parallelize
    @jax.jit
    def f(...):
      ...
  2. We need to carefully deal with nested jax.grad.
  3. What if the input workload is a pipelined inference workload?

How do you feel about this @merrymercy @zhisbug?

Support convolution layers in auto-sharding

Currently, we only registered auto-sharding strategies for matmul (https://github.com/parax-project/tensorflow-parax/blob/526050a32587b42bb12ed3f4dd62a06f8733a208/tensorflow/compiler/xla/service/gpu/auto_sharding.cc#L1149).

We want to register similar strategies for convolution layers.

If we only parallelize the batch, input_channel, and output_channel dimensions and ignore the width and height dimensions, a convolution layer is almost the same as a matmul. We can slightly modify the matmul strategies and port them to convolution, as sketched below.
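An illustrative enumeration (not the actual strategy set in auto_sharding.cc) of how the matmul strategies could map onto convolution when only the batch and channel dimensions are parallelized; spatial H/W dimensions stay unsharded, and resharding costs are left to the auto-sharding solver:

CONV_STRATEGIES = [
    # name,            input split, kernel split, extra communication on the output
    ("batch-parallel", "N",         None,         None),
    ("out-channel",    None,        "C_out",      None),
    ("in-channel",     "C_in",      "C_in",       "all-reduce"),
    ("replicated",     None,        None,         None),
]

for name, in_split, w_split, comm in CONV_STRATEGIES:
    print(f"{name:14s} input split={in_split}, kernel split={w_split}, comm={comm}")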

Then, we can try some large models like wide-resnet and 3d-unet.

Resolve the ambiguity of stage

Currently, our pipeline has two concepts to represent a part of the computation: layer and stage. A layer is a slice of the computation created by layer-slicing, while a stage is a sequence of consecutive layers.

However, we also use the word 'stage' in another way: the class PipelineStage and all its subclasses. That class can represent any slice of the computation, in the form of a Jaxpr or an XlaComputation, so both the logical layer and the logical stage use this class to hold their data. This leads to ambiguity.

To resolve the problem, we should introduce a new name either for the class or for the logical sequence of layers. The renaming of 'stage' should start after our current PRs are merged.

What's the new name in your mind? @zhuohan123

Frontend API design

  • Should we know the parallelization strategy before weight initialization?

                                     Know          Do not know
    Weight reorganization overhead   Zero          High
    Programming style                Declarative   Imperative

If we know the parallelization strategy before weight initialization, we can put the weights in the right places the first time. Otherwise, we have to move tensors around after we choose the parallelization strategy.
However, to know the parallelization strategy before weight initialization, we have to see the whole training graph before initialization, which is not an intuitive, imperative programming style.

  • Should we know which tensors are gradients/weights?

                                                          Know   Do not know
    Difficulty of implementing some specific strategies   Easy   Hard
    Generalization                                        Low    High

If we know which tensors are gradients, we can easily implement strategies like data parallelism and insert the correct all-reduce.
If we do not assume they are gradients, and just treat them as vanilla intermediate results, we have to do a more general analysis.

  • Should we see the training loop?
    If we can see the training loop, then we know the weight tensors are updated iteratively, so we can update them in place and preserve their layouts. However, to see the training loop we would have to compile all the complicated logic like data loading, preprocessing, model checkpointing, and logging into our IR, which is impossible.

There is one D2D copy to be eliminated

See the contrasting figures below:

# Hao: if the following line cannot print, meaning NCCL hangs...
# logger.debug(
#     ">>> Recv from: rank {}, gpu_idx {}, shape: {}, dtype: {}, sample value: {}."
#     .format(src_rank, src_gpu_idx, to_recv.shape, to_recv.dtype,
#             to_recv[0]))
recv_tensor = to_jax_tensor(to_recv)

# 0-copy version
start_indices = tuple(
    ind_in_dst.start for ind_in_dst in indices_in_dst_tile)
new_buffer = jax_buffer_set(
    xla_buffer_to_jax_buffer(self.buffers[uuid]), recv_tensor,
    start_indices)
self.buffers[uuid] = jax_buffer_to_xla_buffer(new_buffer)

If we keep the above lines: (figure omitted)

If we comment out the above lines: (figure omitted)

It causes a ~0.005 s performance difference, but I still need to figure out which line is responsible.

Donate Intermediate Proposal

A recent PR #119 fixes the memory leak caused by too aggressive invar donation, and as a result, now we only donate two kinds of tensors:

  1. Invars that are allowed to be donated. Typically, if a user sets "auto" as donate_argnums, we donate exactly the optimizer state;
  2. Gradient buffers. When #microbatch > 1, gradient accumulation is required, and we create corresponding inputs (accumulated gradients from the last microbatch) and outputs (accumulated gradients after this microbatch).

To explain what we do not donate (but is donatable), we introduce the concept of a main copy. The main copy of a tensor is located on the mesh where it is defined. In contrast, the copies of the tensor sent to other meshes are template copies. A main copy exists only for intermediate and output vars, not for input vars, because the placement of inputs depends only on the dataloader.
This definition follows naturally from our current runtime: when we need a tensor, we get it through a resharding task from the mesh that defines it.

Hence, typically there are two kinds of donatable buffers:

  1. A template copy of an intermediate; it can be donated at any time;
  2. A main copy of an intermediate that is no longer used by other stages, whether they are in the same mesh or not.

We do not donate the above types for one reason: if a buffer is donated, there is a corresponding input-output alias in the XlaComputation. The alias matters in two ways:

  1. As we merge forward, backward and apply gradient together in an auto-sharding pass, we cannot set up an alias for an intermediate because it is not an input;

  2. In auto-sharding, if there is an alias between input_i and output_i, they will share the same ShardingSpec. Hence, if we set too many donations, auto-sharding will be overly restricted. A rough experiment on how much auto-sharding's performance is influenced is described below.

Our experiment reveals two facts:

  1. The time for Jax3DPipeline to do garbage collection is significant.
    Without any change to the runnable, and only making the runtime skip the garbage collection of the two kinds of donatable buffers above, we gain at least 10-15% (from 0.69 s to 0.60 s; other cases see larger improvements).

  2. On the other hand, less-constrained auto-sharding is important for performance:
    Although we did not design a precise experiment, we observed this by switching the donation of the input optimizer state, donating it to the gradients instead of the corresponding output optimizer state.
    With the donation switched off (i.e. we pay our runtime's garbage collection cost), one iteration takes 0.69 s; with it switched on (i.e. we skip our garbage collection but add restrictions to auto-sharding), the time slows down to 0.85 s.

Donating intermediates is critical for reducing the garbage collection cost, but it should not limit auto-sharding. In conclusion, we donate the two kinds of buffers discussed above, but only after the auto-sharding pass, and at that point we only donate buffers with the same shape and the same sharding spec. This misses some donation opportunities, but the overall performance is expected to be better. A sketch of this filter is given below.
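A sketch of that post-auto-sharding donation filter; the attribute names and data layout are illustrative, not the actual parax data structures. Each donatable input is paired with an output only when both shape and sharding spec match, so no alias constraint is ever added to the auto-sharding pass itself.

def choose_donations(donatable_inputs, outputs):
    donations = []
    used_outputs = set()
    for inp in donatable_inputs:
        for j, out in enumerate(outputs):
            if j in used_outputs:
                continue
            if inp["shape"] == out["shape"] and inp["sharding_spec"] == out["sharding_spec"]:
                donations.append((inp["name"], out["name"]))
                used_outputs.add(j)
                break
    return donations

inputs = [{"name": "grad_acc_in", "shape": (1024, 1024), "sharding_spec": "S0R"}]
outputs = [{"name": "grad_acc_out", "shape": (1024, 1024), "sharding_spec": "S0R"}]
print(choose_donations(inputs, outputs))  # [('grad_acc_in', 'grad_acc_out')]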

There are two problems that remain:

  1. When the user, for some purpose, manually sets other inputs (e.g. the batch) to be donated, the alias is created and heavily influences auto-sharding;
  2. The cost for our runtime to release intermediates is formidable; why do we not see similar costs in our 2D parallelism or in other systems?

Reorganize pipeline code

For pipeline parallel, we have several source code files:
pipe.py, pipeline_parallel.py, pipeline_primitive_def.py, pipeline_stage.py, and three_d_parallel.py.

The names and organization are confusing and not intuitive to me. We should organize them better and remove deprecated code.
We can also introduce a separate folder for all pipeline-related code.

@zhisbug @zhuohan123

Fast path of profiling for DP

For very large models, profiling is costly, while simply counting FLOPs and communication volume is already precise enough. Hence, we should create fast paths to reduce the profiling cost:

  • Add an option to use a cost model instead of profiling after auto-sharding. The cost model should consider both collective-communication and computation costs (see the sketch after this list).
  • If the intermediates of other microbatches are too large, do not even compile the auto-sharding strategy.
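A minimal sketch of such a cost model, assuming a ring all-reduce and a peak-FLOPs-based compute estimate; the constants and function names are illustrative and would need per-cluster calibration:

def allreduce_time(num_bytes, num_devices, bandwidth_bytes_per_s):
    if num_devices <= 1:
        return 0.0
    # A ring all-reduce moves 2 * (n - 1) / n of the data through each link.
    return 2 * (num_devices - 1) / num_devices * num_bytes / bandwidth_bytes_per_s

def compute_time(flops, peak_flops, efficiency=0.5):
    # Assume we reach a fixed fraction of peak FLOPs.
    return flops / (peak_flops * efficiency)

def stage_cost(flops, grad_bytes, num_devices, peak_flops=100e12, bandwidth=100e9):
    return compute_time(flops, peak_flops) + allreduce_time(grad_bytes, num_devices, bandwidth)

# e.g. a stage with 10 TFLOPs of work and 1 GB of gradients on 8 GPUs
print(stage_cost(10e12, 1e9, 8))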

Improve cross-mesh resharding code

I'm taking note of TODOs for after the deadline:

  1. Move eager cross-mesh resharding into a separate class, and possibly contribute that runtime to Ray.
  2. Make a new strategy class; a strategy should be able to express pure send/recv, scatter-allgather, or other future strategies.
  3. There are many workarounds added while chasing the paper deadline; address them.

Reuse memory buffers across iterations

To support gradient accumulation, we need to allocate extra buffers where we can accumulate gradients, in addition to buffers that hold gradients generated on-the-fly.

We currently re-allocate these buffers at the start of each training iteration. We are thinking of allocating them only once at the start of the first iteration and reusing them across iterations, because re-allocation seems to cause some overhead; see the evidence below:

(figure omitted)

Zoomed-in version: (figure omitted)

The way to achieve this is to replace the ALLOC instruction with a memset.

A related concern is to properly synchronize the memset and the allocation to avoid race conditions.

[BUG] All-reduce incorrectly skipped

Currently we skip an all-reduce if it is for gradient accumulation and has been rewritten (call these grad-acc all-reduces). However, after that, such an all-reduce can be merged with an all-reduce that is not for grad-acc, and skipping the merged one results in incorrect outputs. We should identify grad-acc all-reduces and only allow them to merge with other grad-acc all-reduces.

A reproducer is:

class SkipAllReduceTest(PipelineBasicTest):

    def test_2_layer_bert(self):
        self.run_n_layer_bert(n_layers=2,
                             batch_size=4,
                             seq_len=4,
                             hidden_size=4,
                             num_heads=1,
                             pipeline_stage_mode="manual_gpipe",
                             forward_stage_layer_ids=[[0,], [1]],
                             overwrite_global_config_dict=dict(
                                sub_physical_mesh_shapes=[(1, 2)] * 2,
                                sub_logical_mesh_shapes=[(1, 2), (2, 1)],
                                submesh_autosharding_global_configs=[dict(force_batch_dim_to_mesh_dim=0)] * 2,
                                allow_all_gather=True,
                                use_scatter_gather=False
                             ))
