
triton support int8kv kv cache. #1238

Merged: hiworldwzj merged 4 commits into main from wzj_fix on Mar 19, 2026

Conversation

@hiworldwzj
Collaborator

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly advances the support for INT8 KV cache within the system by integrating a novel two-stage flash decoding implementation powered by Triton. The changes streamline the attention mechanism for quantized key-value pairs, aiming to improve computational efficiency and speed. Additionally, the backend selection process has been made more robust, and new autotuning profiles are included to maximize the performance benefits of these optimizations.

Highlights

  • Triton INT8 KV Cache Support: Introduced a new two-stage flash decoding mechanism for INT8 KV cache using Triton kernels, enhancing performance and efficiency for quantized key-value storage.
  • Backend Selection Refinement: Modified the attention backend auto-selection logic to ensure that a backend exists in the map before attempting validation, preventing potential errors.
  • Code Refactoring: Refactored the int8kv attention decoding function, renaming ppl_mha_int8kv_decode_att to normal_decode_att and updating its import path for better clarity and modularity.
  • Autotuning Configurations: Added specific autotuning configurations for the new Triton flash decoding stage 1 kernel, optimized for NVIDIA H200 GPUs, to ensure optimal performance across various batch sizes and sequence lengths.


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment
Code Review

This pull request introduces a new 'normal' attention decoding path for int8kv (8-bit quantized key-value) attention, replacing ppl_mha_int8kv_decode_att with normal_decode_att in int8kv.py and implementing it with a two-stage flash decoding kernel. The change also updates backend selection logic to prevent key errors and removes the fa3 and flashinfer backends for int4kv and int8kv. Review comments highlight several improvement opportunities:

  • Refactor the duplicated block_num logic in int8kv_flash_decoding_stage1.py to avoid inconsistencies.
  • Define magic numbers such as BLOCK_SEQ and the block_num thresholds as named constants for readability and maintainability.
  • Infer the device from input tensors instead of hardcoding "cuda" for the mid_o and mid_o_logexpsum allocations.
  • Clarify the implications of, or add an assertion for, gqa_group_size being a power of 2 in int8kv_flash_decoding_stage1.py.

Comment on lines +311 to +316

```python
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```

**high**

This if/elif/else block for block_num is duplicated from int8kv_flash_decoding.py. Duplicated logic can lead to inconsistencies if changes are made in one place but not the other. It's recommended to refactor this into a shared utility function or a common constant.

Suggested change:

```python
# Heuristic to balance parallelism and memory consumption based on batch size
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```
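If the shared-utility route is taken, it could look something like the sketch below; the function name and its module placement are hypothetical, not part of this PR, while the thresholds mirror the duplicated code above:

```python
# Hypothetical shared helper (e.g. in a common utils module) so that
# int8kv_flash_decoding.py and int8kv_flash_decoding_stage1.py use one
# copy of the heuristic instead of two diverging ones.
def get_block_num(batch_size: int) -> int:
    """Balance parallelism against memory consumption based on batch size."""
    if batch_size <= 16:
        return 128  # small batches: more blocks for parallelism
    elif batch_size <= 64:
        return 64
    else:
        return 32   # large batches: fewer blocks to cap memory use
```

Both call sites would then reduce to `block_num = get_block_num(batch_size)`.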

```python
q_head_num = q.shape[1]
head_dim = q.shape[2]

BLOCK_SEQ = 256
```

**medium**

The value 256 for BLOCK_SEQ appears to be a magic number. It would be beneficial to define this as a named constant (e.g., DEFAULT_BLOCK_SEQ) to improve readability and make its purpose clearer. If this value can vary, consider making it a configurable parameter.

Suggested change:

```python
DEFAULT_BLOCK_SEQ = 256
BLOCK_SEQ = DEFAULT_BLOCK_SEQ
```

Comment on lines +32 to +37

```python
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```

**medium**

The if/elif/else block uses several magic numbers (16, 64, 128, 64, 32) to determine block_num. These thresholds and values could be made more explicit by defining them as named constants or by providing comments explaining the rationale behind these specific values. This would improve maintainability and make it easier to understand the performance tuning logic.

Suggested change:

```python
# Heuristic to balance parallelism and memory consumption based on batch size
if batch_size <= 16:
    block_num = 128
elif batch_size <= 64:
    block_num = 64
else:
    block_num = 32
```
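One hedged sketch of the named-constants idea; the constant names and the example batch_size are hypothetical, while the threshold values are copied from the snippet under review:

```python
# Hypothetical constant names; the values come from the reviewed snippet.
SMALL_BATCH_THRESHOLD = 16
MEDIUM_BATCH_THRESHOLD = 64
BLOCK_NUM_SMALL_BATCH = 128   # small batches: more blocks for parallelism
BLOCK_NUM_MEDIUM_BATCH = 64
BLOCK_NUM_LARGE_BATCH = 32    # large batches: fewer blocks to cap memory use

batch_size = 32  # example value for illustration only

if batch_size <= SMALL_BATCH_THRESHOLD:
    block_num = BLOCK_NUM_SMALL_BATCH
elif batch_size <= MEDIUM_BATCH_THRESHOLD:
    block_num = BLOCK_NUM_MEDIUM_BATCH
else:
    block_num = BLOCK_NUM_LARGE_BATCH
```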

```python
else:
    block_num = 32

mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda")
```

**medium**

The device="cuda" is hardcoded. While this function is likely intended for CUDA, it's generally better practice to infer the device from input tensors (e.g., q.device) for greater flexibility and to avoid potential issues if the input tensors are on a different device.

Suggested change:

```python
mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device=q.device)
```

```python
block_num = 32

mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda")
mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device="cuda")
```

**medium**

The device="cuda" is hardcoded. It's generally better to infer the device from input tensors (e.g., q.device). Additionally, torch.float32 is hardcoded for mid_o_logexpsum. If q.dtype is different (e.g., bfloat16), consider whether float32 is strictly necessary here or whether q.dtype could be used for consistency, or add a comment explaining the float32 requirement.

Suggested change:

```python
mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device=q.device)
```
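On the dtype question: keeping a log-sum-exp accumulator in float32 is a common choice because half precision overflows quickly when exponentiating attention scores. A small standalone NumPy illustration (not the project's code) of why:

```python
import numpy as np

# Scores of ~12 already break float16: exp(12) ~ 1.6e5 exceeds
# float16's largest finite value (~65504), so a log-sum-exp
# accumulated in half precision becomes inf, while the same
# computation in float32 stays finite.
scores = np.array([12.0, 12.0], dtype=np.float16)

with np.errstate(over="ignore"):
    naive_fp16 = np.log(np.sum(np.exp(scores)))  # overflows to inf

stable_fp32 = np.log(np.sum(np.exp(scores.astype(np.float32))))
print(bool(np.isinf(naive_fp16)), float(stable_fp32))
```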

```python
assert triton.next_power_of_2(Lk) == Lk
KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1]
assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE
BLOCK_HEAD = triton.next_power_of_2(gqa_group_size)
```

**medium**

The line BLOCK_HEAD = triton.next_power_of_2(gqa_group_size) forces BLOCK_HEAD to be a power of 2. If gqa_group_size is not a power of 2, this could lead to unexpected behavior or inefficient resource allocation. Consider adding an assertion that gqa_group_size itself is a power of 2 if that's a strict requirement, or add a comment explaining the implications of this conversion.

Suggested change:

```python
assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE
BLOCK_HEAD = triton.next_power_of_2(gqa_group_size)
```
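If gqa_group_size is indeed required to be a power of two, the reviewer's point could be addressed with an explicit check. A minimal pure-Python sketch (the helper name is hypothetical; the bit trick stands in for the triton.next_power_of_2 round-trip, so the check has no Triton dependency):

```python
def assert_power_of_two(n: int, name: str = "value") -> None:
    # n & (n - 1) clears the lowest set bit; the result is 0 only
    # when n has exactly one bit set, i.e. n is a power of 2.
    assert n > 0 and (n & (n - 1)) == 0, f"{name} ({n}) must be a power of 2"

# e.g. before computing BLOCK_HEAD:
# assert_power_of_two(gqa_group_size, "gqa_group_size")
```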

@hiworldwzj hiworldwzj merged commit a7c925c into main Mar 19, 2026
1 check passed
@hiworldwzj hiworldwzj deleted the wzj_fix branch March 19, 2026 07:33
