Skip to content
10 changes: 5 additions & 5 deletions c/experimental/stf/src/stf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ void stf_host_launch_submit(stf_host_launch_handle h, stf_host_callback_fn callb
_CCCL_ASSERT(callback != nullptr, "callback must not be null");

auto* scope_ptr = reinterpret_cast<host_launch_type*>(h);
(*scope_ptr)->*[callback](reserved::host_launch_deps& deps) {
(*scope_ptr)->*[callback](host_launch_deps& deps) {
callback(reinterpret_cast<stf_host_launch_deps_handle>(&deps));
};
}
Expand All @@ -476,31 +476,31 @@ void* stf_host_launch_deps_get(stf_host_launch_deps_handle deps, size_t index)
{
_CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

auto* d = reinterpret_cast<reserved::host_launch_deps*>(deps);
auto* d = reinterpret_cast<host_launch_deps*>(deps);
return d->get<slice<char>>(index).data_handle();
}

// Returns the extent along dimension 0 of dependency `index`, with the
// dependency viewed as a slice<char> (i.e. the element count of that char view).
//
// `deps` is the opaque handle handed to the host callback; it must not be null
// (checked with _CCCL_ASSERT). `index` is forwarded unchanged to
// host_launch_deps::get.
size_t stf_host_launch_deps_get_size(stf_host_launch_deps_handle deps, size_t index)
{
  _CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

  // The handle is a type-erased host_launch_deps*: stf_host_launch_submit
  // produces it by casting the address of the callback's deps argument.
  // Declared exactly once, without the `reserved::` qualifier.
  auto* d = reinterpret_cast<host_launch_deps*>(deps);
  return d->get<slice<char>>(index).extent(0);
}

// Returns the value of host_launch_deps::size() for the object behind `deps`
// (the number of dependencies attached to the host launch scope).
//
// `deps` is the opaque handle handed to the host callback; it must not be null
// (checked with _CCCL_ASSERT).
size_t stf_host_launch_deps_size(stf_host_launch_deps_handle deps)
{
  _CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

  // The handle is a type-erased host_launch_deps*. Declared exactly once,
  // without the `reserved::` qualifier.
  auto* d = reinterpret_cast<host_launch_deps*>(deps);
  return d->size();
}

// Returns the pointer reported by host_launch_deps::user_data() for the object
// behind `deps`.
// NOTE(review): presumably nullptr when no user data was attached to the
// scope -- confirm against host_launch_deps::user_data().
//
// `deps` is the opaque handle handed to the host callback; it must not be null
// (checked with _CCCL_ASSERT).
void* stf_host_launch_deps_get_user_data(stf_host_launch_deps_handle deps)
{
  _CCCL_ASSERT(deps != nullptr, "deps handle must not be null");

  // The handle is a type-erased host_launch_deps*. Declared exactly once,
  // without the `reserved::` qualifier.
  auto* d = reinterpret_cast<host_launch_deps*>(deps);
  return d->user_data();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class stream_ctx;

namespace reserved
{
// Forward declaration only: lets code that appears before the definition of
// host_launch_scope refer to it by its qualified name, e.g. the
// `friend class reserved::host_launch_scope` declaration in host_launch_deps.
template <typename Ctx, bool called_from_launch, typename... Deps>
class host_launch_scope;
} // namespace reserved

/**
* @brief Opaque handle passed to untyped host_launch callbacks.
*
Expand All @@ -60,7 +64,6 @@ public:
host_launch_deps(host_launch_deps&&) = default;
host_launch_deps& operator=(host_launch_deps&&) = default;

// If user data was attached with a custom destructor, invoke it before freeing the buffer
~host_launch_deps()
{
if (dtor_ && !user_data_buf_.empty())
Expand Down Expand Up @@ -100,14 +103,16 @@ public:

private:
// Every specialization of reserved::host_launch_scope is a friend and may
// access the private members below. The template header and the qualified
// name form a single declaration.
template <typename, bool, typename...>
friend class reserved::host_launch_scope;

// NOTE(review): presumably one logical data and one instance id per
// dependency, stored at matching indices -- confirm against host_launch_scope.
::std::vector<logical_data_untyped> lds_;
::std::vector<instance_id_t> ids_;
::std::vector<char> user_data_buf_; // byte-copied snapshot of the user data attached to the scope
void (*dtor_)(void*) = nullptr; // optional destructor for user_data_buf_ contents
};

namespace reserved
{
//! \brief Resource wrapper for managing host callback arguments
//!
//! This manages the memory allocated for host callback arguments using the
Expand Down
28 changes: 14 additions & 14 deletions cudax/test/stf/interface/host_launch_deps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void test_stream_basic()

auto scope = ctx.host_launch();
scope.add_deps(task_dep_untyped(lX, access_mode::read));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 1);
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
Expand Down Expand Up @@ -63,7 +63,7 @@ void test_stream_multiple_deps()
auto scope = ctx.host_launch();
scope.add_deps(task_dep_untyped(lX, access_mode::read));
scope.add_deps(task_dep_untyped(lY, access_mode::read));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 2);
auto sX = deps.get<slice<double>>(0);
auto sY = deps.get<slice<double>>(1);
Expand Down Expand Up @@ -91,7 +91,7 @@ void test_stream_user_data()

auto scope = ctx.host_launch();
scope.set_user_data(&uctx, sizeof(uctx));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 0);
EXPECT(deps.user_data() != nullptr);
EXPECT(deps.user_data_size() == sizeof(my_ctx));
Expand All @@ -117,7 +117,7 @@ void test_stream_write_back()

auto scope = ctx.host_launch();
scope.add_deps(task_dep_untyped(lX, access_mode::rw));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
{
Expand All @@ -140,7 +140,7 @@ void test_stream_no_user_data()
stream_ctx ctx;

auto scope = ctx.host_launch();
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 0);
EXPECT(deps.user_data() == nullptr);
EXPECT(deps.user_data_size() == 0);
Expand All @@ -163,7 +163,7 @@ void test_stream_chained()

auto s1 = ctx.host_launch();
s1.add_deps(task_dep_untyped(lX, access_mode::rw));
s1->*[](reserved::host_launch_deps& deps) {
s1->*[](host_launch_deps& deps) {
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
{
Expand All @@ -173,7 +173,7 @@ void test_stream_chained()

auto s2 = ctx.host_launch();
s2.add_deps(task_dep_untyped(lX, access_mode::rw));
s2->*[](reserved::host_launch_deps& deps) {
s2->*[](host_launch_deps& deps) {
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
{
Expand Down Expand Up @@ -209,7 +209,7 @@ void test_graph_basic()

auto scope = ctx.host_launch();
scope.add_deps(task_dep_untyped(lX, access_mode::read));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 1);
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
Expand Down Expand Up @@ -238,7 +238,7 @@ void test_graph_multiple_deps()
auto scope = ctx.host_launch();
scope.add_deps(task_dep_untyped(lX, access_mode::read));
scope.add_deps(task_dep_untyped(lY, access_mode::read));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 2);
auto sX = deps.get<slice<double>>(0);
auto sY = deps.get<slice<double>>(1);
Expand Down Expand Up @@ -266,7 +266,7 @@ void test_graph_user_data()

auto scope = ctx.host_launch();
scope.set_user_data(&uctx, sizeof(uctx));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 0);
EXPECT(deps.user_data() != nullptr);
EXPECT(deps.user_data_size() == sizeof(my_ctx));
Expand All @@ -292,7 +292,7 @@ void test_graph_write_back()

auto scope = ctx.host_launch();
scope.add_deps(task_dep_untyped(lX, access_mode::rw));
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
{
Expand All @@ -315,7 +315,7 @@ void test_graph_no_user_data()
graph_ctx ctx;

auto scope = ctx.host_launch();
scope->*[](reserved::host_launch_deps& deps) {
scope->*[](host_launch_deps& deps) {
EXPECT(deps.size() == 0);
EXPECT(deps.user_data() == nullptr);
EXPECT(deps.user_data_size() == 0);
Expand All @@ -338,7 +338,7 @@ void test_graph_chained()

auto s1 = ctx.host_launch();
s1.add_deps(task_dep_untyped(lX, access_mode::rw));
s1->*[](reserved::host_launch_deps& deps) {
s1->*[](host_launch_deps& deps) {
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
{
Expand All @@ -348,7 +348,7 @@ void test_graph_chained()

auto s2 = ctx.host_launch();
s2.add_deps(task_dep_untyped(lX, access_mode::rw));
s2->*[](reserved::host_launch_deps& deps) {
s2->*[](host_launch_deps& deps) {
auto sX = deps.get<slice<double>>(0);
for (size_t i = 0; i < 64; i++)
{
Expand Down