[kernel][moe] better splitK for fused moe (#1603)
Signed-off-by: AlpinDale <alpindale@gmail.com>
Code Review
This pull request introduces a split-K implementation for the fused MoE kernel to improve performance. The changes involve modifying the Triton kernel to handle split-K logic, including a 2D launch grid, atomic adds for reduction, and a pre-run hook to zero the output buffer. The host-side code is also updated to support launching the split-K kernel.
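The split-K scheme described above can be illustrated with a small pure-Python reference. This is only a sketch of the general technique, not the Triton kernel from this PR: the function name `split_k_matmul` and its arguments are hypothetical, the outer `pid_k` loop stands in for the second launch-grid axis, and the plain `+=` stands in for the atomic add a GPU kernel would need.

```python
def split_k_matmul(a, b, split_k):
    """Reference split-K matmul: C = A @ B with the K loop divided into split_k slices.

    Hypothetical illustration only; a and b are lists of lists of numbers.
    """
    m, k, n = len(a), len(b), len(b[0])
    # The output must start at zero because every K-slice adds into it.
    # This mirrors the role of a pre-run hook that zeroes the output buffer.
    c = [[0.0] * n for _ in range(m)]
    # Ceil-divide K so the last slice picks up any remainder.
    k_per_split = (k + split_k - 1) // split_k
    # pid_k plays the role of the program index along the K-split grid axis.
    for pid_k in range(split_k):
        k_start = pid_k * k_per_split
        k_end = min(k_start + k_per_split, k)
        for i in range(m):
            for j in range(n):
                partial = 0.0
                for kk in range(k_start, k_end):
                    partial += a[i][kk] * b[kk][j]
                # On a GPU, slices run concurrently, so this would be an atomic add.
                c[i][j] += partial
    return c
```

With `split_k=1` there is a single slice and the reduction degenerates to an ordinary matmul, which is why forcing the value to 1 removes any benefit from the extra grid axis.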
My review found a critical issue in the kernel launch logic that effectively disables the split-K optimization. I've provided a specific comment and suggestion to fix this. Once addressed, this PR should correctly enable the performance benefits of split-K.
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
    if not do_split_k:
        config["SPLIT_K"] = 1
The line `config["SPLIT_K"] = 1` unconditionally sets `SPLIT_K` to 1. This overrides any tuned value for `SPLIT_K` from the configuration and effectively disables the split-K optimization, because the kernel is always launched with a grid dimension of 1 along the K-split axis. The subsequent `if not do_split_k:` check is then redundant, since `SPLIT_K` is already 1 whether or not `do_split_k` is True.
To fix this, the unconditional assignment should be removed, and SPLIT_K should only be set to 1 if do_split_k is False.
Suggested change:

    -config["SPLIT_K"] = 1
     BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
     if block_shape is not None:
         BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
     if not do_split_k:
         config["SPLIT_K"] = 1
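To make the effect of the suggestion concrete, here is a minimal standalone sketch of the corrected logic. The wrapper function `resolve_launch_config` and the sample config values are hypothetical and exist only for illustration; in the PR this logic sits inline in the host-side launch code.

```python
def resolve_launch_config(config, block_shape, do_split_k):
    """Hypothetical helper mirroring the suggested fix.

    Returns (config without BLOCK_SIZE_K, effective BLOCK_SIZE_K).
    """
    config = dict(config)  # copy so the caller's tuned config is not mutated
    block_size_k = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        # Clamp the K block to the smaller quantization block dimension.
        block_size_k = min(block_size_k, min(block_shape[0], block_shape[1]))
    if not do_split_k:
        # Only fall back to a single K slice when split-K is disabled.
        config["SPLIT_K"] = 1
    return config, block_size_k


# Example with made-up tuned values.
tuned = {"BLOCK_SIZE_K": 128, "SPLIT_K": 4}
enabled_cfg, enabled_bk = resolve_launch_config(tuned, [64, 64], do_split_k=True)
disabled_cfg, _ = resolve_launch_config(tuned, None, do_split_k=False)
```

With the unconditional assignment removed, the tuned `SPLIT_K` survives when `do_split_k` is True and is forced to 1 only when it is False.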
Minor improvement, roughly 0.4%.
2x RTX 6000 Ada, Qwen3-30B-A3B-FP8, 512 tokens I/O
Main:
PR: