
Comments (8)

tridao commented on August 15, 2024

We compared attention time (softmax(QK^T)V) vs. scan time, without the linear projection.
The dimensions are different: e.g., in a 1.3B model, a Transformer would typically have Q, K, V of hidden dimension 2048, while the input to the scan would be of dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and scan with input of dimension 4096, to be fair / favorable to attention.
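
For concreteness, here is a minimal sketch of the dimension bookkeeping described above (the exact benchmark script is not part of this thread; the head layout and expansion factor below are assumptions based on the defaults mentioned in the discussion):

```python
# Illustrative shapes only, not the actual benchmark script.
d_model = 2048                    # typical hidden size of a ~1.3B Transformer
n_heads, head_dim = 32, 64        # assumed head layout: 32 * 64 = 2048
expand = 2                        # Mamba default
d_inner = expand * d_model        # 4096, the width fed to the selective scan

batch, seqlen = 1, 2048
# Attention kernel sees Q, K, V of shape (batch, seqlen, n_heads, head_dim) -> total dim 2048.
# Scan kernel sees an input of shape (batch, d_inner, seqlen) -> dim 4096.
# Linear projections into and out of both blocks are excluded from the timing.
```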


albertfgu commented on August 15, 2024

We decided to leave those linear projections out because they are orthogonal to the main "sequence mixing mechanism" (attention vs. scan) that is of interest to benchmark. You're right that the comparisons become slightly harder to control (e.g., which model dimension is fair to use?), but we chose a setting that seemed reasonable to us. No matter what, the timings will only be off by a small constant factor with any other "reasonable" setting of dimensions, which is dwarfed by the linear vs. quadratic complexity.
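
As a rough illustration of why the constant factor is dwarfed by the asymptotics (the leading-term formulas and constants below are back-of-the-envelope assumptions, not measurements from the paper):

```python
# Rough FLOP counts per batch element, leading terms only.
def attn_flops(L, d):
    # softmax(QK^T)V: two (L x L x d) matmuls, ~2 * L^2 * d each
    return 4 * L**2 * d

def scan_flops(L, d, N=16, c=9):
    # selective scan: O(L * d * N) work; c is an assumed constant factor
    return c * L * d * N

for L in (2048, 8192, 32768):
    ratio = attn_flops(L, 2048) / scan_flops(L, 4096)
    print(f"L={L}: attention/scan FLOP ratio ~ {ratio:.0f}")
```

The ratio grows linearly with sequence length, so any reasonable choice of dimensions only shifts it by a constant.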


xiayuqing0622 commented on August 15, 2024

> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.

I wrote a simple script to compare these two components (scan and FlashAttention-2 with causal masking) and tested it on an A100. As instructed, the input dim of the scan is 4096 and the input dim of flash-attn is 2048 (32 heads × 64 head dim). However, the scan is much slower than FlashAttention-2 (fwd: scan 0.25 ms vs. flash2 0.14 ms; fwd+bwd: scan 1.25 ms vs. flash2 0.59 ms). Did I get any settings wrong?

```python
import torch
import time

test_bwd = False
batch, length, dim, d_state = 1, 2048, 2048, 16

# Selective scan: input dimension is dim * 2 = 4096 (expand = 2).
# A and D are kept in fp32; everything else is bf16.
from mamba_ssm.ops.selective_scan_interface import SelectiveScanFn

u = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
A = torch.randn(dim * 2, d_state).to("cuda").requires_grad_(True)
B = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
C = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
D = torch.randn(dim * 2).to("cuda").requires_grad_(True)
z = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta_bias = torch.randn(dim * 2).to("cuda").requires_grad_(True)
doutssm = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16)
ssm = SelectiveScanFn.apply

# Warmup, then time 1000 iterations; the printed total in seconds is numerically
# equal to the per-iteration time in milliseconds.
for i in range(10):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
print(time.time() - start)

# FlashAttention-2 (causal): Q, K, V of total dim 2048 = 32 heads * 64 head dim, all bf16.
from flash_attn import flash_attn_func

dim_head = 64
n_heads = dim // dim_head
q = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
k = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
v = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
dout = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16)

for i in range(10):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
print(time.time() - start)
```


albertfgu commented on August 15, 2024

Please format your code with triple backticks followed by "python": ``` python

The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double-check our script to see if anything is different from yours.


xiayuqing0622 commented on August 15, 2024

> We compared attention time (softmax(QK^T)V) vs. scan time, without the linear projection. The dimensions are different: e.g., in a 1.3B model, a Transformer would typically have Q, K, V of hidden dimension 2048, while the input to the scan would be of dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and scan with input of dimension 4096, to be fair / favorable to attention.

And what datatype did you use? When I try to run the scan using fp16, it always raises this error:

```
Traceback (most recent call last):
  File "/home/yuqing/mamba/run.py", line 29, in <module>
    y = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, True)
RuntimeError: Expected weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
```


tridao commented on August 15, 2024

Q, K, V are bf16 for attention.
u, delta, B, C, z are bf16, A and D are fp32 for scan.
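
As a minimal sketch of that dtype split (shapes follow the script above; keeping A and D in fp32 is what avoids the weight_type error shown earlier):

```python
import torch

batch, length, d_inner, d_state = 1, 2048, 4096, 16

# Sequence tensors in bf16 (fp16 would also work for these) ...
u     = torch.randn(batch, d_inner, length, device="cuda", dtype=torch.bfloat16)
delta = torch.randn(batch, d_inner, length, device="cuda", dtype=torch.bfloat16)
B     = torch.randn(batch, 1, d_state, length, device="cuda", dtype=torch.bfloat16)
C     = torch.randn(batch, 1, d_state, length, device="cuda", dtype=torch.bfloat16)
z     = torch.randn(batch, d_inner, length, device="cuda", dtype=torch.bfloat16)

# ... while the state-space parameters A and D stay in fp32, as the kernel requires.
A = torch.randn(d_inner, d_state, device="cuda", dtype=torch.float32)
D = torch.randn(d_inner, device="cuda", dtype=torch.float32)
```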


xiayuqing0622 commented on August 15, 2024

> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.

it works now, thank you!


xiayuqing0622 commented on August 15, 2024

> Please format your code with triple backticks followed by "python": ``` python
>
> The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double-check our script to see if anything is different from yours.

Sorry for the format issue; I've re-edited the code above. I also tested with D=1024: for fwd, it's scan 0.13 ms vs. flash 0.08 ms; for fwd+bwd, it's scan 0.71 ms vs. flash 0.35 ms.

