alpa-projects / alpa
Training and serving large-scale neural networks with auto parallelization.
Home Page: https://alpa.ai
License: Apache License 2.0
We want to schedule the order of send/recv when a stage wants to send tensors to multiple stages.
Principle:
Need to think twice
#162 does not improve the compilation time by much. For example, to fully profile an 8-layer BERT model on a 1x8 mesh, the time is reduced from 190 s to 130 s. This is because the largest part of the compilation time comes from the codegen from XLA to CUDA, and right now this still happens in the worker and is not parallelized. Currently, the whole profiling process is as follows:
Input: stages to profile
Output: A computation cost measured for each stage.
There are several potential ways to parallelize this:
```python
import ray

def profile_all(stages):
    remote_workers = [RemoteWorker.remote() for _ in range(len(stages))]
    # Launch compilation of all stages on the workers in parallel.
    compile_jobs = []
    for stage, worker in zip(stages, remote_workers):
        compile_job = worker.compile.remote(stage)
        compile_jobs.append(compile_job)
    ray.get(compile_jobs)
    # Profile each compiled stage on its worker.
    return [ray.get(worker.profile.remote()) for worker in remote_workers]

@ray.remote
class RemoteWorker:
    def compile(self, stage):
        # Compile the stage and cache the executable on this worker.
        self.compiled = stage.compile()

    def profile(self):
        # Run the cached executable and return its measured cost.
        return profile(self.compiled)
```
Solution
Workaround
In our current design of 3D parallelism, there are possibilities that:
on each mesh, the model might be sharded differently.
The communication between stage0 and stage1 is not necessarily based on device-device send/recv, because it might involve communications from multiple devices to multiple devices.
We should introduce an abstraction in the pipeline runtime to handle this communication pattern.
Unfortunately, this array-indexing operation has a ~3 ms overhead. See below:
And it will accumulate in each resharding task.
We have at least two ways to optimize it. Option 1 can be done easily; option 2 needs some experiments.
Solution:
@flax.remat
(checkpoint at the transformer boundary) https://github.com/parax-project/parax/blob/d37ac0464ddd5a9159f1542d52d3274b024f2997/parax/model/bert_model.py#L315
@manual (checkpoint at the marker boundary)
Conventional graph clustering algorithms (e.g., spectral clustering or k-means) only consider edge weights and don't take vertex weights into account. In other words, they are used to find "minimum cuts" for a graph. In our case, we need to take both the vertex weights (FLOPs used by each operator, i.e., computation cost) and the edge weights (tensors to be sent between stages, i.e., communication cost) into consideration.
To take vertex weights into consideration, one potential algorithmic choice is Balanced Partitioning. In simple words, we need to partition the vertices of a given graph into k almost equal clusters while minimizing the number of edges that are cut by the partition.
One practical algorithm proposed for this problem on large graphs (blog: https://ai.googleblog.com/2018/03/balanced-partitioning-and-hierarchical.html, paper: https://dl.acm.org/doi/pdf/10.1145/2835776.2835829) works as follows: (1) linearize the graph into a sequence of vertices; (2) iteratively swap vertices to improve the objective; (3) partition the sequence into k almost equally sized groups, where each group consists of nodes from a consecutive sub-sequence.
To fit this algorithm to our pipeline stage discovery setting, the linearization step (step 1) can simply be skipped: we can just use the sequential Jaxpr order. If we would like to work on XLA in the future, we can use the execution order that we used for the memory constraints. We can consider the swapping step (step 2) later. The main difference between our case and the original algorithm is that our graph is a directed DAG, so when we do swapping, we should make sure that the sequence remains a valid execution sequence.
We can directly follow the partition step (step 3) in the original algorithm, which is implemented by the following dynamic programming algorithm. Let A[l, r, q] denote the smallest cut size achievable for the subgraph induced by vertices at positions l->r in the sequence, using exactly q partitions.
To begin with, we set A[l, r, 1] = 0 if and only if the total node weight from l->r is smaller than (1 + ε) * total_node_weight / k, otherwise we set A[l, r, 1] = +∞. Then the recursive formula for computing A[l, r, q] can be presented as:
A[l, r, q] = min { A[l, k, 1] + A[k + 1, r, q - 1] + C(l, k, r) | l <= k < r }
Alternatively:
A[l, r, q] = min { A[l, k, q - 1] + A[k + 1, r, 1] + C(l, k, r) | l <= k < r }
(with this new formula we can remove replicated input tensors in the computation of C), where C(l, k, r) is the total edge cost from l->k to k+1->r.
---- Another way to formulate the DP problem ----
Let A[r, q] denote the smallest cut size using q partitions achievable for the subgraph induced by vertices at positions 1->r in the sequence. We set another table B[l, r] = 0 if the total node weight from l->r is smaller than (1 + ε) * total_node_weight / k, otherwise we set B[l, r] = +∞. We first set A[0, 0] = 0 and A[r, 0] = +∞ for r > 0. Then the recursive formula for computing A[r, q] can be presented as:
A[r, q] = min { A[k, q - 1] + B[k + 1, r] + C(1, k, r) | 0 <= k < r }
where C(1, k, r) is the total edge cost from 1->k to k+1->r. Specially, we set C(1, 0, r) = 0 for all r.
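Below is a minimal sketch of this second DP formulation, assuming the operators are already linearized and given as per-node computation weights plus a dict of pairwise edge (communication) costs. The names and the simple loop structure are illustrative, not the actual parax implementation; indices are 0-based here instead of the 1-based notation above.
```python
import numpy as np

def balanced_partition_cost(node_weights, edge_costs, k, eps=0.1):
    """node_weights: list of per-node computation costs (in sequence order).
    edge_costs: dict {(u, v): cost} with u < v, communication between nodes.
    Returns the smallest total cut cost of splitting the sequence into k
    consecutive groups whose weights stay under (1 + eps) * total / k."""
    n = len(node_weights)
    limit = (1 + eps) * sum(node_weights) / k
    prefix = np.concatenate([[0.0], np.cumsum(node_weights)])

    def B(l, r):  # 0 if nodes l..r (inclusive) fit in one balanced group, else +inf
        return 0.0 if prefix[r + 1] - prefix[l] <= limit else np.inf

    def C(cut, r):  # edge cost crossing from nodes <= cut to nodes in (cut, r]
        return sum(c for (u, v), c in edge_costs.items() if u <= cut < v <= r)

    A = np.full((n, k + 1), np.inf)  # A[r, q]: best cut cost for nodes 0..r with q groups
    for r in range(n):
        A[r, 1] = B(0, r)
    for q in range(2, k + 1):
        for r in range(n):
            for cut in range(r):
                A[r, q] = min(A[r, q], A[cut, q - 1] + B(cut + 1, r) + C(cut, r))
    return A[n - 1, k]
```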
I currently lean towards implementing this in Jax, mainly because we only need to run the process for the forward pass and can automatically get the pipeline stages for the backward pass. If implemented in XLA, the only con right now is handling the backward pass. The main pro of implementing in XLA is that we can get a more accurate FLOPs estimate.
The initial implementation plan for pipeline stage discovery:
This is a TODO item after the paper deadline.
Currently, our framework only works when the program follows a "forward, backward, and apply" pattern. We should also support model-parallel inference, since otherwise the trained model cannot be used in practice.
pmap on multiple nodes:
google/jax#2731
google/jax#5667
google/jax#3004
tensorflow/tensorflow#48210
https://github.com/tensorflow/tensorflow/blob/eee783f657a2678768bf4272212615311f8be1b0/tensorflow/compiler/xla/python/xla_client.py#L69
tensorflow/tensorflow@8a72c44
```python
class DistributedArray:
    def __init__(self):
        self.remote_buffers: List[RemoteBuffer]
        self.sharding_specification

    def __del__(self):
        delete(self.host_id)

class RemoteBuffer:
    def __init__(self):
        self.host_id
        self.local_buffers  # CUDA pointers

for i in range(n_epoch):
    distributed_array = device_mesh.execute(distributed_array)
    dump(distributed_array)
```
Calling Python code from C++
Example: https://github.com/parax-project/tensorflow-parax/blob/b13493fd4429e9f6f4373c5d2e8503a3ad6d020f/tensorflow/compiler/xla/service/gpu/auto_sharding.cc#L1382-L1417
Pybind11 doc: https://pybind11.readthedocs.io/en/stable/advanced/embedding.html#
Where to get the hlo module? Right after the auto-sharding pass
https://github.com/parax-project/tensorflow-parax/blob/b13493fd4429e9f6f4373c5d2e8503a3ad6d020f/tensorflow/compiler/xla/service/gpu/auto_sharding.cc#L1646-L1648
Exporting a C++ function to Python
Add your function here https://github.com/parax-project/tensorflow-parax/blob/b13493fd4429e9f6f4373c5d2e8503a3ad6d020f/tensorflow/compiler/xla/python/xla_compiler.cc#L178
A possible implementation
In auto_sharding.cc (at the last line):
```cpp
// Insert before the statement "return true;".
auto [forward_module, backward_module] = partition_hlo_module(module);
PyGILState_STATE gstate = PyGILState_Ensure();
{
  py::object submodule = py::module_::import("parax.auto_sharding");
  py::object set_forward_backward_module =
      submodule.attr("set_forward_backward_module");
  set_forward_backward_module(forward_module, backward_module);
}
PyGILState_Release(gstate);
```
In parax/auto_sharding.py:
```python
received_forward_module = None
received_backward_module = None

def set_forward_backward_module(forward_module, backward_module):
    global received_forward_module, received_backward_module
    received_forward_module = forward_module
    received_backward_module = backward_module

def split_and_compile(forward_hlo, backward_hlo):
    merged_hlo = merge(forward_hlo, backward_hlo)
    # This sets received_forward_module and received_backward_module
    # through the callback from the auto-sharding pass.
    compile_with_auto_sharding(merged_hlo)
    forward_binary = compile_without_auto_sharding(received_forward_module)
    backward_binary = compile_without_auto_sharding(received_backward_module)
    return forward_binary, backward_binary
```
Move this folder into parax to reduce our dependence on Ray. Otherwise, we have to use nightly-built Ray, and our experiments might be blocked by bugs in upstream Ray.
Implement a data loader for the resnet benchmark
```python
class Dataloader:
    def __init__(self,
                 file_path: Union[str, tensorflow.dataloader],
                 physical_mesh: parax.device_mesh.PhysicalMesh,
                 sharding_specs: jax.interpreters.pxla.ShardingSpec,
                 pre_process_func: Callable):
        pass

    def __iter__(self) -> Iterable[Dict[str, DistributedArray]]:
        yield batch
```
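A hypothetical usage sketch of this interface (the variable names and the train_step function are placeholders, not part of the proposed API):
```python
# Build a loader that shards each input batch directly onto the physical mesh.
loader = Dataloader(file_path="imagenet_train.tfrecord",
                    physical_mesh=physical_mesh,
                    sharding_specs=input_sharding_specs,
                    pre_process_func=normalize_and_augment)

for batch in loader:
    # batch is a Dict[str, DistributedArray] already placed on the mesh,
    # so it can be fed to a parallelized train step without extra transfers.
    state = train_step(state, batch["images"], batch["labels"])
```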
Implement a data loader for the BERT benchmark
parax's benchmark
https://github.com/parax-project/parax/blob/master/benchmark/parax/benchmark_2d_one_case_bert.py
huggingface's script
https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_mlm_flax.py
Test our interface on ResNet50 and imagenet.
The major challenge is how to deal with data loader/prefetch.
Current gradient accumulation can introduce extra all-reduce overhead when combined with data parallelism.
Currently, we use Set in several places, which makes the output jaxpr of pipeline slicing non-deterministic. One approach to fix this is to replace all Set with OrderedSet: https://github.com/parax-project/parax/blob/8d7d9dd6436aea38780b9e84f3396c656fb96a0b/parax/util.py#L134
This is a TODO after the deadline.
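A small illustration of the problem, using a dict-backed insertion-ordered set as a stand-in for the OrderedSet referenced above (an illustrative sketch, not the parax implementation):
```python
# Iteration order over a plain Python set of strings depends on string hashing
# (PYTHONHASHSEED), so it can differ between interpreter runs and make any
# jaxpr built from that order non-deterministic.
class OrderedSet:
    def __init__(self, items=()):
        self._items = dict.fromkeys(items)

    def add(self, item):
        self._items[item] = None

    def __iter__(self):
        return iter(self._items)

invars = ["weight", "bias", "input", "label"]
print(list(set(invars)))         # order may change between runs
print(list(OrderedSet(invars)))  # always the insertion order
```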
June
July
August
Hao's algorithm in #59 is composed of two stages, or a small loop in a big loop, where:
However, the algorithm in the big loop is slightly misaligned with the goal of the project: For normal GPU clusters, we can always find a greedy way to cover the full mesh: for example, for 8-GPU nodes, we can have 4 1x2 meshes in a node and repeat this pattern for each node. In addition, for most neural networks, the cost of communication in pipeline parallelism is smaller than the cost of communication in data parallelism. So the greedy solution should perform well in most cases.
This algorithm also doesn't consider that we can have multiple different sub-mesh shapes. This should be important for networks with non-uniform shapes (e.g., ResNet).
However, finding the optimal way to slice a 2D mesh with minimized pipeline execution latency is a very hard problem, and we can't find a polynomial DP algorithm (polynomial w.r.t. the total mesh size) that directly solves it. Our current proposal is to put constraints on both the cluster mesh shape and the sub-mesh shapes. Specifically, we have:
1. The cluster mesh has shape n x 2^m.
2. The sub-mesh shapes can only be 1 x 1, 1 x 2, 1 x 4, 1 x 8, ..., 1 x 2^m, 2 x 2^m, 3 x 2^m, ..., n x 2^m.
Then we can utilize a 1D DP algorithm to get the solution. More specifically, we transform the original problem into the following one: assume we have in total n * 2^m devices; find an optimal way to assign layers to device groups, where each device group can have 1, 2, 4, 8, ..., 2^m, 2 * 2^m, ..., n * 2^m devices. This can be solved by defining the DP state DP[i][k] as the optimal cost of putting layers 0 to (i - 1) on k devices. Then we can derive:
DP[i][k] = min_{j <= i, s <= k, s is a feasible device group size} { DP[i - j][k - s] + computation cost of putting layers (i - j) to (i - 1) on s devices + communication cost }
Because of our specific selection of sub-mesh shapes, we can guarantee that we can map the 1D solution back to the 2D mesh.
Cons of this method:
It cannot directly handle general n x m cluster meshes; to support them, we would need to make sure the size of the small sub-meshes is a factor of m.
The issue of the above algorithm is that the total runtime of a pipeline is determined by the following formula:
pipeline time = (sum of all stages' times) + (#microbatches - 1) * (maximal stage time)
The DP above only minimizes the first term and ignores the bottleneck stage. For example, with stage times of [2, 3, 2, 3] ms and 8 microbatches, the pipeline time is 10 + 7 * 3 = 31 ms, so the maximal stage time dominates.
Some other points:
Enumerate all possible maximum stage times; for each candidate maximum, run:
```
for i in range(n_layers):
    for j in range(n_devices):
        for k in range(i):
            for s in possible_submeshes:
                if compute cost + communication cost < maximum:
                    f[i][j] = min(f[i][j],
                                  f[k][j - s]
                                  + compute cost of layers k to i on mesh s
                                  + communication cost for layer k on mesh s)

cost for this maximum stage time = f[n_layers][n_devices] + (B - 1) * maximum
```
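A simplified, runnable sketch of this enumeration plus DP (not the actual parax code); stage_cost(k, i, s) and comm_cost(k, s) are assumed cost lookups (e.g., from profiling or analytical estimates), and submesh_sizes enumerates the feasible device-group sizes described above:
```python
import numpy as np

def inter_op_dp(n_layers, n_devices, n_microbatches,
                submesh_sizes, stage_cost, comm_cost):
    """stage_cost(k, i, s): cost of layers k..i on s devices;
    comm_cost(k, s): cost of the communication cut at layer k on s devices."""
    best = np.inf
    # Enumerate candidate values for the maximum stage time.
    candidates = sorted({stage_cost(k, i, s) + comm_cost(k, s)
                         for k in range(n_layers)
                         for i in range(k, n_layers)
                         for s in submesh_sizes})
    for t_max in candidates:
        f = np.full((n_layers + 1, n_devices + 1), np.inf)
        f[0, 0] = 0.0
        for i in range(1, n_layers + 1):
            for j in range(1, n_devices + 1):
                for k in range(i):
                    for s in submesh_sizes:
                        if s > j:
                            continue
                        cost = stage_cost(k, i - 1, s) + comm_cost(k, s)
                        if cost <= t_max and f[k, j - s] + cost < f[i, j]:
                            f[i, j] = f[k, j - s] + cost
        total = f[n_layers, n_devices] + (n_microbatches - 1) * t_max
        best = min(best, total)
    return best
```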
How to get communication cost:
Some other issues:
See the profile below.
- Prepare input: 0.79 s
- Create train state: 3.67 s
WARNING:parax.pipeline_parallel.apply_grad:Cannot donate lj (shape: ())
/home/ubuntu/project/pycharm/parax/parax/shard_parallel/auto_sharding.py:732: UserWarning: Detect unexpected behaviors in the auto-sharding pass.
warn("Detect unexpected behaviors in the auto-sharding pass.")
>>>>> profile generate_sharded_xla_computations: 50.908265590667725
/home/ubuntu/project/pycharm/parax/parax/shard_parallel/auto_sharding.py:732: UserWarning: Detect unexpected behaviors in the auto-sharding pass.
warn("Detect unexpected behaviors in the auto-sharding pass.")
>>>>> profile generate_sharded_xla_computations: 51.498419761657715
2021-11-10 19:08:05,074 WARNING worker.py:1239 -- It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace="e3d4b6c7-22f8-4f1e-b2fc-cf40026d1550", ...)
- Compile (driver): 135.90 s
- Compile (worker): 27.27 s
It seems that it should be an RDA.
If we send it every iteration, it will block the default stream.
Define the following terms:
- TBR = to be received
- tensor: the tensor of interest in cross-mesh resharding
- tile: an arbitrary sub-block (a regular nd array) of that tensor
When either of the following two conditions is met:
(1) a TBR tile has the same shape as the TBR tensor (this holds for many homogeneous model cases, e.g., GPT), or
(2) the underlying memory of the TBR tile is a contiguous subset of the TBR tensor's memory (e.g., the tile contains just a few columns of the original tensor while the tensor is column-major),
we do not have to allocate a temporary buffer for the NCCL receive. Instead, we can receive-write in place into the corresponding (contiguous) memory.
This optimization has the effect shown in the following visual contrast:
Unfortunately, we cannot do better if neither of the above two conditions is met.
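A minimal sketch of the contiguity check behind condition (2), assuming a row-major (C-order) tensor; the column-major case mentioned above is symmetric, and the function name is just illustrative:
```python
def tile_is_contiguous(tensor_shape, tile_shape):
    """Return True if a regular nd tile of the given shape occupies one
    contiguous memory range inside a row-major tensor of tensor_shape,
    so a receive could write into it in place without a temporary buffer."""
    dims = list(zip(tensor_shape, tile_shape))
    # Leading dimensions where the tile has extent 1 do not break contiguity.
    while dims and dims[0][1] == 1:
        dims.pop(0)
    if not dims:
        return True  # the tile is a single element (or the tensor is 0-d)
    # The first remaining dimension may be partially covered; every later
    # dimension must be fully covered for the memory range to be contiguous.
    return all(tile == full for full, tile in dims[1:])

assert tile_is_contiguous((4, 6, 8), (4, 6, 8))      # whole tensor
assert tile_is_contiguous((4, 6, 8), (1, 3, 8))      # full sub-rows of one slice
assert not tile_is_contiguous((4, 6, 8), (2, 3, 8))  # strided block
```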
Besides the @parallelize decorator, in current parax, the user has to state the following decorators/changes:
1. Replace jax.grad with parax.grad.
2. Add @manual_layer_construction or @automatic_layer_construction.
Code example:
```python
@parallelize
def f(...):
    @automatic_layer_construction
    def loss_func():
        # network body
        ...

    grad = parax.grad(loss_func)
    # apply grad
    ...
    return new_weights
```
I'm thinking about removing these two extra decorators. Then a user only needs to add the @parallelize decorator to any jax training function to make it model-parallel. This provides a much nicer API:
```python
@parallelize
def f(...):
    # Basically, f can now be any normal jax training step function.
    def loss_func():
        # network body
        ...

    grad = jax.grad(loss_func)
    # apply grad
    ...
    return new_weights
```
The exact modification is as follows:
1. Monkey-patch jax.grad to implement the functionality of parax.grad. With the monkey patch, jax.grad will behave like parax.grad when it is being traced inside a @parallelize decorator. We will set some environment variables in the @parallelize decorator.
```python
original_grad = jax.grad

def new_grad(*args):
    if in_parallelize:
        return parax.grad(*args)
    else:
        return original_grad(*args)

jax.grad = new_grad
```
2. For the @layer_construction decorators, we can move them into the new jax.grad:
```python
original_grad = jax.grad

def new_grad(f, *args):
    if in_parallelize:
        if auto_layer:
            f = auto_layer_construction(f)
        elif manual_layer:
            f = manual_layer_construction(f)
        return parax.grad(f, *args)
    else:
        return original_grad(f, *args)

jax.grad = new_grad
```
Some potential downsides of this modification:
1. Users cannot freely compose @parallelize with other decorators. For example, the following code will not work after the modification:
```python
@parallelize
@jax.jit
def f(...):
    ...
```
2. It requires monkey-patching jax.grad.
How do you feel about this @merrymercy @zhisbug?
Currently, we only registered auto-sharding strategies for matmul (https://github.com/parax-project/tensorflow-parax/blob/526050a32587b42bb12ed3f4dd62a06f8733a208/tensorflow/compiler/xla/service/gpu/auto_sharding.cc#L1149).
We want to register similar strategies for convolution layers.
If we only do parallelization along the batch, input_channel, and output_channel dimensions and ignore the width and height dimensions, the convolution layer is almost the same as matmul. We can slightly modify the strategies of matmul and port them to convolution.
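A small sketch of this correspondence, assuming NHWC inputs and HWIO filters: with a 1x1 kernel (i.e., ignoring the spatial dimensions), the convolution is exactly a matmul whose m, k, n dimensions are (batch * spatial), input_channel, and output_channel:
```python
import jax.numpy as jnp
from jax import lax, random

key = random.PRNGKey(0)
x = random.normal(key, (8, 16, 16, 32))   # NHWC input
w = random.normal(key, (1, 1, 32, 64))    # HWIO filter with a 1x1 kernel

# Convolution view.
conv_out = lax.conv_general_dilated(
    x, w, window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# Matmul view: (batch * spatial, in_channel) @ (in_channel, out_channel).
matmul_out = (x.reshape(-1, 32) @ w.reshape(32, 64)).reshape(8, 16, 16, 64)

assert jnp.allclose(conv_out, matmul_out, atol=1e-4)
```
So the batch and channel sharding strategies of matmul carry over; only the spatial (width/height) dimensions need extra care.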
Then, we can try some large models like wide-resnet and 3d-unet.
Currently, our pipeline has two concepts to represent a part of the computation: layer and stage. A layer is a slice of the computation created by layer-slicing, while a stage is a sequence of consecutive layers.
However, we also have another use of the word 'stage', that is, the class PipelineStage and all its subclasses. The class can represent any slice of the computation, in the form of a Jaxpr or an XlaComputation, so both the logical layer and the logical stage use this class to contain their data. This leads to ambiguity.
To resolve the problem, we should introduce a new name for the class, or a new name for the logical sequence of layers. The replacement of the word 'stage' is supposed to start after our current PRs are merged.
What's the new name in your mind? @zhuohan123
The current gradient accumulation only works for jnp.mean loss because we always use mean reduction.
For other losses or auxiliary states, we should support other reduction types such as sum reduction and concatenation reduction.
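To make the requirement concrete, here is a small sketch (not parax code) of why mean- and sum-reduced losses need different accumulation rules across microbatches, assuming equal microbatch sizes:
```python
import jax
import jax.numpy as jnp

def loss_sum(params, batch):
    return jnp.sum((batch @ params) ** 2)

def loss_mean(params, batch):
    return jnp.mean((batch @ params) ** 2)

params = jnp.ones((4,))
full_batch = jnp.arange(16.0).reshape(8, 4)
microbatches = jnp.split(full_batch, 4)

# Sum reduction: accumulated microbatch gradients equal the full-batch gradient.
g_sum = sum(jax.grad(loss_sum)(params, mb) for mb in microbatches)
assert jnp.allclose(g_sum, jax.grad(loss_sum)(params, full_batch))

# Mean reduction: the accumulated gradients must be rescaled by the
# number of microbatches to match the full-batch gradient.
g_mean = sum(jax.grad(loss_mean)(params, mb) for mb in microbatches) / 4
assert jnp.allclose(g_mean, jax.grad(loss_mean)(params, full_batch))
```
Concatenation reduction (e.g., for per-example auxiliary outputs) would instead require stacking the microbatch results rather than adding them.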
Implement a static 1F1B pipeline schedule
| | Know | Do not know |
|---|---|---|
| Weight reorganization overhead | Zero | High |
| Programming style | Declarative | Imperative |
If we know the parallelization strategy before weight initialization, we can put the weights in the right places the first time. Otherwise, we have to move tensors after we choose the parallelization strategy.
However, if we want to know the parallelization strategy before weight initialization, we have to see the whole training graph before initialization. This is not an intuitive, imperative programming style.
| | Know | Do not know |
|---|---|---|
| Difficulty to implement some specific strategies | Easy | Hard |
| Generalization | Low | High |
If we know which tensors are gradients, we can easily implement strategies like data parallelism and insert the correct all-reduces.
If we do not make the assumption that they are gradients, but just treat them as vanilla intermediate results, we have to do a more general analysis.
See the contrasting figures below:
```python
# Hao: if the following line cannot print, it means NCCL hangs...
# logger.debug(
#     ">>> Recv from: rank {}, gpu_idx {}, shape: {}, dtype: {}, sample value: {}."
#     .format(src_rank, src_gpu_idx, to_recv.shape, to_recv.dtype, to_recv[0]))
recv_tensor = to_jax_tensor(to_recv)

# 0-copy version: write the received tile into the destination buffer in place.
start_indices = tuple(
    ind_in_dst.start for ind_in_dst in indices_in_dst_tile)
new_buffer = jax_buffer_set(
    xla_buffer_to_jax_buffer(self.buffers[uuid]), recv_tensor,
    start_indices)
self.buffers[uuid] = jax_buffer_to_xla_buffer(new_buffer)
```
If we comment out the above lines, there is a ~0.005 s performance difference, but I need to figure out which line causes it.
A recent PR #119 fixes the memory leak caused by too aggressive invar donation, and as a result, now we only donate two kinds of tensors:
1. Tensors specified by donate_argnums: we exactly donate the optimizer state.
To explain what we do not donate (but is donatable), we introduce the concept of a main copy. The main copy of a tensor is located on the mesh where the tensor is defined. In contrast, the copies of the tensor sent to other meshes are template copies. A main copy exists only for intermediate and output vars, not for input vars, because the placement of inputs depends only on the dataloader.
This definition follows naturally from our current runtime: when we need a tensor, we fetch it with a resharding task from the mesh that defines it.
Hence, typically there are two kinds of donatable buffers:
We do not donate the above types for one reason: if a buffer is donated, there is a corresponding input-output alias in XlaComputation. The alias influences two aspects:
1. As we merge forward, backward, and apply-gradient together in an auto-sharding pass, we cannot set up an alias for an intermediate because it is not an input.
2. In auto-sharding, if there is an alias between input_i and output_i, they share the same ShardingSpec. Hence, if we set too many donations, the auto-sharding will be overly restricted. A rough experiment on how much auto-sharding's performance is influenced is listed later.
Our experiment reveals two facts:
The time for Jax3DPipeline to do garbage collection is significant. Without any change to the runnable, and only making the runtime skip the garbage collection of the two kinds of donatable buffers above, we get at least a 10-15% speedup (from 0.69 s to 0.60 s; other cases see larger improvements).
On the other hand, a less-constrained auto-sharding is also important for performance. Although we did not design a precise experiment, we observed this by switching the donation of the input optimizer state: donating it to the gradients instead of to the corresponding output optimizer state. If we switch off the donation, i.e., we suffer from our runtime's garbage collection, the cost is 0.69 s; if we switch on the donation, i.e., we skip our garbage collection but add restrictions to auto-sharding, the time slows down to 0.85 s.
In summary, the donation of intermediates is critical to reduce the garbage collection time, but it should not limit auto-sharding. Therefore, we donate the two kinds of buffers discussed above, but only after the auto-sharding pass. At that point, we only donate pairs with the same shape and the same sharding spec. This misses some donation opportunities, but the overall performance is expected to be better.
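A tiny sketch of this post-auto-sharding donation filter; the lookup-table arguments are hypothetical stand-ins for whatever metadata the runtime has:
```python
def filter_donations(candidate_pairs, shapes, sharding_specs):
    """Keep only (input_var, output_var) donation pairs whose buffers have the
    same shape and the same sharding spec, so the output can reuse the input's
    device buffers directly."""
    donations = []
    for invar, outvar in candidate_pairs:
        if (shapes[invar] == shapes[outvar]
                and sharding_specs[invar] == sharding_specs[outvar]):
            donations.append((invar, outvar))
    return donations
```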
There are two problems that remain:
For pipeline parallelism, we have several source code files: pipe.py, pipeline_parallel.py, pipeline_primitive_def.py, pipeline_stage.py, and three_d_parallel.py.
The names and organization are confusing and not intuitive for me. We should organize them better and remove deprecated code.
We can also introduce a separate folder for all pipeline-related code.
For very large models, profiling is costly, but simply counting FLOPs and communication volume is precise enough. Hence, we should create many fast paths to reduce the profiling cost:
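One such fast path could be a purely analytical estimate along the following lines (a sketch with placeholder peak-FLOPS and bandwidth constants, not parax's profiler):
```python
def matmul_flops(m, n, k):
    # One multiply-accumulate counted as 2 floating-point operations.
    return 2 * m * n * k

def ring_allreduce_bytes(num_elements, dtype_bytes, num_devices):
    # A ring all-reduce moves roughly 2 * (p - 1) / p of the data per device.
    return 2 * (num_devices - 1) / num_devices * num_elements * dtype_bytes

def estimate_layer_time(m, n, k, grad_elements, num_devices,
                        peak_flops=100e12, bandwidth=100e9):
    compute = matmul_flops(m, n, k) / peak_flops
    comm = ring_allreduce_bytes(grad_elements, 4, num_devices) / bandwidth
    return compute + comm
```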
I'm taking note of TODOs after the deadline:
Introduce a strategy class. A strategy should be able to express pure send/recv, scatter-allgather, or other possible future strategies.
To support gradient accumulation, we need to allocate extra buffers where we can accumulate gradients, in addition to buffers that hold gradients generated on the fly.
We currently re-allocate these buffers at the start of each training iteration. We are thinking of allocating them only once at the start of the first iteration and reusing them across iterations, because re-allocation seems to cause some overhead; see the evidence below:
The way to achieve this is to replace the ALLOC instruction with a memset.
A related concern is to properly synchronize the memset and the allocation to avoid race conditions.
Currently, we skip an all-reduce if it is for gradient accumulation and has been rewritten (call these grad-acc all-reduces). However, after that, such an all-reduce can be merged with an all-reduce that is not for gradient accumulation, and skipping the merged one produces incorrect outputs. We should identify grad-acc all-reduces and only allow them to be merged with other grad-acc all-reduces.
A reproducible test case is:
```python
class SkipAllReduceTest(PipelineBasicTest):

    def test_2_layer_bert(self):
        self.run_n_layer_bert(
            n_layers=2,
            batch_size=4,
            seq_len=4,
            hidden_size=4,
            num_heads=1,
            pipeline_stage_mode="manual_gpipe",
            forward_stage_layer_ids=[[0], [1]],
            overwrite_global_config_dict=dict(
                sub_physical_mesh_shapes=[(1, 2)] * 2,
                sub_logical_mesh_shapes=[(1, 2), (2, 1)],
                submesh_autosharding_global_configs=[
                    dict(force_batch_dim_to_mesh_dim=0)
                ] * 2,
                allow_all_gather=True,
                use_scatter_gather=False))
```
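A sketch of the intended grouping fix described above (not the actual XLA pass): tag each all-reduce and only offer merge candidates from within the same group, so a grad-acc all-reduce can never be fused with a normal one:
```python
from collections import defaultdict

def merge_groups(all_reduce_ops, is_grad_acc):
    """all_reduce_ops: iterable of op identifiers;
    is_grad_acc: a hypothetical predicate marking rewritten grad-acc all-reduces.
    Returns lists of ops that are allowed to be merged with each other."""
    groups = defaultdict(list)
    for op in all_reduce_ops:
        groups[bool(is_grad_acc(op))].append(op)
    return list(groups.values())
```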