diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 8a8b121e..72bb7fd2 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -322,7 +322,12 @@ void Buffer::destroy() { internode::free(mask_buffer_ptr); internode::free(sync_buffer_ptr); } - internode::finalize(); + // internode::finalize(); + + GlobalState::instance().counter++; + if (GlobalState::instance().counter > 1) { + internode::finalize(); + } } #endif diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index cd5be4a3..8bf9d465 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -25,6 +25,27 @@ #define TORCH_EXTENSION_NAME deep_ep_cpp #endif +#ifndef GLOBALS_H +#define GLOBALS_H + +class GlobalState { +public: + static GlobalState& instance() { + static GlobalState inst; + return inst; + } + + int64_t counter; + +private: + GlobalState() : counter(0) {} + GlobalState(const GlobalState&) = delete; + GlobalState& operator=(const GlobalState&) = delete; +}; + +#endif + + namespace shared_memory { union MemHandleInner { diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index c4fbb8ed..2a823f8f 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -53,6 +53,9 @@ int init(const std::vector& root_unique_id_val, int rank, int num_ranks nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr); nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); + // Initialize before nvshmem_team_split_strided + nvshmem_barrier_all(); + // Create sub-RDMA teams // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) { @@ -68,7 +71,7 @@ int init(const std::vector& root_unique_id_val, int rank, int num_ranks EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID); } - nvshmem_barrier_all(); + // nvshmem_barrier_all(); return nvshmem_my_pe(); }