Implement block TopK sieve and ranking API for handling multi-key workloads #9066
pauleonix wants to merge 8 commits into
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough
Adds per-item top-k state, a multi-pass radix sieve for iterative refinement, and an atomic rank finalizer; refactors block_topk_air to compose these components and updates the TempStorage binding. New host/test utilities and parameterized Catch2 device tests validate correctness across modes and floating-point edge cases.
Changes
Block-level multi-key top-k via iterative refinement
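To make the walkthrough's "multi-pass radix sieve" concrete, here is a minimal host-side sketch of the general technique. It is not the PR's implementation or the CUB API: the names `ItemState` and `radix_sieve_select_max`, the 32-bit unsigned key type, and the sequential tie-break at the end (standing in for the atomic rank finalizer) are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-item state; the names are illustrative, not the PR's types.
enum class ItemState : std::uint8_t { candidate, selected, rejected };

// Host-side reference sketch (assumption: 32-bit unsigned keys, select the k largest).
inline std::vector<ItemState>
radix_sieve_select_max(const std::vector<std::uint32_t>& keys, int k, int window_bits = 8)
{
  // Like the PR's test, this sketch requires a window-aligned key width.
  assert(window_bits > 0 && window_bits <= 16 && 32 % window_bits == 0);
  const std::size_t n = keys.size();
  if (k <= 0) return std::vector<ItemState>(n, ItemState::rejected);
  if (static_cast<std::size_t>(k) >= n) return std::vector<ItemState>(n, ItemState::selected);

  std::vector<ItemState> state(n, ItemState::candidate);
  const std::uint32_t num_buckets = 1u << window_bits;
  const std::uint32_t mask        = num_buckets - 1u;
  int remaining                   = k; // invariant: 0 < remaining <= number of candidates

  // One pass per digit window, most significant window first.
  for (int lo = 32 - window_bits; lo >= 0; lo -= window_bits)
  {
    // Histogram the current digit of the items that are still undecided.
    std::vector<int> hist(num_buckets, 0);
    for (std::size_t i = 0; i < n; ++i)
    {
      if (state[i] == ItemState::candidate) ++hist[(keys[i] >> lo) & mask];
    }

    // Walk buckets from the largest digit down until the k-th item falls into one.
    std::uint32_t pivot = num_buckets - 1u;
    int above           = 0;
    while (above + hist[pivot] < remaining)
    {
      above += hist[pivot];
      --pivot;
    }

    // Digits above the pivot are in the top-k, digits below are out,
    // and the pivot bucket stays undecided for the next pass.
    for (std::size_t i = 0; i < n; ++i)
    {
      if (state[i] != ItemState::candidate) continue;
      const std::uint32_t digit = (keys[i] >> lo) & mask;
      if (digit > pivot) state[i] = ItemState::selected;
      else if (digit < pivot) state[i] = ItemState::rejected;
    }
    remaining -= above;
  }

  // Remaining candidates all share the same key; break ties by index.
  for (auto& s : state)
  {
    if (s != ItemState::candidate) continue;
    s = (remaining > 0) ? ItemState::selected : ItemState::rejected;
    --remaining;
  }
  return state;
}
```

For keys `{5, 1, 9, 9, 3, 7}` and `k = 3`, the sketch marks both 9s and the 7 as selected. A block-level version would build the histogram cooperatively in shared memory; that part is deliberately left out here.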
Warning: Review ran into problems
🔥 Problems
Stopped waiting for pipeline failures after 30000ms. One of your pipelines takes longer than our 30000ms fetch window to run, so review may not consider pipeline-failure results for inline comments if any failures occurred after the fetch window. Increase the timeout if you want to wait longer.
As suggested by CodeRabbit nitpick.
🥳 CI Workflow Results
🟩 Finished in 1h 11m: Pass: 100%/283 | Total: 2d 20h | Max: 45m 41s | Hits: 88%/219289
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
cub/test/catch2_test_block_topk_rank.cu (2)
75-84: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
critical: device kernel contains lambda expression
The coding guidelines explicitly forbid lambda expressions in device-only or host-device code. Refactor using if constexpr directly:

    -  auto states = [&] {
    -    if constexpr (SelectMax)
    -    {
    -      return sieve.template select_max<IsFullTile, BlockedInput>(keys, k, static_cast<int>(g_in.size()));
    -    }
    -    else
    -    {
    -      return sieve.template select_min<IsFullTile, BlockedInput>(keys, k, static_cast<int>(g_in.size()));
    -    }
    -  }();
    +  decltype(sieve.template select_max<IsFullTile, BlockedInput>(keys, k, static_cast<int>(g_in.size()))) states;
    +  if constexpr (SelectMax)
    +  {
    +    states = sieve.template select_max<IsFullTile, BlockedInput>(keys, k, static_cast<int>(g_in.size()));
    +  }
    +  else
    +  {
    +    states = sieve.template select_min<IsFullTile, BlockedInput>(keys, k, static_cast<int>(g_in.size()));
    +  }

As per coding guidelines: "Never allow lambda expressions in device-only or host-device code".
162-173: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
critical: device kernel contains lambda expression
Same violation as in single_key_kernel. Apply the identical refactoring:

    -  auto states = [&] {
    -    if constexpr (SelectMax)
    -    {
    -      return primary_sieve.template select_max<IsFullTile, blocked_input>(
    -        primary_keys, k, static_cast<int>(g_primary.size()));
    -    }
    -    else
    -    {
    -      return primary_sieve.template select_min<IsFullTile, blocked_input>(
    -        primary_keys, k, static_cast<int>(g_primary.size()));
    -    }
    -  }();
    +  decltype(primary_sieve.template select_max<IsFullTile, blocked_input>(primary_keys, k, static_cast<int>(g_primary.size()))) states;
    +  if constexpr (SelectMax)
    +  {
    +    states = primary_sieve.template select_max<IsFullTile, blocked_input>(primary_keys, k, static_cast<int>(g_primary.size()));
    +  }
    +  else
    +  {
    +    states = primary_sieve.template select_min<IsFullTile, blocked_input>(primary_keys, k, static_cast<int>(g_primary.size()));
    +  }

As per coding guidelines: "Never allow lambda expressions in device-only or host-device code".
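The two suggested refactorings declare the result with decltype and assign it in each branch, which presumes the returned type is default-constructible and assignable. One lambda-free alternative, sketched below under stated assumptions, is a small named function template that returns from inside if constexpr. `mock_sieve` and `select_states` are hypothetical stand-ins, not names from the PR, and in device code the helper would additionally need the project's device or host-device annotation.

```cpp
#include <cassert>

// Hypothetical stand-in for the sieve; only the shape of the calls matters here.
struct mock_sieve
{
  template <bool IsFullTile>
  int select_max(int key, int k) const
  {
    return key + k;
  }

  template <bool IsFullTile>
  int select_min(int key, int k) const
  {
    return key - k;
  }
};

// Named helper replacing the immediately-invoked lambda. Because SelectMax is a
// template parameter, only the taken branch is instantiated and the return type
// is deduced from that branch alone.
template <bool SelectMax, bool IsFullTile, class Sieve>
auto select_states(const Sieve& sieve, int key, int k)
{
  if constexpr (SelectMax)
  {
    return sieve.template select_max<IsFullTile>(key, k);
  }
  else
  {
    return sieve.template select_min<IsFullTile>(key, k);
  }
}
```

A call such as `auto states = select_states<SelectMax, IsFullTile>(sieve, key, k);` then keeps the single-initialization form of the original code without a lambda.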
🧹 Nitpick comments (3)
cub/test/catch2_test_block_topk_rank.cu (3)
224-224: 💤 Low value
suggestion: use static_cast for consistency
The int{sizeof(KeyT) * 8} cast is valid, but static_cast<int>(sizeof(KeyT) * 8) is more common in the codebase and equally clear:

    -  static_assert(int{sizeof(KeyT) * 8} % WindowBits == 0, "test currently requires window-aligned key width");
    +  static_assert(static_cast<int>(sizeof(KeyT) * 8) % WindowBits == 0, "test currently requires window-aligned key width");
241-241: 💤 Low value
suggestion: use static_cast for consistency
Same as line 224, prefer static_cast<int>(sizeof(KeyT) * 8) for consistency:

    -  int hi = int{sizeof(KeyT) * 8};
    +  int hi = static_cast<int>(sizeof(KeyT) * 8);
309-309: ⚖️ Poor tradeoff
suggestion: consider testing k=0 edge case
The REQUIRE explicitly excludes k=0. If the API should support k=0 (returning empty output), add a test case verifying no writes occur. If k=0 is invalid input, document or assert that constraint in the implementation.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 3c086ff4-c3df-4a37-8169-2645a8f0c345
📒 Files selected for processing (2)
cub/cub/block/specializations/block_topk_air.cuh
cub/test/catch2_test_block_topk_rank.cu
Description
During testing, I found and fixed a bug in the previous block TopK AIR implementation concerning bit-twiddling inversion and the handling of 0.0/-0.0.
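For readers unfamiliar with why signed zeros need care in radix-based selection, the following sketch shows the usual order-preserving float-to-unsigned key transform. It does not reproduce the bug fixed in this PR or the PR's code; the function names are hypothetical, and the normalization shown is just one common way to make 0.0 and -0.0 compare as equal keys.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Map a float to an unsigned key whose integer order matches the float order:
// negative values have all bits flipped, non-negative values get the sign bit set.
inline std::uint32_t to_ordered_bits(float x)
{
  std::uint32_t b;
  std::memcpy(&b, &x, sizeof(b));
  return (b & 0x80000000u) ? ~b : (b | 0x80000000u);
}

// Inverse of to_ordered_bits.
inline float from_ordered_bits(std::uint32_t key)
{
  const std::uint32_t b = (key & 0x80000000u) ? (key ^ 0x80000000u) : ~key;
  float x;
  std::memcpy(&x, &b, sizeof(x));
  return x;
}

// Assumption for illustration: collapse -0.0 onto +0.0 before twiddling so the
// two zeros, which compare equal as floats, also produce identical keys.
inline std::uint32_t to_ordered_bits_normalized(float x)
{
  if (x == 0.0f)
  {
    x = 0.0f; // true for both +0.0 and -0.0; stores +0.0
  }
  return to_ordered_bits(x);
}
```

Without normalization, -0.0 maps to 0x7FFFFFFF and 0.0 maps to 0x80000000, so a bitwise sieve treats them as distinct keys even though `-0.0f == 0.0f` holds. A select-min path that reuses a select-max sieve by inverting keys has to apply that inversion consistently with this transform.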
TODO:
- Add simple optimization of integer scan + adjacent difference (no perf improvement with current segmented topk tuning)

closes #8368
Checklist