Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request advances support for the INT8 KV cache by integrating a two-stage flash decoding implementation written in Triton. The changes streamline the attention mechanism for quantized key-value pairs, aiming to improve computational efficiency and decoding speed. The backend selection process has also been made more robust, and new autotuning profiles are included to maximize the performance benefit of these optimizations.
Code Review
This pull request introduces a new 'normal' attention decoding path for int8kv (8-bit quantized key-value) attention, replacing ppl_mha_int8kv_decode_att with normal_decode_att in int8kv.py and implementing it with a two-stage flash decoding kernel. The change also updates the backend selection logic to prevent key errors and removes the fa3 and flashinfer backends for int4kv and int8kv. The review comments highlight several improvement opportunities:

- Refactor the duplicated block_num logic in int8kv_flash_decoding_stage1.py to avoid inconsistencies.
- Define magic numbers such as BLOCK_SEQ and the block_num thresholds as named constants for readability and maintainability.
- Infer the device from input tensors instead of hardcoding "cuda" for the mid_o and mid_o_logexpsum allocations.
- Clarify the implications of gqa_group_size being a power of 2, or add an assertion, in int8kv_flash_decoding_stage1.py.
```python
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```
This if/elif/else block for block_num is duplicated from int8kv_flash_decoding.py. Duplicated logic can lead to inconsistencies if changes are made in one place but not the other. It's recommended to refactor this into a shared utility function or a common constant.
```suggestion
# Heuristic to balance parallelism and memory consumption based on batch size
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```
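One way to deduplicate the heuristic is to hoist it into a shared helper that both `int8kv_flash_decoding.py` and `int8kv_flash_decoding_stage1.py` import. A minimal sketch — `get_stage1_block_num` is a hypothetical name, and the docstring reflects the parallelism/memory rationale suggested above:

```python
def get_stage1_block_num(batch_size: int) -> int:
    """Pick the number of sequence blocks per request for stage-1 flash decoding.

    Small batches leave SMs idle, so more blocks are used to recover
    parallelism; large batches already saturate the GPU, so fewer blocks
    keep the stage-2 reduction and the mid_o buffers small.
    """
    if batch_size <= 16:
        return 128
    elif batch_size <= 64:
        return 64
    else:
        return 32
```

Both call sites would then stay in sync automatically when the thresholds are retuned.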
```python
q_head_num = q.shape[1]
head_dim = q.shape[2]

BLOCK_SEQ = 256
```
The value 256 for BLOCK_SEQ appears to be a magic number. It would be beneficial to define this as a named constant (e.g., DEFAULT_BLOCK_SEQ) to improve readability and make its purpose clearer. If this value can vary, consider making it a configurable parameter.
```suggestion
DEFAULT_BLOCK_SEQ = 256
BLOCK_SEQ = DEFAULT_BLOCK_SEQ
```
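If the block size may need per-GPU tuning, the named constant can also back a configurable parameter. A minimal sketch — `pick_block_seq` and its power-of-2 check are illustrative assumptions, not existing code in this PR:

```python
# Tokens of KV cache processed per stage-1 sequence block; 256 is the current default.
DEFAULT_BLOCK_SEQ = 256


def pick_block_seq(block_seq=None):
    """Return the configured sequence block size, falling back to the default.

    Triton tile sizes generally must be powers of two, so validate eagerly
    rather than failing inside the kernel launch.
    """
    seq = DEFAULT_BLOCK_SEQ if block_seq is None else block_seq
    assert seq > 0 and (seq & (seq - 1)) == 0, "block_seq must be a power of 2"
    return seq
```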
```python
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```
The if/elif/else block uses several magic numbers — the thresholds 16 and 64, and the block counts 128, 64, and 32 — to determine block_num. These could be made more explicit by defining them as named constants or by adding a comment explaining the rationale behind these specific values. This would improve maintainability and make the performance tuning logic easier to understand.
```suggestion
# Heuristic to balance parallelism and memory consumption based on batch size
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```
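Alternatively, the thresholds and block counts can be made self-documenting through named constants. A sketch — the constant names are illustrative, not from the codebase:

```python
# Batch-size thresholds for the stage-1 block_num heuristic.
SMALL_BATCH_MAX = 16   # up to here, use many blocks to keep the GPU busy
MEDIUM_BATCH_MAX = 64  # beyond this, the batch alone provides enough parallelism

SMALL_BATCH_BLOCK_NUM = 128
MEDIUM_BATCH_BLOCK_NUM = 64
LARGE_BATCH_BLOCK_NUM = 32


def block_num_for(batch_size):
    """Map a batch size to the stage-1 sequence block count."""
    if batch_size <= SMALL_BATCH_MAX:
        return SMALL_BATCH_BLOCK_NUM
    elif batch_size <= MEDIUM_BATCH_MAX:
        return MEDIUM_BATCH_BLOCK_NUM
    return LARGE_BATCH_BLOCK_NUM
```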
```python
else:
    block_num = 32

mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda")
```
The device="cuda" is hardcoded. While this function is likely intended for CUDA, it's generally better practice to infer the device from input tensors (e.g., q.device) for greater flexibility and to avoid potential issues if the input tensors are on a different device.
```suggestion
mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device=q.device)
```
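To illustrate the device-inference pattern without requiring a GPU (or torch), here is a sketch with a minimal stand-in tensor type — both `FakeTensor` and `alloc_mid_buffers` are hypothetical names for illustration:

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Minimal stand-in for a torch tensor: just shape, dtype, and device."""
    shape: tuple
    dtype: str = "float16"
    device: str = "cuda:0"


def alloc_mid_buffers(q, batch_size, q_head_num, block_num, head_dim):
    """Allocate stage-1 outputs on q's device instead of a hardcoded "cuda".

    This keeps the kernel correct under multi-GPU placement (e.g. q on
    "cuda:1") without any change at the call sites.
    """
    mid_o = FakeTensor((batch_size, q_head_num, block_num, head_dim),
                       dtype=q.dtype, device=q.device)
    # The log-sum-exp buffer stays float32 regardless of q.dtype,
    # since exponential sums are precision-sensitive.
    mid_o_logexpsum = FakeTensor((batch_size, q_head_num, block_num),
                                 dtype="float32", device=q.device)
    return mid_o, mid_o_logexpsum
```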
```python
block_num = 32

mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda")
mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device="cuda")
```
The device="cuda" is hardcoded; it's generally better to infer the device from input tensors (e.g., q.device). Additionally, torch.float32 is hardcoded for mid_o_logexpsum. If q.dtype is different (e.g., bfloat16), consider whether float32 is strictly necessary here, whether q.dtype could be used for consistency, or add a comment explaining the float32 requirement.
```suggestion
mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device=q.device)
```
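For context on the float32 choice: in two-stage flash decoding, stage 2 re-weights each block's partial output by exponentials of the stored log-sum-exp values, and these exponentials are sensitive to rounding, which is why the buffer is commonly kept in float32 even when q is float16/bfloat16. A pure-Python sketch of that stage-2 merge (scalar partial outputs for illustration; `merge_partials` is a hypothetical name):

```python
import math


def merge_partials(partials):
    """Combine per-block partial results from stage-1 flash decoding.

    partials: list of (o_b, lse_b) pairs, where o_b is the block's
    softmax-normalized partial output (a scalar here for illustration)
    and lse_b = log(sum(exp(scores_b))) over that block.
    """
    max_lse = max(lse for _, lse in partials)  # subtract the max for stability
    weights = [math.exp(lse - max_lse) for _, lse in partials]
    total = sum(weights)
    return sum(o * w for (o, _), w in zip(partials, weights)) / total
```

Merging block-wise results this way reproduces the one-shot softmax-weighted average over the full sequence, which is what makes the block split in stage 1 valid.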
```python
assert triton.next_power_of_2(Lk) == Lk
KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1]
assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE
BLOCK_HEAD = triton.next_power_of_2(gqa_group_size)
```
The line BLOCK_HEAD = triton.next_power_of_2(gqa_group_size) forces BLOCK_HEAD to be a power of 2. If gqa_group_size is not a power of 2, this could lead to unexpected behavior or inefficient resource allocation. Consider adding an assertion that gqa_group_size itself is a power of 2 if that's a strict requirement, or add a comment explaining the implications of this conversion.
```suggestion
assert triton.next_power_of_2(gqa_group_size) == gqa_group_size, "gqa_group_size must be a power of 2"
BLOCK_HEAD = triton.next_power_of_2(gqa_group_size)
```
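If gqa_group_size is indeed required to be a power of 2, a cheap bit-trick check (equivalent to the `next_power_of_2` round-trip asserts already used in this file) makes the requirement explicit; `is_power_of_2` below is an illustrative helper:

```python
def is_power_of_2(n: int) -> bool:
    """True iff n is a positive power of two (1, 2, 4, 8, ...).

    A power of two has exactly one bit set, so n & (n - 1) clears it to 0.
    """
    return n > 0 and (n & (n - 1)) == 0


# Hypothetical usage before computing BLOCK_HEAD:
# assert is_power_of_2(gqa_group_size), "gqa_group_size must be a power of 2"
```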