diff --git a/graph/actions.cpp b/graph/actions.cpp index 39338d8..f7b4659 100644 --- a/graph/actions.cpp +++ b/graph/actions.cpp @@ -379,6 +379,15 @@ std::string CalcAllGradNormAction::to_string() const { return oss.str(); } +std::string CalcAllGradNormAction::get_dot_string() const { + std::ostringstream oss; + oss << "Action_" << action_id << " -> Tensor_" << res->get_id() << ";" << std::endl; + for (const auto& grad : grads) { + oss << "Tensor_" << grad->get_id() << " -> Action_" << action_id << ";" << std::endl; + } + return oss.str(); +} + void ClipGradAction::execute() { assert(lhs != nullptr); // grad assert(rhs != nullptr); // norm @@ -391,6 +400,14 @@ std::string ClipGradAction::to_string() const { return oss.str(); } +std::string ClipGradAction::get_dot_string() const { + std::ostringstream oss; + oss << "Action_" << action_id << " -> Tensor_" << lhs->get_id() << ";" << std::endl; + oss << "Tensor_" << rhs->get_id() << "-> Action_" << action_id << ";" << std::endl; + return oss.str(); + +} + void AdamStepAction::execute() { param->inc_t(); int t = param->get_t(); @@ -1015,7 +1032,7 @@ void printDotGraph() { // build edge auto parent = tensor_view->get_parent(); if (parent != nullptr) { - out << "Tensor_" << parent->get_id() << " -> Tensor_" << tensor_view->get_id() << ";" << std::endl; + out << "Tensor_" << parent->get_id() << " -> Tensor_" << tensor_view->get_id() << "[style=dashed, dir=none];" << std::endl; } } diff --git a/graph/actions.h b/graph/actions.h index 7940b05..75f786e 100644 --- a/graph/actions.h +++ b/graph/actions.h @@ -214,6 +214,7 @@ class CalcAllGradNormAction : public Action { return "CalcAllGradNormAction"; } std::string to_string() const override; + std::string get_dot_string() const override; private: std::vector grads; }; @@ -228,6 +229,7 @@ class ClipGradAction : public Action { return "ClipGradAction"; } std::string to_string() const override; + std::string get_dot_string() const override; private: float grad_clip_val; }; diff --git a/handwritten_recognition_topo.png b/handwritten_recognition_topo.png index 2e92dc9..ba667b6 100644 Binary files a/handwritten_recognition_topo.png and b/handwritten_recognition_topo.png differ