
retnet's Introduction

RetNet

An implementation of Retentive Network: A Successor to Transformer for Large Language Models in PyTorch.

About this repository

This is a minimal, pure PyTorch implementation of RetNet. RetNet paper: Retentive Network: A Successor to Transformer for Large Language Models.

The contributor(s) to this repository are not authors of the original paper. All credit for the idea and formulation of RetNet goes to the original authors.

The purpose of this repository is to aid scientific and technological understanding and advancement. The code prioritizes correctness and readability over optimization.

Features implemented

  • Single-scale and MultiScale retention:
    • parallel paradigm
    • recurrent paradigm
    • chunkwise paradigm
  • Multi-layer retentive network with FFN and LayerNorm
    • parallel paradigm
    • recurrent paradigm
    • chunkwise paradigm
  • Causal language model (CLM) built on top of the retentive network

Usage and Examples

  • See scripts prefixed with test_ for examples of basic usage
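
For orientation, a rough usage sketch follows; the module path, class name, and constructor arguments here are assumptions inferred from the issues later on this page rather than the repository's documented API, so the test_ scripts remain the authoritative examples.

import torch
from src.retnet import RetNet  # module path as referenced in the issues below (assumed)

# Hypothetical hyperparameters; the real constructor signature may differ.
layers, hidden_dim, ffn_size, heads = 2, 64, 128, 4
batch_size, seq_len = 1, 16

model = RetNet(layers, hidden_dim, ffn_size, heads)
X = torch.randn(batch_size, seq_len, hidden_dim)
Y = model(X)  # parallel paradigm over the whole sequence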

Positional Encodings

The main implementation in src/ uses Microsoft's xPos for positional encoding.

The implementation in src/complex uses complex values to encode position, which requires parameters and data throughput to use torch.ComplexFloat (64-bit). This has some limitations, because torch does not yet support half-precision complex types, and it requires twice as much memory as real-valued data at 32-bit precision.
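
As a quick illustration of the memory point (a small standalone check, not tied to this repository's code): torch.complex64 (ComplexFloat) stores two 32-bit floats per element, so a complex tensor takes twice the memory of a float32 tensor of the same shape.

import torch

real = torch.zeros(1024, 1024, dtype=torch.float32)
cplx = torch.zeros(1024, 1024, dtype=torch.complex64)

print(real.element_size())  # 4 bytes per element
print(cplx.element_size())  # 8 bytes per element (two float32 components)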

Contributions

All contributions are welcome. Please see issues for an idea of what needs doing.

If you would like to contribute to this project, please fork it and submit a pull request for review.

References

@misc{sun2023retentive,
      title={Retentive Network: A Successor to Transformer for Large Language Models}, 
      author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
      year={2023},
      eprint={2307.08621},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

retnet's People

Contributors

jamie-stirling, leffff, michaelfu1998-create, regenhardt


retnet's Issues

Some Questions about Attention Mask

Hello, I have reviewed some of the code and noticed that RetNet does not use an attention mask. Don't you need to mask out the pad IDs? Or do the pad IDs have no impact on the rest of the sequence?
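
Not an authoritative answer, but one common way to neutralize padding in a retention-style layer is to zero out the key/value rows at padded positions before they enter the decayed sum, so pads contribute nothing to other positions; a hedged sketch, where pad_mask and its convention (1 = real token, 0 = pad) are my assumptions:

import torch

def mask_pad_tokens(K, V, pad_mask):
    """Zero key/value rows at padded positions so they add nothing to retention."""
    m = pad_mask.unsqueeze(-1).to(K.dtype)  # (batch, seq_len, 1)
    return K * m, V * m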

Real-valued implementation using xPos

The current implementation uses complex arithmetic to implement the original paper, which has known issues with stability and precision. It's been suggested that xPos is a more stable and efficient way to achieve the same things by representing rotations using Euler's identity.

It would be nice if all constructors had an additional option to do arithmetic in real algebra using xPos (rotary positional embeddings), as described in this paper:
https://arxiv.org/abs/2212.10554
and implemented here:
https://github.com/microsoft/torchscale/blob/main/torchscale/component/xpos_relative_position.py

This may solve current issues with memory stability.
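
To illustrate the underlying idea (a generic sketch, not this repository's or torchscale's xPos code): a complex rotation e^{i n θ} applied to a feature pair is the same as a real 2x2 rotation built from cos/sin via Euler's identity, which is what rotary/xPos-style encodings exploit.

import torch

theta, n = 0.3, 5               # toy angle and position
angle = torch.tensor(n * theta)
x = torch.randn(2)              # one feature pair (x0, x1)

# Complex form: rotate x0 + i*x1 by e^{i*n*theta}.
z = torch.complex(x[0], x[1]) * torch.exp(1j * angle)

# Real form: the same rotation written with cos/sin.
c, s = torch.cos(angle), torch.sin(angle)
y = torch.stack([c * x[0] - s * x[1], s * x[0] + c * x[1]])

print(torch.allclose(torch.stack([z.real, z.imag]), y, atol=1e-6))  # True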

The complex theta should cancel out

Maybe I am missing something, but do we need the Theta? Since its magnitude is 1, multiplying it with its conjugate should cancel out in the parallel version.

Proposed improvement/collaboration: removing the O(T^2) training cost

Hi there, I just found this work thanks to @yk's recent video. Nice job! There are similarities with work I've been doing for a few months, and while I'm a little bummed you beat me to publishing, I wasn't going to be able to do a good job of evaluating the architectures anyway (this is a side project that is currently thrashing my laptop, and I'm not sure I could justify the cloud costs to train even a moderately sized model just out of curiosity), and I'm glad the idea is being investigated and released with a permissive license.

I'm not sure if you're looking for suggestions or collaborations, but thought I'd put my ideas out there and see what happens. I'm happy to provide more details/collaborate on a future work if there's interest, or feel free to point me towards someone else who might be interested or run with it yourself.

TL;DR

From my understanding of the paper/code (and I apologise if I've got any of this wrong), computing retention values is still O(T^2) in sequence length T and prone to underflow (hence the nan replacement). Neither of these is necessary. The computation you're performing is just an exponential moving average which can be computed in O(T) with a scan using an associative operator, meaning associative_scan implementations can do it very efficiently in parallel.

Details

Unfortunately we're still waiting on pytorch's associative_scan implementation, so I'll be using jax below, for which a primitive exists. Note I've got a pytorch version working which wraps the jax implementations with jax2torch, though I can't make it work nicely with torch's compile and I'm more comfortable with jax anyway.

Below is an implementation that takes an arbitrary decay factor at each step. To get the same performance as in your paper, I think you can just set factors = gamma * ones_like(values).

import typing as tp
import jax
import jax.numpy as jnp

Pair = tp.Tuple[jnp.ndarray, jnp.ndarray]


def _cumulative_ema_op(a: Pair, b: Pair) -> Pair:
    xa, fa = a
    xb, fb = b
    return xa * fb + xb, fa * fb


def cumulative_ema(
    values: jnp.ndarray, factors: jnp.ndarray, reverse: bool = False, axis: int = 0
) -> jnp.ndarray:
    """
    Compute cumulative exponential moving average.

    If `reverse == False` and `axis == 0`,
        output[0] = values[0]
        output[i+1] = output[i] * factors[i+1] + values[i+1]
        
    If `reverse == True`, then the result is the reverse of the non-reversed call on
    arguments reversed on the given axis.

    Args:
        values: N-D float values
        factors: same shape/dtype as values
        axis: the axis to compute exponential moving average along.
        reverse: if True, perform accumulation in reverse.

    Returns:
        cumulative ema values, same shape as values/factors.
    """
    if axis < 0:
        axis += len(values.shape)
    assert values.shape == factors.shape, (values.shape, factors.shape)
    f, t = jax.lax.associative_scan(
        _cumulative_ema_op, (values, factors), reverse=reverse, axis=axis
    )
    del t
    return f

Thus computing retention values from Q, K and V values would be:

def retention(Q, K, gamma, V, reverse=False):
    """
    Notation:
      T: time dimension
      A: attention dimension
      C: number of output channels
    
    Args:
        Q: [T, A] query
        K: [T, A] key
        gamma: [] decay constant
        V: [T, C] values

    Returns:
        [T, C]
    """
    rhs = jnp.einsum('ta,tc->tac', K, V)
    rhs = cumulative_ema(rhs, jnp.full_like(rhs, gamma), axis=0, reverse=reverse)
    return jnp.einsum('ta,tac->tc', Q, rhs)

I've left out the batch dimension for simplicity, but I'm sure you could make the appropriate modifications (or if you decide to use jax, just vmap it). I'll spare you the full theoretical derivation of why this computes (Q K.T * D) @ V, but the short version is that we use property 1 from here (see last slide) and note that DX = cumulative_ema(X, jnp.full_like(X, gamma), axis=0). This is O(TAC) in space/time rather than O(T^2(A + C)) in time and O(T(T + C)) in space.

Creating a bidirectional encoder is thus trivial by combining two - one with reverse=False and the other with reverse=True.
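
For readers without an associative_scan primitive, the same recurrence can also be computed with a plain sequential loop (still O(T) work overall, just not parallelized over T); a minimal PyTorch sketch of the recurrence the scan implements:

import torch

def cumulative_ema_loop(values, factors):
    """Sequential reference: out[0] = values[0]; out[t] = out[t-1] * factors[t] + values[t]."""
    out = torch.empty_like(values)
    state = torch.zeros_like(values[0])
    for t in range(values.shape[0]):
        state = state * factors[t] + values[t]
        out[t] = state
    return out

# Example with a constant decay gamma, as used for retention:
T, A, C = 8, 4, 3
kv = torch.randn(T, A, C)
ema = cumulative_ema_loop(kv, torch.full_like(kv, 0.9))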

Now, with that implementation you might be tempted to play around with the architecture a little. I've played with creating only two transformed matrices, factors (sigmoid-activated to ensure decay) and values of the same shape (rather than Q, K, V), and using them in the cumulative_ema directly, which reduces the O(TAC) memory/time requirement to O(TC); a rough sketch is shown below. Conceptually this just means that each token embedding at each layer decides how much of the past to forget and what to add, based on the previous layer's embedding. I don't see any barriers to implementing a complex version to allow for periodic behaviour, but I haven't attempted that.
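
A rough PyTorch sketch of that O(TC) variant as I understand the description (class and parameter names here are my own, not the author's code): each layer produces a sigmoid-gated decay and a value per token and runs them through the cumulative EMA sequentially.

import torch
import torch.nn as nn

class GatedEMAMixer(nn.Module):
    """Hypothetical O(T*C) token mixer: per-token learned decay and value, no Q/K."""
    def __init__(self, dim):
        super().__init__()
        self.to_factor = nn.Linear(dim, dim)
        self.to_value = nn.Linear(dim, dim)

    def forward(self, x):  # x: (batch, T, C)
        f = torch.sigmoid(self.to_factor(x))  # decay in (0, 1), so the past is always forgotten
        v = self.to_value(x)
        out = torch.zeros_like(v)
        state = torch.zeros_like(v[:, 0])
        for t in range(x.shape[1]):
            state = state * f[:, t] + v[:, t]
            out[:, t] = state
        return out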

My implementation is keras_core-based (so you can use the pytorch backend so long as you don't try to compile). It needs a lot of cleaning up before I'm prepared to make it public, but I'm happy to share privately. Very small-scale experiments, where I've just replaced BERT's self-attention mechanism with the bidirectional O(TC) implementation discussed above and removed positional embeddings entirely, have proved promising (faster training, better performance than BERT). I have no way of validating whether performance scales with model size - I was planning on looking for collaborators/sponsors for that, so if you're interested let me know :).

Chunkwise retention giving different output

The implementation of the chunkwise retention paradigm on the chunkwise-real branch gives different outputs from the other two paradigms.

It appears there may be a mistake in the paper on which the implementation was based, in equation (7). A pull request fixing this and obtaining outputs consistent with the other two paradigms would be greatly appreciated.

This can be reproduced by running `python src/tests.py`, with stdout:

FFF
======================================================================
FAIL: test_retnet (__main__.TestRetNet)
verify that the three implementations of RetNet are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 137, in test_retnet
    self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true

======================================================================
FAIL: test_multiscale (__main__.TestRetention)
verify that the three implementations of MultiScaleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 86, in test_multiscale
    self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true

======================================================================
FAIL: test_simple (__main__.TestRetention)
verify that the three implementations of SimpleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 45, in test_simple
    assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) # fails
AssertionError

----------------------------------------------------------------------
Ran 3 tests in 0.098s

FAILED (failures=3)

About the complex

Sorry for bothering you, and this may be a dumb question: what is the Complex type used for here?

I'm not very good at math, so it would be great if you could explain why complex values are needed.

What about cross-attention?

Can this model achieve cross-attention, similar to how the Transformer handles embedding matrices from different modalities?

Dimensions of forward_recurrent

In the MultiScaleRetention class, it is mentioned that 's_n_1s' has dimensions (batch_size, heads, head_size, head_size), while in SimpleRetention, 's_n_1' is defined as 's_n_1s[i]'. However, you mentioned that 's_n_1' has dimensions (batch_size, hidden_size, v_dim). Can you clarify this?

_get_D function very slow for long sequence

First, many thanks for your implementation!

It seems that the _get_D function

def _get_D(self, sequence_length):
    # D[n,m] = gamma ** (n - m) if n >= m else 0
    D = torch.zeros((sequence_length, sequence_length), requires_grad=False)
    for n in range(sequence_length):
        for m in range(sequence_length):
            if n >= m:
                D[n, m] = self.gamma ** (n - m)
    return D

gets really slow for long sequence lengths, resulting in very low GPU utilization.

By changing it to the style below, it gets better. I'm not sure it's perfectly correct, but for gamma < 1 it seems all good.

def _get_D(self, sequence_length):
    n = torch.arange(sequence_length).unsqueeze(1)
    m = torch.arange(sequence_length).unsqueeze(0)

    # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0
    D = (self.gamma ** (n - m)) * (n >= m).float()  # this overflows to inf when m is much larger than n, giving NaN after masking
    # fill the NaN with 0
    D[D != D] = 0

    return D
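
An alternative sketch (my suggestion, not code from the repository) avoids producing inf/NaN in the first place by clamping the exponent before masking; it assumes the same class context (self.gamma, torch imported) as the snippets above.

def _get_D(self, sequence_length):
    n = torch.arange(sequence_length).unsqueeze(1)
    m = torch.arange(sequence_length).unsqueeze(0)
    # Clamp n - m at 0 so gamma is never raised to a large negative power,
    # then zero out the upper triangle (positions where n < m).
    D = (self.gamma ** (n - m).clamp(min=0)) * (n >= m)
    return D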

Training is slow and some errors (perhaps)

Thank you for reproducing retnet!

However, when I actually run the code, I find that training is slow: 5-6 times slower for the same task compared to a transformer (the transformer uses half precision, RetNet does not). The memory usage is also very unstable. Is this due to the loops in the code or to RetNet itself? Is there any plan or way to optimize this?

First, there seems to be a problem in the code: it seems we need to change this line to

return (self.swish(X @ self.W_G.to(self.complex_type)) + Y) @ self.W_O.to(self.complex_type)

Confusion about "the chunkwise recurrent representation of retention"

I have a question regarding "Chunkwise Recurrent Representation of Retention." The original expression in the paper is as follows:
[image: equation (7) from the paper, the chunkwise recurrent representation of retention]
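
For reference, my transcription of equation (7), the chunkwise recurrent form (reproduced from the paper from memory, so possibly imperfect; indices are zero-based to match the code below, and B is the chunk size):

R_i = K_{[i]}^\top (V_{[i]} \odot \zeta) + \gamma^{B} R_{i-1}, \qquad \zeta_{ij} = \gamma^{B-i-1}

\mathrm{Retention}(X_{[i]}) = (Q_{[i]} K_{[i]}^\top \odot D) V_{[i]} + (Q_{[i]} R_{i-1}) \odot \xi, \qquad \xi_{ij} = \gamma^{i+1}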

In your implementation, the code looks like this:

 r_i = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1

The first part of this equation computes the KV matrix for the current chunk and then multiplies it by a scaling factor. My understanding is that, ignoring the batch size, the shapes of K and V for the current chunk are both (2, 3); in other words, the current chunk contains 2 tokens, so the KV matrix has shape (3, 3). Based on your code, you multiply this KV matrix by the last row of the D matrix (shape (2, 2)). For example, if the D matrix is [[1, 0], [0.9, 1]], then D[-1].view(1, chunk_size, 1) becomes [[0.9], [1]], and these values are multiplied with the first and second rows of the V matrix to implement the decay.

However, when we take the inner product of the chunk's Q matrix with the first half of R_i, it seems like both q tokens within the Q matrix use the same decay factor; is that correct? In other words, within the same chunk, if we want to perform attention, the second q token should intuitively be multiplied by a decay factor (0.9) when attending to the first v token, but when the first q token operates on the first v token, it doesn't need this decay factor.

Additionally, for the second half of R_i, it seems that the entire chunk is treated as a whole: r_i_1 is decayed as a whole, and the decay is applied as many times as the length of the chunk.

There's another question I have regarding the cross-chunk calculations.

        #e[i,j] = gamma ** (i+1)
        e = torch.zeros(batch, chunk_size, 1)
        
        for _i in range(chunk_size):
            e[:, _i, :] = self.gamma ** (_i + 1)
        
        cross_chunk = (Q @ r_i_1) * e

In the code, the variable 'e' appears to play a role in decay as well. However, based on the code, the final result of computing (Q @ r_i_1) might be something like [o1, o2, o3]^T, where each 'oi' is a row vector with D dimensions. What I'd like to point out is that, according to your code, 'o1' actually gets the least decay and 'o3' gets the most. But intuitively, for the current Q, shouldn't the vector corresponding to 'o1' be the farthest from the q tokens within the current chunk? In other words, shouldn't the decay of 'o1' be the greatest? So, should the code be like this:

        #e[i,j] = gamma ** (i+1)
        e = torch.zeros(batch, chunk_size, 1)
        
        for _i in range(chunk_size):
            # e[:, _i, :] = self.gamma ** (_i + 1)
            e[:, _i, :] = self.gamma ** (chunk_size - _i)
      
        cross_chunk = (Q @ r_i_1) * e

This is very confusing to me. Is there a more detailed derivation or a clearer explanation of how equation (7) in the original article is obtained? In particular, for the exponential part of the decay factor, is the result of this calculation consistent with the result of the fully parallel computation? Can someone help me with this?

Is RetNet equivalent to an ordinary GPT when the decay is set to 1?

I'm a little confused about what RetNet does in practice. In the formula Retention(X) = (Q @ K.T * D) @ V, if the decay is 1, the mathematical derivation proving the equivalence between the RNN form and RetNet's transformer-like (parallel) form still works. When the decay equals 1, D becomes the normal causal attention mask used by almost all existing GPT models. Does that mean all existing GPT models can be converted into a RetNet by simply modifying the inference function, without any further training? Am I correct, or am I missing something?

Q, K and D device difference

ret = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0)

Q and K are placed on whatever device the model is on because they come from model parameters, while D is created in SimpleRetention._get_D and is never moved to a device. Therefore, if you train on CUDA, Q and K are on cuda while D is on the CPU, and an error arises.
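
A minimal sketch of one possible fix (assuming D is built inside SimpleRetention's forward pass via self._get_D, as the snippet above suggests) is to move D onto the same device as the inputs:

# Inside the forward pass (sketch; exact variable names may differ from the repository):
D = self._get_D(sequence_length).to(Q.device)
ret = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0)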

/src/retnet.py GPU

2511 if has_torch_function_variadic(input, weight, bias):
2512 return handle_torch_function(
2513 layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
2514 )
-> 2515 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)
