diff --git a/nemo_run/run/torchx_backend/schedulers/slurm.py b/nemo_run/run/torchx_backend/schedulers/slurm.py
index 66fabd5d..912bca43 100644
--- a/nemo_run/run/torchx_backend/schedulers/slurm.py
+++ b/nemo_run/run/torchx_backend/schedulers/slurm.py
@@ -22,6 +22,7 @@
 import json
 import logging
 import os
+import threading
 import time
 from dataclasses import asdict
 from datetime import datetime
@@ -78,6 +79,8 @@ def __init__(
         super().__init__(session_name)
         self.experiment = experiment
         self._consecutive_sacct_failures: dict[str, int] = {}
+        self._start_time_threads: dict[str, threading.Thread] = {}
+        self._start_time_stop_events: dict[str, threading.Event] = {}
 
     # TODO: Move this into the SlurmExecutor
     def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel):
@@ -190,6 +193,41 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]:  # t
 
         return AppDryRunInfo(req, repr)
 
+    def _poll_job_start_time(
+        self, job_id: str, tunnel: Tunnel, stop_event: threading.Event
+    ) -> None:
+        attempt = 0
+        while not stop_event.is_set():
+            try:
+                result = tunnel.run(
+                    f"squeue --start --noheader -j {job_id} -o '%i|%S|%T'",
+                    warn=True,
+                    hide=True,
+                )
+                output = (result.stdout or "").strip()
+                if output and result.return_code == 0:
+                    # Array jobs produce one line per task; print only the first
+                    line = output.splitlines()[0]
+                    parts = line.strip().split("|")
+                    if len(parts) >= 3:
+                        _, start_time, state = parts[0].strip(), parts[1].strip(), parts[2].strip()
+                        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                        print(
+                            f"[SLURM] Job {job_id} - State: {state}, Estimated start: {start_time}, Current time: {now}",
+                            flush=True,
+                        )
+                        if state.upper() not in ("PENDING", "CF", "CONFIGURING"):
+                            return
+                else:
+                    print(f"[SLURM] Job {job_id} is no longer pending.", flush=True)
+                    return
+            except Exception as e:
+                log.debug(f"Failed to poll start time for job {job_id}: {e}")
+
+            delay = min(30 * (2**attempt), 900)
+            attempt += 1
+            stop_event.wait(delay)
+
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayRequest]) -> str:  # type: ignore
         # Setup
         req = dryrun_info.request
@@ -218,6 +256,23 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayReques
 
         # Save metadata
         _save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term)
+
+        # Stop any existing polling thread for this job_id (retry scenario)
+        if job_id in self._start_time_stop_events:
+            self._start_time_stop_events.pop(job_id).set()
+            self._start_time_threads.pop(job_id, None)
+
+        stop_event = threading.Event()
+        self._start_time_stop_events[job_id] = stop_event
+        thread = threading.Thread(
+            target=self._poll_job_start_time,
+            args=(job_id, self.tunnel, stop_event),
+            daemon=True,
+            name=f"slurm-start-time-{job_id}",
+        )
+        self._start_time_threads[job_id] = thread
+        thread.start()
+
         return job_id
 
     def _cancel_existing(self, app_id: str) -> None:
@@ -231,6 +286,10 @@ def _cancel_existing(self, app_id: str) -> None:
         assert self.tunnel, "Tunnel is None."
         self.tunnel.run(f"scancel {app_id}", hide=False)
 
+        if app_id in self._start_time_stop_events:
+            self._start_time_stop_events.pop(app_id).set()
+            self._start_time_threads.pop(app_id, None)
+
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
             job_dirs = _get_job_dirs()
@@ -366,7 +425,11 @@ def log_iter(
         else:
             return [f"Failed getting logs for {app_id}"]
 
-    def close(self) -> None: ...
+    def close(self) -> None:
+        for stop_event in self._start_time_stop_events.values():
+            stop_event.set()
+        self._start_time_threads.clear()
+        self._start_time_stop_events.clear()
 
 
 class TunnelLogIterator(LogIterator):
diff --git a/test/run/torchx_backend/schedulers/test_slurm.py b/test/run/torchx_backend/schedulers/test_slurm.py
index 9fc08ba0..3e99a93f 100644
--- a/test/run/torchx_backend/schedulers/test_slurm.py
+++ b/test/run/torchx_backend/schedulers/test_slurm.py
@@ -18,6 +18,7 @@
 import logging
 import os
 import tempfile
+import threading
 from unittest import mock
 
 import pytest
@@ -143,6 +144,7 @@ def test_schedule(slurm_scheduler, slurm_executor):
     with (
         mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"),
         mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"),
+        mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"),
     ):
         # Create a fresh mock tunnel for each test to avoid interference
        mock_tunnel = mock.MagicMock()
@@ -473,6 +475,7 @@ def test_schedule_with_dependencies(slurm_scheduler, slurm_executor):
         mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"),
         mock.patch.object(SlurmExecutor, "parse_deps", return_value=["54321"]),
         mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"),
+        mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"),
     ):
         # Create a fresh mock tunnel for testing
         mock_tunnel = mock.MagicMock()
@@ -726,3 +729,216 @@ def test_non_heterogeneous_ray_cluster(slurm_scheduler, temp_dir):
     # Verify run_as_group was NOT set
     assert not hasattr(executor, "run_as_group") or not executor.run_as_group
     assert isinstance(dryrun_info.request, SlurmRayRequest)
+
+
+# ---------------------------------------------------------------------------
+# Tests for start-time polling feature
+# ---------------------------------------------------------------------------
+
+
+def test_poll_job_start_time_prints_while_pending(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|PENDING\n"
+    mock_tunnel.run.return_value.return_code = 0
+
+    mock_print = mocker.patch("builtins.print")
+
+    # Stop after the first iteration by setting the event inside wait
+    def wait_once(timeout=None):
+        stop_event.set()
+        return True
+
+    stop_event.wait = wait_once
+
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+    mock_print.assert_called_once()
+    printed = mock_print.call_args[0][0]
+    assert job_id in printed
+    assert "PENDING" in printed
+    assert "2026-03-14T15:30:00" in printed
+    assert "Current time:" in printed
+
+
+def test_poll_job_start_time_stops_when_job_starts(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|RUNNING\n"
+    mock_tunnel.run.return_value.return_code = 0
+
+    mocker.patch("builtins.print")
+    wait_called = []
+    original_wait = stop_event.wait
+    stop_event.wait = lambda t=None: wait_called.append(t) or original_wait(0)
+
+    slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
+    assert len(wait_called) == 0  # returned immediately, no wait
+
+
+def test_poll_job_start_time_stops_when_queue_empty(slurm_scheduler, mocker):
+    job_id = "12345"
+    stop_event = threading.Event()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = ""
+    mock_tunnel.run.return_value.return_code = 0
+
+    mock_print = mocker.patch("builtins.print")
mocker.patch("builtins.print") + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) + + mock_print.assert_called_once() + assert "no longer pending" in mock_print.call_args[0][0] + + +def test_poll_job_start_time_continues_on_exception(slurm_scheduler, mocker): + job_id = "12345" + stop_event = threading.Event() + mock_tunnel = mock.MagicMock() + # First call raises, second call returns empty to stop the loop + second_result = mock.MagicMock() + second_result.stdout = "" + mock_tunnel.run.side_effect = [ + Exception("squeue failed"), + second_result, + ] + + mocker.patch("builtins.print") + # Patch wait so the inter-poll sleep doesn't block the test (edge case #1) + stop_event.wait = mock.MagicMock(return_value=False) + + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) + assert mock_tunnel.run.call_count == 2 + stop_event.wait.assert_called_once_with(30) + + +def test_poll_job_start_time_handles_none_stdout(slurm_scheduler, mocker): + job_id = "12345" + stop_event = threading.Event() + mock_tunnel = mock.MagicMock() + mock_tunnel.run.return_value.stdout = None + + mock_print = mocker.patch("builtins.print") + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) + + mock_print.assert_called_once() + assert "no longer pending" in mock_print.call_args[0][0] + + +def test_poll_job_start_time_skips_nonzero_return_code(slurm_scheduler, mocker): + job_id = "12345" + stop_event = threading.Event() + mock_tunnel = mock.MagicMock() + mock_tunnel.run.return_value.stdout = "slurm_load_jobs error: Invalid job id specified" + mock_tunnel.run.return_value.return_code = 1 + + mock_print = mocker.patch("builtins.print") + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) + + mock_print.assert_called_once() + assert "no longer pending" in mock_print.call_args[0][0] + + +def test_poll_job_start_time_deduplicates_array_job_lines(slurm_scheduler, mocker): + job_id = "12345" + stop_event = threading.Event() + mock_tunnel = mock.MagicMock() + mock_tunnel.run.return_value.stdout = ( + f"{job_id}_1|2026-03-14T15:30:00|PENDING\n{job_id}_2|2026-03-14T15:30:00|PENDING\n" + ) + mock_tunnel.run.return_value.return_code = 0 + + mock_print = mocker.patch("builtins.print") + + def wait_once(timeout=None): + stop_event.set() + return True + + stop_event.wait = wait_once + + slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event) + assert mock_print.call_count == 1 + + +def test_schedule_starts_start_time_polling_thread(slurm_scheduler, mocker): + job_id = "99999" + dryrun_info = mock.MagicMock() + + mock_tunnel = mock.MagicMock() + mock_tunnel.run.return_value.stdout = job_id + slurm_scheduler.tunnel = mock_tunnel + + mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel") + mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir") + + # Block the polling thread so is_alive() is True when we check + started = threading.Event() + + def blocking_poll(poll_job_id, poll_tunnel, stop_event): + started.set() + stop_event.wait() + + mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time", side_effect=blocking_poll) + + slurm_scheduler.schedule(dryrun_info) + + started.wait(timeout=2) + assert job_id in slurm_scheduler._start_time_threads + thread = slurm_scheduler._start_time_threads[job_id] + assert thread.daemon + assert thread.is_alive() + assert job_id in slurm_scheduler._start_time_stop_events + + # Cleanup + slurm_scheduler._start_time_stop_events[job_id].set() + + +def 
+    job_id = "99999"
+    old_ev = threading.Event()
+    slurm_scheduler._start_time_stop_events[job_id] = old_ev
+    slurm_scheduler._start_time_threads[job_id] = mock.MagicMock()
+
+    dryrun_info = mock.MagicMock()
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = job_id
+    slurm_scheduler.tunnel = mock_tunnel
+
+    mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel")
+    mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir")
+    mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time")
+
+    slurm_scheduler.schedule(dryrun_info)
+
+    assert old_ev.is_set()
+    assert slurm_scheduler._start_time_stop_events[job_id] is not old_ev  # new event
+
+    # Cleanup
+    slurm_scheduler._start_time_stop_events[job_id].set()
+
+
+def test_close_stops_all_polling_threads(slurm_scheduler):
+    ev1, ev2 = threading.Event(), threading.Event()
+    slurm_scheduler._start_time_stop_events = {"1": ev1, "2": ev2}
+    slurm_scheduler.close()
+    assert ev1.is_set()
+    assert ev2.is_set()
+    assert slurm_scheduler._start_time_threads == {}
+    assert slurm_scheduler._start_time_stop_events == {}
+
+
+def test_cancel_stops_polling_thread_for_job(slurm_scheduler, mocker):
+    job_id = "12345"
+    ev = threading.Event()
+    slurm_scheduler._start_time_stop_events[job_id] = ev
+    slurm_scheduler._start_time_threads[job_id] = mock.MagicMock()
+    mocker.patch(
+        "nemo_run.run.torchx_backend.schedulers.slurm._get_job_dirs",
+        return_value={job_id: ("dir", mock.MagicMock(), "")},
+    )
+    slurm_scheduler.tunnel = mock.MagicMock()
+
+    slurm_scheduler._cancel_existing(job_id)
+
+    assert ev.is_set()
+    assert job_id not in slurm_scheduler._start_time_stop_events