Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion csrc/deep_ep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,12 @@ void Buffer::destroy() {
internode::free(mask_buffer_ptr);
internode::free(sync_buffer_ptr);
}
internode::finalize();
// internode::finalize();

GlobalState::instance().counter++;
if (GlobalState::instance().counter > 1) {
internode::finalize();
}
}
#endif

Expand Down
21 changes: 21 additions & 0 deletions csrc/deep_ep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,27 @@
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif

#ifndef GLOBALS_H
#define GLOBALS_H

/// Singleton holder for a single integer counter.
///
/// The only way to obtain an object of this type is `GlobalState::instance()`,
/// which hands out a reference to one function-local static instance.
/// Copying is disabled, so every caller observes the same `counter`.
class GlobalState {
public:
    // Non-copyable: there must only ever be the one shared instance.
    GlobalState(const GlobalState&) = delete;
    GlobalState& operator=(const GlobalState&) = delete;

    /// Returns the shared instance, constructing it on first use.
    static GlobalState& instance() {
        static GlobalState shared_state;
        return shared_state;
    }

    /// Starts at 0. Public plain (non-atomic) integer that callers read and
    /// modify directly.
    int64_t counter{0};

private:
    // Kept user-provided (not `= default`) so the class is never treated as
    // an aggregate and cannot be brace-constructed from outside `instance()`.
    GlobalState() {}
};

#endif


namespace shared_memory {

union MemHandleInner {
Expand Down
5 changes: 4 additions & 1 deletion csrc/kernels/runtime.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ int init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks
nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);

// Synchronize all PEs before calling nvshmem_team_split_strided
nvshmem_barrier_all();

// Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
Expand All @@ -68,7 +71,7 @@ int init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
}

nvshmem_barrier_all();
// nvshmem_barrier_all();
return nvshmem_my_pe();
}

Expand Down