diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 6faee3f3978..01f1c300f09 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -112,13 +112,10 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { // will only swap the ptrs NOT the contents. IrContainer::swap(*(a.ir_container()), *(b.ir_container())); - // Fix parent pointers after swapping containers - // After swap, each Fusion owns a different IrContainer, so we must - // update the parent backpointers in those containers to point to their new - // owners + // After swapping container contents, update Statement::ir_container_ + // pointers so each Statement points to the Fusion whose container now + // holds it. if (a.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - a.ir_container()->parent_ = &a; for (auto val : a.vals()) { val->ir_container_ = &a; } @@ -127,8 +124,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { } } if (b.ir_container_) { - // Also update all Statement ir_container_ pointers to point to new owner - b.ir_container()->parent_ = &b; for (auto val : b.vals()) { val->ir_container_ = &b; } @@ -162,7 +157,8 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container()); + auto ir_cloner = + IrContainer::copy(from->ir_container(), to->ir_container(), to); // Remap cached special val pointers through the cloner if (from->zero_val_) { @@ -255,8 +251,8 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // Default constructor -Fusion::Fusion() : ir_container_(std::make_unique()) { - ir_container_->parent_ = this; +Fusion::Fusion() : ir_container_(std::make_shared()) { + ir_container_->addFusion(this); } // Copy constructor @@ -288,6 +284,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept { Fusion::~Fusion() { clear(); + if (ir_container_) { + ir_container_->removeFusion(this); + } } void Fusion::clear() noexcept { diff --git a/csrc/fusion.h b/csrc/fusion.h index 4fbf71c988f..6ba7a3788df 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -149,7 +149,6 @@ class NVF_API Fusion : public PolymorphicBase { typedef std::unordered_map> PermutationMap; protected: - // Direct access to underlying container IrContainer* ir_container() { NVF_ERROR( ir_container_.get() != nullptr, @@ -164,6 +163,10 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container_.get(); } + std::shared_ptr ir_container_ptr() const { + return ir_container_; + } + public: // Registration (public API with passkey) virtual void registerStmt(IrBuilderPasskey, Statement* stmt) { @@ -636,7 +639,7 @@ class NVF_API Fusion : public PolymorphicBase { std::unique_ptr> all_tvs_ptr_ = nullptr; inline static const std::string exact_mappings_key = "exact_mappings"; - std::unique_ptr ir_container_; + std::shared_ptr ir_container_; Val* zero_val_ = nullptr; Val* one_val_ = nullptr; diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 0aea7a0287e..f0537714ce0 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -75,14 +75,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); - - std::swap(a.parent_, b.parent_); } -IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { +IrCloner IrContainer::copy( + const IrContainer* from, + IrContainer* to, + Fusion* dest_fusion) { to->clear(); - IrCloner ir_cloner(to->parent()); + IrCloner ir_cloner(dest_fusion); // Copy values in deterministic order for (auto val : from->deterministic_vals()) { @@ -133,7 +134,7 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { } NVF_ERROR( - const_stmt->container() == this->parent(), + sharing_fusions_.count(const_stmt->container()) > 0, "Container claims to own stmt, but stmt disagrees."); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -152,4 +153,29 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } +void IrContainer::addFusion(Fusion* fusion) { + sharing_fusions_.insert(fusion); +} + +void IrContainer::removeFusion(Fusion* fusion) { + sharing_fusions_.erase(fusion); +} + +void IrContainer::transferFusion(Fusion* from, Fusion* to) { + sharing_fusions_.erase(from); + sharing_fusions_.insert(to); +} + +size_t IrContainer::sharingCount() const { + return sharing_fusions_.size(); +} + +bool IrContainer::hasMultipleFusions() const { + return sharing_fusions_.size() > 1; +} + +const std::unordered_set& IrContainer::sharingFusions() const { + return sharing_fusions_; +} + } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index a61f78be5f2..e3738be7349 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -86,7 +86,10 @@ class IrContainer { } protected: - static IrCloner copy(const IrContainer* from, IrContainer* to); + static IrCloner copy( + const IrContainer* from, + IrContainer* to, + Fusion* dest_fusion); static void swap(IrContainer& a, IrContainer& b) noexcept; @@ -127,16 +130,15 @@ class IrContainer { StmtNameType expr_name_counter_ = 0; public: - Fusion* parent() const { - NVF_ERROR( - parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.") - return parent_; - } + void addFusion(Fusion* fusion); + void removeFusion(Fusion* fusion); + void transferFusion(Fusion* from, Fusion* to); + size_t sharingCount() const; + bool hasMultipleFusions() const; + const std::unordered_set& sharingFusions() const; private: - // Parent Fusion that owns this container (for pure composition pattern) - // Used by Statement::fusion() to navigate back to owning Fusion - Fusion* parent_ = nullptr; + std::unordered_set sharing_fusions_; }; } // namespace nvfuser diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index 070fc1b27eb..a76bb19d563 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -26,6 +26,9 @@ namespace nvfuser { +// TODO: Remove when std::shared_mutex is added to IrContainer. +constexpr bool kPhase2DisableParallelCompile = true; + namespace { // Replace CUDA tensor with Meta tensor because storing tensors can cause // out-of-memory issues. Other arguments are returned as-is. @@ -436,7 +439,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { try { for (const auto& [group_to_run, group_runtime_inputs] : zip(runtime_workspace_.group_run_order, all_runtime_inputs)) { - if (num_groups == 1 || isOptionDisabled(DisableOption::ParallelCompile)) { + if (num_groups == 1 || kPhase2DisableParallelCompile || + isOptionDisabled(DisableOption::ParallelCompile)) { compileKernel(group_runtime_inputs, group_to_run); } else { // launch compileKernel thread here @@ -470,7 +474,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { throw; } - if (num_groups != 1 && !isOptionDisabled(DisableOption::ParallelCompile)) { + if (num_groups != 1 && !kPhase2DisableParallelCompile && + !isOptionDisabled(DisableOption::ParallelCompile)) { // Wait until all segments finish compiling getThreadPool()->waitWorkComplete(); NVF_ERROR(