
Qs on inference · mixture-of-depths (closed)

cyLi-Tiger commented on September 17, 2024
Qs on inference


Comments (6)

sramshetty commented on September 17, 2024

Yeah, good catch. Currently I only route when the initial sequence length is greater than the capacity; all following tokens during generation are included. One thing I wonder is whether the same capacity should apply at inference as at training, since inference will be cheaper regardless. Nonetheless, I'll try to fix this, or at least make it an option. Thanks!
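Roughly the current behavior, as a sketch (the function and argument names here are illustrative, not the exact code in this repo):

```python
import torch

def route_if_over_capacity(x, router_weights, capacity):
    # Sketch: let every token through while the sequence fits the
    # capacity; otherwise keep only the top-`capacity` tokens.
    bsz, seq_len, dim = x.shape
    if seq_len <= capacity:
        return x  # under budget: no token is excluded
    topk = torch.topk(router_weights, k=capacity, dim=-1)
    idx, _ = torch.sort(topk.indices, dim=-1)  # keep original token order
    return torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, dim))
```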


sramshetty commented on September 17, 2024

Hi, I can see where the confusion is coming from, and I agree there may be a better or cleaner way of doing this. That said, the Llama implementation also passes the full sequence to each layer during inference, since you still want to compute attention for the new token over the full sequence. It maintains a start_pos that is used for masking and caching which, if I understand correctly, achieves what you might have in mind.

I tried to keep the implementation as close to Llama as possible, so I kept the start_pos argument.
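For anyone following along, here's a minimal sketch of that start_pos pattern (simplified from the Llama reference code; rotary embeddings and the causal mask are omitted for brevity):

```python
import torch
import torch.nn.functional as F

class CachedAttention(torch.nn.Module):
    # Simplified sketch of the Llama-style pattern: the full context is
    # attended over, but only the new positions are written, at start_pos.
    def __init__(self, max_batch, max_seq_len, n_heads, head_dim):
        super().__init__()
        shape = (max_batch, max_seq_len, n_heads, head_dim)
        self.cache_k = torch.zeros(shape)
        self.cache_v = torch.zeros(shape)

    def forward(self, xq, xk, xv, start_pos):
        bsz, seqlen = xk.shape[0], xk.shape[1]
        # Write only the new keys/values, starting at start_pos.
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv
        # Attend over everything cached so far (causal mask omitted here).
        keys = self.cache_k[:bsz, :start_pos + seqlen].transpose(1, 2)
        values = self.cache_v[:bsz, :start_pos + seqlen].transpose(1, 2)
        out = F.scaled_dot_product_attention(xq.transpose(1, 2), keys, values)
        return out.transpose(1, 2)
```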

On the other hand, if you mean that routing should minimize the number of tokens passed through the block, then yes: every token will pass through the routed blocks until there are more than C (capacity) tokens. I was basing this decision on the paper's strategy to:

" Set a static compute budget that is less than that of an equivalent vanilla transformer by limiting
the number of tokens in a sequence that can participate in a block’s computations (i.e., selfattention and subsequent MLP). For example, while a vanilla transformer might permit all the tokens in a sequence to participate in self-attention, we might limit the number to 50% of the tokens in a sequence."

Since we seek a compute budget, we don't necessarily mind using all the tokens if doing so doesn't exceed that budget. Once we exceed the capacity, we need to perform routing to limit the number of tokens flowing through the block.
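In code, that budgeted routing looks roughly like this (a sketch following the paper's description; `block` and `router` are illustrative stand-ins for the attention+MLP block and a linear router, not names from this repo):

```python
import torch

def mod_block(x, block, router, capacity):
    # Sketch of the budgeted routing described above: if the sequence
    # fits the capacity, every token participates; otherwise only the
    # top-`capacity` tokens by router weight go through attention + MLP,
    # and the rest skip the block via the residual path.
    bsz, seq_len, dim = x.shape
    weights = router(x).squeeze(-1)            # (bsz, seq_len) router scores
    if seq_len <= capacity:
        return x + weights.unsqueeze(-1) * block(x)
    top = torch.topk(weights, k=capacity, dim=-1)
    idx, _ = torch.sort(top.indices, dim=-1)   # preserve token order
    gather_idx = idx.unsqueeze(-1).expand(-1, -1, dim)
    selected = torch.gather(x, 1, gather_idx)
    processed = block(selected) * torch.gather(weights, 1, idx).unsqueeze(-1)
    return x.scatter_add(1, gather_idx, processed)  # skipped tokens pass through
```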

Hopefully this answers your question or clears up the confusion; otherwise, I'm happy to provide more details. Thanks!


cyLi-Tiger commented on September 17, 2024

I got your point that as long as we don't exceed the budget, it's fine for all the tokens to be routed to attention and the MLP. However, the capacity remains unchanged during the autoregressive decoding process, and all the tokens will be routed to attention and the MLP. I think the capacity needs to be dynamically updated according to the number of already-routed tokens; otherwise, we may exceed the budget. Correct me if I made any mistakes, thanks!


sramshetty commented on September 17, 2024

Not sure I understand: the number of tokens routed is given by the capacity, which we set during training as max_seq_len // 8. During inference, when we perform autoregressive decoding, the router won't exclude any tokens until the generation length has reached our capacity; after that, only the top-k tokens will be processed. Is there a case where you see it exceeding the set capacity?
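To make the intended behavior concrete (a sketch; the capacity value is just an example):

```python
def tokens_processed(seq_len: int, capacity: int) -> int:
    # Intended behavior: every token participates until the sequence
    # exceeds the capacity; after that, only the top-k (k = capacity)
    # tokens are processed by the block.
    return seq_len if seq_len <= capacity else capacity

# e.g., with capacity = max_seq_len // 8 = 64 (assuming max_seq_len = 512):
# tokens_processed(10, 64) -> 10, tokens_processed(200, 64) -> 64
```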


cyLi-Tiger commented on September 17, 2024

"when we perform autoregressive decoding the router won't exclude any tokens until the generation length has reached our capacity"

According to your code, we can't tell when the generation length will reach the capacity. And even if we could, before that point all the generated tokens are routed to attention and the MLP, which is exactly the same as a vanilla transformer.

"after that the top-k tokens will be processed"

What do you mean here? I think once we run out of capacity, each new token will be directly excluded by the router.

"Is there a case where you see it exceeding the set capacity?"

My point is that the capacity here is static. For example, say our capacity is 4: k will always be 1 during decoding since seq_len == 1, so every token goes through attention and the MLP. If the final generation length is 10, we exceed the capacity by 6.
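To spell out the arithmetic in that example:

```python
capacity = 4
generation_length = 10

# Each decoding step sees seq_len == 1, so k = min(capacity, seq_len) = 1
# and the single new token is always routed through attention and the MLP.
tokens_through_block = sum(min(capacity, 1) for _ in range(generation_length))
print(tokens_through_block)             # 10
print(tokens_through_block - capacity)  # 6 tokens over the static budget
```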


sramshetty commented on September 17, 2024

So, I ended up realizing I hadn't been using the router in a causal manner for inference. Now, in #4 I use the router on the initial text/prompt for generation (which may not be necessary) and then use the router in a causal manner for autoregressive decoding. I don't think capacity is relevant for inference, since it would be non-causal as currently implemented, which the paper avoids by using its proposed sampling methods.
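For context, a causal decode-time decision can be sketched as a per-token threshold on the router weight, in the spirit of the paper's proposed sampling methods (names and the 0.5 threshold are illustrative assumptions, not the exact code in #4):

```python
import torch

def causal_route(x_t, router, block, threshold=0.5):
    # x_t: (bsz, 1, dim), the single new token at this decoding step.
    # The decision uses only this token's router weight, so it stays causal:
    # no future tokens, and no re-ranking of past tokens, are needed.
    w = torch.sigmoid(router(x_t))           # (bsz, 1, 1) routing weight
    use_block = (w > threshold).float()
    return x_t + use_block * w * block(x_t)  # below threshold: residual skip
```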

