Comments (4)
Hi 👋 This error occurs because, since transformers v4.39, the argument `seq_len` has been removed from `LlamaRotaryEmbedding.forward()`. But the code for ColossalLlama was written against a much older version (around v4.34, I guess). At that time, Flash Attention, a technique that significantly speeds up attention and reduces memory consumption, had just come out and had not yet been integrated into `LlamaAttention`. That's why we needed a `flash_attn_patch` to enable this feature back then. This patch is based on a function signature from an older version of transformers.
But by now, Flash Attention has already been integrated into the Hugging Face Llama implementation (see the classes `LlamaFlashAttention2` and `LlamaSdpaAttention`). So I think you can just set `use_flash_attn` to `False`, and the Llama model will automatically use the Flash Attention feature. I believe this patch will be removed later.
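To make the incompatibility concrete: before v4.38, `LlamaRotaryEmbedding.forward()` accepted a `seq_len` keyword (which the old patch relies on), while newer versions expect `position_ids` instead. Here is a minimal sketch of a version-agnostic dispatch; the helper name and the toy classes are hypothetical stand-ins for the real modules, for illustration only:

```python
import inspect


def call_rotary_emb(rotary_emb, x, position_ids=None, seq_len=None):
    """Dispatch to whichever signature the installed transformers version uses.

    Pre-v4.38: forward(x, seq_len=...)   (what the old patch assumed)
    v4.38+:    forward(x, position_ids)  (seq_len was removed)
    """
    params = inspect.signature(rotary_emb.forward).parameters
    if "seq_len" in params:
        return rotary_emb(x, seq_len=seq_len)
    return rotary_emb(x, position_ids)


# Toy stand-ins for the two signatures (hypothetical, for demonstration only):
class OldRotary:
    def forward(self, x, seq_len=None):
        return ("old", seq_len)
    __call__ = forward


class NewRotary:
    def forward(self, x, position_ids):
        return ("new", position_ids)
    __call__ = forward


print(call_rotary_emb(OldRotary(), None, seq_len=16))          # ('old', 16)
print(call_rotary_emb(NewRotary(), None, position_ids=[0, 1]))  # ('new', [0, 1])
```

Inspecting the bound method's signature at call time is one way to keep a patch working across both versions, though once the native integration is available the patch is simply unnecessary.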
from colossalai.
When I change transformers to 4.38.0, it shows:
File "/home/user1/workspace/colossal-ai/ColossalAI/examples/language/llama2/attn.py", line 133, in attention_forward
cos, sin = self.rotary_emb(v, seq_len=kv_len)
File "/home/user1/anaconda3/envs/colossalai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
TypeError: LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'
So, which version of transformers should I use with Flash Attention?
Transformers v4.37 is OK. But just as I said, you can use v4.39 and still enjoy the speedup from flash-attn by setting `use_flash_attn` to `False`, because Flash Attention has been integrated into the transformers library and no longer needs our patch.
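For reference, a tiny helper to check whether a given transformers version still predates the `seq_len` removal, i.e. whether the old patch can even be applied. The cutoff follows the v4.39 change described above; the helper itself is hypothetical:

```python
def patch_compatible(tf_version: str) -> bool:
    """Return True if ColossalAI's flash_attn_patch can still be applied.

    The seq_len argument was removed from LlamaRotaryEmbedding.forward()
    in transformers v4.39, so the old patch only works on versions before
    that. (Helper name is hypothetical; cutoff per the discussion above.)
    """
    major, minor = (int(p) for p in tf_version.split(".")[:2])
    return (major, minor) < (4, 39)


print(patch_compatible("4.37.2"))  # True  -> the patch still works
print(patch_compatible("4.39.0"))  # False -> set use_flash_attn=False instead
```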
> Transformers v4.37 is OK. But just as I said, you can use v4.39 and still enjoy the speedup from flash-attn by setting `use_flash_attn` to `False`, because Flash Attention has been integrated into the transformers library and no longer needs our patch.
Hi, it looks like if I set `use_flash_attn` to `False`, the GPU memory usage increases.
Here is my env:
Package Version
------------------------- -----------
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
annotated-types 0.6.0
async-timeout 4.0.3
attrs 23.2.0
bcrypt 4.1.2
beautifulsoup4 4.12.3
cachetools 5.3.3
certifi 2024.2.2
cffi 1.16.0
cfgv 3.4.0
charset-normalizer 3.3.2
click 8.1.7
cmake 3.29.0.1
colossalai 0.3.6
contexttimer 0.3.3
cryptography 42.0.5
datasets 2.18.0
decorator 5.1.1
Deprecated 1.2.14
dill 0.3.8
distlib 0.3.8
dropout-layer-norm 0.1
einops 0.7.0
fabric 3.2.2
filelock 3.13.3
flash-attn 2.2.1
frozenlist 1.4.1
fsspec 2024.2.0
fused-dense-lib 0.0.0
google 3.0.0
google-auth 2.29.0
google-auth-oauthlib 1.0.0
grpcio 1.62.1
huggingface-hub 0.22.2
identify 2.5.35
idna 3.6
invoke 2.2.0
Jinja2 3.1.3
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
lit 18.1.2
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.16
networkx 3.3
ninja 1.11.1.1
nodeenv 1.8.0
numpy 1.26.4
nvidia-cublas-cu11 11.10.3.66
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu11 11.7.101
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11 8.5.0.96
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu11 10.9.0.58
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu11 10.2.10.91
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu11 11.4.0.1
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu11 11.7.4.91
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu11 2.14.3
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu11 11.7.91
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
packaging 24.0
pandas 2.2.1
paramiko 3.4.0
pip 23.3.1
platformdirs 4.2.0
pre-commit 3.7.0
protobuf 5.26.1
psutil 5.9.8
pyarrow 15.0.2
pyarrow-hotfix 0.6
pyasn1 0.6.0
pyasn1_modules 0.4.0
pycparser 2.22
pydantic 2.6.4
pydantic_core 2.16.3
Pygments 2.17.2
PyNaCl 1.5.0
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
ray 2.10.0
referencing 0.34.0
regex 2023.12.25
requests 2.31.0
requests-oauthlib 2.0.0
rich 13.7.1
rotary-emb 0.1
rpds-py 0.18.0
rsa 4.9
safetensors 0.4.2
sentencepiece 0.1.99
setuptools 68.2.2
six 1.16.0
soupsieve 2.5
sympy 1.12
tensorboard 2.14.0
tensorboard-data-server 0.7.2
tokenizers 0.13.3
torch 2.0.0
tqdm 4.66.2
transformers 4.33.3
triton 2.0.0
typing_extensions 4.11.0
tzdata 2024.1
urllib3 2.2.1
virtualenv 20.25.1
Werkzeug 3.0.2
wheel 0.41.2
wrapt 1.16.0
xentropy-cuda-lib 0.1
xxhash 3.4.1
yarl 1.9.4