Comments (8)
We compared attention time (softmax(QK^T)V) vs scan time, without the linear projection.
The dimensions are different, e.g. in a 1.3B model Transformers would typically have Q, K, V of hidden dimensions 2048, while the input to the scan would be of dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and scan with input of dimension 4096 to be fair / favorable to attention.
from mamba.
We decided to leave those linear projections out because they are orthogonal to the main "sequence mixing mechanism" (attention vs scan) that is of interest to benchmark. You're right that the comparisons become slightly harder to control (e.g. what model dimension to use is fair?), but we chose a setting that seemed reasonable to us. No matter what, the timings will only be off by a small constant factor with any other "reasonable" setting of dimensions, which is dwarfed by the linear vs quadratic complexity.
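To make the "small constant factor vs. linear/quadratic" point concrete, here is a rough back-of-the-envelope sketch. It is not taken from the paper or the repo: `attention_flops`, `scan_ops`, and the `ops_per_update` constant are hypothetical names, and the scan cost is modelled simply as one unit of work per (token, channel, state) triple. It uses the dimensions mentioned above (attention dim 2048, scan input dim 4096) and assumes a state size of 16.

```python
def attention_flops(seq_len, dim, causal=True):
    """Rough matmul FLOPs for softmax(QK^T)V, ignoring the softmax itself."""
    flops = 4 * seq_len * seq_len * dim  # 2*L^2*d for QK^T plus 2*L^2*d for PV
    return flops // 2 if causal else flops  # causal masking skips about half


def scan_ops(seq_len, d_inner, d_state, ops_per_update=1):
    """Elementwise op count for a scan; ops_per_update is a placeholder constant."""
    return ops_per_update * seq_len * d_inner * d_state


ratios = {}
for seq_len in (2048, 4096, 8192, 16384):
    # Assumed setting: attention dim 2048, scan input dim 4096, state size 16.
    ratios[seq_len] = attention_flops(seq_len, 2048) / scan_ops(seq_len, 4096, 16)
    print(f"L={seq_len}: attention / scan op-count ratio = {ratios[seq_len]:.0f}")
```

Under this model the ratio grows linearly with sequence length whatever value `ops_per_update` takes, which is the sense in which the choice of dimensions only shifts things by a constant. Op counts are not wall-clock times, though: attention runs as matmuls while the scan is elementwise work, so the length at which the scan actually becomes faster has to be measured.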
> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.
I wrote a simple script to compare these two components (scan and FlashAttention-2 with causal masking) and tested it on an A100. As instructed, the input dim of the scan is 4096 and the input dim of flash attention is 2048 (32 heads * 64 head dim). However, the scan is much slower than FlashAttention-2 (fwd: scan 0.25 ms vs. flash2 0.14 ms; fwd+bwd: scan 1.25 ms vs. flash2 0.59 ms). Did I get any settings wrong?
```python
import torch
import time

test_bwd = False  # set to True to time forward + backward
batch, length, dim, d_state = 1, 2048, 2048, 16

# --- Selective scan: inputs have dim * 2 channels (expand=2) ---
from mamba_ssm.ops.selective_scan_interface import SelectiveScanFn

# u, delta, B, C, z are bf16; A, D and delta_bias stay fp32
u = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
A = torch.randn(dim * 2, d_state).to("cuda").requires_grad_(True)
B = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
C = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
D = torch.randn(dim * 2).to("cuda").requires_grad_(True)
z = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta_bias = torch.randn(dim * 2).to("cuda").requires_grad_(True)
doutssm = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16)
ssm = SelectiveScanFn.apply

# Warm-up iterations (not timed)
for i in range(10):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
# Total seconds over 1000 iterations, i.e. average milliseconds per iteration
print(time.time() - start)

# --- FlashAttention-2 (causal): dim channels split into heads of size 64 ---
from flash_attn import flash_attn_func

dim_head = 64
n_heads = dim // dim_head
q = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
k = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
v = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
dout = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16)

# Warm-up iterations (not timed)
for i in range(10):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
# Total seconds over 1000 iterations, i.e. average milliseconds per iteration
print(time.time() - start)
```
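The warm-up / synchronize / time pattern repeated in the script can be factored into one small helper. The following is only a minimal sketch, not code from the Mamba or FlashAttention repos; `time_per_iter_ms` and its parameters are hypothetical names, and the demo uses a CPU-only stand-in workload so it runs without a GPU.

```python
import time


def time_per_iter_ms(fn, warmup=10, iters=1000, sync=None):
    """Average wall-clock milliseconds per call of fn.

    sync is an optional callable (for example torch.cuda.synchronize) that is
    run before starting and before stopping the clock, so that queued
    asynchronous GPU work is included in the measurement.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) * 1000.0 / iters


# CPU-only stand-in workload so the sketch runs anywhere.
calls = []
elapsed_ms = time_per_iter_ms(lambda: calls.append(sum(range(100))), warmup=2, iters=50)
print(f"{elapsed_ms:.4f} ms per call over {len(calls)} total calls")
```

With the script above, one would presumably pass a function that runs the scan (or `flash_attn_func`) plus the optional backward pass, with `sync=torch.cuda.synchronize`, so both kernels are measured the same way.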
Please format your code with triple backticks followed by "python": ``` python

The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double check our script to see if anything is different than yours.
> We compared attention time (softmax(QK^T)V) vs scan time, without the linear projection. The dimensions are different, e.g. in a 1.3B model Transformers would typically have Q, K, V of hidden dimensions 2048, while the input to the scan would be of dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and scan with input of dimension 4096 to be fair / favorable to attention.
And what datatype did you use? When I try to run scan using fp16, it always raises the error:
Traceback (most recent call last):
File "/home/yuqing/mamba/run.py", line 29, in
y = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, True)
RuntimeError: Expected weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Q, K, V are bf16 for attention.
u, delta, B, C, z are bf16, A and D are fp32 for scan.
> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.

It works now, thank you!
> Please format your code with triple backticks followed by "python": ``` python
>
> The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double check our script to see if anything is different than yours.
Sorry for the formatting issue; I've re-edited the code above. I also tested with D=1024: for fwd, scan takes 0.13 ms vs. 0.08 ms for flash; for fwd+bwd, scan takes 0.71 ms vs. 0.35 ms for flash.
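For easier comparison, the timings reported in this thread can be turned into scan-to-flash slowdown ratios. This is just arithmetic on the numbers quoted above; the dictionary layout and the names `timings_ms` and `slowdown` are made up for illustration.

```python
# Timings in ms as reported in this thread (scan vs. FlashAttention-2).
timings_ms = {
    ("dim=2048", "fwd"): {"scan": 0.25, "flash": 0.14},
    ("dim=2048", "fwd+bwd"): {"scan": 1.25, "flash": 0.59},
    ("dim=1024", "fwd"): {"scan": 0.13, "flash": 0.08},
    ("dim=1024", "fwd+bwd"): {"scan": 0.71, "flash": 0.35},
}

# How many times slower the scan is than flash attention in each setting.
slowdown = {key: t["scan"] / t["flash"] for key, t in timings_ms.items()}

for (dim, mode), ratio in slowdown.items():
    print(f"{dim} {mode}: scan / flash = {ratio:.2f}x")
```

In all four settings the scan comes out roughly 1.6x to 2.1x slower than flash attention, so halving the dimension does not change the picture much.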