Comments (9)

Chillee commented on August 18, 2024

F.scaled_dot_product_attention automatically decides which backend to dispatch to. For example, it can choose to dispatch to the FlashAttention2 kernel, or, on platforms where FlashAttention2 is not supported, to the "math" implementation, which is attention implemented using more primitive PyTorch operators.

In the case of decoding, however, the FlashAttention algorithm is not beneficial. In fact, it's actively detrimental. So in this case, it's better to dispatch to more primitive operators, where torch.compile can codegen the kernels from scratch.
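
As a concrete sketch of pinning that choice (the tensor shapes here are illustrative assumptions, not gpt-fast's actual sizes), the dispatch can be restricted to the math backend with the torch.backends.cuda.sdp_kernel context manager discussed later in this thread:

    import torch
    import torch.nn.functional as F

    # Hypothetical decode-step shapes: bs=1, a single query token attending
    # to a 1024-token KV cache (values chosen only for illustration).
    q = torch.randn(1, 32, 1, 128, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 32, 1024, 128, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 32, 1024, 128, device="cuda", dtype=torch.bfloat16)

    # Restrict dispatch to the "math" implementation so torch.compile can
    # codegen the kernel instead of calling a prebuilt FlashAttention kernel.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    ):
        out = F.scaled_dot_product_attention(q, k, v)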

Chillee commented on August 18, 2024

The big issue is the work partitioning structure. FlashAttention parallelizes across heads, batch size (BS), and output_seq_len (i.e. the query sequence length). In this case, BS and output_seq_len are both 1, so the only parallelism is across heads. An A100 GPU has 108 SMs, so the kernel just can't utilize the entire GPU efficiently.
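
A quick back-of-envelope on that launch grid (the head count is my assumption for a 7B-class model, not a number from the thread):

    # FlashAttention's launch grid at decode time is roughly
    # num_heads * bs * output_seq_len thread blocks.
    num_heads = 32       # assumption: typical 7B-class model
    bs = 1               # low-latency decoding
    output_seq_len = 1   # one query token per decode step
    num_sms = 108        # A100

    blocks = num_heads * bs * output_seq_len
    print(f"{blocks} blocks for {num_sms} SMs "
          f"-> at most {blocks / num_sms:.0%} of SMs get any work")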

> The saving for avoiding doing IO to global memory would scale to ~[2, num_head, bs, input_seq_len, output_seq_len] over the entire output length.

output_seq_len is only 1 in this case. And for the low-latency setting, bs is also 1. So your intermediate matrix is of size [2, num_head, 1, input_seq_len]. That's not nothing, but it's not a large enough advantage to dwarf issue 1. I would expect FlashDecoding to perform better than Inductor's generated kernel.
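
To put rough numbers on "not nothing" (sizes and dtype are my assumptions, not from the thread), compare the attention-matrix IO saved per decode step against the KV-cache read that no attention algorithm can avoid:

    # IO saved per decode step by skipping the attention matrix,
    # vs. the unavoidable KV-cache read (assumed fp16, 7B-ish sizes).
    num_heads, head_dim, input_seq_len = 32, 128, 2048
    bytes_per_el = 2  # fp16/bf16

    attn_matrix = 2 * num_heads * 1 * input_seq_len * bytes_per_el
    kv_cache = 2 * num_heads * input_seq_len * head_dim * bytes_per_el
    print(f"attn matrix: {attn_matrix / 2**10:.0f} KiB")  # ~256 KiB
    print(f"KV cache:    {kv_cache / 2**20:.0f} MiB")     # ~32 MiB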

huntzhan commented on August 18, 2024

Inspired by your project, I've successfully applied the optimization strategy to baichuan 13b:
https://github.com/armed-gpt/gpt-blazing
Could I submit a PR to add a reference case to the README?

huntzhan commented on August 18, 2024

Thanks a lot!

ekagra-ranjan commented on August 18, 2024

Hi @Chillee - why would you say that FlashAttention is actively detrimental for decoding?

During the prefill stage, it helps avoid materializing the attention matrix of shape [2, num_head, bs, input_seq_len, input_seq_len] in global memory, which saves time by avoiding IO to global memory.

During the decoding phase, it avoids materializing an attention matrix of shape ~[2, num_head, bs, input_seq_len] in global memory for each decoded token. The saving from avoiding this IO to global memory would scale to ~[2, num_head, bs, input_seq_len, output_seq_len] over the entire output length.

These are huge matrices which take a lot of space. Flash attention makes the model take less GPU memory and also gives a speedup by avoiding global memory IO (even without reducing FLOPs). Usually the prefill stage is FLOP bound, and the reduction in memory IO still gives a good speedup there. The decoding phase is usually memory-bandwidth bound, so shouldn't the reduction in memory IO be useful?
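
For scale, here is the same size arithmetic at prefill vs. decode (sizes and dtype are my assumptions, not the commenter's); it shows why the avoided materialization is enormous at prefill but small per decode step:

    # Attention-matrix sizes at prefill vs. decode (assumed fp16, 7B-ish sizes).
    num_heads, seq = 32, 2048
    bytes_per_el = 2

    prefill = num_heads * seq * seq * bytes_per_el  # [num_head, seq, seq]
    decode = num_heads * 1 * seq * bytes_per_el     # [num_head, 1, seq] per token
    print(f"prefill attn matrix: {prefill / 2**20:.0f} MiB")  # ~256 MiB
    print(f"decode attn matrix:  {decode / 2**10:.0f} KiB")   # ~128 KiB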

huntzhan commented on August 18, 2024

Hi @Chillee, it looks like torch.compile cannot handle the torch.backends.cuda.sdp_kernel context manager. The function decode_n_tokens, in which the torch.backends.cuda.sdp_kernel context manager is used, is not compiled. Does that mean the aforementioned behavior is not applied?

  File "/home/wden/.local/lib/python3.8/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor _GeneratorContextManager call_function <function sdp_kernel at 0x7fd9c5584a60>

from user code:
   File "/data/xxx", line 451, in model_decode_one_token
    with torch.backends.cuda.sdp_kernel(

By the way, I replaced torch.backends.cuda.sdp_kernel with the following statements:

    # Disable the FlashAttention and memory-efficient backends globally,
    # leaving only the math implementation enabled.
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
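
As a side note: in newer PyTorch releases (2.3+), the deprecated torch.backends.cuda.sdp_kernel context manager is superseded by torch.nn.attention.sdpa_kernel. A minimal sketch of the same restriction with that API (assuming q, k, v are already defined):

    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Restrict SDPA to the math backend for the wrapped region only,
    # without flipping the global backend flags.
    with sdpa_kernel(SDPBackend.MATH):
        out = F.scaled_dot_product_attention(q, k, v)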

Chillee commented on August 18, 2024

> The function decode_n_tokens, in which the torch.backends.cuda.sdp_kernel context manager is used, is not compiled. Does that mean the aforementioned behavior is not applied?

No, decode_n_tokens calls decode_one_token, which does sit inside the context manager. And the annotation is still respected there.
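
A minimal sketch of that structure (simplified from gpt-fast's generate.py; the tensor arguments stand in for the real model state):

    import torch
    import torch.nn.functional as F

    @torch.compile
    def decode_one_token(q, k, v):
        # Compiled region: SDPA here respects whatever backend selection
        # is active when it runs.
        return F.scaled_dot_product_attention(q, k, v)

    def decode_n_tokens(q, k, v, num_new_tokens):
        # Eager wrapper: the context manager stays outside the compiled
        # region, so dynamo never has to trace _GeneratorContextManager.
        outs = []
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            for _ in range(num_new_tokens):
                outs.append(decode_one_token(q, k, v))
        return outs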

huntzhan commented on August 18, 2024

> The function decode_n_tokens, in which the torch.backends.cuda.sdp_kernel context manager is used, is not compiled. Does that mean the aforementioned behavior is not applied?
>
> No, decode_n_tokens calls decode_one_token, which does sit inside the context manager. And the annotation is still respected there.

I see. Since decode_one_token is not called directly and its compilation is deferred, the annotation is still respected.

huntzhan commented on August 18, 2024

> Inspired by your project, I've successfully applied the optimization strategy to baichuan 13b: https://github.com/armed-gpt/gpt-blazing Could I submit a PR to add a reference case to the README?

Hi @Chillee, I've submitted PR #48; could you take a look and comment? Thanks.
