
llama : add example for tree-based parallel decoding #3137

Closed
ggerganov opened this issue Sep 12, 2023 · 7 comments

Labels: performance (speed related topics), research 🔬

@ggerganov (Member)

Refs:

In simple terms, after implementing batched decoding (a.k.a. parallel decoding) we can extend the inference functionality to support applying a custom attention mask to the batch. This can be used to create a causal tree mask that allows evaluating a tree of continuations in a single pass, instead of a large batch of independent sequences.

This is useful for implementing advanced speculative strategies such as SpecInfer's token tree verification and Medusa heads.
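
To make the tree-mask idea concrete, here is a minimal sketch (illustrative C++, not the actual llama.cpp API) of building such a mask for a batch of draft tokens, where each token records the index of its parent within the batch:

```cpp
#include <cstdint>
#include <vector>

// mask[i*n + j] == 0.0f  -> batch token i may attend to batch token j
// mask[i*n + j] == -inf  -> attention is masked out
// parent[i] is the in-batch index of token i's parent, or -1 if its parent is
// the last token of the common prefix already in the KV cache
std::vector<float> make_tree_mask(const std::vector<int32_t> & parent) {
    const int   n       = (int) parent.size();
    const float neg_inf = -1e9f; // added to the KQ scores before the softmax

    std::vector<float> mask(n*n, neg_inf);
    for (int i = 0; i < n; ++i) {
        mask[i*n + i] = 0.0f;                            // each token sees itself
        for (int a = parent[i]; a != -1; a = parent[a]) {
            mask[i*n + a] = 0.0f;                        // ... and all of its ancestors
        }
    }
    // tokens of the shared prompt already in the cache are visible to every
    // branch and would be unmasked separately in the full 2D mask
    return mask;
}
```

For example, parent = {-1, 0, 0, 1} encodes a root with two children and one grandchild: token 3 attends to tokens 0, 1 and itself, but not to token 2, so the whole tree of continuations can be verified in a single eval.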

@Azeirah (Contributor) commented Sep 12, 2023

I read the Medusa blog post yesterday and was very impressed with the results; it sounds like a very practical alternative to draft-based speculative decoding. I'm not sure how easy it would be to fine-tune and distribute the Medusa heads, though. I hope it's not another case of TheBloke having to redo all of his models again if we want to see widespread support 😅

@ggerganov (Member Author) commented Sep 16, 2023

I've been thinking about this today and I believe I have a very good plan for supporting it, so I wanted to write down some thoughts. The interesting thing about the idea is that with a custom attention mask, we can automatically support batched decoding using a unified KV cache. The unified KV cache will contain the caches of multiple sequences and will provide information about each cached token - which sequences it is part of and its position.

Utilizing the custom attention mask, instead of the existing diagonal mask, we can attend to different tokens for different sequences by simply applying the correct mask. For this to work we need to store non-roped KV data in the cache and rope the entire sequence during inference (edit: this is actually not necessary). We also need to extend the rope operator to accept a vector of custom positions instead of just n_past, and extend the mask operator to accept a 2D mask tensor - these will be generated on-the-fly for each eval call based on the evaluated sequences. All other functionality will be reused, so the graph will remain almost the same.

This will allow us to do:

  • batched prompt processing (e.g. during training)
  • parallel decoding with common prefix (the common KV cache will be shared - no copies)
  • tree-based parallel decoding (useful for speculative strategies)
  • KV cache modifications (shift, delete, insert, compress)
  • no context swaps (llama : try to avoid context swap #2060)
  • efficient strided perplexity
  • hot-swap of new sequences (useful for cloud services and serving multiple clients)
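
A rough sketch of what the unified cache metadata and the generated 2D mask could look like (all names are illustrative, not the final API):

```cpp
#include <cstdint>
#include <set>
#include <vector>

// each cache cell knows the position of its token and the sequences it belongs to
struct kv_cell {
    int32_t           pos = -1;
    std::set<int32_t> seq_id;   // a shared prefix is stored once, tagged with several ids
};

// one mask row per batch token; each batch token carries its own sequence id
// and position instead of a single global n_past
// (the batch tokens are assumed to already occupy their own cells)
std::vector<float> make_kv_mask(
        const std::vector<kv_cell> & cells,
        const std::vector<int32_t> & batch_seq_id,
        const std::vector<int32_t> & batch_pos) {
    const int   n_kv    = (int) cells.size();
    const int   n_batch = (int) batch_seq_id.size();
    const float neg_inf = -1e9f;

    std::vector<float> mask(n_batch*n_kv, neg_inf);
    for (int i = 0; i < n_batch; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            // attend only to cells of the same sequence that are not in the future
            if (cells[j].seq_id.count(batch_seq_id[i]) > 0 && cells[j].pos <= batch_pos[i]) {
                mask[i*n_kv + j] = 0.0f;
            }
        }
    }
    return mask;
}
```

The rope operator would similarly receive the per-token positions of the batch instead of a single n_past.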

As a secondary objective, I think I now see an elegant way to quantize the V cache by utilizing a 2-stage cache:

  • L0: small classical (non-quantized) cache with a size of 64
  • L1: big quantum (quantized) cache with a size that is a multiple of 32

Eval calls will be split into chunks of 32 or higher powers of 2, directly utilizing the L1 cache. The rest of the batch (mod 32) will be evaluated with the L0 cache, plus an extra call at the end to move any full 32-token chunks from the L0 into the L1 cache. Unless I'm missing something, this should allow quantization of the KV cache without modifications in the backends, only slightly changing the graph using existing ops.
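
A tiny sketch of the batch-splitting arithmetic this implies (hypothetical helper, just to illustrate the chunking):

```cpp
#include <cstdio>

// L1 = big quantized cache, filled in blocks of 32; L0 = small classical cache
void split_eval(int n_tokens, int n_l0_used) {
    const int BLOCK = 32;

    const int n_l1 = (n_tokens/BLOCK)*BLOCK;  // full blocks go straight into L1
    const int n_l0 = n_tokens - n_l1;         // the remainder (mod 32) goes into L0

    // after the eval, any full blocks accumulated in L0 are quantized and
    // moved into L1 with one extra call
    const int n_flush = ((n_l0_used + n_l0)/BLOCK)*BLOCK;

    printf("L1 eval: %d tokens, L0 eval: %d tokens, L0 -> L1 flush: %d tokens\n",
           n_l1, n_l0, n_flush);
}

// e.g. a batch of 100 tokens with 20 tokens already sitting in L0:
// 96 tokens go directly to L1, 4 to L0, and nothing is flushed yet (20 + 4 < 32)
```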

I'll be giving this idea a try soon. It will likely introduce some significant API changes (e.g. no more n_past), but if it works out, it will unlock some very interesting applications and bring the roadmap close to full completion.

@slaren (Member) commented Sep 16, 2023

For this to work we need to store non-roped KV data in the cache and rope the entire sequence during inference.

I imagine that to be able to do this efficiently, it will require a new op to apply the RoPE on the fly without having to make a copy of the entire K tensor in main memory. Maybe this could be done in a fused attention op similar to ggml_flash_attn.
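
Something along these lines, perhaps - a plain sketch (not a real ggml op) of rotating a cached, non-roped K row on the fly inside the q·k dot product, so that no roped copy of K is ever materialized:

```cpp
#include <math.h>

// q is already roped for its own position; k is the raw (non-roped) cached row.
// Assumes the interleaved-pair RoPE layout, purely for illustration.
float dot_q_roped_k(const float * q, const float * k, int head_dim, int pos, float theta_base) {
    float sum = 0.0f;
    for (int i = 0; i < head_dim; i += 2) {
        // rotation angle of this dimension pair at the cached token's position
        const float theta = pos*powf(theta_base, -(float) i/head_dim);
        const float c = cosf(theta);
        const float s = sinf(theta);

        // rotate the pair in registers instead of writing out a roped copy of K
        const float k0 = k[i]*c - k[i + 1]*s;
        const float k1 = k[i]*s + k[i + 1]*c;

        sum += q[i]*k0 + q[i + 1]*k1;
    }
    return sum;
}
```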

@dillfrescott

Can't wait for this!

@ggerganov (Member Author) commented Sep 17, 2023

I imagine that to be able to do this efficiently, it will require a new op to apply the RoPE on the fly without having to make a copy of the entire K tensor in main memory

Yes, this is a concern and I need to do some testing to see what the best option is. We would be making a copy of the cache only per transformer layer, not of the entire cache all at once. The bigger problem, however, is that we would probably need to store a non-transposed V in the cache to be able to rope it the same way as K (otherwise we'd need an extra rope implementation for transposed data, which I would like to avoid). This keeps the implementation simple and symmetrical, but the performance of KQV = V @ KQ would likely suffer a lot.

If there is no way around this, then we might have to fall back to storing a roped cache and re-roping it with the delta of the positions when we do cache operations (i.e. the second option in #2060). This should work well with RoPE, since I think it is an additive operation (i.e. rotating the 2D vectors by a certain angle), however I'm not sure what the case is for ALiBi.

We also have models such as StarCoder that don't involve position encoding of the KV data - in such cases we don't have to do anything when we do cache operations.

Edit: looks like ALiBi is not a problem because it is applied just to KQ - i.e. not to K and V individually like RoPE - so it should be handled the same way as StarCoder.
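
If we end up storing the roped cache, the shift operation could look roughly like this (illustrative only) - since RoPE is a pure rotation, moving a token from position p to p + delta only requires rotating the stored K row by the angle corresponding to delta:

```cpp
#include <math.h>

// R(p*f) followed by R(delta*f) equals R((p + delta)*f), so only delta is needed
void rope_shift_inplace(float * k, int head_dim, int delta, float theta_base) {
    for (int i = 0; i < head_dim; i += 2) {
        const float theta = delta*powf(theta_base, -(float) i/head_dim);
        const float c = cosf(theta);
        const float s = sinf(theta);

        const float x0 = k[i];
        const float x1 = k[i + 1];

        k[i]     = x0*c - x1*s;  // rotate the dimension pair in place
        k[i + 1] = x0*s + x1*c;
    }
}
```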

@ggerganov (Member Author) commented Sep 17, 2023

The bigger problem however is that we probably would need to store a non-transposed V in the cache to be able to rope it the same way as K (otherwise we'll need extra rope implementation for transposed data, which I would like to avoid).

Please ignore this comment - I was somehow confused into thinking that the V tensor is also roped. Anyway, I'm considering both options of a roped and a non-roped K cache during the implementation in #3228 and will choose whichever makes more sense.

@ggerganov ggerganov moved this from Todo to In Progress in ggml : roadmap Oct 16, 2023
@ggerganov ggerganov self-assigned this Oct 16, 2023
@ggerganov ggerganov moved this from In Progress to Done in ggml : roadmap Oct 18, 2023
@ggerganov (Member Author)

This is now supported. Using a single draft model with branching drafts does not seem to give a significant improvement over the standard single-sequence draft speculation. The next test is to apply the method directly to Medusa models with multiple heads, where no branching strategy is required.
