
Comments (10)

albertfgu commented on July 17, 2024

The amount of information the model can store is proportional to its state size. So in this task, the reason it's able to generalize perfectly is because there isn't much information it has to remember (just one token).

Your question might be touching on a second, more nuanced point, although I'm not exactly sure if this is what you mean. The vocabulary size also implicitly affects memorization ability, because the larger the vocab, the more memory is required to represent one token (information theoretically, $\log_2(|vocab|)$ bits for a uniform distribution). And so a finite state size can also only support a maximum vocabulary in principle.
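
As a back-of-envelope illustration of that bound (a throwaway helper, not anything from the codebase):

import math

# Minimum bits to store one token drawn uniformly from a vocabulary,
# per the information-theoretic argument above.
def bits_per_token(vocab_size: int) -> float:
    return math.log2(vocab_size)

print(bits_per_token(16))  # 4.0 bits for the 16-token vocab in this task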


hrbigelow commented on July 17, 2024

Hi Albert,

Thank you for the explanation, and sorry, I think I was unclear. I understand your point about informational capacity and vocab size. But the problem seems to come a bit upstream of that: namely, that $\Delta$ can't learn to perform a linear (affine) separation of all possible combinations of memory_token vs. all other tokens. I do see that in the input-dependent version the B and C parameters are also learning something, but they are also just linear, so that doesn't seem to resolve the quandary. Also, I'm not actually asking about generalization ability: I just can't see how such a model could achieve 100% training accuracy.

From this section:

[screenshot of the induction head task description from the paper]

my understanding is that the synthetic dataset is generated like this:

EDIT: forgot to add in the special tok a second time

from random import choice, choices

def random_seq(L=256, V=16, P=10):
    """
    Section E.1: training consists of randomly generating data every step,
    with a batch size of 8.
    """
    # P is the length of the prefix region; the special token first occurs
    # right after it. Note the returned sequence has L + 2 tokens.
    assert P < L - 2
    vocab = list(range(V))
    memory_tok = choice(vocab)
    other_toks = [t for t in vocab if t != memory_tok]
    # Section 3.1 of https://arxiv.org/pdf/2212.14052.pdf: the 'special token'
    special_tok = V
    seq = (choices(vocab, k=P) + [special_tok, memory_tok]
           + choices(other_toks, k=L - P - 2) + [special_tok, memory_tok])
    return seq

if __name__ == '__main__':
    L, V, P = 20, 5, 3 
    print(f'seq_length={L}, vocab_size={V}, prefix={P}')
    for b in range(20):
        print(random_seq(L, V, P))

"""
seq_length=20, vocab_size=5, prefix=3
[4, 0, 3, 5, 3, 0, 1, 4, 0, 0, 1, 2, 1, 0, 2, 2, 4, 1, 2, 1, 5, 3]
[4, 3, 1, 5, 0, 4, 4, 2, 3, 4, 4, 1, 3, 4, 4, 2, 2, 2, 4, 3, 5, 0]
[0, 4, 3, 5, 1, 4, 2, 2, 3, 2, 3, 4, 4, 3, 3, 0, 0, 4, 3, 0, 5, 1]
[1, 2, 3, 5, 1, 2, 3, 2, 3, 2, 2, 3, 0, 0, 0, 4, 2, 0, 2, 2, 5, 1]
[2, 3, 2, 5, 2, 1, 4, 1, 4, 3, 1, 1, 4, 4, 3, 4, 4, 4, 0, 0, 5, 2]
[4, 1, 4, 5, 4, 0, 2, 1, 0, 3, 1, 3, 0, 1, 1, 1, 3, 2, 2, 2, 5, 4]
[3, 4, 0, 5, 3, 1, 1, 1, 0, 1, 2, 2, 0, 2, 0, 2, 4, 4, 0, 2, 5, 3]
[3, 3, 0, 5, 2, 0, 4, 3, 1, 1, 1, 1, 3, 0, 3, 4, 3, 4, 4, 1, 5, 2]
[3, 4, 4, 5, 4, 3, 0, 0, 3, 3, 1, 2, 0, 3, 3, 3, 0, 1, 2, 1, 5, 4]
[3, 1, 3, 5, 2, 1, 0, 0, 4, 1, 1, 3, 1, 3, 4, 3, 1, 1, 3, 3, 5, 2]
[1, 1, 4, 5, 2, 0, 4, 1, 0, 4, 1, 3, 4, 4, 1, 0, 1, 1, 0, 3, 5, 2]
[4, 2, 0, 5, 3, 0, 2, 0, 1, 4, 2, 0, 1, 4, 1, 4, 4, 0, 1, 0, 5, 3]
[0, 1, 1, 5, 0, 3, 4, 1, 1, 1, 3, 1, 2, 4, 1, 3, 2, 1, 4, 3, 5, 0]
[2, 2, 2, 5, 1, 0, 2, 4, 0, 4, 3, 4, 4, 0, 0, 3, 4, 4, 3, 3, 5, 1]
[0, 3, 3, 5, 2, 3, 4, 3, 1, 0, 4, 1, 1, 1, 4, 3, 1, 1, 1, 0, 5, 2]
[0, 2, 3, 5, 0, 1, 1, 1, 2, 3, 4, 1, 3, 2, 1, 3, 4, 2, 3, 1, 5, 0]
[0, 4, 3, 5, 0, 1, 1, 1, 2, 1, 1, 2, 2, 3, 1, 2, 4, 2, 4, 4, 5, 0]
[3, 2, 4, 5, 4, 1, 0, 0, 1, 1, 1, 1, 0, 0, 3, 1, 2, 3, 3, 1, 5, 4]
[0, 3, 3, 5, 4, 1, 0, 2, 3, 2, 0, 0, 1, 3, 1, 1, 3, 3, 1, 3, 5, 4]
[3, 0, 3, 5, 4, 2, 3, 3, 3, 1, 3, 3, 2, 1, 0, 2, 1, 1, 1, 0, 5, 4]
"""

So, in the training data, every token in the 16-token vocab will eventually be used as the memory_tok. My simplistic view at the moment is that $\Delta$ must learn to evaluate to a positive value for memory_tok, so that $\bar{B}$ evaluates positively, and at the same time must evaluate to near zero for all other tokens != memory_tok. But $\Delta$ is basically just a perceptron, so it's not possible for it to simultaneously learn this round-robin separation of memory tokens from other tokens for each possible memory_tok.
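
To make the objection concrete, here is a minimal sketch with random (hypothetical) embeddings and $\Delta$ parameters, showing that a fixed affine $\Delta$ assigns every token the same score in every sequence, regardless of which token is currently serving as memory_tok:

import numpy as np

rng = np.random.default_rng(0)
V, D = 16, 8
E = rng.normal(size=(V, D))     # hypothetical token embeddings
w, b = rng.normal(size=D), 0.1  # Delta's affine parameters, fixed after training

delta = E @ w + b               # one scalar per token, identical in every sequence
print(delta[3], delta[7])      # a token's score is the same whether it happens to
                               # be this sequence's memory token or a distractor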

For reference, I've tried to summarize the real-valued version of the algorithm in einsum as:

[image: einsum summary of the real-valued algorithm]

Source colab here

Please do let me know if I made a mistake!


albertfgu commented on July 17, 2024

It's a 2-layer model, not a 1-layer one. I think you might be right that a single (S6) layer can't learn this task. A single Mamba block probably can, though, because of the local convolution before the main SSM.
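
A minimal sketch of why the local convolution could help (the one-hot embeddings and identity conv taps here are assumptions for illustration): a causal width-2 convolution lets position t see token t-1, so "the token right after S" becomes linearly separable before the SSM even acts:

import numpy as np

V, S = 6, 5                              # toy vocab of 6 with S = 5 (hypothetical)
seq = [4, 0, 3, S, 3, 0, 1, S, 2]
X = np.eye(V)[seq]                       # (T, V) one-hot embeddings
prev = np.vstack([np.zeros(V), X[:-1]])  # causal shift: a conv tap on token t-1
feat = np.concatenate([X, prev], axis=1) # width-2 causal conv with identity taps
print(feat[:, V + S])                    # 1.0 exactly at positions right after S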


hrbigelow commented on July 17, 2024

Thanks for your response. I was actually aware it was a two-layer model; sorry, I should have mentioned that.

To be clear, this is the graphical model structure, if I'm not mistaken. The hidden state at layer $l$ and time $t$, $h^l_t$, has a Markov blanket of $h^l_{t-1}$, $h^{l-1}_t$, and $x_t$ (due to residual connections).


layer2   * -> * -> * -> * -> * -> * ->     ...       * -> *
         ^    ^    ^    ^    ^    ^                  ^    ^
         |    |    |    |    |    |                  |    |
layer1   * -> * -> * -> * -> * -> * ->     ...       * -> *
         ^    ^    ^    ^    ^    ^                  ^    ^
         |    |    |    |    |    |                  |    |
input    .    .    S    M    .    .        ...       S    M
pos      0    1    2    3    4    5                T-1    T

By the data processing inequality, $I(h_t; h_{t-1}) \ge I(h_t; h_{t-2}) \ge I(h_t; h_{t-3}) \ge \cdots$ (for either layer), so information injected long ago can only decay. The hidden states must therefore retain the 'M' information from time 3 across that vast stretch, and the only way they can do that is to ignore the majority of intervening information coming in. It seems the only way to ignore it is through the $\Delta(x_t)$ functions acting on B. But the $\Delta$ functions are just linear separators of the input, so I don't see how they can perform that separation effectively.

I'm sure I'm missing something basic though - obviously the model does solve the
task.


albertfgu commented on July 17, 2024

After the first layer, the representations have all been mixed. Let me use your diagram:

layer2   * -> * -> * -> * -> * -> * ->     ...       * -> *
         ^    ^    ^    ^    ^    ^                  ^    ^
         |    |    |    |    |    |                  |    |
layer1   * -> * -> * -> o -> * -> o ->     ...       * -> *
         ^    ^    ^    ^    ^    ^                  ^    ^
         |    |    |    |    |    |                  |    |
input    .    .    S    M1   .    M2       ...       S    M
pos      0    1    2    3    4    5                T-1    T

If I'm understanding correctly, your objection is that

  • M1 at the bottom is important while M2 should be ignored
  • but M1 and M2 have the same representation, so how can the model achieve this?

While this is true, the point is that the two o marks (which are outputs of the first layer) can depend on everything before them (because they are outputs of one SSM), and have different representations.


albertfgu commented on July 17, 2024

Or perhaps I'm misunderstanding your question still. Reading it again, I don't quite understand this phrase that you've repeated a few times:

Delta can't learn to perform a linear (affine) separation of all possible combinations of memory_token vs. all-other-tokens

I think maybe another point is that the model isn't classifying M (memory-token) vs all-other-tokens. It's classifying S. All it needs to do is know that it's seen S, and from then on ignore everything else.


hrbigelow commented on July 17, 2024

After the first layer, the representations have all been mixed. Let me use your diagram:

layer2   * -> * -> * -> * -> * -> * ->     ...       * -> *
         ^    ^    ^    ^    ^    ^                  ^    ^
         |    |    |    |    |    |                  |    |
layer1   * -> * -> * -> o -> * -> o ->     ...       * -> *
         ^    ^    ^    ^    ^    ^                  ^    ^
         |    |    |    |    |    |                  |    |
input    .    .    S    M1   .    M2       ...       S    M
pos      0    1    2    3    4    5                T-1    T

If I'm understanding correctly, your objection is that

  • M1 at the bottom is important while M2 should be ignored
  • but M1 and M2 have the same representation, so how can the model achieve this?

While this is true, the point is that the two o marks (which are outputs of the first layer) can depend on everything before them (because they are outputs of one SSM), and have different representations.

Ahh yes this does help clarify.

So, at a high level, what I was trying to understand is whether Mamba's ability to perfectly solve the induction head task across such a long context depends crucially on the fact that $\Delta(u_l)$ is input-dependent. My simplistic picture was that maybe $\Delta(u_l)$ learns to output near-zero values for some tokens, so that in the recurrence relation:

$h_l = \exp(\Delta(u_l)A) h_{l-1} + \Delta(u_l) B(u_l) u_l$

when $l$ is in the intervening positions after the first 'M' (the positions we want the model to 'ignore'), I thought that maybe the model achieves this ignoring through the second term $\Delta(u_l) B(u_l) u_l$ being close to zero for most of those timesteps.
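
A minimal numeric sketch of that recurrence (scalar state; all values hypothetical) shows the two regimes such a $\Delta$ would select between:

import numpy as np

A, B = -1.0, 1.0                   # hypothetical scalar SSM parameters

def step(h, u, delta):
    return np.exp(delta * A) * h + delta * B * u

h = 0.7                            # state holding the memorized information
print(step(h, u=0.3, delta=1e-3))  # ~0.700: input ignored, memory preserved
print(step(h, u=0.3, delta=8.0))   # ~2.40 = delta*B*u: old state washed out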

I had thought maybe this ability for Mamba to ignore long stretches was due to this idea of input-dependent gating, described here:

[screenshot of the paper's discussion of selection as input-dependent gating]

But that's probably not the case, right? Maybe what's really going on is that those terms aren't near zero during the recurrence over those million or so intervening tokens. Rather, the state $h_l$ just has enough subspaces that the folding in of so much more information still does not disturb the information from the first M.

Thanks again!

By the way, I do want to be mindful not to clutter GitHub issues with theoretical discussions if that is not appropriate. But I am very grateful for your answers; this is a truly exciting development.


albertfgu commented on July 17, 2024

I think I still don't understand your question. Letting $\Delta \to 0$ where it wants to (in particular once it's memorized the token, the part of the hidden state that stores the memory token should have $\Delta=0$ for subsequent timesteps) is precisely the motivation for the input-dependent selection.
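
Since $\Delta$ is elementwise over channels in Mamba, one way to picture this is per channel. In this hypothetical sketch, a channel with $\Delta \approx 0$ freezes and holds the memory while another channel keeps integrating new inputs:

import numpy as np

A = np.array([-1.0, -1.0])         # hypothetical per-channel parameters
h = np.array([0.9, 0.2])           # channel 0 is storing the memory token
for u in [0.3, -0.5, 0.8]:         # intervening distractor inputs
    delta = np.array([1e-4, 1.0])  # channel 0 frozen, channel 1 still updating
    h = np.exp(delta * A) * h + delta * 1.0 * u
print(h)  # channel 0 ~ 0.90 (memory intact); channel 1 tracks recent inputs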

It would be cool for someone to verify by actually looking at the learned activations.


hrbigelow commented on July 17, 2024

I think I still don't understand your question. Letting $\Delta \rightarrow 0$ where it wants to (in particular once it's memorized the token, the part of the hidden state that stores the memory token should have $\Delta = 0$ for subsequent timesteps) is precisely the motivation for the input-dependent selection.

I'm starting to read your thesis, by the way; there is too much background I'm unfamiliar with at the moment.

That would be very interesting if true! I still don't see it, though. In the very first layer, where the $u_l$ are freshly embedded tokens before any mixing, I don't see how $\Delta(u_l)$ could be close to zero for all of the million or so intervening token inputs: $\Delta(u_l)$ is just defined as $\tau_{\Delta}(\mathrm{Parameter} + \mathrm{Linear}_N(u_l))$. So, in your example above, at the very least M2 will produce the same $\Delta$ value as M1.

In the second layer, the $u_l$ are now richer representations that include history, so it's harder to reason about them. But $\Delta(u_l)$ in the second layer is still just a linear separator, so again it's hard to imagine a pattern where $\Delta(u_3) > 0$ and $\Delta(u_i) \approx 0$ for all $i \in [4, 10^6]$ (position 3 being the occurrence of M1).

It would be cool for someone to verify by actually looking at the learned activations.

Good idea, if I get a chance I'll try to test that.


albertfgu commented on July 17, 2024

Yes, I'm referring to the 2nd layer, as previously discussed. In the second layer you're not working with $u_t$ but with $y_t$, the outputs of the first layer. And again, what the model actually needs to operate on is not the memorization tokens but the "induction token" S. It's easy for the first layer to construct representations $y_t$ that encode whether or not S has been previously seen.

I think it would be a great exercise to write down a closed-form mechanism that solves this task, which I believe is fairly simple, and empirically check if a trained model learns a similar mechanism.
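
For what it's worth, here is one hedged sketch of such a mechanism in plain Python (nothing here is verified against a trained model): layer 1 flags positions immediately following S, and layer 2 writes its state only at flagged positions ($\Delta \approx 1$ there, $\Delta \approx 0$ everywhere else):

def induction_predict(seq, special_tok):
    """Given the tokens up to and including the final S, predict the next token."""
    memory = None
    prev = None
    for tok in seq:
        if prev == special_tok:  # layer-1 feature: "previous token was S"
            memory = tok         # layer-2 selective write (Delta ~ 1 here)
        prev = tok               # every other position leaves memory untouched
    return memory

# seq = prefix + [S, M] + distractors + [S]; the answer should be M
seq = [4, 0, 3, 5, 3, 0, 1, 4, 2, 1, 5]       # S = 5, M = 3
print(induction_predict(seq, special_tok=5))  # -> 3

Checking whether a trained model's $\Delta$ activations match this write-only-after-S pattern would be exactly the verification suggested above.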

