
Comments (3)

achew010 commented on May 22, 2024

Thanks for giving some context @fxmarty. From your explanation, I have a better understanding of what caused the regression.

> To me the more likely culprit is that PyTorch now picks FA2 for you instead of mem-efficient attention, and somehow FA2 is slower for you. Could you try playing with the context manager https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel and report here?

Indeed, as you said: by playing around with this context manager and setting both versions to use the same backend (only the memory-efficient backend enabled), I was able to match their speeds. It seems that PyTorch is choosing which backend to use based on the presence of a custom mask after all: if no custom mask is passed in, it will choose the faster FA2 backend. (See below.)
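For reference, a minimal sketch of how to force a single SDPA backend with this context manager (tensor shapes are arbitrary placeholders):

```python
import torch
import torch.nn.functional as F

# Arbitrary query/key/value tensors: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Disable the FlashAttention and math backends so every run is forced onto
# the same memory-efficient kernel, making the timings comparable.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```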

[Benchmark results: "Causal Mask: attn mask is passed to Torch SDPA" vs. "Causal Mask: handled internally in Torch SDPA"]

> It is still unclear to me why you see this regression, given that transformers==4.39 used to always use the attn_mask argument (= never FA2, see here and here in 4.39)

The introduction of a sliding window here influences which backend the SDPA kernel will use. In my setup, having max_context_length=4096 equal to sliding_window=4096 causes this check to set ignore_causal_mask=False, which then falls through to the code here that produces a custom attention mask to be passed to the SDPA kernel. By running my example script with a context length smaller than the sliding window value (e.g. <=4095), I avoid the generation of a custom mask and subsequently let PyTorch SDPA use the faster FA2 backend; by doing so I am able to replicate the throughputs I saw in 4.39. A sketch of this dispatch logic follows.
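For illustration, a simplified sketch of the check described above (an approximation of the mask-ignoring logic in transformers, not the exact library code; the function name is illustrative):

```python
from typing import Optional

def can_ignore_causal_mask(key_value_length: int, sliding_window: Optional[int]) -> bool:
    # The custom mask may only be dropped (letting SDPA's is_causal handle
    # causality, which keeps the FA2 backend eligible) when the sliding
    # window never actually truncates the attention span.
    return sliding_window is None or key_value_length < sliding_window

print(can_ignore_causal_mask(4096, sliding_window=4096))  # False: custom mask built, no FA2
print(can_ignore_causal_mask(4095, sliding_window=4096))  # True: mask dropped, FA2 eligible
```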

This clarifies everything, thanks a lot for the help!


amyeroberts commented on May 22, 2024

cc @fxmarty


fxmarty commented on May 22, 2024

Hi @achew010, thank you for the report. Two PRs may be at play here: #30127 and #30070. Long story short:

  • SDPA requires attn_mask=None to be able to dispatch to its FA2 backend (see the sketch after this list).
  • the implementation of attention with SDPA when using a sliding window used to be incorrect, not applying the sliding window at all due to the mask being dropped (for the above reason). #30127 ensures the correctness of the sliding-window mask and no longer relies on SDPA's is_causal argument.
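A minimal sketch of the constraint in the first bullet (shapes and the window size are arbitrary placeholders):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Eligible for the FA2 backend: no explicit mask, causality handled internally.
out_fa2 = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Not eligible for FA2: an explicit sliding-window causal mask is passed,
# so SDPA falls back to the memory-efficient (or math) backend.
window = 256
causal = torch.ones(1024, 1024, device="cuda", dtype=torch.bool).tril()
banded = causal & ~causal.tril(diagonal=-window)  # keep only the last `window` positions
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=banded)
```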

As you can see here, a check is done on key_value_length < sliding_window, which still allows the mask to be ignored for some sequence lengths.

> I also don't see the computation savings at sequence length = 8192 from the introduction of sliding-window attention, compared to if there weren't a windowed causal mask at all (calculating attention across all 8192 tokens).

This was already the case for transformers<=4.39 (and also in Mistral's public code release). Unfortunately, apart from the original HazyResearch/flash-attn implementation (attn_implementation="flash_attention_2" in Transformers, see the doc), there is no efficient sliding-window implementation for eager & SDPA currently. I know Driss from PyTorch was working on this.
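For completeness, a hedged sketch of opting into the flash-attn kernels in Transformers (the model name is only an example; the flash-attn package must be installed):

```python
import torch
from transformers import AutoModelForCausalLM

# Dispatch to the HazyResearch flash-attn kernels, which do implement
# sliding-window attention efficiently.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # example model that uses a sliding window
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
```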

It is still unclear to me why you see this regression, given that transformers==4.39 used to always use the attn_mask argument (= never FA2, see here and here in 4.39). To me the more likely culprit is that PyTorch now picks FA2 for you instead of mem-efficient attention, and somehow FA2 is slower for you. Could you try playing with the context manager https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel and report here?

