Comments (6)
Hi @feifeibear. Thank you for sharing the idea! In our opinion, this is basically a trade-off between memory cost and communication cost.
The current design of the 3D Linear layer applies an all-gather on the input matrix A and a reduce-scatter on the output matrix C in the forward pass (and an all-gather on the gradients of C and a reduce-scatter on the gradients of A in the backward pass), so that each stored activation is 1/N of the original size.
An alternative design is to use an all-reduce on C in the forward pass, as well as on the gradients of A in the backward pass, but then the activations are 1/N^(2/3) of the original size.
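For concreteness, here is a rough 1-D sketch of the first design's forward collectives in plain torch.distributed. The real 3D layer splits A, W, and C over an N^(1/3) x N^(1/3) x N^(1/3) process cube, so the sharding below (input split along the batch dimension, weight split along K) is a simplifying assumption for illustration, not Colossal-AI's implementation:

```python
import torch
import torch.distributed as dist

def linear_forward_design1(a_shard: torch.Tensor, w_shard: torch.Tensor,
                           group=None) -> torch.Tensor:
    """a_shard: (B/N, K) row shard of the input A; w_shard: (K/N, M) row shard of W.
    Sketch of the all-gather + reduce-scatter pattern only; the backward pass
    (all-gather on grad C, reduce-scatter on grad A) is omitted."""
    n = dist.get_world_size(group)
    rank = dist.get_rank(group)
    # 1) All-gather the input shards so every rank sees the full A of shape (B, K).
    gathered = [torch.empty_like(a_shard) for _ in range(n)]
    dist.all_gather(gathered, a_shard, group=group)
    a_full = torch.cat(gathered, dim=0)
    # 2) Local matmul against this rank's slice of the K dimension -> a partial C.
    k = w_shard.shape[0]
    partial_c = a_full[:, rank * k:(rank + 1) * k] @ w_shard  # (B, M), partial sums
    # 3) Reduce-scatter: sum the partials across ranks and keep only a (B/N, M)
    #    shard, so the stored activation is again 1/N of the full C.
    c_shard = torch.empty(a_shard.shape[0], w_shard.shape[1],
                          dtype=a_shard.dtype, device=a_shard.device)
    dist.reduce_scatter(c_shard,
                        [t.contiguous() for t in partial_c.chunk(n, dim=0)],
                        group=group)
    return c_shard
```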
Considering activation checkpointing, where the forward pass is recomputed, the first design applies 2 * all-gather + 1 * reduce-scatter on A and 2 * reduce-scatter + 1 * all-gather on C in total, while the second design applies 3 * all-reduce. Since a ring all-reduce has a cost similar to an all-gather plus a reduce-scatter, the total communication costs of the two designs should be similar.
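As a back-of-the-envelope check, assume for simplicity that A and C have the same size M and use the standard ring-collective volumes: all-gather and reduce-scatter each move (N-1)/N * M per device, and a ring all-reduce moves twice that. The helper below is purely illustrative arithmetic:

```python
# Per-device communication volume under ring collectives.
# Assumes A and C have the same size M for simplicity (they generally differ).
def ring_cost(n: int, m: float, collective: str) -> float:
    frac = (n - 1) / n
    return {"all_gather": frac, "reduce_scatter": frac,
            "all_reduce": 2 * frac}[collective] * m

def design1(n: int, m: float) -> float:
    # A: 2 * all-gather (fwd + recompute) + 1 * reduce-scatter (grad of A)
    # C: 2 * reduce-scatter (fwd + recompute) + 1 * all-gather (grad of C)
    return (2 * ring_cost(n, m, "all_gather") + ring_cost(n, m, "reduce_scatter")
            + 2 * ring_cost(n, m, "reduce_scatter") + ring_cost(n, m, "all_gather"))

def design2(n: int, m: float) -> float:
    # 2 * all-reduce on C (fwd + recompute) + 1 * all-reduce on grad of A
    return 3 * ring_cost(n, m, "all_reduce")

print(design1(8, 1.0), design2(8, 1.0))  # both 5.25 * M: the totals coincide
```

Under this simplified model the two totals coincide exactly; in practice they diverge because A and C differ in size.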
However, we are indeed concerned that small tensors reduce bandwidth utilization, and they are hard to fuse. To find the optimal setup, we are testing as many models and networking environments as possible and letting the results decide.
I agree that 3D parallelism can shrink the peak activation footprint on a single GPU at the cost of more communication. The method definitely works in some special cases. Perhaps a simple search method could be derived to figure out which parts of the DNN are suitable for 3D parallelism under a limited memory budget; a rough sketch of that idea follows.
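For instance, a minimal sketch of such a search: greedily 3D-parallelize the blocks that free the most activation memory per unit of extra communication until the budget is met. All names, numbers, and the greedy heuristic here are hypothetical placeholders, not anything in colossalai:

```python
from dataclasses import dataclass

@dataclass
class Block:
    name: str
    act_mem: float      # activation memory without 3D parallelism (GB)
    mem_saving: float   # memory freed by 3D-parallelizing this block (GB)
    comm_cost: float    # extra communication it would add (relative units)

def pick_blocks(blocks: list[Block], budget: float) -> list[str]:
    """Greedily pick blocks with the best memory-saved-per-communication ratio
    until peak activation memory fits the budget."""
    total = sum(b.act_mem for b in blocks)
    chosen: list[str] = []
    for b in sorted(blocks, key=lambda b: b.mem_saving / b.comm_cost, reverse=True):
        if total <= budget:
            break
        total -= b.mem_saving
        chosen.append(b.name)
    return chosen

# Toy numbers reflecting that self-attention blocks usually hold more
# activations than FFN blocks.
blocks = [
    Block("attn.0", act_mem=4.0, mem_saving=3.0, comm_cost=0.8),
    Block("ffn.0",  act_mem=2.0, mem_saving=1.5, comm_cost=0.6),
    Block("attn.1", act_mem=4.0, mem_saving=3.0, comm_cost=0.8),
    Block("ffn.1",  act_mem=2.0, mem_saving=1.5, comm_cost=0.6),
]
print(pick_blocks(blocks, budget=8.0))  # ['attn.0', 'attn.1']
```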
This could be a good idea. For example, self-attention blocks usually consume more memory than MLP (FFN) blocks.
This issue is stale because it has been open for 14 days with no activity.
@1SAA, the communication profiling results may support some of my assumptions in the discussion.
We have updated a lot. This issue was closed due to inactivity. Thanks.