
[Kernel] feat: add NVFP4 blockwise MoE kernels for sm_120 #1528

Open
AlpinDale wants to merge 1 commit into main from sm_120_nvfp4_moe

@AlpinDale
Collaborator

Not fully optimized, as a lot of the sm_100 codepath is still used for this.

Tested with alpindale/Ling-mini-2.0-NVFP4: it gets about 91 tok/s decode, which is slower than the 140 tok/s with AWQ Marlin.

@gemini-code-assist
Contributor

Summary of Changes

Hello @AlpinDale, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request lays the groundwork for supporting NVFP4 blockwise Mixture-of-Experts (MoE) operations on NVIDIA's latest SM120 (Blackwell) GPU architecture. It integrates a new reference kernel, updates the underlying CUTLASS dependency, and adjusts the build system and Python dispatch mechanisms to enable this new hardware capability. While functional, the current implementation is noted as not fully optimized, indicating it's an initial step towards full performance parity.

Highlights

  • SM120 NVFP4 MoE Kernel Support: Introduced initial support for NVFP4 blockwise Mixture-of-Experts (MoE) kernels specifically for NVIDIA's SM120 architecture (e.g., Blackwell GPUs). This includes a new CUDA kernel file and associated dispatch logic.
  • CUTLASS Version Update: The CUTLASS library revision has been updated from v4.0.0 to v4.2.0 to incorporate improved support for SM120 devices.
  • Conditional Kernel Dispatch: Implemented logic in _custom_ops.py to detect SM120 compute capability and conditionally dispatch to the new cutlass_fp4_group_mm_sm120 kernel if available, falling back to existing kernels otherwise.
  • Kernel Registration and Compilation Flags: New kernel files have been added, and CMake build configurations have been updated to include SM120 architecture flags, ensuring the new kernels are compiled and registered with the PyTorch bindings.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for NVFP4 blockwise Mixture-of-Experts (MoE) kernels for the sm_120 (Blackwell) architecture. The changes span CMake build configuration, Python dispatch logic, and new CUDA kernel implementations. While the overall approach is sound, I've identified several critical issues in the new sm_120 kernel that could lead to incorrect results, particularly in the scale dequantization logic and the loop bounds. There are also opportunities to improve maintainability and performance. My review provides specific suggestions for each point.


    // E4M3 scale factor dequantization
    __device__ __forceinline__ float dequantize_e4m3_scale(uint8_t e4m3_val) {
      if (e4m3_val == 0) return 1.0f;

critical

The special handling for e4m3_val == 0 is incorrect. An E4M3 value of 0x00 represents +0.0f, but the function currently returns 1.0f. The subsequent logic in the function already correctly handles denormalized numbers (when exp == 0), so this special case is not only incorrect but also unnecessary. Removing this line will allow the function to correctly return 0.0f for an input of 0x00.
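
As a cross-check for the point above, here is a minimal host-side reference decoder. It is a sketch that assumes the scale byte follows the OCP FP8 E4M3 layout (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7, no infinities, 0x7F and 0xFF reserved for NaN). The function name `dequantize_e4m3` is hypothetical and is not part of the PR.

```python
import math


def dequantize_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 byte (assumed OCP layout, exponent bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF   # 4 exponent bits
    mant = byte & 0x7         # 3 mantissa bits
    if exp == 0xF and mant == 0x7:
        return math.nan       # only NaN encoding; E4M3 has no infinities
    if exp == 0:
        # Subnormal range: no implicit leading 1, fixed exponent 2^-6.
        # The byte 0x00 falls through here and decodes to 0.0, not 1.0.
        return sign * (mant / 8.0) * 2.0 ** -6
    # Normal range: implicit leading 1.
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)
```

Comparing the kernel's `dequantize_e4m3_scale` against a table built from such a reference for all 256 byte values is one way to confirm that no other special case diverges.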

    // Compute dot product in 16-element blocks
    float sum = 0.0f;
    int k_packed = K / 2;
    int k_blocks = K / 16;

critical

The calculation of k_blocks uses integer division, which will truncate the result if K is not a multiple of 16. This will cause the kernel to skip processing the tail end of the K dimension, leading to incorrect matrix multiplication results. You should use ceiling division to ensure all elements are processed.

    int k_blocks = (K + 15) / 16;
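
To make the effect of the suggestion concrete, the following sketch enumerates the element ranges each block would cover. It is illustrative only: `block_ranges` is a hypothetical helper, and it assumes the kernel's inner loop would also clamp the last block to `K`, since rounding the block count up without that guard would read past the end of the row.

```python
def block_ranges(K: int, block: int = 16) -> list:
    """Return (start, end) element ranges covering [0, K) in blocks of `block`."""
    k_blocks = (K + block - 1) // block  # ceiling division, as in the suggestion
    # The final range is clamped to K so a partial tail block stays in bounds.
    return [(b * block, min((b + 1) * block, K)) for b in range(k_blocks)]


# K = 40: truncating division yields 40 // 16 == 2 blocks (elements 32..39 skipped),
# while ceiling division yields 3 blocks with a partial tail.
ranges = block_ranges(40)
```

If the NVFP4 layout already requires `K` to be a multiple of 16, an explicit check on `K` at kernel launch would be an alternative to handling a partial block.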

Comment thread aphrodite/_custom_ops.py
    major, minor = torch.cuda.get_device_capability(device)

    # Use SM120 kernel for compute capability 12.0 and above
    if major == 12 and minor == 0:

high

The check for the SM120 architecture is too specific. By checking for major == 12 and minor == 0, you are limiting this path to compute capability 12.0 exactly. To ensure forward compatibility with future minor revisions of the same architecture (e.g., 12.1), it's better to check only the major version number.

Suggested change
-    if major == 12 and minor == 0:
+    if major == 12:
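
The relaxed check can be isolated in a small predicate so it is testable without a GPU. This is a sketch under assumptions: `should_use_sm120_kernel` and its `kernel_available` flag are hypothetical names, standing in for however `_custom_ops.py` detects that the sm_120 op was compiled in.

```python
from typing import Tuple


def should_use_sm120_kernel(capability: Tuple[int, int],
                            kernel_available: bool) -> bool:
    """Decide whether to dispatch to the sm_120 NVFP4 MoE kernel."""
    major, _minor = capability
    # Match on the major version only, so 12.0 and 12.1 both take this path.
    return major == 12 and kernel_available


# With PyTorch installed, the tuple would come from
# torch.cuda.get_device_capability(device), as in the diff above.
```

Note that the C++ dispatch quoted later in this review tests `version_num >= 120`, so the two sides would still disagree for a hypothetical major version above 12 unless they are aligned.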

Comment on lines +417 to +426
    #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
      {
        int32_t version_num = get_sm_version_num();
        if (version_num >= 120) {
          return cutlass_fp4_group_mm_sm120(output, a, b, a_blockscale,
                                            b_blockscales, alphas, problem_sizes,
                                            expert_offsets, sf_offsets);
        }
      }
    #endif

medium

This block of code, which dispatches to the sm120 kernel, is a duplicate of the logic at lines 368-377. This redundancy makes the code harder to maintain. The initial check at the beginning of the function is sufficient to handle the dispatch, so this second block can be safely removed.

Comment on lines +57 to +64
    for (int e = 0; e < num_experts; e++) {
      int start = expert_offsets[e];
      int end = (e == num_experts - 1) ? M : expert_offsets[e + 1];
      if (tid_y >= start && tid_y < end) {
        expert_id = e;
        break;
      }
    }

medium

This loop performs a linear scan to find the expert_id. This can be inefficient if num_experts is large. Since expert_offsets is sorted, you can achieve better performance by using a binary search to locate the expert.

    // Find expert using binary search
    int low = 0, high = num_experts;
    while (low < high) {
        int mid = low + (high - low) / 2;
        if (tid_y < expert_offsets[mid]) {
            high = mid;
        } else {
            low = mid + 1;
        }
    }
    int expert_id = low - 1;
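
Before swapping the scan for a binary search, it may be worth checking that both agree on edge cases. The sketch below models the two lookups on the host; the function names and the `-1` "no expert" sentinel are assumptions, since the initial value of `expert_id` is not shown in the diff. One difference to watch: the original loop leaves `expert_id` unset for rows at or beyond `M`, whereas the suggested search would return the last expert, so the sketch adds an explicit range guard.

```python
from bisect import bisect_right
from typing import Sequence


def find_expert_linear(row: int, expert_offsets: Sequence[int], M: int) -> int:
    """Model of the original linear scan; returns -1 if no expert owns the row."""
    num_experts = len(expert_offsets)
    for e in range(num_experts):
        start = expert_offsets[e]
        end = M if e == num_experts - 1 else expert_offsets[e + 1]
        if start <= row < end:
            return e
    return -1


def find_expert_bisect(row: int, expert_offsets: Sequence[int], M: int) -> int:
    """Binary-search equivalent, assuming expert_offsets is sorted ascending."""
    if not 0 <= row < M:
        return -1  # guard that the bare suggestion above does not include
    # Index of the last offset <= row; empty experts are skipped naturally.
    return bisect_right(expert_offsets, row) - 1
```

For example, with offsets `[0, 3, 3, 7]` and `M = 10`, expert 1 is empty and row 3 belongs to expert 2 under both lookups.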
