Skip to content

[cub] Replace cub parameter framework with cuda::argument#9074

Open
pciolkosz wants to merge 22 commits into
NVIDIA:mainfrom
pciolkosz:replace_cub_parameter_framework
Open

[cub] Replace cub parameter framework with cuda::argument#9074
pciolkosz wants to merge 22 commits into
NVIDIA:mainfrom
pciolkosz:replace_cub_parameter_framework

Conversation

@pciolkosz
Copy link
Copy Markdown
Contributor

This PR replaces most of the functionality in segmented_params.cuh with cuda::argument wrappers from #8875. This PR contains the other one, since it's not merged yet.

There are two things that were left from the original implementation, the static dispatch over bounded set of values and get_param that either gets item from a sequence at a given index or returns a uniform value depending on the argument. Both of those things were more fitting for a cub-specific functionality, but its not set in stone

@pciolkosz pciolkosz requested review from a team as code owners May 20, 2026 04:42
@pciolkosz pciolkosz requested a review from wmaxey May 20, 2026 04:42
@github-project-automation github-project-automation Bot moved this to Todo in CCCL May 20, 2026
@pciolkosz pciolkosz requested a review from pauleonix May 20, 2026 04:42
@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL May 20, 2026
@pciolkosz pciolkosz force-pushed the replace_cub_parameter_framework branch from 845daaf to 5dd3c87 Compare May 20, 2026 06:00
@pciolkosz pciolkosz requested a review from a team as a code owner May 20, 2026 06:00
@pciolkosz pciolkosz requested a review from shwina May 20, 2026 06:00
@pciolkosz pciolkosz force-pushed the replace_cub_parameter_framework branch from 5dd3c87 to 8a3b299 Compare May 20, 2026 06:01
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 20, 2026

Review Change Stack

📝 Walkthrough

Summary by CodeRabbit

  • New Features

    • Introduced CUDA argument wrapper types (__constant, __immediate, __deferred_value, __deferred_sequence) for improved compile-time bounds validation and flexible parameter passing in segmented operations.
    • Added argument bounds utilities with static and runtime bounds support for enhanced type safety.
  • Refactor

    • Modernized parameter handling in segmented topk operations to use unified argument wrapper interface.

Walkthrough

This PR introduces a new CUDA argument wrapper system with static and runtime bounds validation, then refactors CUB's batched top-K dispatch, kernel, and agent layers to use unified cuda::argument::__traits-based parameter access instead of nested type-trait patterns. The argument infrastructure covers scalar/sequence values, compile-time/runtime bounds, device-deferred wrappers, and comprehensive trait-based metadata exposure for policy selection and unwrapping.

Changes

CUDA Argument Infrastructure

Layer / File(s) Summary
Bounds types and namespace setup
libcudacxx/include/cuda/__argument/argument_bounds.h, libcudacxx/include/cuda/std/__internal/namespaces.h
__no_bounds sentinel, __static_bounds<_Lowest, _Max> with static_assert validation, __runtime_bounds<_Tp> with constructor assert, and factory functions __bounds() for both forms. Namespace macros _CCCL_BEGIN_NAMESPACE_CUDA_ARGUMENT and _CCCL_END_NAMESPACE_CUDA_ARGUMENT.
Argument wrapper types and unwrap operations
libcudacxx/include/cuda/__argument/argument.h
__constant for NTTP values, __immediate<_Arg, _StaticBounds> for host-accessible arguments with runtime bounds, __deferred_base/__deferred_value/__deferred_sequence for device-resident arguments. All include constructor bounds intersection validation and __unwrap extraction utilities.
Traits and numeric bounds functions
libcudacxx/include/cuda/__argument/argument.h, libcudacxx/include/cuda/argument
__traits<_Tp> exposes element type, value type, and compile-time lowest/max bounds with classification flags. __lowest_(_Tp) and __max_(_Tp) free functions compute effective bounds with bounds intersection validation for each wrapper type.
Bounds and traits test coverage
libcudacxx/test/libcudacxx/cuda/argument/argument_bounds.pass.cpp, libcudacxx/test/libcudacxx/cuda/argument/argument_traits.pass.cpp
Compile-time tests for __static_bounds and __runtime_bounds construction, type preservation under NTTP, bounds intersection detection via __has_bounds_intersection, and __traits classification (single-value, element type, deferred/immediate/constant) across wrapper types.
Argument wrapper tests
libcudacxx/test/libcudacxx/cuda/argument/static_argument.pass.cpp, libcudacxx/test/libcudacxx/cuda/argument/deferred_argument.pass.cpp, libcudacxx/test/libcudacxx/cuda/argument/dynamic_argument.pass.cpp
Tests for __constant with scalar/array NTTPs, __immediate with and without static bounds, __deferred_value from span/pointer, __deferred_sequence from span, runtime bounds validation, and __unwrap lvalue/rvalue semantics for each category.
Bounds conversion edge case test
libcudacxx/test/libcudacxx/cuda/argument/static_bounds_conversion.fail.cpp
Negative test validating static bounds conversion constraints at compile time.
Integration and usage examples
libcudacxx/test/libcudacxx/cuda/argument/usage_example.pass.cpp
Integration test with helper templates select_variant, compute_buffer_size, process_segments that use __traits and __unwrap uniformly across plain values, static constants, immediate dynamic, and deferred uniform arguments to demonstrate compile-time and runtime decisions.
Test macro support
libcudacxx/test/support/test_macros.h
Adds TEST_HAS_CLASS_NTTP feature detection macro to conditionally enable class-type NTTP testing when C++20 support is available and NVCC ≥ 13.1.

CUB Batched Top-K Refactor

Layer / File(s) Summary
Unified parameter access API
cub/cub/detail/segmented_params.cuh
Introduces detail::params::get_param(_Tp&&, size_t) overload set for cuda::argument wrappers (constant, immediate, deferred-value, deferred-sequence) to extract per-segment or uniform values. Removes obsolete parameter-type templates (static_constant_param, uniform_param, per_segment_param) and trait helpers (is_static_param_v, is_uniform_param_v, is_per_segment_param_v, static_max_value_v, static_min_value_v).
Dispatch layer refactoring
cub/cub/device/dispatch/dispatch_batched_topk.cuh
Changes dispatch signature to accept unwrapped SelectDirectionT select_direction instead of wrapped parameter type; internally wraps via wrap_select_direction helpers. Replaces all in-file parameter type aliases (select_direction_*, segment_size_*, k_*, num_segments_*, total_num_items_guarantee). Uses cuda::argument::__traits for policy bound computation, switches all parameter reads to params::get_param API, and requires num_segments to be single-value.
Kernel integration
cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh
Updates device_segmented_topk_kernel pointer types and constraint assertions to use cuda::argument::__traits<NumSegmentsParameterT>::element_type and ::cuda::argument::__traits<SegmentSizeParameterT>::max instead of nested value_type members and params::static_max_value_v helpers.
Agent worker implementation
cub/cub/agent/agent_batched_topk.cuh
Refactors agent_batched_topk_worker_per_segment to derive element types from cuda::argument::__traits<...>::element_type, computes compile-time predicates (only_small_segments, is_full_tile) using __traits methods, and accesses per-segment values via params::get_param(segment_sizes, segment_id) and params::get_param(k_param, segment_id).
CUB tests and benchmarks
cub/benchmarks/bench/segmented_topk/fixed/keys.cu, cub/benchmarks/bench/segmented_topk/variable/keys.cu, cub/test/catch2_test_device_segmented_topk_keys.cu, cub/test/catch2_test_device_segmented_topk_pairs.cu
Fixed and variable-size benchmarks now construct total_num_items, segment_sizes, k, and num_segments via cuda::argument::__immediate and __constant wrappers with explicit __bounds<...> specifications. Device tests similarly replace batched_topk::*_uniform and total_num_items_guarantee selector types with immediate arguments for segment size, k, num_segments, and total item count.

Suggested reviewers

  • pauleonix
  • wmaxey

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (3)
cub/cub/detail/segmented_params.cuh (1)

31-43: 💤 Low value

suggestion: Missing [[nodiscard]] on get_param overloads. Per coding guidelines, most functions with non-void return should have this attribute.

 _CCCL_TEMPLATE(class _Tp)
 _CCCL_REQUIRES((!::cuda::argument::__is_wrapper_v<::cuda::std::remove_cv_t<::cuda::std::remove_reference_t<_Tp>>>) )
-_CCCL_HOST_DEVICE constexpr auto get_param(_Tp&& __arg, [[maybe_unused]] size_t __index) noexcept
+[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto get_param(_Tp&& __arg, [[maybe_unused]] size_t __index) noexcept

Same applies to the other get_param overloads on lines 46-47, 53-54, 67-68, 74-75. As per coding guidelines, most functions with a non-void return type should use [[nodiscard]].

cub/cub/device/dispatch/dispatch_batched_topk.cuh (1)

51-66: 💤 Low value

suggestion: Both wrap_select_direction overloads return non-void and should have [[nodiscard]].

-_CCCL_HOST_DEVICE inline auto wrap_select_direction(detail::topk::select dir)
+[[nodiscard]] _CCCL_HOST_DEVICE inline auto wrap_select_direction(detail::topk::select dir)
-_CCCL_HOST_DEVICE auto wrap_select_direction(IteratorT iter)
+[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(IteratorT iter)
libcudacxx/include/cuda/__argument/argument_bounds.h (1)

103-113: ⚡ Quick win

suggestion: Complete Doxygen tags for the documented __bounds overloads. The documented non-void factory functions currently only provide //! @brief; add `//! `@param for each parameter and //! @return`` for both overloads to satisfy header documentation requirements.

As per coding guidelines: "When a function is documented with Doxygen, it must include: //! @brief, `//! `@param`[in/out/in,out]` for every parameter, and `//! `@return for non-void functions."


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a53ca9f6-f66d-4f20-a942-2e8bd23c2c84

📥 Commits

Reviewing files that changed from the base of the PR and between 459e81a and 8a3b299.

📒 Files selected for processing (20)
  • cub/benchmarks/bench/segmented_topk/fixed/keys.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
  • cub/cub/agent/agent_batched_topk.cuh
  • cub/cub/detail/segmented_params.cuh
  • cub/cub/device/dispatch/dispatch_batched_topk.cuh
  • cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh
  • cub/test/catch2_test_device_segmented_topk_keys.cu
  • cub/test/catch2_test_device_segmented_topk_pairs.cu
  • libcudacxx/include/cuda/__argument/argument.h
  • libcudacxx/include/cuda/__argument/argument_bounds.h
  • libcudacxx/include/cuda/argument
  • libcudacxx/include/cuda/std/__internal/namespaces.h
  • libcudacxx/test/libcudacxx/cuda/argument/argument_bounds.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/argument_traits.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/deferred_argument.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/dynamic_argument.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/static_argument.pass.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/static_bounds_conversion.fail.cpp
  • libcudacxx/test/libcudacxx/cuda/argument/usage_example.pass.cpp
  • libcudacxx/test/support/test_macros.h

Comment on lines +274 to +280
template <auto _Lowest, auto _Max>
_CCCL_API constexpr __immediate(_Arg __arg, __static_bounds<_Lowest, _Max>) noexcept
: arg{::cuda::std::move(__arg)}
{
__validate_bounds();
__validate_value();
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

important: These constructors accept a __static_bounds<_Lowest, _Max> argument but validate against the class template parameter _StaticBounds, so explicitly-instantiated types can silently ignore the bounds token passed at construction. Add a compile-time constraint (for example, static_assert(::cuda::std::is_same_v<_StaticBounds, __static_bounds<_Lowest, _Max>>) or a requires clause) so construction fails when the token and _StaticBounds disagree.

Also applies to: 294-302

Comment on lines +340 to +345
template <auto _Lowest, auto _Max>
_CCCL_API constexpr __deferred_base(_Arg __arg, __static_bounds<_Lowest, _Max>) noexcept
: arg{::cuda::std::move(__arg)}
{
__validate_bounds_intersection<__element_type, _StaticBounds>(__runtime_bounds_);
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

important: __deferred_base has the same bounds-token mismatch risk: constructor parameters carry __static_bounds<_Lowest, _Max> but all checks use _StaticBounds. This can make user-provided static bounds inert for explicitly-typed wrappers. Enforce _StaticBounds == __static_bounds<_Lowest, _Max> at compile time (or remove the redundant bounds parameter in favor of _StaticBounds-typed overloads).

Also applies to: 357-366

@github-actions
Copy link
Copy Markdown
Contributor

😬 CI Workflow Results

🟥 Finished in 4h 13m: Pass: 94%/341 | Total: 10d 11h | Max: 4h 13m | Hits: 41%/1754003

See results here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: In Review

Development

Successfully merging this pull request may close these issues.

1 participant