lucidrains / ring-attention-pytorch

License: MIT License

Topics: attention-mechanism, efficient-attention, long-context, distributed-attention
ring-attention-pytorch's Introduction

Ring Attention - Pytorch

Implementation of Ring Attention, from Liu et al. at Berkeley AI, in Pytorch.

It basically splits the data across the sequence dimension (instead of batch) and applies ring reduce to the processing of the tiles of the attention matrix, flash attention style.
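
To make the ring reduce concrete, below is a minimal single-process sketch (all shapes hypothetical): the ring is simulated by iterating over key/value chunks in ring order, and partial attention results are combined across ring steps with the same online log-sum-exp accumulation flash attention uses.

import torch
import torch.nn.functional as F

# simulate a ring of 4 ranks by sharding the sequence dimension
ring_size = 4
q, k, v = (torch.randn(1, 1024, 64) for _ in range(3))
scale = q.shape[-1] ** -0.5

q_chunks = q.chunk(ring_size, dim = 1)
kv_chunks = list(zip(k.chunk(ring_size, dim = 1), v.chunk(ring_size, dim = 1)))

outs = []
for rank, qc in enumerate(q_chunks):
    acc = torch.zeros_like(qc)                           # running weighted sum of values
    row_max = torch.full(qc.shape[:-1], float('-inf'))   # running row maximum
    row_sum = torch.zeros(qc.shape[:-1])                 # running softmax denominator

    for step in range(ring_size):
        # the key/value chunk that would arrive on this ring pass
        kc, vc = kv_chunks[(rank + step) % ring_size]
        sim = torch.einsum('b i d, b j d -> b i j', qc, kc) * scale
        new_max = torch.maximum(row_max, sim.amax(dim = -1))
        correction = (row_max - new_max).exp()           # rescale earlier partial results
        p = (sim - new_max.unsqueeze(-1)).exp()
        row_sum = row_sum * correction + p.sum(dim = -1)
        acc = acc * correction.unsqueeze(-1) + torch.einsum('b i j, b j d -> b i d', p, vc)
        row_max = new_max

    outs.append(acc / row_sum.unsqueeze(-1))

# the sharded ring computation matches full attention
full = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)).squeeze(1)
assert torch.allclose(torch.cat(outs, dim = 1), full, atol = 1e-5)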

I believe some form of this is being used for the 1-10 million token context of the latest Gemini; the other possibility would be unpublished improvements on top of RMT.

In addition, the repository also contains the logic for Striped Attention, a follow-up paper that permutes the sequence for better workload balancing in autoregressive transformers.

It also contains support for grouped query attention, popularized by the Llama series of models. This will further save on communication costs during the ring reduce.
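
The communication saving is visible from tensor sizes alone: only the keys and values travel around the ring, so fewer key/value heads shrink every ring pass proportionally. A rough illustration, with hypothetical shapes:

import torch

batch, seq_chunk, dim_head = 1, 512, 64

# standard multi-head attention: 8 key/value heads cross the ring each pass
kv_mha = torch.randn(2, batch, seq_chunk, 8, dim_head, dtype = torch.float16)

# grouped query attention: 8 query heads share 2 key/value heads
kv_gqa = torch.randn(2, batch, seq_chunk, 2, dim_head, dtype = torch.float16)

print(kv_mha.numel() * kv_mha.element_size())   # 1048576 bytes per ring pass
print(kv_gqa.numel() * kv_gqa.element_size())   # 262144 bytes, 4x less communication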

Appreciation

  • A16Z Open Source AI Grant Program for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

  • Tri Dao for all his tremendous hard work maintaining Flash Attention over the last year or two, on which the CUDA version in this repository depends

  • Phil Tillet for Triton, without which the forward ring flash attention CUDA kernel would have taken an order of magnitude more work.

Install

$ pip install ring-attention-pytorch

Usage

import torch
from ring_attention_pytorch import RingAttention

attn = RingAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
    auto_shard_seq = True,
    ring_attn = True,
    ring_seq_size = 512
)

tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)

assert attended.shape == tokens.shape
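
Striped attention can be toggled on the same module. A hedged sketch, assuming the constructor accepts a striped_ring_attn flag (the name appears in the codebase and mirrors the test script's --striped-ring-attn option):

attn = RingAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
    auto_shard_seq = True,
    ring_attn = True,
    striped_ring_attn = True,   # assumed constructor flag, mirroring assert.py's --striped-ring-attn
    ring_seq_size = 512
)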

Test

First install requirements

$ pip install -r requirements.txt

Then, for example, testing autoregressive striped ring attention on CUDA would be

$ python assert.py --use-cuda --causal --striped-ring-attn

Todo

  • make it work with derived causal mask based on rank and chunk sizes

  • modify flash attention to output intermediates and figure out backwards with recompute and ring passes

  • functions for splitting the sequence evenly among ranks, either within attention function, or in the external ring transformer wrapper

  • basic test case with two processes and check for equivalent output and gradients

  • testing

    • make sure key padding mask works
    • make sure causal mask works
    • rotary embeddings, with proper key/value offset depending on ring rank
  • striped attention

    • add the permutation logic before and after the transformer
    • add causal masking logic - account for sub bucketing by flash attention
  • fix issue with ring attention when flash buckets > 1

  • move flash attention back to key / value column traversal on outer loop and save on ring communication

    • backwards
    • forwards
  • fix rotary positions for striped ring attention when flash buckets > 1

  • allow for variable ring passes per layer, for local -> global attention in ring transformer as one goes up the layers.

  • when doing ring passes, alternate between designated send and receive buffers (see the sketch after this list)

  • instead of max ring passes, allow specifying the lookback in terms of sequence length, and derive the number of flash attention buckets + ring passes from that

  • ability to have ring size < world size, sharding the batch and sequence, and doing ring reduce with the correct set of ranks

  • add flash attention kernel version in the presence of cuda

    • for forwards, use modified Triton flash attention forwards that outputs row sums, maxes, and exponentiated weighted sum
    • for backwards, use Tri's flash attention kernels, accumulate dq, dk, dv across rings
    • refactor to have naive ring+flash attention work with (batch, seq, head, dim)
    • handle key padding mask for forwards by translating mask to bias
    • figure out how Tri handles key padding mask for backwards
    • scale output of flash attention forwards on the last ring pass reduce
    • verify backwards working in a100 runpod
    • dk, dv need to be float32, while kv needs to be float16. see if both can be cast to int before being stacked and ring passed all in one go, then reinterpreted back to float32 and float16
    • prevent an unnecessary tl.load on the first ring pass
    • cuda backwards pass must have same dq, dk, dv as naive
  • fix naive flash attention backwards

  • validate cuda causal and striped ring attention works

  • make sure cuda striped attention works for multiple buckets, otherwise flash attention is ineffective

  • for cuda striped attention, for backwards hack, pad the extra token once and index out when passing into Tri's cuda kernel

  • find a machine with 8 GPUs and test with a quarter million tokens first

  • see for cuda version whether softmax_D can be computed once and cached over the ring reduce. go for modified triton backwards if not

  • think about how to craft a special Dataset that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training

  • add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl

  • figure out how to pytest distributed pytorch

  • use sdp context manager to validate when it is possible to use ring_flash_attn_cuda, otherwise assert out

  • improvise a variant where each machine keeps compressed summary tokens, and only those summary tokens are ring passed past some given distance
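
For the send/receive buffer item above, here is a minimal sketch of what a double-buffered ring pass could look like with torch.distributed point-to-point ops; the function and variable names are illustrative, not the repository's internal API:

import torch
import torch.distributed as dist

def ring_pass(x, receive_buffer):
    # send the current tensor to the right neighbor while receiving from the
    # left into a preallocated buffer; returning the pair swapped lets the two
    # buffers alternate roles each pass, avoiding a fresh allocation per step
    rank, world = dist.get_rank(), dist.get_world_size()
    ops = [
        dist.P2POp(dist.isend, x, (rank + 1) % world),
        dist.P2POp(dist.irecv, receive_buffer, (rank - 1) % world),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return receive_buffer, x   # received tensor becomes current, old tensor is the next buffer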

Citations

@article{Liu2023RingAW,
    title    = {Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author   = {Hao Liu and Matei Zaharia and Pieter Abbeel},
    journal  = {ArXiv},
    year     = {2023},
    volume   = {abs/2310.01889},
    url      = {https://api.semanticscholar.org/CorpusID:263608461}
}
@article{Brandon2023StripedAF,
    title   = {Striped Attention: Faster Ring Attention for Causal Transformers},
    author  = {William Brandon and Aniruddha Nrusimha and Kevin Qian and Zachary Ankner and Tian Jin and Zhiye Song and Jonathan Ragan-Kelley},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.09431},
    url     = {https://api.semanticscholar.org/CorpusID:265220849}
}
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@article{Ainslie2023GQATG,
    title   = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebrón and Sumit K. Sanghai},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.13245},
    url     = {https://api.semanticscholar.org/CorpusID:258833177}
}

The Bitter Lesson - Richard Sutton


ring-attention-pytorch's Issues

inference for open LLM

Can this be used to extend the architecture of an open-source LLM from Hugging Face and run inference with ring attention?

Comment about use of all gather

Hi Phil!

Hope you're doing well. As you saw with Gemini Pro 1.5, which works on 1 million tokens, open source has some work to do to catch up :D Porting Ring Attention to PyTorch is definitely the first step towards that.

@rwightman made an interesting comment on your current approach to implementing Ring Attention, thought it would be useful to share that with you: https://twitter.com/wightmanr/status/1758275957557719308. Basically Ross had to implement something similar to make the SigLIP loss function work, leveraging neighbour exchange instead of allgather.
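
For context, the pattern being referenced, sketched with torch.distributed (assuming an initialized process group): all_gather materializes all world_size chunks on every rank at once, while a neighbour exchange only ever holds one extra chunk in flight.

import torch
import torch.distributed as dist

def neighbour_exchange(x):
    # ring-style alternative to all_gather: send to the right neighbour and
    # receive from the left, so peak memory is one extra chunk rather than
    # world_size chunks materialized simultaneously
    rank, world = dist.get_rank(), dist.get_world_size()
    received = torch.empty_like(x)
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, x, (rank + 1) % world),
        dist.P2POp(dist.irecv, received, (rank - 1) % world),
    ])
    for req in reqs:
        req.wait()
    return received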

Btw, if your implementation is done, I would like to leverage it to port the LWM model that came out 2 days ago (https://github.com/LargeWorldModel/LWM). I would port the model to the Hugging Face Transformers library by adding an LWMForCausalLM class. Since the weights are open-sourced, I can convert them to the Transformers format.

Btw are you still active on any Discord channel?

Cheers,

Niels

Connection closed by peer

When I set use_cuda = True and set

  os.environ['MASTER_ADDR'] = '127.0.0.1'
  os.environ['MASTER_PORT'] = '1234'

the error is as follows:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/app/assert.py", line 87, in start
    ring_out = ddp_ring_attention_net(seq)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1509, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1345, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/ring_attention_pytorch/ring_attention.py", line 568, in forward
    x = attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/ring_attention_pytorch/ring_attention.py", line 368, in forward
    out = ring_flash_attn_cuda(
  File "<@beartype(ring_attention_pytorch.ring_flash_attention_cuda.ring_flash_attn_cuda) at 0x7fdd828c4ee0>", line 214, in ring_flash_attn_cuda
  File "/app/ring_attention_pytorch/ring_flash_attention_cuda.py", line 752, in ring_flash_attn_cuda
    return ring_flash_attn_cuda_(q, k, v, mask, causal, bucket_size, ring_reduce_col, striped_ring_attn, max_lookback_seq_len, ring_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 551, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/app/ring_attention_pytorch/ring_flash_attention_cuda.py", line 526, in forward
    for (ring_rank, is_last), ((kv, mask), (receive_kv, receive_mask)) in ring_pass_fn(kv, mask, receive_buffers = (receive_kv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):
  File "/app/ring_attention_pytorch/ring.py", line 127, in all_ring_pass
    new_tensor, new_receive_buffer = one_ring_pass(tensor, receive_buffer, ring_size)
  File "/app/ring_attention_pytorch/ring.py", line 88, in ring_pass
    send_and_receive_(x, receive_buffer, circular_rank_right(ring_size = ring_size), circular_rank_left(ring_size = ring_size))
  File "/app/ring_attention_pytorch/ring.py", line 69, in send_and_receive_
    dist.recv(receive_buffer, receive_from_rank)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1680, in recv
    pg.recv([tensor], src, tag).wait()
RuntimeError: [/opt/pytorch/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:534] Connection closed by peer [172.24.0.2]:34966

Cross Attention variant?

Hi there,

Sorry if this is a stupid issue but I was wondering if it would be possible to apply Ring Attention to Cross Attention? I was thinking of using RingFlashAttentionCUDAFunction directly but it seems like the transformer block itself has modifications.

Thanks

ValueError: Invalid expression '[ True]', must be integers

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Traceback (most recent call last):
  File "/Users/defalt/Desktop/Athena/research/Gemini/gemini_block.py", line 18, in <module>
    out = model(x)  # Apply the model to the input tensor
          ^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/defalt/Desktop/Athena/research/Gemini/gemini_torch/model.py", line 101, in forward
    x = self.attn(x)
        ^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ring_attention_pytorch/ring_attention.py", line 228, in forward
    q, k, v = rearrange('b n (qkv h d) -> qkv b h n d', qkv, qkv = 3, h = self.heads)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 70, in inner
    graph = construct_graph(*args, backend=backend, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 20, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 45, in construct_graph
    output_tracers = func(*args, **kwargs, backend=einx.backend.tracer)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/op/rearrange.py", line 118, in rearrange
    exprs_in, exprs_out = parse(description, *[einx.param.get_shape(tensor) for tensor in tensors], cse=cse, **parameters)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 20, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/op/rearrange.py", line 59, in parse
    + [einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None) for k, v in parameters.items()],
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/op/rearrange.py", line 59, in <listcomp>
    + [einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None) for k, v in parameters.items()],
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/expr/util.py", line 36, in __init__
    self.expr2 = _input_expr(expr2)
                 ^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/expr/util.py", line 29, in _input_expr
    raise ValueError(f"Invalid expression '{expr}', must be integers")
ValueError: Invalid expression '[ True]', must be integers

8 A100s

Willing to give you 8 A100s for this, lmk through email

I'm doing an image generation experiment, but my script outputs a JSON file. How do I train a Transformer model to generate a pixel representation of an image?

I'm doing an experiment with image generation, but my script outputs a JSON file. How can I train a transformer model on it?

import cv2
import json
import numpy as np
import os
from PIL import Image

def image_to_text(image_path, text_path):
    # Read the image and convert to grayscale
    image = cv2.imread(image_path)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    # Apply thresholding to get a binary image
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    
    # Find contours
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    contours_list = [contour.flatten().tolist() for contour in contours]  # Flatten and convert to list

    # Convert image to list of pixel values
    pixels = image.flatten().tolist()

    # Save mode and size
    mode = image.shape[2] if len(image.shape) == 3 else 1  # channel count: 3 if color, 1 if grayscale
    size = image.shape[:2]

    # Write pixel data and contour information to text file
    with open(text_path, 'w') as text_file:
        json.dump({'mode': mode, 'size': size, 'pixels': pixels, 'contours': contours_list}, text_file)

def text_to_image(text_path, output_image_path):
    # Read pixel data and contour information from text file
    with open(text_path, 'r') as text_file:
        data = json.load(text_file)
        mode = data['mode']
        size = tuple(data['size'])
        pixels = data['pixels']
        contours_list = data['contours']

    # Reconstruct the image from the pixel information
    image_array = np.array(pixels, dtype=np.uint8)
    if mode == 1:
        image_array = image_array.reshape(size[0], size[1])  # Grayscale
    else:
        image_array = image_array.reshape(size[0], size[1], mode)  # Color

    img = Image.fromarray(image_array)
    img.save(output_image_path)

    # Reconstruct the image contours
    contours = [np.array(contour).reshape(-1, 1, 2) for contour in contours_list]  # Reshape to contour format
    img_contours = cv2.imread(output_image_path)
    cv2.drawContours(img_contours, contours, -1, (0, 255, 0), 2)
    cv2.imwrite(output_image_path, img_contours)


def batch_process(input_folder, output_folder_text, output_folder_images):
    # Ensure the output folders exist
    if not os.path.exists(output_folder_text):
        os.makedirs(output_folder_text)
    if not os.path.exists(output_folder_images):
        os.makedirs(output_folder_images)
    
    # Iterate over all image files in the folder
    for filename in os.listdir(input_folder):
        if filename.lower().endswith(('.jpg', '.png', '.jpeg')):  # handle common image formats
            print(f"Processing {filename}...")
            image_path = os.path.join(input_folder, filename)
            base_filename = os.path.splitext(filename)[0]
            text_path = os.path.join(output_folder_text, base_filename + '.txt')
            output_image_path = os.path.join(output_folder_images, filename)
            
            # Image to text
            image_to_text(image_path, text_path)
            
            # Text to image
            text_to_image(text_path, output_image_path)

# Usage example
input_folder = 'D:/llama2.c-master/1/images'
output_folder_text = 'D:/llama2.c-master/1/text'
output_folder_images = 'D:/llama2.c-master/1/imagesout'

batch_process(input_folder, output_folder_text, output_folder_images)

@lucidrains Can you help me?
