Hi there, just found this work thanks to @yk's recent video. Nice job! There are similarities with work I've been doing for a few months. While I'm a little bummed you beat me to publishing, I wasn't going to be able to do a good job of evaluating the architectures anyway (this is a side project that is currently thrashing my laptop, and I'm not sure I could justify the cloud costs to train even a moderately sized model just out of curiosity), so I'm glad the idea is being investigated and released under a permissive license.
I'm not sure if you're looking for suggestions or collaborations, but I thought I'd put my ideas out there and see what happens. I'm happy to provide more details or collaborate on future work if there's interest, or feel free to point me towards someone else who might be interested, or run with it yourself.
TL;DR
From my understanding of the paper/code (and I apologise if I've got any of this wrong), computing retention values is still O(T^2) in sequence length T, and prone to underflow (hence the nan replacement). Neither of these is necessary. The computation you're performing is just an exponential moving average, which can be computed in O(T) with a scan using an associative operator, meaning `associative_scan` implementations can do it very efficiently in parallel.
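To make that equivalence concrete, here's a toy check that the sequential EMA recurrence and an associative scan produce the same values (the scan operator here is the same one used in the code below):

```python
import jax
import jax.numpy as jnp


def ema_loop(x, gamma):
    # sequential reference: y[0] = x[0]; y[i] = gamma * y[i-1] + x[i]
    ys = []
    y = jnp.zeros(())
    for xi in x:
        y = gamma * y + xi
        ys.append(y)
    return jnp.stack(ys)


def ema_scan(x, gamma):
    # same recurrence expressed as a parallel scan over (value, factor) pairs
    def op(a, b):
        xa, fa = a
        xb, fb = b
        return xa * fb + xb, fa * fb

    y, _ = jax.lax.associative_scan(op, (x, jnp.full_like(x, gamma)))
    return y


x = jnp.arange(1.0, 6.0)
assert jnp.allclose(ema_loop(x, 0.9), ema_scan(x, 0.9))
```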
Details
Unfortunately we're still waiting on pytorch's `associative_scan` implementation, so I'll be using jax below, for which a primitive exists. Note I've got a pytorch version working which wraps the jax implementation with `jax2torch`, though I can't make it work nicely with `torch.compile`, and I'm more comfortable with jax anyway.
Below is an implementation that takes an arbitrary decay factor at each step. To get the same behaviour as in your paper, I think you can just set it to `factors = gamma * ones_like(values)`, but keeping the factors general allows some of the variations discussed further down.
```python
import typing as tp

import jax
import jax.numpy as jnp

Pair = tp.Tuple[jnp.ndarray, jnp.ndarray]


def _cumulative_ema_op(a: Pair, b: Pair) -> Pair:
    # combine (value, factor) pairs; this operator is associative,
    # which is what lets `associative_scan` parallelize the EMA
    xa, fa = a
    xb, fb = b
    return xa * fb + xb, fa * fb


def cumulative_ema(
    values: jnp.ndarray, factors: jnp.ndarray, reverse: bool = False, axis: int = 0
) -> jnp.ndarray:
    """
    Compute cumulative exponential moving average.

    If `reverse == False` and `axis == 0`,

        output[0] = values[0]
        output[i+1] = output[i] * factors[i+1] + values[i+1]

    If `reverse == True`, then the result is the reverse of the non-reversed call on
    arguments reversed on the given axis.

    Args:
        values: N-D float values.
        factors: same shape/dtype as values.
        reverse: if True, perform accumulation in reverse.
        axis: the axis to compute the exponential moving average along.

    Returns:
        cumulative ema values, same shape as values/factors.
    """
    if axis < 0:
        axis += len(values.shape)
    assert values.shape == factors.shape, (values.shape, factors.shape)
    ema, _ = jax.lax.associative_scan(
        _cumulative_ema_op, (values, factors), reverse=reverse, axis=axis
    )
    return ema
```
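As a sanity check, the associative-scan result can be compared against a plain sequential `jax.lax.scan`; the function is restated compactly here so the snippet runs on its own:

```python
import jax
import jax.numpy as jnp


def cumulative_ema(values, factors, reverse=False, axis=0):
    # compact restatement of the function above
    def op(a, b):
        xa, fa = a
        xb, fb = b
        return xa * fb + xb, fa * fb

    ema, _ = jax.lax.associative_scan(op, (values, factors), reverse=reverse, axis=axis)
    return ema


def ema_sequential(values, factors):
    # sequential reference using lax.scan
    def step(carry, vf):
        v, f = vf
        carry = carry * f + v
        return carry, carry

    _, ys = jax.lax.scan(step, jnp.zeros(values.shape[1:]), (values, factors))
    return ys


v = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
f = jnp.full_like(v, 0.9)
assert jnp.allclose(cumulative_ema(v, f), ema_sequential(v, f), atol=1e-5)
# reverse=True is equivalent to flip -> scan -> flip
assert jnp.allclose(
    cumulative_ema(v, f, reverse=True),
    jnp.flip(ema_sequential(jnp.flip(v, 0), jnp.flip(f, 0)), 0),
    atol=1e-5,
)
```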
Thus computing retention values from Q, K and V would be:

```python
def retention(Q, K, gamma, V, reverse=False):
    """
    Notation:
        T: time dimension
        A: attention dimension
        C: number of output channels

    Args:
        Q: [T, A] query
        K: [T, A] key
        gamma: [] decay constant
        V: [T, C] values
        reverse: if True, accumulate from the end of the sequence.

    Returns:
        [T, C]
    """
    # outer product of keys and values at each timestep
    rhs = jnp.einsum("ta,tc->tac", K, V)
    # decayed cumulative sum over time
    rhs = cumulative_ema(rhs, jnp.full_like(rhs, gamma), axis=0, reverse=reverse)
    # contract with queries
    return jnp.einsum("ta,tac->tc", Q, rhs)
```
I've left out the batch dimension for simplicity, but I'm sure you could make the appropriate modifications (or if you decide to use jax, just `vmap` it). I'll spare you the full theoretical derivation for why this computes `(Q @ K.T * D) @ V`, but the short version is that we use property 1 from here (see last slide) and note that `D @ X == cumulative_ema(X, jnp.full_like(X, gamma), axis=0)`. This is O(TAC) in space/time rather than O(T^2(A + C)) in time and O(T(T + C)) in space.
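For what it's worth, the equivalence with the quadratic form is easy to check numerically on a toy example (retention restated compactly so this runs standalone):

```python
import jax
import jax.numpy as jnp


def retention_quadratic(Q, K, gamma, V):
    # explicit O(T^2) form: (Q @ K.T * D) @ V with D[i, j] = gamma**(i - j) for j <= i
    T = Q.shape[0]
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    D = jnp.where(j <= i, gamma ** (i - j), 0.0)
    return (Q @ K.T * D) @ V


def retention_scan(Q, K, gamma, V):
    # the O(T) scan form from above, restated compactly
    def op(a, b):
        xa, fa = a
        xb, fb = b
        return xa * fb + xb, fa * fb

    rhs = jnp.einsum("ta,tc->tac", K, V)
    rhs, _ = jax.lax.associative_scan(op, (rhs, jnp.full_like(rhs, gamma)))
    return jnp.einsum("ta,tac->tc", Q, rhs)


Q = jax.random.normal(jax.random.PRNGKey(0), (6, 4))
K = jax.random.normal(jax.random.PRNGKey(1), (6, 4))
V = jax.random.normal(jax.random.PRNGKey(2), (6, 2))
assert jnp.allclose(
    retention_quadratic(Q, K, 0.9, V), retention_scan(Q, K, 0.9, V), atol=1e-4
)
```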
Creating a bidirectional encoder is thus trivial by combining two: one with `reverse=False` and the other with `reverse=True`.
Now with that implementation you might be tempted to play around with the architecture a little. I've played with creating only two transformed matrices, `factors` (sigmoid-activated to ensure decay) and `values` of the same shape (rather than Q, K, V), and using them in the `cumulative_ema` directly, which reduces the O(TAC) memory/time requirement to O(TC). Conceptually this just means that each token embedding at each layer decides how much of the past to forget and what to add, based on the previous layer's embedding. I don't see any barriers to implementing a complex-valued version to allow for periodic behaviour, but I haven't attempted that.
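A rough sketch of that variant as a single layer. The projection names `w_f`/`w_v` and the plain linear projections are my own simplification for illustration, not the actual implementation:

```python
import jax
import jax.numpy as jnp


def ema_token_mixer(x, w_f, b_f, w_v, b_v):
    # hypothetical sketch: two projections instead of Q/K/V, producing
    # per-token decay factors and values that feed the EMA directly
    factors = jax.nn.sigmoid(x @ w_f + b_f)  # in (0, 1), so guaranteed decay
    values = x @ w_v + b_v

    def op(a, b):
        xa, fa = a
        xb, fb = b
        return xa * fb + xb, fa * fb

    # cumulative EMA over the time axis: O(TC) memory/time
    y, _ = jax.lax.associative_scan(op, (values, factors))
    return y


key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (10, 8))  # [T, C]
w_f = jax.random.normal(k2, (8, 8)) * 0.1
w_v = jax.random.normal(k3, (8, 8)) * 0.1
b_f = jnp.zeros(8)
b_v = jnp.zeros(8)
assert ema_token_mixer(x, w_f, b_f, w_v, b_v).shape == (10, 8)
```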
My implementation is `keras_core`-based (so you can use the pytorch backend so long as you don't try to compile). It needs a lot of cleaning up before I'm prepared to make it public, but I'm happy to share it privately. Very small-scale experiments, where I've just replaced BERT's self-attention mechanism with the bidirectional O(TC) implementation discussed above and removed positional embeddings entirely, have proved promising (faster training, better performance than BERT). I have no way of validating whether performance scales with model size - I was planning on looking for collaborators/sponsors for that, so if you're interested let me know :).