Comments (6)
Yeah, good catch. Currently I only route if the initial sequence length is greater than the capacity; during generation, all following tokens are included. One thing I wonder is whether the same capacity should apply to inference as to training, since inference will be cheaper regardless. Nonetheless, I'll try to fix this or at least make it an option. Thanks!
from mixture-of-depths.
Hi, I can see where the confusion is coming from, and I agree there may be a better or cleaner way of doing this. However, the llama implementation also passes the full sequence to each layer during inference, since you still want to compute attention for the new token over the full sequence. They maintain a start_pos that is used for masking and caching, which, if I understand correctly, achieves what you might have in mind. I tried to keep the implementation as close to Llama as possible, so I kept the start_pos argument.
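For reference, here is a minimal sketch of how a llama-style forward step can use start_pos for caching and masking; the tensor names, single-head simplification, and cache layout are illustrative assumptions, not the exact repo code:

```python
import torch

def attention_step(x, cache_k, cache_v, wq, wk, wv, start_pos):
    # x: (batch, seq_len, dim) -- only the new token(s) during decoding
    bsz, seq_len, _ = x.shape
    q, k, v = x @ wq, x @ wk, x @ wv

    # write the new keys/values into the cache starting at start_pos
    cache_k[:, start_pos:start_pos + seq_len] = k
    cache_v[:, start_pos:start_pos + seq_len] = v

    # attend over everything cached so far (positions 0 .. start_pos+seq_len)
    keys = cache_k[:, :start_pos + seq_len]
    values = cache_v[:, :start_pos + seq_len]
    scores = q @ keys.transpose(-2, -1) / (q.shape[-1] ** 0.5)

    # a causal mask is only needed for the multi-token prompt pass;
    # a single decoded token may attend to the whole cached prefix
    if seq_len > 1:
        mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)
        scores[..., start_pos:] = scores[..., start_pos:] + mask

    return torch.softmax(scores, dim=-1) @ values
```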
On the other hand, if you mean that routing should minimize the number of tokens passed through the block, then yes, every token will pass through the routed blocks until there are more than C (capacity) tokens. I was basing this decision on the paper's strategy to:
" Set a static compute budget that is less than that of an equivalent vanilla transformer by limiting
the number of tokens in a sequence that can participate in a block’s computations (i.e., selfattention and subsequent MLP). For example, while a vanilla transformer might permit all the tokens in a sequence to participate in self-attention, we might limit the number to 50% of the tokens in a sequence."
Since we only seek a compute budget, we don't necessarily mind using all the tokens as long as doing so doesn't exceed that budget. Once we exceed the capacity, we need to perform routing to limit the number of tokens flowing through the block.
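A rough sketch of that logic (the function and variable names here are hypothetical, and the actual repo code may differ): routing only kicks in once the sequence exceeds the capacity.

```python
import torch

def route_tokens(x, router_weights, capacity):
    # x: (batch, seq_len, dim); the router scores each token with one scalar
    scores = (x @ router_weights).squeeze(-1)  # (batch, seq_len)

    seq_len = x.shape[1]
    if seq_len <= capacity:
        # within budget: every token participates in the block
        return torch.arange(seq_len).expand(x.shape[0], -1), scores

    # over budget: keep only the top-k scoring tokens, with k = capacity
    topk_scores, topk_idx = torch.topk(scores, k=capacity, dim=-1)
    return topk_idx, topk_scores
```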
Hopefully this answers your question or clears up the confusion; otherwise I'm happy to provide more details. Thanks!
from mixture-of-depths.
I got your point that as long as we don't exceed the budget, it's fine for all the tokens to be routed to attention and the MLP. However, the capacity remains unchanged during the autoregressive decoding process, and all the tokens will be routed to attention and the MLP. I think the capacity needs to be dynamically updated according to the number of already-routed tokens; otherwise, we may exceed the budget. Correct me if I made any mistakes, thanks!
from mixture-of-depths.
Not sure I understand; the number of tokens routed is given by the capacity, which we set during training as max_seq_len // 8. During inference, when we perform autoregressive decoding, the router won't exclude any tokens until the generation length has reached our capacity; after that, only the top-k tokens will be processed. Is there a case where you see it exceeding the set capacity?
from mixture-of-depths.
"when we perform autoregressive decoding the router won't exclude any tokens until the generation length has reached our capacity"
According to your code, we can't tell in advance when the generation length will reach the capacity. And even if we knew, before that point all the generated tokens are routed to attention and the MLP, which is exactly the same as a vanilla transformer.
"after that the top-k tokens will be processed"
What do you mean here? I think after we run out of capacity, each new token will simply be excluded by the router.
"Is there a case where you see it exceeding the set capacity?"
My point is that the capacity here is static. For example, say our capacity is 4; k will always be 1 since seq_len == 1 during decoding, so every token goes through attention and the MLP. If the final generation length is 10, then we exceed the capacity by 6.
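Spelling that example out in code (the numbers are purely illustrative):

```python
capacity = 4
generated = 10  # tokens decoded one at a time, so seq_len == 1 per step

# at each step k = min(capacity, seq_len) = 1, so every token is kept
tokens_processed = sum(min(capacity, 1) for _ in range(generated))
print(tokens_processed)             # 10
print(tokens_processed - capacity)  # 6 tokens over the static budget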
from mixture-of-depths.
So, I ended up realizing I hadn't been using the router in a causal manner for inference. Now, in #4, I use the router on the initial text/prompt for generation (which may not be necessary) and then use the router in a causal manner during autoregressive decoding. I don't think capacity is relevant for inference, since it would be non-causal as currently implemented, which the paper avoids by using its proposed sampling methods.
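A minimal sketch of what a causal per-token decision can look like at decode time; this follows the paper's general idea of deciding from each token's own router score, but the names and the 0.5 threshold are assumptions, not the exact code in #4:

```python
import torch

def causal_route_step(x_t, router_weights, block_fn):
    # x_t: (batch, 1, dim) -- the single token being decoded
    score = torch.sigmoid(x_t @ router_weights)  # (batch, 1, 1)

    # decide per token, using only information available at this step,
    # so the choice never depends on future tokens (stays causal)
    use_block = score > 0.5
    return torch.where(use_block, block_fn(x_t) * score, x_t)
```

Scaling the block output by the router score keeps the routing decision on the gradient path during training, which is the same trick the top-k variant relies on.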
from mixture-of-depths.