Comments (4)

Orion-Zheng commented on June 11, 2024

Hi 😃 This error occurs because, since transformers v4.39, the seq_len argument has been removed from LlamaRotaryEmbedding.forward(). But the code for ColossalLlama was written much earlier (I guess around v4.34). At that time, Flash Attention, a technique that significantly speeds up attention and reduces memory consumption, had just come out and hadn't yet been integrated into LlamaAttention. That's why we needed a flash_attn_patch to enable this feature back then, and the patch is written against a function signature from an older version of Transformers.
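Roughly, the mismatch looks like this (a paraphrase for illustration, not the verbatim transformers source):

  # transformers <= v4.37: seq_len is accepted, so the patched call works
  #     def forward(self, x, seq_len=None): ...
  cos, sin = self.rotary_emb(v, seq_len=kv_len)   # the call in the patch

  # transformers >= v4.38: position_ids is required (seq_len is deprecated,
  # then removed in v4.39), so the same call raises a TypeError
  #     def forward(self, x, position_ids): ...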
But by now, Flash Attention has already been integrated into the Hugging Face Llama implementation (see the classes LlamaFlashAttention2 and LlamaSdpaAttention). So I think you can just set use_flash_attn to False, and the Llama model will use the integrated flash attention feature automatically. I believe this patch will be removed later.
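For example, with a recent transformers release you can request the integrated kernel directly when loading the model. A minimal sketch (the checkpoint name is only an example, and the flash-attn package must be installed):

  import torch
  from transformers import AutoModelForCausalLM

  # Assumes transformers >= 4.36: transformers selects its own
  # LlamaFlashAttention2 path, so no external patch is needed.
  model = AutoModelForCausalLM.from_pretrained(
      "meta-llama/Llama-2-7b-hf",               # example checkpoint
      torch_dtype=torch.bfloat16,               # flash attention needs fp16/bf16
      attn_implementation="flash_attention_2",
  )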

shawnricecake commented on June 11, 2024

When I change transformers to 4.38.0, it shows:

  File "/home/user1/workspace/colossal-ai/ColossalAI/examples/language/llama2/attn.py", line 133, in attention_forward
    cos, sin = self.rotary_emb(v, seq_len=kv_len)
  File "/home/user1/anaconda3/envs/colossalai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'

So, which version of transformers should I use with flash attention?

Orion-Zheng commented on June 11, 2024

Transformers v4.37 is OK. But as I said, you can also use v4.39 and still enjoy the speedup from flash_attn by setting use_flash_attn to False, because flash attention has been integrated into the transformers library and no longer needs our patch.
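If you really want to keep the patch on a newer release, the failing call could presumably be adapted to the new signature, since attention_forward already receives position_ids. An untested sketch:

  # Untested sketch: pass position_ids instead of the removed seq_len.
  # The rest of the patch may need matching changes, because the newer
  # apply_rotary_pos_emb no longer indexes cos/sin by position_ids itself.
  cos, sin = self.rotary_emb(v, position_ids=position_ids)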

shawnricecake commented on June 11, 2024

Hi, it looks like if I set use_flash_attn to False, the GPU memory usage increases.
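For reference, the two settings can be compared like this (a minimal sketch):

  import torch

  # Minimal sketch: measure peak GPU memory for one training step, then rerun
  # with the other use_flash_attn setting and compare the two numbers.
  torch.cuda.reset_peak_memory_stats()
  # ... run one forward/backward step of the training script here ...
  peak_gib = torch.cuda.max_memory_allocated() / 2**30
  print(f"peak GPU memory: {peak_gib:.2f} GiB")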

And here is my env:

Package                   Version
------------------------- -----------
absl-py                   2.1.0
aiohttp                   3.9.3
aiosignal                 1.3.1
annotated-types           0.6.0
async-timeout             4.0.3
attrs                     23.2.0
bcrypt                    4.1.2
beautifulsoup4            4.12.3
cachetools                5.3.3
certifi                   2024.2.2
cffi                      1.16.0
cfgv                      3.4.0
charset-normalizer        3.3.2
click                     8.1.7
cmake                     3.29.0.1
colossalai                0.3.6
contexttimer              0.3.3
cryptography              42.0.5
datasets                  2.18.0
decorator                 5.1.1
Deprecated                1.2.14
dill                      0.3.8
distlib                   0.3.8
dropout-layer-norm        0.1
einops                    0.7.0
fabric                    3.2.2
filelock                  3.13.3
flash-attn                2.2.1
frozenlist                1.4.1
fsspec                    2024.2.0
fused-dense-lib           0.0.0
google                    3.0.0
google-auth               2.29.0
google-auth-oauthlib      1.0.0
grpcio                    1.62.1
huggingface-hub           0.22.2
identify                  2.5.35
idna                      3.6
invoke                    2.2.0
Jinja2                    3.1.3
jsonschema                4.21.1
jsonschema-specifications 2023.12.1
lit                       18.1.2
Markdown                  3.6
markdown-it-py            3.0.0
MarkupSafe                2.1.5
mdurl                     0.1.2
mpmath                    1.3.0
msgpack                   1.0.8
multidict                 6.0.5
multiprocess              0.70.16
networkx                  3.3
ninja                     1.11.1.1
nodeenv                   1.8.0
numpy                     1.26.4
nvidia-cublas-cu11        11.10.3.66
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu11    11.7.101
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu11    11.7.99
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu11  11.7.99
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu11         8.5.0.96
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu11         10.9.0.58
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu11        10.2.10.91
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu11      11.4.0.1
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu11      11.7.4.91
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu11          2.14.3
nvidia-nccl-cu12          2.19.3
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu11          11.7.91
nvidia-nvtx-cu12          12.1.105
oauthlib                  3.2.2
packaging                 24.0
pandas                    2.2.1
paramiko                  3.4.0
pip                       23.3.1
platformdirs              4.2.0
pre-commit                3.7.0
protobuf                  5.26.1
psutil                    5.9.8
pyarrow                   15.0.2
pyarrow-hotfix            0.6
pyasn1                    0.6.0
pyasn1_modules            0.4.0
pycparser                 2.22
pydantic                  2.6.4
pydantic_core             2.16.3
Pygments                  2.17.2
PyNaCl                    1.5.0
python-dateutil           2.9.0.post0
pytz                      2024.1
PyYAML                    6.0.1
ray                       2.10.0
referencing               0.34.0
regex                     2023.12.25
requests                  2.31.0
requests-oauthlib         2.0.0
rich                      13.7.1
rotary-emb                0.1
rpds-py                   0.18.0
rsa                       4.9
safetensors               0.4.2
sentencepiece             0.1.99
setuptools                68.2.2
six                       1.16.0
soupsieve                 2.5
sympy                     1.12
tensorboard               2.14.0
tensorboard-data-server   0.7.2
tokenizers                0.13.3
torch                     2.0.0
tqdm                      4.66.2
transformers              4.33.3
triton                    2.0.0
typing_extensions         4.11.0
tzdata                    2024.1
urllib3                   2.2.1
virtualenv                20.25.1
Werkzeug                  3.0.2
wheel                     0.41.2
wrapt                     1.16.0
xentropy-cuda-lib         0.1
xxhash                    3.4.1
yarl                      1.9.4
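Note that the list above pins transformers 4.33.3. Built-in Llama flash attention only landed around v4.34 (via the use_flash_attention_2 flag) and the attn_implementation argument around v4.36, so on this environment setting use_flash_attn to False presumably falls back to eager attention, which materializes the full attention matrix and uses more memory. On a newer release, the active implementation can be checked like this (a sketch, reusing a model loaded as in the earlier snippet):

  # Sketch (assumes transformers >= 4.36): print the attention class in use.
  print(type(model.model.layers[0].self_attn).__name__)
  # -> "LlamaFlashAttention2" when flash attention is active,
  #    "LlamaAttention" for the eager fallback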
