From 44d0853a9d29f05f2d1a1800b8416178cc7c853d Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Mon, 23 Mar 2026 22:26:28 -0700 Subject: [PATCH] Enable Pathways profiling with jax.profiler.ProfileOptions. This change allows users to configure Pathways profiling by passing a jax.profiler.ProfileOptions object to the start_trace function. The options are translated into the Pathways profile request, enabling control over a subset of parameters. Specifically, the supported options are `start_timestamp_ns`, `duration_ms`, `host_tracer_level`, `advanced_configuration`, and `python_tracer_level`. Compatible with JAX 0.9.2 and Pathways images tagged with 0.9.2 and above. PiperOrigin-RevId: 888449980 --- pathwaysutils/profiling.py | 83 ++++++++++++++++++++++++++-- pathwaysutils/test/profiling_test.py | 65 ++++++++++++++++++++-- 2 files changed, 137 insertions(+), 11 deletions(-) diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 34b315e..502309b 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -57,12 +57,75 @@ def toy_computation() -> None: x.block_until_ready() +def _is_default_profile_options( + profiler_options: jax.profiler.ProfileOptions, +) -> bool: + if jax.version.__version_info__ < (0, 9, 2): + return True + + default_options = jax.profiler.ProfileOptions() + return ( + profiler_options.host_tracer_level == default_options.host_tracer_level + and profiler_options.python_tracer_level + == default_options.python_tracer_level + and profiler_options.duration_ms == default_options.duration_ms + and not getattr(profiler_options, "advanced_configuration", None) + ) + + def _create_profile_request( log_dir: os.PathLike[str] | str, + profiler_options: jax.profiler.ProfileOptions | None = None, ) -> Mapping[str, Any]: """Creates a profile request mapping from the given options.""" - profile_request = {} - profile_request["traceLocation"] = str(log_dir) + profile_request: dict[str, Any] = { + "traceLocation": str(log_dir), + } + + if profiler_options 
is None or _is_default_profile_options(profiler_options): + return profile_request + + advanced_config = None + if getattr(profiler_options, "advanced_configuration", None): + advanced_config = {} + for k, v in getattr(profiler_options, "advanced_configuration").items(): + # Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue + # json-compatible dict + if isinstance(v, bool): + advanced_config[k] = {"boolValue": v} + elif isinstance(v, int): + advanced_config[k] = {"intValue": v} + elif isinstance(v, str): + advanced_config[k] = {"stringValue": v} + else: + raise ValueError( + f"Unsupported advanced configuration value type: {type(v)}. " + "Supported types are bool, int, and str." + ) + + xprof_options: dict[str, Any] = { + "traceDirectory": str(log_dir), + } + + if profiler_options.host_tracer_level != 2: + xprof_options["hostTraceLevel"] = profiler_options.host_tracer_level + + pw_trace_opts: dict[str, Any] = {} + if profiler_options.python_tracer_level: + pw_trace_opts["enablePythonTracer"] = bool( + profiler_options.python_tracer_level + ) + + if advanced_config: + pw_trace_opts["advancedConfiguration"] = advanced_config + + if pw_trace_opts: + xprof_options["pwTraceOptions"] = pw_trace_opts + + profile_request["xprofTraceOptions"] = xprof_options + + if profiler_options.duration_ms > 0: + profile_request["maxDurationSecs"] = profiler_options.duration_ms / 1000.0 return profile_request @@ -104,7 +167,7 @@ def start_trace( *, create_perfetto_link: bool = False, create_perfetto_trace: bool = False, - profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument + profiler_options: jax.profiler.ProfileOptions | None = None, ) -> None: """Starts a profiler trace. @@ -133,7 +196,6 @@ def start_trace( This feature is experimental for Pathways on Cloud and may not be fully supported. profiler_options: Profiler options to configure the profiler for collection. - Options are not currently supported and ignored. 
""" if not str(log_dir).startswith("gs://"): raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}") @@ -144,7 +206,18 @@ def start_trace( "features for Pathways on Cloud and may not be fully supported." ) - _start_pathways_trace_from_profile_request(_create_profile_request(log_dir)) + if jax.version.__version_info__ < (0, 9, 2) and profiler_options is not None: + _logger.warning( + "ProfileOptions are not supported until JAX 0.9.2 and will be omitted. " + "Some options can be specified via command line flags." + ) + profiler_options = None + + profile_request = _create_profile_request(log_dir, profiler_options) + + _logger.debug("Profile request: %s", profile_request) + + _start_pathways_trace_from_profile_request(profile_request) _original_start_trace( log_dir=log_dir, diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index e2cbe4f..7d524f8 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -14,6 +14,7 @@ import json import logging +import unittest from unittest import mock from absl.testing import absltest @@ -225,9 +226,11 @@ def test_start_trace_success(self): self.mock_toy_computation.assert_called_once() self.mock_plugin_executable_cls.assert_called_once_with( - json.dumps( - {"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}} - ) + json.dumps({ + "profileRequest": { + "traceLocation": "gs://test_bucket/test_dir", + } + }) ) self.mock_plugin_executable_cls.return_value.call.assert_called_once() self.mock_original_start_trace.assert_called_once_with( @@ -391,10 +394,60 @@ def test_monkey_patched_stop_server(self): mocks["stop_server"].assert_called_once() - def test_create_profile_request_no_options(self): - request = profiling._create_profile_request("gs://bucket/dir") - self.assertEqual(request, {"traceLocation": "gs://bucket/dir"}) + @parameterized.parameters(None, jax.profiler.ProfileOptions()) + def 
test_create_profile_request_default_options(self, profiler_options): + request = profiling._create_profile_request( + "gs://bucket/dir", profiler_options=profiler_options + ) + self.assertEqual( + request, + { + "traceLocation": "gs://bucket/dir", + }, + ) + + @unittest.skipIf( + jax.version.__version_info__ < (0, 9, 2), + "ProfileOptions requires JAX 0.9.2 or newer", + ) + def test_create_profile_request_with_options(self): + options = jax.profiler.ProfileOptions() + options.host_tracer_level = 2 + options.python_tracer_level = 1 + options.duration_ms = 2000 + options.start_timestamp_ns = 123456789 + options.advanced_configuration = { + "tpu_num_chips_to_profile_per_task": 3, + "tpu_num_sparse_core_tiles_to_trace": 5, + "tpu_trace_mode": "TRACE_COMPUTE", + } + + request = profiling._create_profile_request( + "gs://bucket/dir", profiler_options=options + ) + self.assertEqual( + request, + { + "traceLocation": "gs://bucket/dir", + "maxDurationSecs": 2.0, + "xprofTraceOptions": { + "traceDirectory": "gs://bucket/dir", + "pwTraceOptions": { + "enablePythonTracer": True, + "advancedConfiguration": { + "tpu_num_chips_to_profile_per_task": {"intValue": 3}, + "tpu_num_sparse_core_tiles_to_trace": {"intValue": 5}, + "tpu_trace_mode": {"stringValue": "TRACE_COMPUTE"}, + }, + }, + }, + }, + ) + @unittest.skipIf( + jax.version.__version_info__ < (0, 9, 2), + "ProfileOptions requires JAX 0.9.2 or newer", + ) @parameterized.parameters( ({"traceLocation": "gs://test_bucket/test_dir"},), ({