diff --git a/gigl/src/common/custom_launcher.py b/gigl/src/common/custom_launcher.py
new file mode 100644
--- /dev/null
+++ b/gigl/src/common/custom_launcher.py
@@ -0,0 +1,102 @@
+"""Subprocess dispatch for ``CustomLauncherConfig``-backed launchers.
+
+Takes ``CustomLauncherConfig.command`` and ``CustomLauncherConfig.args``
+verbatim and shells out via ``subprocess.run(shell_line, shell=True)``.
+The shell-style invocation honors leading ``KEY=VALUE`` env-var
+assignments in ``command`` so callers can self-document required env
+without forcing the dispatcher to parse env separately.
+
+The receiving subprocess has no special protocol — it is expected to be
+a plain CLI that argparses whatever flags the YAML wires up via
+``args[]``. The dispatcher performs no template substitution; any
+dynamic content (runtime URIs, image refs, etc.) is the caller's
+responsibility — typically resolved at YAML-load time before the
+proto reaches this module.
+"""
+
+import shlex
+import subprocess
+from collections.abc import Mapping
+from typing import Optional
+
+from gigl.common import Uri
+from gigl.common.logger import Logger
+from gigl.src.common.constants.components import GiGLComponents
+from snapchat.research.gbml.gigl_resource_config_pb2 import CustomLauncherConfig
+
+logger = Logger()
+
+_LAUNCHABLE_COMPONENTS: frozenset[GiGLComponents] = frozenset(
+    {GiGLComponents.Trainer, GiGLComponents.Inferencer}
+)
+
+
+def launch_custom(
+    custom_launcher_config: CustomLauncherConfig,
+    applied_task_identifier: str,
+    task_config_uri: Uri,
+    resource_config_uri: Uri,
+    process_command: str,
+    process_runtime_args: Mapping[str, str],
+    cpu_docker_uri: Optional[str],
+    cuda_docker_uri: Optional[str],
+    component: GiGLComponents,
+) -> None:
+    """Shell out to ``custom_launcher_config.command`` with ``args[]`` appended.
+
+    Composes a shell line as ``command`` followed by each ``args[]``
+    element passed through ``shlex.quote``, then invokes
+    ``subprocess.run(shell_line, shell=True, check=True)``.
+
+    The dispatcher takes ``command`` and ``args[]`` verbatim — no
+    template substitution of any kind. Any placeholder text in those
+    fields reaches ``subprocess.run`` literally; consumers that want
+    substitution should resolve it at YAML-load time before the proto
+    reaches this module.
+
+    ``applied_task_identifier``, ``task_config_uri``,
+    ``resource_config_uri``, ``process_command``,
+    ``process_runtime_args``, ``cpu_docker_uri``, and ``cuda_docker_uri``
+    are accepted for API symmetry with the GLT-side Vertex AI launchers
+    but are intentionally not plumbed into the subprocess — the
+    receiving CLI is expected to source whatever context it needs from
+    the flags wired up in ``args[]`` (or from env vars inherited from
+    the parent process).
+
+    Args:
+        custom_launcher_config: Proto whose ``command`` is the shell
+            snippet to execute and whose ``args`` are positional
+            arguments appended verbatim.
+        applied_task_identifier: Accepted for API symmetry; ignored.
+        task_config_uri: Accepted for API symmetry; ignored.
+        resource_config_uri: Accepted for API symmetry; ignored.
+        process_command: Accepted for API symmetry; ignored.
+        process_runtime_args: Accepted for API symmetry; ignored.
+        cpu_docker_uri: Accepted for API symmetry; ignored.
+        cuda_docker_uri: Accepted for API symmetry; ignored.
+        component: Which GiGL component is being launched. Must be in
+            ``_LAUNCHABLE_COMPONENTS``.
+
+    Raises:
+        ValueError: If ``component`` is not Trainer or Inferencer, or if
+            ``custom_launcher_config.command`` is empty or whitespace-only.
+        subprocess.CalledProcessError: If the spawned subprocess exits
+            non-zero.
+    """
+    if component not in _LAUNCHABLE_COMPONENTS:
+        raise ValueError(f"Invalid component: {component}")
+    # A whitespace-only command would run args[0] as the program (or
+    # nothing at all), so treat it like an unset command.
+    if not custom_launcher_config.command.strip():
+        raise ValueError("CustomLauncherConfig.command must be set")
+
+    command: str = custom_launcher_config.command
+    args: list[str] = list(custom_launcher_config.args)
+
+    # Only args[] are quoted; ``command`` stays raw so leading KEY=VALUE
+    # assignments are interpreted by the shell.
+    shell_line = " ".join([command, *(shlex.quote(a) for a in args)])
+    logger.info(f"Launching {component.name} via subprocess: {shell_line!r}")
+    # NOTE(review): shell=True runs ``command`` unescaped, so the resource
+    # config supplying it must be trusted.
+    subprocess.run(shell_line, shell=True, check=True)
diff --git a/tests/unit/src/common/custom_launcher_test.py b/tests/unit/src/common/custom_launcher_test.py
new file mode 100644
--- /dev/null
+++ b/tests/unit/src/common/custom_launcher_test.py
@@ -0,0 +1,167 @@
+"""Unit tests for ``gigl.src.common.custom_launcher``."""
+
+from unittest.mock import MagicMock, patch
+
+from absl.testing import absltest
+
+from gigl.common import Uri
+from gigl.src.common.constants.components import GiGLComponents
+from gigl.src.common.custom_launcher import launch_custom
+from snapchat.research.gbml import gigl_resource_config_pb2
+from tests.test_assets.test_case import TestCase
+
+
+class TestLaunchCustom(TestCase):
+    """Exercises ``launch_custom`` subprocess dispatch and guards.
+
+    The launcher takes ``command`` and ``args[]`` from the proto
+    verbatim (no template substitution) and shells out via
+    ``subprocess.run``. Tests patch ``subprocess.run`` to capture the
+    composed shell line without actually spawning processes.
+    """
+
+    def _build_config(
+        self,
+        command: str,
+        args: list[str] | None = None,
+    ) -> gigl_resource_config_pb2.CustomLauncherConfig:
+        return gigl_resource_config_pb2.CustomLauncherConfig(
+            command=command,
+            args=args or [],
+        )
+
+    @patch("gigl.src.common.custom_launcher.subprocess.run")
+    def test_dispatches_subprocess_with_literal_command_and_args(
+        self, mock_run: MagicMock
+    ) -> None:
+        config = self._build_config(
+            command="python -m my.cli",
+            args=["--foo=bar", "--baz=qux"],
+        )
+        launch_custom(
+            custom_launcher_config=config,
+            applied_task_identifier="job-42",
+            task_config_uri=Uri("gs://bucket/task.yaml"),
+            resource_config_uri=Uri("gs://bucket/resource.yaml"),
+            process_command="ignored",
+            process_runtime_args={"ignored": "v"},
+            cpu_docker_uri="gcr.io/p/cpu:tag",
+            cuda_docker_uri="gcr.io/p/cuda:tag",
+            component=GiGLComponents.Trainer,
+        )
+
+        mock_run.assert_called_once()
+        shell_line = mock_run.call_args.args[0]
+        self.assertIn("python -m my.cli", shell_line)
+        self.assertIn("--foo=bar", shell_line)
+        self.assertIn("--baz=qux", shell_line)
+        # subprocess invoked with shell=True and check=True.
+        self.assertTrue(mock_run.call_args.kwargs.get("shell", False))
+        self.assertTrue(mock_run.call_args.kwargs.get("check", False))
+
+    @patch("gigl.src.common.custom_launcher.subprocess.run")
+    def test_empty_command_raises_value_error(self, mock_run: MagicMock) -> None:
+        config = self._build_config(command="", args=["ignored"])
+        with self.assertRaises(ValueError):
+            launch_custom(
+                custom_launcher_config=config,
+                applied_task_identifier="job",
+                task_config_uri=Uri("gs://bucket/task.yaml"),
+                resource_config_uri=Uri("gs://bucket/resource.yaml"),
+                process_command="",
+                process_runtime_args={},
+                cpu_docker_uri=None,
+                cuda_docker_uri=None,
+                component=GiGLComponents.Trainer,
+            )
+        mock_run.assert_not_called()
+
+    @patch("gigl.src.common.custom_launcher.subprocess.run")
+    def test_whitespace_only_command_raises_value_error(
+        self, mock_run: MagicMock
+    ) -> None:
+        config = self._build_config(command="   ", args=["ignored"])
+        with self.assertRaises(ValueError):
+            launch_custom(
+                custom_launcher_config=config,
+                applied_task_identifier="job",
+                task_config_uri=Uri("gs://bucket/task.yaml"),
+                resource_config_uri=Uri("gs://bucket/resource.yaml"),
+                process_command="",
+                process_runtime_args={},
+                cpu_docker_uri=None,
+                cuda_docker_uri=None,
+                component=GiGLComponents.Trainer,
+            )
+        mock_run.assert_not_called()
+
+    @patch("gigl.src.common.custom_launcher.subprocess.run")
+    def test_invalid_component_raises_value_error(self, mock_run: MagicMock) -> None:
+        config = self._build_config(command="echo")
+        with self.assertRaises(ValueError):
+            launch_custom(
+                custom_launcher_config=config,
+                applied_task_identifier="job",
+                task_config_uri=Uri("gs://bucket/task.yaml"),
+                resource_config_uri=Uri("gs://bucket/resource.yaml"),
+                process_command="",
+                process_runtime_args={},
+                cpu_docker_uri=None,
+                cuda_docker_uri=None,
+                component=GiGLComponents.DataPreprocessor,
+            )
+        mock_run.assert_not_called()
+
+    @patch("gigl.src.common.custom_launcher.subprocess.run")
+    def test_args_with_spaces_are_shell_quoted(self, mock_run: MagicMock) -> None:
+        config = self._build_config(
+            command="echo", args=["a b c", "--name=with space"]
+        )
+        launch_custom(
+            custom_launcher_config=config,
+            applied_task_identifier="job",
+            task_config_uri=Uri("gs://bucket/task.yaml"),
+            resource_config_uri=Uri("gs://bucket/resource.yaml"),
+            process_command="",
+            process_runtime_args={},
+            cpu_docker_uri=None,
+            cuda_docker_uri=None,
+            component=GiGLComponents.Trainer,
+        )
+        shell_line = mock_run.call_args.args[0]
+        # shlex.quote wraps tokens with spaces in single quotes so the
+        # shell sees one argv element per proto args[] entry.
+        self.assertIn("'a b c'", shell_line)
+        self.assertIn("'--name=with space'", shell_line)
+
+    @patch("gigl.src.common.custom_launcher.subprocess.run")
+    def test_unsubstituted_gigl_placeholder_passes_through_verbatim(
+        self, mock_run: MagicMock
+    ) -> None:
+        # The launcher performs no template substitution: any
+        # ``${gigl:*}`` placeholder in command/args reaches subprocess
+        # unchanged. Consumers that want substitution must resolve at
+        # YAML-load time before the proto reaches this module.
+        config = self._build_config(
+            command="python -m my.cli",
+            args=["--foo=${gigl:bar}"],
+        )
+        launch_custom(
+            custom_launcher_config=config,
+            applied_task_identifier="job",
+            task_config_uri=Uri("gs://bucket/task.yaml"),
+            resource_config_uri=Uri("gs://bucket/resource.yaml"),
+            process_command="",
+            process_runtime_args={},
+            cpu_docker_uri=None,
+            cuda_docker_uri=None,
+            component=GiGLComponents.Trainer,
+        )
+        shell_line = mock_run.call_args.args[0]
+        # The placeholder is preserved verbatim inside the shell-quoted
+        # arg.
+        self.assertIn("${gigl:bar}", shell_line)
+
+
+if __name__ == "__main__":
+    absltest.main()