llama : add example for tree-based parallel decoding #3137
Comments
I read the Medusa blog post yesterday and I was very impressed with the results - it sounds like a very practical alternative to draft-based speculative decoding. I'm not sure how easy it would be to fine-tune and distribute the Medusa heads though. I hope it's not another case of TheBloke having to redo all of his models if we want to see widespread support 😅
I've been thinking about this today and I believe I have a very good plan for supporting it, so I wanted to write down some thoughts.

The interesting thing about the idea is that with a custom attention mask, we can automatically support batched decoding using a unified KV cache. The unified KV cache will contain the caches of multiple sequences and will provide information about each cached token - which sequences it is part of and what its position is. Utilizing the custom attention mask, instead of the existing diagonal mask, we can attend to different tokens for different sequences by simply applying the correct mask.

For this to work we need to store non-roped KV data in the cache and rope the entire sequence during inference (edit: this is actually not necessary). We also need to extend the rope operator to accept a vector of custom positions instead of just a single starting position. This will allow us to evaluate multiple sequences - or a whole tree of continuations - in a single pass.
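To make the idea more concrete, here is a minimal, self-contained sketch (plain C++, not llama.cpp code) of how the per-token positions and the causal tree mask could be derived for a small batch of drafted tokens. The `parent` array and `n_past` value are made up for illustration; each token gets a position one past its parent and is allowed to attend to itself and its ancestors, while the cached prompt tokens (always visible) are omitted from the mask.

```cpp
// Sketch: custom positions + causal tree mask for a batch of drafted tokens.
#include <cstdio>
#include <cmath>
#include <vector>

int main() {
    // hypothetical tree of 5 drafted tokens: 0 is the root continuation,
    // 1 and 2 branch from 0, 3 branches from 1, 4 branches from 2
    const std::vector<int> parent = { -1, 0, 0, 1, 2 };
    const int n_tokens = (int) parent.size();
    const int n_past   = 10; // tokens already in the KV cache (assumed)

    std::vector<int>   pos(n_tokens);                          // custom position per token
    std::vector<float> mask(n_tokens * n_tokens, -INFINITY);   // -INF = masked out

    for (int i = 0; i < n_tokens; ++i) {
        // position = parent position + 1 (the root follows the last cached token)
        pos[i] = parent[i] < 0 ? n_past : pos[parent[i]] + 1;

        // allow attention to self and to all ancestors within the batch
        for (int j = i; j >= 0; j = parent[j]) {
            mask[i*n_tokens + j] = 0.0f;
        }
    }

    for (int i = 0; i < n_tokens; ++i) {
        printf("token %d: pos = %d, attends to:", i, pos[i]);
        for (int j = 0; j < n_tokens; ++j) {
            if (mask[i*n_tokens + j] == 0.0f) printf(" %d", j);
        }
        printf("\n");
    }
    return 0;
}
```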
As a secondary objective, I think I now see an elegant way to quantize the KV cache: use two caches - a small unquantized L0 cache and a quantized L1 cache organized in chunks of 32 (or a higher power of 2) tokens.

Eval calls will be split into chunks of 32 (or higher powers of 2) that go directly into the L1 cache. The rest of the batch (mod 32) will be evaluated with the L0 cache, plus an extra call at the end to move any full 32-token chunks from the L0 into the L1 cache. Unless I'm missing something, this should allow quantization of the KV cache without modifications to the backends, only slightly changing the graph using existing ops. I will be giving this idea a try soon - it will likely introduce some significant API changes (e.g. no more …).
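Here is a rough sketch of the batch-splitting arithmetic described above, only tracking token counts to show which part of a batch would land in which cache. The `KVCacheL0`/`KVCacheL1` structs and the append/flush helpers are placeholders, not actual llama.cpp API.

```cpp
// Sketch of the hypothetical two-level (L0 unquantized / L1 quantized) KV cache.
#include <cstdio>

constexpr int kChunk = 32; // quantization block size of the L1 cache

struct KVCacheL0 { int n_tokens = 0; }; // small unquantized cache (< 32 tokens)
struct KVCacheL1 { int n_tokens = 0; }; // quantized cache, grows in 32-token chunks

// placeholder helpers - stand-ins for the real graph operations
void append_l0(KVCacheL0 & l0, int n) { l0.n_tokens += n; }
void append_l1(KVCacheL1 & l1, int n) { l1.n_tokens += n; }
void flush_l0_to_l1(KVCacheL0 & l0, KVCacheL1 & l1) {
    const int full = (l0.n_tokens / kChunk) * kChunk; // move complete chunks only
    l1.n_tokens += full;
    l0.n_tokens -= full;
}

void eval_batch(KVCacheL0 & l0, KVCacheL1 & l1, int n_batch) {
    const int n_l1 = (n_batch / kChunk) * kChunk; // part that can go straight to L1
    const int n_l0 = n_batch - n_l1;              // remainder (mod 32) goes to L0

    if (n_l1 > 0) append_l1(l1, n_l1);
    if (n_l0 > 0) append_l0(l0, n_l0);

    // extra step: if L0 has accumulated a full chunk, quantize and move it to L1
    flush_l0_to_l1(l0, l1);

    printf("batch %3d -> L1: %3d tokens, L0: %2d tokens\n", n_batch, l1.n_tokens, l0.n_tokens);
}

int main() {
    KVCacheL0 l0; KVCacheL1 l1;
    const int batches[] = { 100, 7, 20, 40 };
    for (int n : batches) eval_batch(l0, l1, n);
    return 0;
}
```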
I imagine that to do this efficiently, it will require a new op to apply the RoPE on the fly, without having to make a copy of the entire cache.
Can't wait for this!
Yes, this is a concern and I need to do some testing to see what the best option is. It will be making a copy of the cache only per transformer layer, so not the entire cache all at once. The bigger problem, however, is that we would probably need to store a non-transposed `V` cache.

We also have models such as StarCoder that don't involve position encoding of the KV data - in such cases we don't have to do anything when we do cache operations.

Edit: looks like ALiBi is not a problem because it's applied just to the `KQ` scores.
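For what it's worth, here is a tiny self-contained sketch (plain C++ placeholders, not ggml code) of the memory argument above: if the non-roped K data is roped on the fly inside each layer's attention, the scratch copy only ever holds a single layer's worth of cache, not the whole thing.

```cpp
// Sketch: per-layer rope-and-copy keeps the scratch buffer to one layer's K data.
#include <algorithm>
#include <cstdio>
#include <vector>

struct Layer { std::vector<float> k_cache; }; // non-roped K data for one layer

// placeholder for the real RoPE with per-token positions
static std::vector<float> rope_copy(const std::vector<float> & k, int /*pos0*/) {
    return k; // the actual rotation is omitted - only the copy matters here
}

int main() {
    const int n_layer = 4;
    std::vector<Layer> layers(n_layer, { std::vector<float>(1024, 0.0f) });

    size_t max_scratch = 0;
    for (const auto & layer : layers) {
        // the roped copy lives only for the duration of this layer's attention
        std::vector<float> k_roped = rope_copy(layer.k_cache, /*pos0 =*/ 0);
        max_scratch = std::max(max_scratch, k_roped.size() * sizeof(float));
        // ... attention would use k_roped here ...
    }

    printf("peak scratch: %zu bytes (one layer), total cache: %zu bytes\n",
           max_scratch, n_layer * 1024 * sizeof(float));
    return 0;
}
```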
Please ignore this comment - I was somehow confused that the …
This is now supported. Using a single draft model with branching drafts does not seem to give a significant improvement over the standard single-sequence draft speculation. The next test is to apply the method directly to Medusa models with multiple heads, where no branching strategy is required.
Refs:
In simple terms, after implementing batched decoding (a.k.a. parallel decoding) we can extend the inference functionality to support applying a custom attention mask to the batch. This can be used to create a causal tree mask that allows us to evaluate a tree of continuations in a single pass, instead of a large batch of independent sequences.

This is useful for implementing advanced speculative strategies such as SpecInfer's token tree verification and Medusa heads.
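As an illustration of how such a tree can sit on top of a unified KV cache, here is a small self-contained sketch (plain C++ with a made-up `parent` array, not llama.cpp API): every root-to-leaf path of the draft tree is treated as one sequence, and each drafted token is a member of every sequence whose path passes through it, so its KV data is stored only once while the causal tree mask follows from the shared membership.

```cpp
// Sketch: mapping a tree of drafted tokens to overlapping sequences.
#include <cstdio>
#include <vector>

int main() {
    // parent[i] is the parent of drafted token i (-1 = root of the tree)
    const std::vector<int> parent = { -1, 0, 0, 1, 1, 2 };
    const int n_tokens = (int) parent.size();

    // leaves = tokens that are nobody's parent; each leaf defines one path/sequence
    std::vector<bool> is_leaf(n_tokens, true);
    for (int p : parent) if (p >= 0) is_leaf[p] = false;

    std::vector<std::vector<int>> seq_ids(n_tokens); // sequences each token belongs to
    int n_seq = 0;
    for (int i = 0; i < n_tokens; ++i) {
        if (!is_leaf[i]) continue;
        // walk from leaf to root, adding this sequence id to every token on the path
        for (int j = i; j >= 0; j = parent[j]) {
            seq_ids[j].push_back(n_seq);
        }
        n_seq++;
    }

    for (int i = 0; i < n_tokens; ++i) {
        printf("token %d -> seq ids:", i);
        for (int s : seq_ids[i]) printf(" %d", s);
        printf("\n");
    }
    printf("stored once in the KV cache: %d tokens, %d sequences\n", n_tokens, n_seq);
    return 0;
}
```

Roughly speaking, this corresponds to submitting the whole tree as a single batch in which each token carries the list of sequence ids it is part of, so the unified KV cache can record shared prefix tokens once.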