Comments (3)
Thanks for giving some context @fxmarty, from your explanation I have a better understanding of what caused the regression.
> To me the more likely culprit is that PyTorch now picks FA2 for you instead of mem-efficient attention and somehow FA2 is slower for you. Could you try playing with the decorator https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel and report here?
Indeed, as you said, by playing around with this context manager and forcing both runs to use the same backend (only `enable_mem_efficient=True`), I was able to match their speeds. It seems that PyTorch is choosing which backend to use based on the presence of a custom mask after all: if no custom mask is passed in, it picks the faster FA2 backend (see below).
| Causal attn mask passed to Torch SDPA | Causal mask handled internally in Torch SDPA |
| --- | --- |
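For reference, a minimal sketch of pinning the SDPA backend with this context manager (shapes and dtypes here are arbitrary, not taken from the benchmark script):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the memory-efficient backend only,
# ruling out both the FA2 backend and the math fallback.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```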
> It is still unclear to me why you see this regression, given that `transformers==4.39` used to always use the `attn_mask` argument (= never FA2, see here and here in 4.39)
The introduction of a sliding window here influences which backend the SDPA kernel uses. In my setup, having `max_context_length=4096`, the same as `sliding_window=4096`, causes this check to set `ignore_causal_mask=False`. Execution then falls through to the code here that produces a custom attention mask to pass to the SDPA kernel. By running my example script with a context length smaller than the sliding window value (e.g. <=4095), I avoid the generation of a custom mask and let PyTorch SDPA use the faster FA2 backend; doing so, I am able to replicate the throughputs I saw in 4.39.
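To make the dispatch behaviour concrete (a sketch at the raw SDPA level, not the transformers code path): passing any explicit `attn_mask`, even a plain causal one, rules out the FA2 backend.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# FA2-eligible: no explicit mask, causality handled inside the kernel.
fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Not FA2-eligible: an explicit (here, plain causal) boolean mask is passed,
# so SDPA falls back to the mem-efficient or math backends.
causal = torch.ones(4096, 4096, device="cuda", dtype=torch.bool).tril()
slow = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
```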
This clarifies everything, thanks a lot for the help!
cc @fxmarty
Hi @achew010, thank you for the report. Two PRs may be at play here, #30127 and #30070. Long story short:

- SDPA requires `attn_mask=None` to be able to dispatch to its FA2 backend.
- The implementation of attention with SDPA when using a sliding window used to be incorrect: the sliding window was not applied at all, because the mask was dropped (for the above reason). #30127 ensures the correctness of the sliding mask and does not rely on SDPA's `is_causal` argument (see the mask sketch just below).
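As an illustration (a minimal sketch of the masking scheme, not the #30127 code itself), a sliding-window causal mask combines the causal constraint with a bound on how far back each query may attend:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True = the key position may be attended to by the query position.
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]          # key <= query (causal)
    recent = idx[:, None] - idx[None, :] < window  # query - key < window
    return causal & recent
```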
As you can see here, a check is done on `key_value_length < sliding_window`, which still allows the mask to be ignored for some sequence lengths.
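A simplified sketch of the kind of check described (the function name and exact conditions are assumptions for illustration, not the actual transformers code):

```python
def can_ignore_causal_mask(attention_mask, key_value_length, sliding_window=None):
    # The explicit mask may only be dropped (letting SDPA run with
    # is_causal=True and dispatch to its FA2 backend) when plain causal
    # masking is all that is needed.
    if attention_mask is not None:
        return False
    if sliding_window is not None and key_value_length >= sliding_window:
        # The window truncates attention, so a custom mask is required.
        return False
    return True
```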
> I also don't see the computation savings at sequence length=8192 from the introduction of sliding window attention, compared to having no windowed causal mask at all (calculating attention across all 8192 tokens).
This was already the case for `transformers<=4.39` (and also in the Mistral public code release). Unfortunately, apart from the original HazyResearch/flash-attn implementation (`attn_implementation="flash_attention_2"` in Transformers, see the doc), there is no efficient implementation for eager & SDPA currently. I know Driss from PyTorch was working on this.
It is still unclear to me why you see this regression, given that `transformers==4.39` used to always use the `attn_mask` argument (= never FA2, see here and here in 4.39). To me the more likely culprit is that PyTorch now picks FA2 for you instead of mem-efficient attention and somehow FA2 is slower for you. Could you try playing with the decorator https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel and report here?