65 changes: 64 additions & 1 deletion nemo_run/run/torchx_backend/schedulers/slurm.py
@@ -22,6 +22,7 @@
import json
import logging
import os
import threading
import time
from dataclasses import asdict
from datetime import datetime
@@ -78,6 +79,8 @@ def __init__(
super().__init__(session_name)
self.experiment = experiment
self._consecutive_sacct_failures: dict[str, int] = {}
self._start_time_threads: dict[str, threading.Thread] = {}
self._start_time_stop_events: dict[str, threading.Event] = {}

# TODO: Move this into the SlurmExecutor
def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel):
@@ -190,6 +193,41 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t

return AppDryRunInfo(req, repr)

def _poll_job_start_time(
self, job_id: str, tunnel: Tunnel, stop_event: threading.Event
) -> None:
attempt = 0
while not stop_event.is_set():
try:
result = tunnel.run(
f"squeue --start --noheader -j {job_id} -o '%i|%S|%T'",
warn=True,
hide=True,
)
output = (result.stdout or "").strip()
if output and result.return_code == 0:
# Array jobs produce one line per task — print only the first
line = output.splitlines()[0]
parts = line.strip().split("|")
if len(parts) >= 3:
_, start_time, state = parts[0].strip(), parts[1].strip(), parts[2].strip()
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(
f"[SLURM] Job {job_id} - State: {state}, Estimated start: {start_time}, Current time: {now}",
flush=True,
)
if state.upper() not in ("PENDING", "CF", "CONFIGURING"):
return
else:
print(f"[SLURM] Job {job_id} is no longer pending.", flush=True)
return
except Exception as e:
log.debug(f"Failed to poll start time for job {job_id}: {e}")

delay = min(30 * (2**attempt), 900)
attempt += 1
stop_event.wait(delay)

def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayRequest]) -> str: # type: ignore
# Setup
req = dryrun_info.request
@@ -218,6 +256,23 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayReques

# Save metadata
_save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term)

# Stop any existing polling thread for this job_id (retry scenario)
if job_id in self._start_time_stop_events:
self._start_time_stop_events.pop(job_id).set()
self._start_time_threads.pop(job_id, None)

stop_event = threading.Event()
self._start_time_stop_events[job_id] = stop_event
thread = threading.Thread(
target=self._poll_job_start_time,
args=(job_id, self.tunnel, stop_event),
daemon=True,
name=f"slurm-start-time-{job_id}",
)
self._start_time_threads[job_id] = thread
thread.start()

return job_id

def _cancel_existing(self, app_id: str) -> None:
@@ -231,6 +286,10 @@ def _cancel_existing(self, app_id: str) -> None:
assert self.tunnel, "Tunnel is None."
self.tunnel.run(f"scancel {app_id}", hide=False)

if app_id in self._start_time_stop_events:
self._start_time_stop_events.pop(app_id).set()
self._start_time_threads.pop(app_id, None)

def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
try:
job_dirs = _get_job_dirs()
@@ -366,7 +425,11 @@ def log_iter(
else:
return [f"Failed getting logs for {app_id}"]

def close(self) -> None: ...
def close(self) -> None:
for stop_event in self._start_time_stop_events.values():
stop_event.set()
self._start_time_threads.clear()
self._start_time_stop_events.clear()


class TunnelLogIterator(LogIterator):
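As a quick illustration of what the new `_poll_job_start_time` method does, here is a minimal standalone sketch (not part of the diff itself): it parses one pipe-separated line in the `'%i|%S|%T'` format the poller requests from `squeue --start` (job id, estimated start, state) and reproduces the same `min(30 * 2**attempt, 900)` backoff between polls. The sample job id and timestamp below are illustrative only.

from datetime import datetime

# One squeue line in the '%i|%S|%T' format requested above (job id | estimated start | state);
# array jobs emit one such line per task, and the poller only reads the first.
sample = "12345|2026-03-14T15:30:00|PENDING\n"

line = sample.strip().splitlines()[0]
job_id, start_time, state = (part.strip() for part in line.split("|")[:3])
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"[SLURM] Job {job_id} - State: {state}, Estimated start: {start_time}, Current time: {now}")

# Backoff between polls, same formula as the method: 30s doubling, capped at 900s.
print([min(30 * (2**attempt), 900) for attempt in range(8)])
# -> [30, 60, 120, 240, 480, 900, 900, 900]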
216 changes: 216 additions & 0 deletions test/run/torchx_backend/schedulers/test_slurm.py
@@ -18,6 +18,7 @@
import logging
import os
import tempfile
import threading
from unittest import mock

import pytest
@@ -143,6 +144,7 @@ def test_schedule(slurm_scheduler, slurm_executor):
with (
mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"),
mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"),
mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"),
):
# Create a fresh mock tunnel for each test to avoid interference
mock_tunnel = mock.MagicMock()
@@ -473,6 +475,7 @@ def test_schedule_with_dependencies(slurm_scheduler, slurm_executor):
mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"),
mock.patch.object(SlurmExecutor, "parse_deps", return_value=["54321"]),
mock.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir"),
mock.patch.object(SlurmTunnelScheduler, "_poll_job_start_time"),
):
# Create a fresh mock tunnel for testing
mock_tunnel = mock.MagicMock()
@@ -726,3 +729,216 @@ def test_non_heterogeneous_ray_cluster(slurm_scheduler, temp_dir):
# Verify run_as_group was NOT set
assert not hasattr(executor, "run_as_group") or not executor.run_as_group
assert isinstance(dryrun_info.request, SlurmRayRequest)


# ---------------------------------------------------------------------------
# Tests for start-time polling feature
# ---------------------------------------------------------------------------


def test_poll_job_start_time_prints_while_pending(slurm_scheduler, mocker):
job_id = "12345"
stop_event = threading.Event()
mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|PENDING\n"
mock_tunnel.run.return_value.return_code = 0

mock_print = mocker.patch("builtins.print")

# Stop after first iteration by setting the event inside wait
def wait_once(timeout=None):
stop_event.set()
return True

stop_event.wait = wait_once

slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
mock_print.assert_called_once()
printed = mock_print.call_args[0][0]
assert job_id in printed
assert "PENDING" in printed
assert "2026-03-14T15:30:00" in printed
assert "Current time:" in printed


def test_poll_job_start_time_stops_when_job_starts(slurm_scheduler, mocker):
job_id = "12345"
stop_event = threading.Event()
mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = f"{job_id}|2026-03-14T15:30:00|RUNNING\n"
mock_tunnel.run.return_value.return_code = 0

mocker.patch("builtins.print")
wait_called = []
original_wait = stop_event.wait
stop_event.wait = lambda t=None: wait_called.append(t) or original_wait(0)

slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
assert len(wait_called) == 0 # returned immediately, no wait


def test_poll_job_start_time_stops_when_queue_empty(slurm_scheduler, mocker):
job_id = "12345"
stop_event = threading.Event()
mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = ""
mock_tunnel.run.return_value.return_code = 0

mock_print = mocker.patch("builtins.print")
slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)

mock_print.assert_called_once()
assert "no longer pending" in mock_print.call_args[0][0]


def test_poll_job_start_time_continues_on_exception(slurm_scheduler, mocker):
job_id = "12345"
stop_event = threading.Event()
mock_tunnel = mock.MagicMock()
# First call raises, second call returns empty to stop the loop
second_result = mock.MagicMock()
second_result.stdout = ""
mock_tunnel.run.side_effect = [
Exception("squeue failed"),
second_result,
]

mocker.patch("builtins.print")
# Patch wait so the inter-poll sleep doesn't block the test (edge case #1)
stop_event.wait = mock.MagicMock(return_value=False)

slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
assert mock_tunnel.run.call_count == 2
stop_event.wait.assert_called_once_with(30)


def test_poll_job_start_time_handles_none_stdout(slurm_scheduler, mocker):
job_id = "12345"
stop_event = threading.Event()
mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = None

mock_print = mocker.patch("builtins.print")
slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)

mock_print.assert_called_once()
assert "no longer pending" in mock_print.call_args[0][0]


def test_poll_job_start_time_skips_nonzero_return_code(slurm_scheduler, mocker):
job_id = "12345"
stop_event = threading.Event()
mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = "slurm_load_jobs error: Invalid job id specified"
mock_tunnel.run.return_value.return_code = 1

mock_print = mocker.patch("builtins.print")
slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)

mock_print.assert_called_once()
assert "no longer pending" in mock_print.call_args[0][0]


def test_poll_job_start_time_deduplicates_array_job_lines(slurm_scheduler, mocker):
job_id = "12345"
stop_event = threading.Event()
mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = (
f"{job_id}_1|2026-03-14T15:30:00|PENDING\n{job_id}_2|2026-03-14T15:30:00|PENDING\n"
)
mock_tunnel.run.return_value.return_code = 0

mock_print = mocker.patch("builtins.print")

def wait_once(timeout=None):
stop_event.set()
return True

stop_event.wait = wait_once

slurm_scheduler._poll_job_start_time(job_id, mock_tunnel, stop_event)
assert mock_print.call_count == 1


def test_schedule_starts_start_time_polling_thread(slurm_scheduler, mocker):
job_id = "99999"
dryrun_info = mock.MagicMock()

mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = job_id
slurm_scheduler.tunnel = mock_tunnel

mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel")
mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir")

# Block the polling thread so is_alive() is True when we check
started = threading.Event()

def blocking_poll(poll_job_id, poll_tunnel, stop_event):
started.set()
stop_event.wait()

mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time", side_effect=blocking_poll)

slurm_scheduler.schedule(dryrun_info)

started.wait(timeout=2)
assert job_id in slurm_scheduler._start_time_threads
thread = slurm_scheduler._start_time_threads[job_id]
assert thread.daemon
assert thread.is_alive()
assert job_id in slurm_scheduler._start_time_stop_events

# Cleanup
slurm_scheduler._start_time_stop_events[job_id].set()


def test_schedule_stops_existing_thread_on_duplicate_job_id(slurm_scheduler, mocker):
job_id = "99999"
old_ev = threading.Event()
slurm_scheduler._start_time_stop_events[job_id] = old_ev
slurm_scheduler._start_time_threads[job_id] = mock.MagicMock()

dryrun_info = mock.MagicMock()
mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = job_id
slurm_scheduler.tunnel = mock_tunnel

mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel")
mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir")
mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time")

slurm_scheduler.schedule(dryrun_info)

assert old_ev.is_set()
assert slurm_scheduler._start_time_stop_events[job_id] is not old_ev # new event

# Cleanup
slurm_scheduler._start_time_stop_events[job_id].set()


def test_close_stops_all_polling_threads(slurm_scheduler):
ev1, ev2 = threading.Event(), threading.Event()
slurm_scheduler._start_time_stop_events = {"1": ev1, "2": ev2}
slurm_scheduler.close()
assert ev1.is_set()
assert ev2.is_set()
assert slurm_scheduler._start_time_threads == {}
assert slurm_scheduler._start_time_stop_events == {}


def test_cancel_stops_polling_thread_for_job(slurm_scheduler, mocker):
job_id = "12345"
ev = threading.Event()
slurm_scheduler._start_time_stop_events[job_id] = ev
slurm_scheduler._start_time_threads[job_id] = mock.MagicMock()
mocker.patch(
"nemo_run.run.torchx_backend.schedulers.slurm._get_job_dirs",
return_value={job_id: ("dir", mock.MagicMock(), "")},
)
slurm_scheduler.tunnel = mock.MagicMock()

slurm_scheduler._cancel_existing(job_id)

assert ev.is_set()
assert job_id not in slurm_scheduler._start_time_stop_events
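
To exercise the new tests locally, an invocation along the lines of `pytest test/run/torchx_backend/schedulers/test_slurm.py -k "start_time or polling"` should select most of them; the `mocker` fixture assumes pytest-mock is installed.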