Asserting current device and CUB stream matches#9119
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (4)
📝 WalkthroughSummary by CodeRabbit
important: WalkthroughDriver-API no-throw wrappers and cub::detail::validate_stream_device(cudaStream_t) were added; dispatch entrypoints now call it up-front and return early on mismatch; tests for cross-device stream behavior were added. ChangesStream-device validation layer
Assessment against linked issues
Suggested reviewers
Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (16)
cub/cub/device/dispatch/dispatch_merge_sort.cuh (1)
406-406: ⚡ Quick winsuggestion: qualify
validate_stream_device(stream)with its global namespace-qualified symbol (matching its declaration namespace) instead of using unqualified lookup.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 477-477
cub/cub/device/dispatch/dispatch_radix_sort.cuh (1)
1141-1141: ⚡ Quick winsuggestion: use the global namespace-qualified form of
validate_stream_device(stream)at both dispatch entry points to satisfy the free-function qualification rule.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 1206-1206
cub/cub/device/dispatch/dispatch_reduce.cuh (1)
481-481: ⚡ Quick winsuggestion: qualify
validate_stream_device(stream)from the global namespace in both locations rather than relying on unqualified lookup.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 754-754
cub/cub/device/dispatch/dispatch_reduce_by_key.cuh (1)
609-609: ⚡ Quick winsuggestion: switch both
validate_stream_device(stream)calls to the fully qualified global-namespace symbol to comply with the free-function call rule.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 698-698
cub/cub/device/dispatch/dispatch_reduce_deterministic.cuh (1)
342-342: ⚡ Quick winsuggestion: qualify
validate_stream_device(stream)from the global namespace instead of calling it unqualified.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".cub/cub/device/dispatch/dispatch_reduce_nondeterministic.cuh (1)
176-176: ⚡ Quick winsuggestion: call
validate_stream_device(stream)via its global namespace-qualified symbol here.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".cub/cub/device/dispatch/dispatch_rle.cuh (1)
608-608: ⚡ Quick winsuggestion: make both
validate_stream_device(stream)calls fully qualified from the global namespace to align with project call-qualification rules.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 666-666
cub/cub/device/dispatch/dispatch_scan.cuh (1)
865-865: ⚡ Quick winsuggestion: use the global namespace-qualified form for
validate_stream_device(stream)in both locations rather than unqualified calls.
As per coding guidelines "All calls to free functions must be fully qualified from the global namespace, e.g.::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 933-933
cub/cub/device/dispatch/dispatch_scan_by_key.cuh (1)
599-599: ⚡ Quick winsuggestion: Qualify
validate_stream_devicefrom the global namespace in both dispatch entrypoints to match the repository call-style rule.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 737-737
cub/cub/device/dispatch/dispatch_segmented_radix_sort.cuh (1)
620-620: ⚡ Quick winsuggestion: Use a globally qualified call for
validate_stream_deviceat both insertion points to keep dispatch code aligned with repository qualification rules.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 907-907
cub/cub/device/dispatch/dispatch_segmented_reduce.cuh (1)
424-424: ⚡ Quick winsuggestion: Fully qualify
validate_stream_devicefrom global scope in both dispatch paths for consistency with the project’s free-function call rule.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 531-531
cub/cub/device/dispatch/dispatch_segmented_scan.cuh (1)
132-132: ⚡ Quick winsuggestion: Qualify
validate_stream_devicefrom the global namespace here to satisfy the repository’s free-function qualification requirement.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".cub/cub/device/dispatch/dispatch_segmented_sort.cuh (1)
692-692: ⚡ Quick winsuggestion: Switch both
validate_stream_deviceinvocations to globally qualified form to match the enforced free-function qualification convention.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 1285-1285
cub/cub/device/dispatch/dispatch_select_if.cuh (1)
846-846: ⚡ Quick winsuggestion: Apply global qualification to
validate_stream_devicein both dispatch entrypoints to comply with the project-wide free-function call convention.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 1105-1105
cub/cub/device/dispatch/dispatch_three_way_partition.cuh (1)
367-367: ⚡ Quick winsuggestion: Use globally qualified
validate_stream_devicecalls in both updated dispatch layers to align with the mandatory free-function qualification rule.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".Also applies to: 438-438
cub/cub/device/dispatch/dispatch_topk.cuh (1)
478-478: ⚡ Quick winsuggestion: Qualify
validate_stream_devicefrom global scope in this dispatch entrypoint to satisfy the repository free-function qualification rule.As per coding guidelines, "All calls to free functions must be fully qualified from the global namespace, e.g.
::cuda::ceil_div, even when calling functions in the same namespace".
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 5d183277-1a01-4666-9f1b-617c330bdabb
📒 Files selected for processing (26)
cub/cub/device/dispatch/dispatch_adjacent_difference.cuhcub/cub/device/dispatch/dispatch_batch_memcpy.cuhcub/cub/device/dispatch/dispatch_batched_topk.cuhcub/cub/device/dispatch/dispatch_find.cuhcub/cub/device/dispatch/dispatch_for.cuhcub/cub/device/dispatch/dispatch_histogram.cuhcub/cub/device/dispatch/dispatch_merge.cuhcub/cub/device/dispatch/dispatch_merge_sort.cuhcub/cub/device/dispatch/dispatch_radix_sort.cuhcub/cub/device/dispatch/dispatch_reduce.cuhcub/cub/device/dispatch/dispatch_reduce_by_key.cuhcub/cub/device/dispatch/dispatch_reduce_deterministic.cuhcub/cub/device/dispatch/dispatch_reduce_nondeterministic.cuhcub/cub/device/dispatch/dispatch_rle.cuhcub/cub/device/dispatch/dispatch_scan.cuhcub/cub/device/dispatch/dispatch_scan_by_key.cuhcub/cub/device/dispatch/dispatch_segmented_radix_sort.cuhcub/cub/device/dispatch/dispatch_segmented_reduce.cuhcub/cub/device/dispatch/dispatch_segmented_scan.cuhcub/cub/device/dispatch/dispatch_segmented_sort.cuhcub/cub/device/dispatch/dispatch_select_if.cuhcub/cub/device/dispatch/dispatch_three_way_partition.cuhcub/cub/device/dispatch/dispatch_topk.cuhcub/cub/device/dispatch/dispatch_transform.cuhcub/cub/device/dispatch/dispatch_unique_by_key.cuhcub/cub/util_device.cuh
There was a problem hiding this comment.
Actionable comments posted: 2
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 2d2b79d9-a47e-40c6-99d9-5d9e96a54b55
📒 Files selected for processing (26)
cub/cub/device/dispatch/dispatch_adjacent_difference.cuhcub/cub/device/dispatch/dispatch_batch_memcpy.cuhcub/cub/device/dispatch/dispatch_batched_topk.cuhcub/cub/device/dispatch/dispatch_find.cuhcub/cub/device/dispatch/dispatch_for.cuhcub/cub/device/dispatch/dispatch_histogram.cuhcub/cub/device/dispatch/dispatch_merge.cuhcub/cub/device/dispatch/dispatch_merge_sort.cuhcub/cub/device/dispatch/dispatch_radix_sort.cuhcub/cub/device/dispatch/dispatch_reduce.cuhcub/cub/device/dispatch/dispatch_reduce_by_key.cuhcub/cub/device/dispatch/dispatch_reduce_deterministic.cuhcub/cub/device/dispatch/dispatch_reduce_nondeterministic.cuhcub/cub/device/dispatch/dispatch_rle.cuhcub/cub/device/dispatch/dispatch_scan.cuhcub/cub/device/dispatch/dispatch_scan_by_key.cuhcub/cub/device/dispatch/dispatch_segmented_radix_sort.cuhcub/cub/device/dispatch/dispatch_segmented_reduce.cuhcub/cub/device/dispatch/dispatch_segmented_scan.cuhcub/cub/device/dispatch/dispatch_segmented_sort.cuhcub/cub/device/dispatch/dispatch_select_if.cuhcub/cub/device/dispatch/dispatch_three_way_partition.cuhcub/cub/device/dispatch/dispatch_topk.cuhcub/cub/device/dispatch/dispatch_transform.cuhcub/cub/device/dispatch/dispatch_unique_by_key.cuhcub/cub/util_device.cuh
✅ Files skipped from review due to trivial changes (1)
- cub/cub/device/dispatch/dispatch_histogram.cuh
bernhardmgruber
left a comment
There was a problem hiding this comment.
Thanks a lot for this contribution! Please add a unit test to at least one algorithm calling it with a stream that does not match the current device. This test must be written in a way that it also works if there is only one GPU/device in the system (just succeeding is fine I think). I can try it briefly on my machine where I have two GPUs.
| error = cudaStreamGetDevice(stream, &streamDevice); | ||
| if (error != cudaSuccess) | ||
| { | ||
| return error; | ||
| } |
There was a problem hiding this comment.
Suggestion: Let's not reuse the error variable:
| error = cudaStreamGetDevice(stream, &streamDevice); | |
| if (error != cudaSuccess) | |
| { | |
| return error; | |
| } | |
| if (const auto error = cudaStreamGetDevice(stream, &streamDevice);) | |
| { | |
| return error; | |
| } |
| error = cudaGetDevice(¤tDevice); | ||
| if (error != cudaSuccess) | ||
| { | ||
| return error; | ||
| } |
There was a problem hiding this comment.
| error = cudaGetDevice(¤tDevice); | |
| if (error != cudaSuccess) | |
| { | |
| return error; | |
| } | |
| if (const auto error = cudaGetDevice(¤tDevice);) | |
| { | |
| return error; | |
| } |
| return cudaErrorInvalidDevice; | ||
| } | ||
| # endif // _CCCL_CTK_AT_LEAST(12,8) | ||
| return error; |
There was a problem hiding this comment.
| return error; | |
| return cudaSuccess; |
| CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t validate_stream_device(cudaStream_t stream) | ||
| { | ||
| cudaError_t error = cudaSuccess; | ||
| # if _CCCL_CTK_AT_LEAST(12, 8) |
There was a problem hiding this comment.
Important: sometimes users violate our API requirements, but their software ran fine for a long time. They would be upset if we suddenly enforce requirements, causing their software to break. Let's add a macro to disable this new feature:
| # if _CCCL_CTK_AT_LEAST(12, 8) | |
| # if _CCCL_CTK_AT_LEAST(12, 8) && !defined(CCCL_DISABLE_STREAM_DEVICE_CHECK) |
If possible, add a unit test that calls a simple algorithm like DeviceFor with a stream and a different current device and define the CCCL_DISABLE_STREAM_DEVICE_CHECK macro, to see whether the escape hatch works.
| # if _CCCL_CTK_AT_LEAST(12, 8) | ||
| int streamDevice; | ||
| error = cudaStreamGetDevice(stream, &streamDevice); |
There was a problem hiding this comment.
We can make this function work even before CTK 12.8 using CUDA Driver API. We already have this implemented for cuda::stream_ref. It should look as:
{
::CUdevice current_device;
if (const auto error = ::cuda::__driver::__ctxGetDeviceNoThrow(current_device); error != cudaSuccess)
{
return error;
}
::CUcontext stream_ctx;
if (const auto error = ::cuda::__driver::__streamGetCtxNoThrow(stream_ctx, stream); error != cudaSuccess)
{
return error;
}
if (const auto error = ::cuda::__driver::__ctxPushNoThrow(stream_ctx); error != cudaSuccess)
{
return error;
}
::CUdevice stream_device;
if (const auto error = ::cuda::__driver::__ctxGetDeviceNoThrow(stream_device); error != cudaSuccess)
{
return error;
}
if (const auto error = ::cuda::__driver::__ctxPopNoThrow(); error != cudaSuccess)
{
return error;
}
_CCCL_ASSERT(current_device == stream_device, "current device must match CUB stream device");
}The only problem is that we need to add __meowNoThrow variants of all context-related driver APIs to <cuda/__driver/driver_api.h>.
If you don't feel comfortable doing this, I will make a follow up PR after this one is merged :)
| cudaError_t error = cudaSuccess; | ||
| # if _CCCL_CTK_AT_LEAST(12, 8) | ||
| int streamDevice; | ||
| error = cudaStreamGetDevice(stream, &streamDevice); | ||
| if (error != cudaSuccess) | ||
| { | ||
| return error; | ||
| } | ||
| int currentDevice; | ||
| error = cudaGetDevice(¤tDevice); | ||
| if (error != cudaSuccess) | ||
| { | ||
| return error; | ||
| } | ||
| _CCCL_ASSERT(currentDevice == streamDevice, "current device must match CUB stream device"); | ||
| if (currentDevice != streamDevice) | ||
| { | ||
| return cudaErrorInvalidDevice; | ||
| } | ||
| # endif // _CCCL_CTK_AT_LEAST(12,8) | ||
| return error; |
There was a problem hiding this comment.
Critical: Since this is an assertion, we need to make sure all of the CUDA Runtime/Driver calls are done only when assertions are enabled, because they won't get optimized out and can introduce some unwanted overhead.
There was a problem hiding this comment.
Actionable comments posted: 6
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: bd797629-5166-4034-acec-48de9bb9ed1e
📒 Files selected for processing (4)
cub/cub/util_device.cuhcub/test/catch2_test_device_for.cucub/test/catch2_test_device_for_api.culibcudacxx/include/cuda/__driver/driver_api.h
| #define CCCL_DISABLE_STREAM_DEVICE_CHECK | ||
|
|
There was a problem hiding this comment.
important: defining CCCL_DISABLE_STREAM_DEVICE_CHECK at file scope disables validation for every test in this translation unit, not just the new escape-hatch case. Move this case to a dedicated test file (or dedicated compile target) so default behavior tests stay meaningful.
… end of tests, and making sure pop always gets executed in validate_stream_device
|
Hi, thanks to you both for the feedbacks:
happy to keep modifying this if needed :) |
Description
closes #7782
Adding an assertion to all the dispatching codes to ensure current device and CUB stream matches. Calling it at the very beginning of each of the dispatch functions, hence the number of modified files
The assertion itself uses
cudaStreamGetDevicewhich was introduced in CTK 12.8 so it's guarded by the macro _CCCL_CTK_AT_LEAST(12,8).I'm new to the project so unsure if there is a better place to call the assertion rather than doing it in every dispatch file, also unsure if the assertion should be put in the
cub/cub/util_device.cuhfile like i did or elsewhere, please tell me if this issue should be addressed differently and i'll try to do it !Checklist