flash-attention-minimal

A minimal re-implementation of Flash Attention with CUDA and PyTorch. The official implementation can be quite daunting for a CUDA beginner (like myself), so this repo tries to be small and educational.

  • The entire forward pass is written in ~100 lines in flash.cu.
  • The variable names follow the notations from the original paper.

Usage

Prerequisite

  • PyTorch (with CUDA)
  • Ninja, used by PyTorch to JIT-compile the C++/CUDA extension
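The Ninja dependency comes from PyTorch's JIT extension loader, which compiles flash.cu at import time. A hypothetical loading snippet (the exact call in bench.py may differ; this is the standard pattern, guarded so it is a no-op without a GPU):

```python
import torch
from torch.utils.cpp_extension import load

# Build the extension only when a CUDA device (and toolchain) is present.
if torch.cuda.is_available():
    minimal_attn = load(
        name="minimal_attn",
        sources=["flash.cu"],   # assumes flash.cu is in the working directory
        extra_cuda_cflags=["-O2"],
    )
    # out = minimal_attn.forward(q, k, v)
```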

Benchmark

Compare the wall-clock time between manual attention and minimal flash attention:

python bench.py

Sample output on a T4:

=== profiling manual attention ===
...
Self CPU time total: 52.389ms
Self CUDA time total: 52.545ms

=== profiling minimal flash attention === 
...  
Self CPU time total: 11.452ms
Self CUDA time total: 3.908ms

Speed-up achieved!
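For reference, the manual baseline is standard softmax attention, which materializes the full N x N score matrix. A sketch of what bench.py's manual_attn presumably computes (name and (batch, heads, N, d) shapes are assumptions):

```python
import math
import torch

def manual_attn(q, k, v):
    # O(N^2) attention: the full score matrix is built in memory,
    # which is exactly the read/write traffic Flash Attention avoids.
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    att = torch.softmax(att, dim=-1)
    return att @ v
```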

I don't have a GPU

Try out this online colab demo.

Caveats

  • No backward pass! To be honest, I found it a lot more complex than the forward pass, and the forward pass alone was enough to show the use of shared memory to avoid large N^2 reads/writes.
  • In the inner loop, I assign each thread to a row of the output matrix. This differs from the original implementation.
  • This thread-per-row simplification makes the matrix multiplications very slow. This is probably why, for longer sequences and larger block sizes, it ends up slower than the manual implementation.
  • Q, K, and V are in float32, unlike the original implementation, which uses float16.
  • The block size is fixed at 32 at compile time.
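The caveats above are easier to follow next to the algorithm itself. Below is a sketch of the tiled forward pass in plain PyTorch (not the kernel; block size and names are illustrative): K and V are streamed in blocks while each query row keeps a running max and running softmax denominator, so the full N x N matrix is never materialized.

```python
import torch

def flash_forward_sketch(q, k, v, Bc=32):
    # q, k, v: (N, d). Online-softmax tiling over K/V blocks of size Bc.
    N, d = q.shape
    scale = d ** -0.5
    O = torch.zeros_like(q)
    l = torch.zeros(N)                    # running softmax denominator per row
    m = torch.full((N,), float("-inf"))   # running row max per row
    for j in range(0, N, Bc):
        kj, vj = k[j:j + Bc], v[j:j + Bc]
        S = (q @ kj.T) * scale            # (N, Bc) scores for this block only
        m_new = torch.maximum(m, S.max(dim=-1).values)
        P = torch.exp(S - m_new[:, None])
        alpha = torch.exp(m - m_new)      # rescales previously accumulated stats
        l = alpha * l + P.sum(dim=-1)
        O = alpha[:, None] * O + P @ vj
        m = m_new
    return O / l[:, None]
```

The output matches ordinary softmax attention; the point is that each iteration touches only a (N, Bc) slice, which is what the kernel keeps in shared memory.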

Todos

  • Add backward pass
  • Speed up matmuls
  • Dynamically set block size

flash-attention-minimal's Issues

slow in for loop test

The kernel is slower when I time it in a for loop:

import time
import torch

REPEAT = 10
manual_result = manual_attn(q, k, v)  # warmup
torch.cuda.synchronize()              # finish warmup before starting the clock
st = time.time()
for _ in range(REPEAT):
    manual_result = manual_attn(q, k, v)
    torch.cuda.synchronize()
print(f"manual attention mean time(ms): {((time.time() - st) * 1000) / REPEAT}")

minimal_result = minimal_attn.forward(q, k, v)  # warmup
torch.cuda.synchronize()
st = time.time()
for _ in range(REPEAT):
    minimal_result = minimal_attn.forward(q, k, v)
    torch.cuda.synchronize()
print(f"minimal attention mean time(ms): {((time.time() - st) * 1000) / REPEAT}")

Correctness parameters

Hi Peter,

I just found your post on HN. Congratulations on the post!

I am one of the developers behind Faial, a tool that analyzes CUDA kernels and finds data races.

I ran our tool against the kernel flash.cu and found that it is data-race free as long as the following conditions are met:

  • N > 0
  • Bc == blockDim.x
  • Br == blockDim.x
  • Tr <= blockDim.x
  • N >= blockDim.x * blockDim.x

Faial is a research project, so I am wondering if having access to these correctness conditions is valuable to you as a developer.

Please let me know if you'd like me to try out any combinations of parameters to see if the kernel is still data-race free.
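The conditions reported above are easy to check mechanically before launching the kernel. A hypothetical helper (the function name and parameter spelling are mine, not part of the repo or of Faial):

```python
def is_launch_datarace_free(N, Bc, Br, Tr, blockDim_x):
    # Encodes the data-race-freedom conditions Faial reported for flash.cu.
    return (
        N > 0
        and Bc == blockDim_x
        and Br == blockDim_x
        and Tr <= blockDim_x
        and N >= blockDim_x * blockDim_x
    )
```

Note that with the repo's fixed block size of 32, the last condition means the sequence length N must be at least 32 * 32 = 1024 for these guarantees to apply.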
