From 3a04799af57f2696fe0721d6c06988d1507ffab4 Mon Sep 17 00:00:00 2001 From: Sannidhya Chauhan Date: Mon, 13 Apr 2026 12:25:00 -0700 Subject: [PATCH] Expose ProfilerSessionOrchestrator in JAX profiler bindings. PiperOrigin-RevId: 899127579 --- tsl/profiler/lib/BUILD | 42 ++++- tsl/profiler/lib/profiler_collection.cc | 23 +++ tsl/profiler/lib/profiler_collection.h | 4 +- tsl/profiler/lib/profiler_controller.cc | 9 ++ tsl/profiler/lib/profiler_controller.h | 6 + tsl/profiler/lib/profiler_interface.h | 13 +- tsl/profiler/lib/profiler_orchestrator.cc | 120 +++++++++++++++ tsl/profiler/lib/profiler_orchestrator.h | 66 ++++++++ .../lib/profiler_orchestrator_test.cc | 143 ++++++++++++++++++ tsl/profiler/lib/profiler_session.cc | 20 +++ tsl/profiler/lib/profiler_session.h | 10 +- 11 files changed, 447 insertions(+), 9 deletions(-) create mode 100644 tsl/profiler/lib/profiler_orchestrator.cc create mode 100644 tsl/profiler/lib/profiler_orchestrator.h create mode 100644 tsl/profiler/lib/profiler_orchestrator_test.cc diff --git a/tsl/profiler/lib/BUILD b/tsl/profiler/lib/BUILD index 1071f08e2..54c1ee3b2 100644 --- a/tsl/profiler/lib/BUILD +++ b/tsl/profiler/lib/BUILD @@ -140,7 +140,7 @@ cc_library( ]), deps = [ "//tsl/profiler/protobuf:xplane_proto_cc", - "@xla//xla/tsl/platform:status", + "@com_google_absl//absl/status", ], ) @@ -378,7 +378,6 @@ cc_library( ":profiler_interface", "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/status", - "@xla//xla/tsl/platform:status", ], ) @@ -391,3 +390,42 @@ cc_library( "@com_google_absl//absl/strings:string_view", ], ) + +cc_library( + name = "profiler_orchestrator", + srcs = ["profiler_orchestrator.cc"], + hdrs = ["profiler_orchestrator.h"], + visibility = internal_visibility([ + "@xla//xla/python:__pkg__", + ]), + deps = [ + ":profiler_session", + "//tsl/platform:platform_port", + "//tsl/profiler/protobuf:profiler_options_proto_cc", + "//tsl/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/profiler/convert:post_process_single_host_xplane", + "@xla//xla/tsl/profiler/utils:xplane_builder", + "@xla//xla/tsl/profiler/utils:xplane_schema", + "@xla//xla/tsl/profiler/utils:xplane_utils", + ], +) + +tsl_cc_test( + name = "profiler_orchestrator_test", + srcs = ["profiler_orchestrator_test.cc"], + deps = [ + ":profiler_factory", + ":profiler_interface", + ":profiler_orchestrator", + ":profiler_session", + "//tsl/profiler/protobuf:profiler_options_proto_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@xla//xla/tsl/platform:test", + ], +) diff --git a/tsl/profiler/lib/profiler_collection.cc b/tsl/profiler/lib/profiler_collection.cc index f3ffec62b..253870533 100644 --- a/tsl/profiler/lib/profiler_collection.cc +++ b/tsl/profiler/lib/profiler_collection.cc @@ -55,5 +55,28 @@ absl::Status ProfilerCollection::CollectData( return status; } +absl::Status ProfilerCollection::Consume(void* ptr) { + absl::Status status; + for (auto& profiler : profilers_) { + absl::Status s = profiler->Consume(ptr); + if (!s.ok() && !absl::IsUnimplemented(s)) { + status.Update(s); + } + } + return status; +} + +absl::Status ProfilerCollection::Serialize( + void* ptr, tensorflow::profiler::XSpace* output_space) { + absl::Status status; + for (auto& profiler : profilers_) { + absl::Status s = profiler->Serialize(ptr, output_space); + if (!s.ok() && !absl::IsUnimplemented(s)) { + status.Update(s); + } + } + return status; +} + } // namespace profiler } // namespace tsl diff --git a/tsl/profiler/lib/profiler_collection.h b/tsl/profiler/lib/profiler_collection.h index e2b9fd3ef..b206b449b 100644 --- a/tsl/profiler/lib/profiler_collection.h +++ b/tsl/profiler/lib/profiler_collection.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/status/status.h" -#include "xla/tsl/platform/status.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -38,6 +37,9 @@ class ProfilerCollection : public ProfilerInterface { absl::Status Stop() override; absl::Status CollectData(tensorflow::profiler::XSpace* space) override; + absl::Status Consume(void* ptr) override; + absl::Status Serialize(void* ptr, + tensorflow::profiler::XSpace* output_space) override; private: std::vector> profilers_; diff --git a/tsl/profiler/lib/profiler_controller.cc b/tsl/profiler/lib/profiler_controller.cc index d9c58717c..4ba378b9f 100644 --- a/tsl/profiler/lib/profiler_controller.cc +++ b/tsl/profiler/lib/profiler_controller.cc @@ -85,5 +85,14 @@ absl::Status ProfilerController::CollectData( return status; } +absl::Status ProfilerController::Consume(void* ptr) { + return profiler_->Consume(ptr); +} + +absl::Status ProfilerController::Serialize( + void* ptr, tensorflow::profiler::XSpace* output_space) { + return profiler_->Serialize(ptr, output_space); +} + } // namespace profiler } // namespace tsl diff --git a/tsl/profiler/lib/profiler_controller.h b/tsl/profiler/lib/profiler_controller.h index cc0334e9d..1afdb593c 100644 --- a/tsl/profiler/lib/profiler_controller.h +++ b/tsl/profiler/lib/profiler_controller.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_PROFILER_CONTROLLER_H_ #define TENSORFLOW_TSL_PROFILER_LIB_PROFILER_CONTROLLER_H_ +#include #include #include "absl/status/status.h" @@ -45,6 +46,11 @@ class ProfilerController : public ProfilerInterface { absl::Status CollectData(tensorflow::profiler::XSpace* space) override; + absl::Status Consume(void* ptr) override; + + absl::Status Serialize(void* ptr, + tensorflow::profiler::XSpace* output_space) override; + private: enum class ProfilerState { kInit = 0, diff --git a/tsl/profiler/lib/profiler_interface.h b/tsl/profiler/lib/profiler_interface.h index 2b0b71242..a5b7306a1 100644 --- a/tsl/profiler/lib/profiler_interface.h +++ b/tsl/profiler/lib/profiler_interface.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_PROFILER_INTERFACE_H_ #define TENSORFLOW_TSL_PROFILER_LIB_PROFILER_INTERFACE_H_ -#include "xla/tsl/platform/status.h" +#include "absl/status/status.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { @@ -41,6 +41,17 @@ class ProfilerInterface { // Saves collected profile data into XSpace. virtual absl::Status CollectData(tensorflow::profiler::XSpace* space) = 0; + + // Pulls collected profile data into arbitrary raw memory. + virtual absl::Status Consume(void* ptr) { + return absl::UnimplementedError("Consume not implemented"); + } + + // Serializes collected profile data into XSpace. + virtual absl::Status Serialize(void* ptr, + tensorflow::profiler::XSpace* output_space) { + return absl::UnimplementedError("Serialize not implemented"); + } }; } // namespace profiler diff --git a/tsl/profiler/lib/profiler_orchestrator.cc b/tsl/profiler/lib/profiler_orchestrator.cc new file mode 100644 index 000000000..e586fd69f --- /dev/null +++ b/tsl/profiler/lib/profiler_orchestrator.cc @@ -0,0 +1,120 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tsl/profiler/lib/profiler_orchestrator.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/profiler/convert/post_process_single_host_xplane.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "tsl/platform/host_info.h" +#include "tsl/profiler/lib/profiler_session.h" +#include "tsl/profiler/protobuf/profiler_options.pb.h" + +namespace tsl { +namespace profiler { + +ProfilerSessionOrchestrator::ProfilerSessionOrchestrator( + const tensorflow::ProfileOptions& options) + : options_(options) {} + +ProfilerSessionOrchestrator::~ProfilerSessionOrchestrator() { + Stop().IgnoreError(); +} + +absl::Status ProfilerSessionOrchestrator::Start() { + if (session_ != nullptr) { + return absl::FailedPreconditionError("Session already started."); + } + global_start_time_ns_ = tsl::Env::Default()->NowNanos(); + session_ = tsl::ProfilerSession::Create(options_); + if (session_ == nullptr) { + return absl::InternalError("Failed to create ProfilerSession."); + } + return session_->Status(); +} + +absl::Status ProfilerSessionOrchestrator::Stop() { + if (session_ == nullptr) { + return absl::OkStatus(); // Already stopped or not started. + } + session_.reset(); + return absl::OkStatus(); +} + +absl::StatusOr ProfilerSessionOrchestrator::Consume() { + if (session_ == nullptr) { + return absl::FailedPreconditionError("Session not started."); + } + + consume_buffers_.emplace_back(sizeof(std::vector)); + auto& buffer = consume_buffers_.back(); + TF_RETURN_IF_ERROR(session_->Consume(buffer.data())); + + uint64_t now = tsl::Env::Default()->NowNanos(); + consume_stop_times_.push_back(now); + + return consume_buffers_.size() - 1; +} + +absl::Status ProfilerSessionOrchestrator::Serialize(int buffer_index) { + if (session_ == nullptr) { + return absl::FailedPreconditionError("Session not started."); + } + + if (buffer_index < 0 || buffer_index >= consume_buffers_.size()) { + return absl::InvalidArgumentError("Invalid buffer index."); + } + serialize_space_.Clear(); + auto& buffer = consume_buffers_[buffer_index]; + TF_RETURN_IF_ERROR(session_->Serialize(buffer.data(), &serialize_space_)); + + serialize_space_.add_hostnames(tsl::port::Hostname()); + profiler::SetXSpacePidIfNotSet(serialize_space_, + tsl::Env::Default()->GetProcessId()); + + uint64_t stop_time = 0; + if (buffer_index < consume_stop_times_.size()) { + stop_time = consume_stop_times_[buffer_index]; + } + profiler::PostProcessSingleHostXSpace(&serialize_space_, + global_start_time_ns_, stop_time); + + { + profiler::XPlaneBuilder xplane(profiler::FindOrAddMutablePlaneWithName( + &serialize_space_, tsl::profiler::kTaskEnvPlaneName)); + xplane.AddStatValue( + *xplane.GetOrCreateStatMetadata(tsl::profiler::GetTaskEnvStatTypeStr( + tsl::profiler::kEnvProfileOptions)), + options_); + } + + return absl::OkStatus(); +} + +void ProfilerSessionOrchestrator::ClearConsumeBuffers() { + std::vector>().swap(consume_buffers_); + consume_stop_times_.clear(); +} + +} // namespace profiler +} // namespace tsl diff --git a/tsl/profiler/lib/profiler_orchestrator.h b/tsl/profiler/lib/profiler_orchestrator.h new file mode 100644 index 000000000..c9f83dc1c --- /dev/null +++ b/tsl/profiler/lib/profiler_orchestrator.h @@ -0,0 +1,66 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TSL_PROFILER_LIB_PROFILER_ORCHESTRATOR_H_ +#define TENSORFLOW_TSL_PROFILER_LIB_PROFILER_ORCHESTRATOR_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tsl/profiler/lib/profiler_session.h" +#include "tsl/profiler/protobuf/profiler_options.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tsl { +namespace profiler { + +class ProfilerSessionOrchestrator { + public: + explicit ProfilerSessionOrchestrator( + const tensorflow::ProfileOptions& options); + ~ProfilerSessionOrchestrator(); + + absl::Status Start(); + + absl::StatusOr Consume(); + + absl::Status Serialize(int buffer_index); + + absl::Status Stop(); + + void ClearConsumeBuffers(); + + const std::vector& GetConsumeBuffer(int index) const { + return consume_buffers_[index]; + } + const tensorflow::profiler::XSpace& GetSerializeSpace() const { + return serialize_space_; + } + + private: + tensorflow::ProfileOptions options_; + std::unique_ptr session_; + std::vector> consume_buffers_; + tensorflow::profiler::XSpace serialize_space_; + uint64_t global_start_time_ns_; + std::vector consume_stop_times_; +}; + +} // namespace profiler +} // namespace tsl + +#endif // TENSORFLOW_TSL_PROFILER_LIB_PROFILER_ORCHESTRATOR_H_ diff --git a/tsl/profiler/lib/profiler_orchestrator_test.cc b/tsl/profiler/lib/profiler_orchestrator_test.cc new file mode 100644 index 000000000..45996ec05 --- /dev/null +++ b/tsl/profiler/lib/profiler_orchestrator_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include "tsl/profiler/lib/profiler_orchestrator.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "xla/tsl/platform/test.h" +#include "tsl/profiler/lib/profiler_factory.h" +#include "tsl/profiler/lib/profiler_interface.h" +#include "tsl/profiler/lib/profiler_session.h" +#include "tsl/profiler/protobuf/profiler_options.pb.h" + +namespace tsl { +namespace profiler { +namespace { + +class MockProfiler : public ProfilerInterface { + public: + absl::Status Start() override { return absl::OkStatus(); } + absl::Status Stop() override { return absl::OkStatus(); } + absl::Status CollectData(tensorflow::profiler::XSpace*) override { + return absl::OkStatus(); + } + absl::Status Consume(void* ptr) override { + consume_called_ = true; + return absl::OkStatus(); + } + absl::Status Serialize(void* ptr, + tensorflow::profiler::XSpace* output_space) override { + serialize_called_ = true; + return absl::OkStatus(); + } + + bool consume_called() const { return consume_called_; } + bool serialize_called() const { return serialize_called_; } + + private: + bool consume_called_ = false; + bool serialize_called_ = false; +}; + +TEST(ProfilerSessionOrchestratorTest, SimpleLifecycle) { + ClearRegisteredProfilersForTest(); + + static MockProfiler* active_mock = nullptr; + + RegisterProfilerFactory([](const tensorflow::ProfileOptions& options) { + auto mock = absl::make_unique(); + active_mock = mock.get(); + return mock; + }); + + tensorflow::ProfileOptions options = ProfilerSession::DefaultOptions(); + ProfilerSessionOrchestrator orchestrator(options); + + ASSERT_OK(orchestrator.Start()); + auto index_or = orchestrator.Consume(); + ASSERT_OK(index_or.status()); + ASSERT_OK(orchestrator.Serialize(index_or.value())); + + EXPECT_TRUE(active_mock != nullptr); + if (active_mock) { + EXPECT_TRUE(active_mock->consume_called()); + EXPECT_TRUE(active_mock->serialize_called()); + } + + ASSERT_OK(orchestrator.Stop()); +} + +TEST(ProfilerSessionOrchestratorTest, MultipleConsumeAndSelectiveSerialize) { + ClearRegisteredProfilersForTest(); + + RegisterProfilerFactory([](const tensorflow::ProfileOptions& options) { + return absl::make_unique(); + }); + + tensorflow::ProfileOptions options = ProfilerSession::DefaultOptions(); + // Using default sizes here + ProfilerSessionOrchestrator orchestrator(options); + + ASSERT_OK(orchestrator.Start()); + + auto index1_or = orchestrator.Consume(); + ASSERT_OK(index1_or.status()); + + auto index2_or = orchestrator.Consume(); + ASSERT_OK(index2_or.status()); + + EXPECT_NE(index1_or.value(), index2_or.value()); + + ASSERT_OK(orchestrator.Serialize(index1_or.value())); + ASSERT_OK(orchestrator.Serialize(index2_or.value())); + + ASSERT_OK(orchestrator.Stop()); +} + +TEST(ProfilerSessionOrchestratorTest, ClearConsumeBuffers) { + ClearRegisteredProfilersForTest(); + + RegisterProfilerFactory([](const tensorflow::ProfileOptions& options) { + return absl::make_unique(); + }); + + tensorflow::ProfileOptions options = ProfilerSession::DefaultOptions(); + ProfilerSessionOrchestrator orchestrator(options); + + ASSERT_OK(orchestrator.Start()); + + auto index1_or = orchestrator.Consume(); + ASSERT_OK(index1_or.status()); + EXPECT_EQ(index1_or.value(), 0); + + orchestrator.ClearConsumeBuffers(); + + auto index2_or = orchestrator.Consume(); + ASSERT_OK(index2_or.status()); + EXPECT_EQ(index2_or.value(), 0); // Should be 0 again after clear! + + ASSERT_OK(orchestrator.Stop()); +} + +} // namespace +} // namespace profiler +} // namespace tsl + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tsl/profiler/lib/profiler_session.cc b/tsl/profiler/lib/profiler_session.cc index ce2333f87..3ee8ecb28 100644 --- a/tsl/profiler/lib/profiler_session.cc +++ b/tsl/profiler/lib/profiler_session.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tsl/profiler/lib/profiler_session.h" +#include #include #include @@ -90,6 +91,25 @@ absl::Status ProfilerSession::CollectDataInternal(XSpace* space) { profiler_lock_.ReleaseIfActive(); return absl::OkStatus(); } + +absl::Status ProfilerSession::Consume(void* ptr) { + absl::MutexLock l(mutex_); + TF_RETURN_IF_ERROR(status_); + if (profilers_ == nullptr) { + return absl::FailedPreconditionError("No active profilers in session."); + } + return profilers_->Consume(ptr); +} + +absl::Status ProfilerSession::Serialize( + void* ptr, tensorflow::profiler::XSpace* output_space) { + absl::MutexLock l(mutex_); + TF_RETURN_IF_ERROR(status_); + if (profilers_ == nullptr) { + return absl::FailedPreconditionError("No active profilers in session."); + } + return profilers_->Serialize(ptr, output_space); +} #endif absl::Status ProfilerSession::CollectData(XSpace* space) { diff --git a/tsl/profiler/lib/profiler_session.h b/tsl/profiler/lib/profiler_session.h index 6644017f5..8df9af186 100644 --- a/tsl/profiler/lib/profiler_session.h +++ b/tsl/profiler/lib/profiler_session.h @@ -15,15 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_PROFILER_SESSION_H_ #define TENSORFLOW_TSL_PROFILER_LIB_PROFILER_SESSION_H_ -#include +#include #include -#include #include "absl/status/status.h" #include "absl/synchronization/mutex.h" -#include "xla/tsl/platform/status.h" -#include "xla/tsl/platform/types.h" -#include "tsl/platform/platform.h" #include "tsl/platform/thread_annotations.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -67,6 +63,10 @@ class ProfilerSession { absl::Status CollectData(tensorflow::profiler::XSpace* space) TF_LOCKS_EXCLUDED(mutex_); + absl::Status Consume(void* ptr) TF_LOCKS_EXCLUDED(mutex_); + absl::Status Serialize(void* ptr, tensorflow::profiler::XSpace* output_space) + TF_LOCKS_EXCLUDED(mutex_); + private: // Constructs an instance of the class and starts profiling explicit ProfilerSession(const tensorflow::ProfileOptions& options);