Comments (10)
The amount of information the model can store is proportional to its state size. So in this task, the reason it's able to generalize perfectly is that there isn't much information it has to remember (just one token).
Your question might be touching on a second, more nuanced point, although I'm not exactly sure if this is what you mean. The vocabulary size also implicitly affects memorization ability, because the larger the vocab, the more memory is required to represent one token (information-theoretically, log(|vocab|) bits under a uniform distribution). So a finite state size can also only support a maximum vocabulary in principle.
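As a rough back-of-the-envelope illustration of that counting argument (the state size and precision below are made-up numbers for illustration, not the actual model's):

import math

# A token drawn uniformly from a vocabulary of size V carries log2(V) bits.
vocab_size = 16
bits_per_token = math.log2(vocab_size)   # 4.0 bits

# A recurrent state with N dimensions kept at roughly p bits of effective
# precision can hold on the order of N * p bits, which bounds how much it
# can memorize in principle.  These numbers are hypothetical.
state_dim, bits_per_dim = 16, 16
state_capacity_bits = state_dim * bits_per_dim

print(f"bits per token: {bits_per_token}")
print(f"rough state capacity (bits): {state_capacity_bits}")
print(f"max tokens storable, roughly: {state_capacity_bits / bits_per_token:.0f}")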
Hi Albert,
Thank you for the explanation, and sorry, I do think I was unclear. I understand your point about informational capacity and vocab size. But it seems the problem comes a bit upstream of that, namely that Delta can't learn to perform a linear (affine) separation of all possible combinations of memory_token vs. all other tokens. I do see that in the input-dependent version the B and C parameters also learn something, but they are also just linear, so it doesn't seem to resolve this quandary. Also, I'm actually not even asking about generalization ability - I just can't see how such a model could achieve 100% training accuracy.
From the section of the paper describing the task, together with the training details in Section E.1, my understanding is that the synthetic dataset is generated roughly like this:
(EDIT: forgot to add in the special tok a second time.)
from random import choice, choices

def random_seq(L=256, V=16, P=10):
    """
    Section E.1: training consists of randomly generating data every step,
    with a batch size of 8.
    """
    # prefix is the region where the special token first occurs
    assert P < L - 2
    vocab = list(range(V))
    memory_tok = choice(vocab)
    other_toks = [t for t in vocab if t != memory_tok]
    # Section 3.1 of https://arxiv.org/pdf/2212.14052.pdf: the 'special token'
    special_tok = V
    seq = (choices(vocab, k=P) + [special_tok, memory_tok]
           + choices(other_toks, k=L - P - 2) + [special_tok, memory_tok])
    return seq

if __name__ == '__main__':
    L, V, P = 20, 5, 3
    print(f'seq_length={L}, vocab_size={V}, prefix={P}')
    for b in range(20):
        print(random_seq(L, V, P))
"""
seq_length=20, vocab_size=5, prefix=3
[4, 0, 3, 5, 3, 0, 1, 4, 0, 0, 1, 2, 1, 0, 2, 2, 4, 1, 2, 1, 5, 3]
[4, 3, 1, 5, 0, 4, 4, 2, 3, 4, 4, 1, 3, 4, 4, 2, 2, 2, 4, 3, 5, 0]
[0, 4, 3, 5, 1, 4, 2, 2, 3, 2, 3, 4, 4, 3, 3, 0, 0, 4, 3, 0, 5, 1]
[1, 2, 3, 5, 1, 2, 3, 2, 3, 2, 2, 3, 0, 0, 0, 4, 2, 0, 2, 2, 5, 1]
[2, 3, 2, 5, 2, 1, 4, 1, 4, 3, 1, 1, 4, 4, 3, 4, 4, 4, 0, 0, 5, 2]
[4, 1, 4, 5, 4, 0, 2, 1, 0, 3, 1, 3, 0, 1, 1, 1, 3, 2, 2, 2, 5, 4]
[3, 4, 0, 5, 3, 1, 1, 1, 0, 1, 2, 2, 0, 2, 0, 2, 4, 4, 0, 2, 5, 3]
[3, 3, 0, 5, 2, 0, 4, 3, 1, 1, 1, 1, 3, 0, 3, 4, 3, 4, 4, 1, 5, 2]
[3, 4, 4, 5, 4, 3, 0, 0, 3, 3, 1, 2, 0, 3, 3, 3, 0, 1, 2, 1, 5, 4]
[3, 1, 3, 5, 2, 1, 0, 0, 4, 1, 1, 3, 1, 3, 4, 3, 1, 1, 3, 3, 5, 2]
[1, 1, 4, 5, 2, 0, 4, 1, 0, 4, 1, 3, 4, 4, 1, 0, 1, 1, 0, 3, 5, 2]
[4, 2, 0, 5, 3, 0, 2, 0, 1, 4, 2, 0, 1, 4, 1, 4, 4, 0, 1, 0, 5, 3]
[0, 1, 1, 5, 0, 3, 4, 1, 1, 1, 3, 1, 2, 4, 1, 3, 2, 1, 4, 3, 5, 0]
[2, 2, 2, 5, 1, 0, 2, 4, 0, 4, 3, 4, 4, 0, 0, 3, 4, 4, 3, 3, 5, 1]
[0, 3, 3, 5, 2, 3, 4, 3, 1, 0, 4, 1, 1, 1, 4, 3, 1, 1, 1, 0, 5, 2]
[0, 2, 3, 5, 0, 1, 1, 1, 2, 3, 4, 1, 3, 2, 1, 3, 4, 2, 3, 1, 5, 0]
[0, 4, 3, 5, 0, 1, 1, 1, 2, 1, 1, 2, 2, 3, 1, 2, 4, 2, 4, 4, 5, 0]
[3, 2, 4, 5, 4, 1, 0, 0, 1, 1, 1, 1, 0, 0, 3, 1, 2, 3, 3, 1, 5, 4]
[0, 3, 3, 5, 4, 1, 0, 2, 3, 2, 0, 0, 1, 3, 1, 1, 3, 3, 1, 3, 5, 4]
[3, 0, 3, 5, 4, 2, 3, 3, 3, 1, 3, 3, 2, 1, 0, 2, 1, 1, 1, 0, 5, 4]
"""
So, in the training data, every token in the 16-token vocab will eventually be used as the memory_tok. My simplistic view at the moment, then, is that the Delta operator must learn to evaluate to a positive value for memory_tok, so that \bar{B} evaluates positively, and at the same time must evaluate to near zero for all other tokens != memory_tok. But Delta is basically just a perceptron, so it's not possible for it to simultaneously learn to perform this round-robin separation of memory tokens from other tokens for each possible memory_tok.
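For concreteness, here's the kind of simplified, scalar, real-valued recurrence I have in mind. The parameter names and the Euler-style discretization are my own simplification for illustration, not the actual Mamba parameterization:

import numpy as np

def softplus(z):
    return np.log1p(np.exp(z))

def selective_step(h, x, A, B, C, w_delta, b_delta):
    # delta is an affine function of the token embedding x followed by a
    # softplus -- the 'perceptron' I'm referring to above.  For the scheme
    # I'm imagining, delta would have to be large when x is memory_tok and
    # near zero for every other token.
    delta = softplus(w_delta @ x + b_delta)   # input-dependent step size
    A_bar = np.exp(delta * A)                 # discretized decay (A < 0)
    B_bar = delta * (B @ x)                   # simplified discretization of B
    h_new = A_bar * h + B_bar                 # gated write into the state
    y = C * h_new                             # readout
    return h_new, y

# toy usage: 4-dim token embedding, scalar state
rng = np.random.default_rng(0)
x = rng.normal(size=4)
h, y = selective_step(0.0, x, A=-1.0, B=rng.normal(size=4), C=1.0,
                      w_delta=rng.normal(size=4), b_delta=0.0)
print(h, y)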
For reference, I've tried to summarize the real-valued version of the algorithm in einsum (source colab here). Please do let me know if I made a mistake!
It's a 2 layer model, not 1 layer. I think you might be right that a single (S6) layer can't learn this task. Although, a single Mamba block probably can, because of the local convolution before the main SSM.
Thanks for your response. I was aware it was a two-layer model actually - sorry I should have mentioned that.
To be clear, this is the graphical model structure, if I'm not mistaken (hidden states at each layer and position, ignoring residual connections):
layer2 * -> * -> * -> * -> * -> * -> ... * -> *
^ ^ ^ ^ ^ ^ ^ ^
| | | | | | | |
layer1 * -> * -> * -> * -> * -> * -> ... * -> *
^ ^ ^ ^ ^ ^ ^ ^
| | | | | | | |
input . . S M . . ... S M
pos 0 1 2 3 4 5 T-1 T
By the data processing inequality, the hidden states at time 3 (and onward) must retain the 'M' information across that vast stretch. And the only way they can do that is to ignore the majority of the intervening information coming in. It seems the only way to ignore it is through the Delta gating on B. But Delta is just linear in the input, so I can't see how the layers can effectively do that separation.
I'm sure I'm missing something basic though - obviously the model does solve the task.
After the first layer, the representations have all been mixed. Let me use your diagram:
layer2 * -> * -> * -> * -> * -> * -> ... * -> *
^ ^ ^ ^ ^ ^ ^ ^
| | | | | | | |
layer1 * -> * -> * -> o -> * -> o -> ... * -> *
^ ^ ^ ^ ^ ^ ^ ^
| | | | | | | |
input . . S M1 . M2 ... S M
pos 0 1 2 3 4 5 T-1 T
If I'm understanding correctly, your objection is that
- M1 at the bottom is important while M2 should be ignored
- but M1 and M2 have the same representation, so how can the model achieve this?
While this is true, the point is that the two o marks (which are outputs of the first layer) can depend on everything before them (because they are outputs of one SSM), and have different representations.
Or perhaps I'm misunderstanding your question still. Reading it again, I don't quite understand this phrase that you've repeated a few times:
Delta can't learn to perform a linear (affine) separation of all possible combinations of memory_token vs. all-other-tokens
I think maybe another point is that the model isn't classifying M (memory-token) vs all-other-tokens. It's classifying S. All it needs to do is know that it's seen S, and from then on ignore everything else.
Ahh yes this does help clarify.
So at a high level, what I was trying to understand is whether Mamba's ability to perfectly solve the induction head task across such long contexts depends crucially on Delta going to (near) zero on the intervening tokens, so that \bar{A} stays close to the identity and \bar{B} close to zero once the memory token has been written.
I had thought maybe this ability of Mamba to ignore long stretches was due to the idea of input-dependent gating described in the paper.
But that's probably not the case, right? Maybe what's really going on is that those terms aren't near zero during the recurrence over those million or so intervening tokens; rather, the state manages to retain the memorized token anyway.
Thanks again!
By the way, I do want to be mindful not to clutter github issues with theoretical discussions if that is not appropriate. But I am very grateful for your answers - this is a truly exciting development.
I think I still don't understand your question. Letting $\Delta \rightarrow 0$ where it wants to (in particular, once it's memorized the token, the part of the hidden state that stores the memory token should have $\Delta = 0$ for subsequent timesteps) is precisely the motivation for the input-dependent selection.
It would be cool for someone to verify by actually looking at the learned activations.
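As a quick numerical sanity check of that intuition (a simplified scalar recurrence, not the actual kernel): with delta = 0, the discretized transition becomes the identity and the input term vanishes, so whatever is stored passes through unchanged.

import numpy as np

A = -1.0      # continuous-time decay
h = 3.14      # part of the state holding the memorized token
x = 0.7       # contribution from some intervening token

for delta in (1.0, 0.1, 0.0):
    A_bar = np.exp(delta * A)   # exp(0 * A) = 1 -> identity transition
    B_bar = delta               # -> 0           -> input is ignored
    h_next = A_bar * h + B_bar * x
    print(f"delta={delta:4.1f}  A_bar={A_bar:.3f}  B_bar={B_bar:.3f}  h_next={h_next:.3f}")
# with delta = 0, h_next == h exactly: the memorized value is untouched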
I think I still don't understand your question. Letting
$\Delta \rightarrow 0$ where it wants to (in particular once it's memorized the token, the part of the hidden state that stores the memory token should have $\Delta = 0$ for subsequent timesteps) is precisely the motivation for the input-dependent selection.
I'm starting to read your thesis by the way - there is too much background I'm unfamiliar with at the moment.
That would be very interesting if true! I still don't see it though. Because in the very first layer, $\Delta$ is a function of the raw token embedding, so the memory token looks the same whether it should be written or ignored. In the second layer, the inputs have already been mixed, so maybe the selection happens there?
It would be cool for someone to verify by actually looking at the learned activations.
Good idea, if I get a chance I'll try to test that.
Yes, I'm referring to the 2nd layer as previously discussed. In the second layer, you're not working with the raw token embeddings anymore; the inputs have already been mixed by the first layer.
I think it would be a great exercise to write down a closed-form mechanism that solves this task, which I believe is fairly simple, and empirically check if a trained model learns a similar mechanism.
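To make that concrete, here's one hand-written (non-learned) reference mechanism, written as plain Python bookkeeping rather than as an SSM, just to pin down the behavior a trained model would have to implement:

def induction_head_reference(seq, special_tok):
    # 'layer 1': notice the position immediately after a special token.
    # 'layer 2': the first time that happens, write the token into a memory
    # slot and freeze it; whenever the special token reappears, predict the
    # memorized token.
    memory = None
    prev_was_special = False
    prediction = None
    for tok in seq:
        if prev_was_special and memory is None:
            memory = tok                 # write once, then ignore everything else
        if tok == special_tok and memory is not None:
            prediction = memory          # recall on the later special token
        prev_was_special = (tok == special_tok)
    return prediction

# one of the sequences printed above, fed up to the final special token
seq = [4, 0, 3, 5, 3, 0, 1, 4, 0, 0, 1, 2, 1, 0, 2, 2, 4, 1, 2, 1, 5]
print(induction_head_reference(seq, special_tok=5))   # -> 3

Whether a trained two-layer model implements a soft version of this is exactly what looking at the learned activations would tell us.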