From 8fe04c83f7bb7cd35ac3aa2c64d3ee66e2f996be Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 16:08:28 +0500 Subject: [PATCH 1/8] Extract running job helper functions --- AGENTS.md | 1 + .../scheduled_tasks/running_jobs.py | 77 +++++++++++-------- 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index bb1a7aac0..84b778a61 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -18,6 +18,7 @@ - Python targets 3.9+ with 4-space indentation and max line length of 99 (see `ruff.toml`; `E501` is ignored but keep lines readable). - Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party). - Keep primary/public functions before local helper functions in a module section. +- Roughly keep function definitions in the order they are referenced within a file so call flow stays easy to follow. - Keep private classes, exceptions, and similar implementation-specific types close to the private functions that use them unless they are shared more broadly in the module. - Prefer pydantic-style models in `core/models`. - Document attributes when the note adds behavior, compatibility, or semantic context that is not obvious from the name and type. Use attribute docstrings without leading newline. diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 5916c9054..f26f9b72a 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -150,13 +150,12 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job_provisioning_data = job_submission.job_provisioning_data if job_provisioning_data is None: logger.error("%s: job_provisioning_data of an active job is None", fmt(job_model)) - job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER - job_model.termination_reason_message = ( - "Unexpected server error: job_provisioning_data of an active job is None" + await _terminate_running_job( + session=session, + job_model=job_model, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message="Unexpected server error: job_provisioning_data of an active job is None", ) - switch_job_status(session, job_model, JobStatus.TERMINATING) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() return job = find_job(run.jobs, job_model.replica_num, job_model.job_num) @@ -177,8 +176,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): "%s: waiting for all jobs in the replica to be provisioned", fmt(job_model), ) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() + await _mark_job_processed(session=session, job_model=job_model) return cluster_info = _get_cluster_info( @@ -205,11 +203,12 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): try: _interpolate_secrets(secrets, job.job_spec) except InterpolatorError as e: - job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER - job_model.termination_reason_message = f"Secrets interpolation error: {e.args[0]}" - switch_job_status(session, job_model, JobStatus.TERMINATING) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() + await _terminate_running_job( + session=session, + job_model=job_model, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=f"Secrets interpolation error: {e.args[0]}", + ) return server_ssh_private_keys = get_instance_ssh_private_keys( @@ -219,12 +218,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): if initial_status == JobStatus.PROVISIONING: if job_provisioning_data.hostname is None: await _wait_for_instance_provisioning_data(session, job_model) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() + await _mark_job_processed(session=session, job_model=job_model) return if _should_wait_for_other_nodes(run, job, job_model): - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() + await _mark_job_processed(session=session, job_model=job_model) return # fails are acceptable until timeout is exceeded @@ -388,25 +385,14 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): ) if initial_status != job_model.status and job_model.status == JobStatus.RUNNING: - job_model.probes = [] - for probe_num in range(len(job.job_spec.probes)): - job_model.probes.append( - ProbeModel( - name=f"{job_model.job_name}-{probe_num}", - probe_num=probe_num, - due=common_utils.get_current_datetime(), - success_streak=0, - active=True, - ) - ) + _initialize_running_job_probes(job_model=job_model, job=job) if job_model.status == JobStatus.RUNNING: await _maybe_register_replica(session, run_model, run, job_model, job.job_spec.probes) if job_model.status == JobStatus.RUNNING: await _check_gpu_utilization(session, job_model, job) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() + await _mark_job_processed(session=session, job_model=job_model) async def _refetch_job_model(session: AsyncSession, job_model: JobModel) -> JobModel: @@ -456,6 +442,37 @@ async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel return res.unique().scalar_one() +async def _mark_job_processed(session: AsyncSession, job_model: JobModel) -> None: + job_model.last_processed_at = common_utils.get_current_datetime() + await session.commit() + + +async def _terminate_running_job( + session: AsyncSession, + job_model: JobModel, + termination_reason: JobTerminationReason, + termination_reason_message: str, +) -> None: + job_model.termination_reason = termination_reason + job_model.termination_reason_message = termination_reason_message + switch_job_status(session, job_model, JobStatus.TERMINATING) + await _mark_job_processed(session=session, job_model=job_model) + + +def _initialize_running_job_probes(job_model: JobModel, job: Job) -> None: + job_model.probes = [] + for probe_num in range(len(job.job_spec.probes)): + job_model.probes.append( + ProbeModel( + name=f"{job_model.job_name}-{probe_num}", + probe_num=probe_num, + due=common_utils.get_current_datetime(), + success_streak=0, + active=True, + ) + ) + + async def _wait_for_instance_provisioning_data(session: AsyncSession, job_model: JobModel): """ This function will be called until instance IP address appears From e4dc8b785f3d64a793760e660b73c3c85cf2137a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 16:38:55 +0500 Subject: [PATCH 2/8] Split running job state handlers --- AGENTS.md | 1 + .../scheduled_tasks/running_jobs.py | 446 +++++++++++------- 2 files changed, 280 insertions(+), 167 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 84b778a61..3e1f85eaf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,7 @@ - Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party). - Keep primary/public functions before local helper functions in a module section. - Roughly keep function definitions in the order they are referenced within a file so call flow stays easy to follow. +- Prefer early returns over nested `if`/`else` blocks when they make the control flow simpler. - Keep private classes, exceptions, and similar implementation-specific types close to the private functions that use them unless they are shared more broadly in the module. - Prefer pydantic-style models in `core/models`. - Document attributes when the note adds behavior, compatibility, or semantic context that is not obvious from the name and type. Use attribute docstrings without leading newline. diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index f26f9b72a..2a5f5dc90 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -30,6 +30,7 @@ JobRuntimeData, JobSpec, JobStatus, + JobSubmission, JobTerminationReason, ProbeSpec, Run, @@ -216,173 +217,47 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): ) if initial_status == JobStatus.PROVISIONING: - if job_provisioning_data.hostname is None: - await _wait_for_instance_provisioning_data(session, job_model) - await _mark_job_processed(session=session, job_model=job_model) - return - if _should_wait_for_other_nodes(run, job, job_model): - await _mark_job_processed(session=session, job_model=job_model) - return - - # fails are acceptable until timeout is exceeded - if job_provisioning_data.dockerized: - logger.debug( - "%s: process provisioning job with shim, age=%s", - fmt(job_model), - job_submission.age, - ) - ssh_user = job_provisioning_data.username - assert run.run_spec.ssh_key_pub is not None - user_ssh_key = run.run_spec.ssh_key_pub.strip() - public_keys = [project.ssh_public_key.strip(), user_ssh_key] - if job_provisioning_data.backend == BackendType.LOCAL: - # No need to update ~/.ssh/authorized_keys when running shim locally - user_ssh_key = "" - success = await common_utils.run_async( - _process_provisioning_with_shim, - server_ssh_private_keys, - job_provisioning_data, - None, - session, - run, - job_model, - job_provisioning_data, - volumes, - job.job_spec.registry_auth, - public_keys, - ssh_user, - user_ssh_key, - ) - else: - assert cluster_info is not None - logger.debug( - "%s: process provisioning job without shim, age=%s", - fmt(job_model), - job_submission.age, - ) - # FIXME: downloading file archives and code here is a waste of time if - # the runner is not ready yet - file_archives = await _get_job_file_archives( - session=session, - archive_mappings=job.job_spec.file_archives, - user=run_model.user, - ) - code = await _get_job_code( - session=session, - project=project, - repo=repo_model, - code_hash=_get_repo_code_hash(run, job), - ) - success = await common_utils.run_async( - _submit_job_to_runner, - server_ssh_private_keys, - job_provisioning_data, - None, - session, - run, - job_model, - job, - cluster_info, - code, - file_archives, - secrets, - repo_creds, - success_if_not_available=False, - ) - - if not success: - # check timeout - provisioning_timeout = get_provisioning_timeout( - backend_type=job_provisioning_data.get_base_backend(), - instance_type_name=job_provisioning_data.instance_type.name, - ) - if job_submission.age > provisioning_timeout: - job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED - job_model.termination_reason_message = ( - f"Runner did not become available within {provisioning_timeout.total_seconds()}s." - f" Job submission age: {job_submission.age.total_seconds()}s)" - ) - switch_job_status(session, job_model, JobStatus.TERMINATING) - # instance will be emptied by process_terminating_jobs - - else: # fails are not acceptable - if initial_status == JobStatus.PULLING: - assert cluster_info is not None - logger.debug( - "%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age - ) - # FIXME: downloading file archives and code here is a waste of time if - # the runner is not ready yet - file_archives = await _get_job_file_archives( - session=session, - archive_mappings=job.job_spec.file_archives, - user=run_model.user, - ) - code = await _get_job_code( - session=session, - project=project, - repo=repo_model, - code_hash=_get_repo_code_hash(run, job), - ) - success = await common_utils.run_async( - _process_pulling_with_shim, - server_ssh_private_keys, - job_provisioning_data, - None, - session, - run, - job_model, - job, - cluster_info, - code, - file_archives, - secrets, - repo_creds, - server_ssh_private_keys, - job_provisioning_data, - ) - else: - logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age) - success = await common_utils.run_async( - _process_running, - server_ssh_private_keys, - job_provisioning_data, - job_submission.job_runtime_data, - session, - run_model, - job_model, - ) - - if success: - _reset_disconnected_at(session, job_model) - else: - if job_model.termination_reason: - logger.warning( - "%s: failed due to %s, age=%s", - fmt(job_model), - job_model.termination_reason.value, - job_submission.age, - ) - switch_job_status(session, job_model, JobStatus.TERMINATING) - # job will be terminated and instance will be emptied by process_terminating_jobs - else: - # No job_model.termination_reason set means ssh connection failed - _set_disconnected_at_now(session, job_model) - if _should_terminate_job_due_to_disconnect(job_model): - if job_provisioning_data.instance_type.resources.spot: - job_model.termination_reason = ( - JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY - ) - else: - job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - job_model.termination_reason_message = "Instance is unreachable" - switch_job_status(session, job_model, JobStatus.TERMINATING) - else: - logger.warning( - "%s: is unreachable, waiting for the instance to become reachable again, age=%s", - fmt(job_model), - job_submission.age, - ) + await _process_running_job_provisioning_state( + session=session, + run=run, + run_model=run_model, + repo_model=repo_model, + project=project, + job_model=job_model, + job=job, + job_submission=job_submission, + job_provisioning_data=job_provisioning_data, + volumes=volumes, + cluster_info=cluster_info, + secrets=secrets, + repo_creds=repo_creds, + server_ssh_private_keys=server_ssh_private_keys, + ) + elif initial_status == JobStatus.PULLING: + await _process_running_job_pulling_state( + session=session, + run=run, + run_model=run_model, + repo_model=repo_model, + project=project, + job_model=job_model, + job=job, + job_submission=job_submission, + job_provisioning_data=job_provisioning_data, + cluster_info=cluster_info, + secrets=secrets, + repo_creds=repo_creds, + server_ssh_private_keys=server_ssh_private_keys, + ) + else: + await _process_running_job_running_state( + session=session, + run_model=run_model, + job_model=job_model, + job_submission=job_submission, + job_provisioning_data=job_provisioning_data, + server_ssh_private_keys=server_ssh_private_keys, + ) if initial_status != job_model.status and job_model.status == JobStatus.RUNNING: _initialize_running_job_probes(job_model=job_model, job=job) @@ -442,6 +317,243 @@ async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel return res.unique().scalar_one() +async def _process_running_job_provisioning_state( + session: AsyncSession, + run: Run, + run_model: RunModel, + repo_model: RepoModel, + project: ProjectModel, + job_model: JobModel, + job: Job, + job_submission: JobSubmission, + job_provisioning_data: JobProvisioningData, + volumes: list[Volume], + cluster_info: Optional[ClusterInfo], + secrets: dict[str, str], + repo_creds: Optional[RemoteRepoCreds], + server_ssh_private_keys: tuple[str, Optional[str]], +) -> None: + if job_provisioning_data.hostname is None: + await _wait_for_instance_provisioning_data(session, job_model) + return + if _should_wait_for_other_nodes(run, job, job_model): + return + + # fails are acceptable until timeout is exceeded + if job_provisioning_data.dockerized: + logger.debug( + "%s: process provisioning job with shim, age=%s", + fmt(job_model), + job_submission.age, + ) + ssh_user = job_provisioning_data.username + assert run.run_spec.ssh_key_pub is not None + user_ssh_key = run.run_spec.ssh_key_pub.strip() + public_keys = [project.ssh_public_key.strip(), user_ssh_key] + if job_provisioning_data.backend == BackendType.LOCAL: + # No need to update ~/.ssh/authorized_keys when running shim locally + user_ssh_key = "" + success = await common_utils.run_async( + _process_provisioning_with_shim, + server_ssh_private_keys, + job_provisioning_data, + None, + session, + run, + job_model, + job_provisioning_data, + volumes, + job.job_spec.registry_auth, + public_keys, + ssh_user, + user_ssh_key, + ) + else: + assert cluster_info is not None + logger.debug( + "%s: process provisioning job without shim, age=%s", + fmt(job_model), + job_submission.age, + ) + # FIXME: downloading file archives and code here is a waste of time if + # the runner is not ready yet + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=job.job_spec.file_archives, + user=run_model.user, + ) + code = await _get_job_code( + session=session, + project=project, + repo=repo_model, + code_hash=_get_repo_code_hash(run, job), + ) + success = await common_utils.run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + None, + session, + run, + job_model, + job, + cluster_info, + code, + file_archives, + secrets, + repo_creds, + success_if_not_available=False, + ) + + if success: + return + + # check timeout + provisioning_timeout = get_provisioning_timeout( + backend_type=job_provisioning_data.get_base_backend(), + instance_type_name=job_provisioning_data.instance_type.name, + ) + if job_submission.age > provisioning_timeout: + job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED + job_model.termination_reason_message = ( + f"Runner did not become available within {provisioning_timeout.total_seconds()}s." + f" Job submission age: {job_submission.age.total_seconds()}s)" + ) + switch_job_status(session, job_model, JobStatus.TERMINATING) + # instance will be emptied by process_terminating_jobs + + +async def _process_running_job_pulling_state( + session: AsyncSession, + run: Run, + run_model: RunModel, + repo_model: RepoModel, + project: ProjectModel, + job_model: JobModel, + job: Job, + job_submission: JobSubmission, + job_provisioning_data: JobProvisioningData, + cluster_info: Optional[ClusterInfo], + secrets: dict[str, str], + repo_creds: Optional[RemoteRepoCreds], + server_ssh_private_keys: tuple[str, Optional[str]], +) -> None: + assert cluster_info is not None + logger.debug("%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age) + # FIXME: downloading file archives and code here is a waste of time if + # the runner is not ready yet + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=job.job_spec.file_archives, + user=run_model.user, + ) + code = await _get_job_code( + session=session, + project=project, + repo=repo_model, + code_hash=_get_repo_code_hash(run, job), + ) + success = await common_utils.run_async( + _process_pulling_with_shim, + server_ssh_private_keys, + job_provisioning_data, + None, + session, + run, + job_model, + job, + cluster_info, + code, + file_archives, + secrets, + repo_creds, + server_ssh_private_keys, + job_provisioning_data, + ) + + if success: + _reset_disconnected_at(session, job_model) + return + + if job_model.termination_reason: + logger.warning( + "%s: failed due to %s, age=%s", + fmt(job_model), + job_model.termination_reason.value, + job_submission.age, + ) + switch_job_status(session, job_model, JobStatus.TERMINATING) + # job will be terminated and instance will be emptied by process_terminating_jobs + return + + # No job_model.termination_reason set means ssh connection failed + _set_disconnected_at_now(session, job_model) + if not _should_terminate_job_due_to_disconnect(job_model): + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(job_model), + job_submission.age, + ) + return + + if job_provisioning_data.instance_type.resources.spot: + job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + else: + job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + job_model.termination_reason_message = "Instance is unreachable" + switch_job_status(session, job_model, JobStatus.TERMINATING) + + +async def _process_running_job_running_state( + session: AsyncSession, + run_model: RunModel, + job_model: JobModel, + job_submission: JobSubmission, + job_provisioning_data: JobProvisioningData, + server_ssh_private_keys: tuple[str, Optional[str]], +) -> None: + logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age) + success = await common_utils.run_async( + _process_running, + server_ssh_private_keys, + job_provisioning_data, + job_submission.job_runtime_data, + session, + run_model, + job_model, + ) + + if success: + _reset_disconnected_at(session, job_model) + return + + if job_model.termination_reason: + logger.warning( + "%s: failed due to %s, age=%s", + fmt(job_model), + job_model.termination_reason.value, + job_submission.age, + ) + switch_job_status(session, job_model, JobStatus.TERMINATING) + # job will be terminated and instance will be emptied by process_terminating_jobs + return + # No job_model.termination_reason set means ssh connection failed + _set_disconnected_at_now(session, job_model) + if not _should_terminate_job_due_to_disconnect(job_model): + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(job_model), + job_submission.age, + ) + return + if job_provisioning_data.instance_type.resources.spot: + job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + else: + job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + job_model.termination_reason_message = "Instance is unreachable" + switch_job_status(session, job_model, JobStatus.TERMINATING) + + async def _mark_job_processed(session: AsyncSession, job_model: JobModel) -> None: job_model.last_processed_at = common_utils.get_current_datetime() await session.commit() From d92078c2d7f6d4dc80b4044e0b084c2c423ffa3c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 17:13:26 +0500 Subject: [PATCH 3/8] Introduce running job context --- .../scheduled_tasks/running_jobs.py | 318 +++++++++--------- 1 file changed, 167 insertions(+), 151 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 2a5f5dc90..83a1e3fc7 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -2,6 +2,7 @@ import re import uuid from collections.abc import Iterable +from dataclasses import dataclass from datetime import timedelta from typing import Dict, List, Optional @@ -141,133 +142,151 @@ async def _process_next_running_job(): lockset.difference_update([job_model_id]) +@dataclass +class _RunningJobContext: + job_model: JobModel + run_model: RunModel + repo_model: RepoModel + project: ProjectModel + run: Run + job: Job + job_submission: JobSubmission + job_provisioning_data: Optional[JobProvisioningData] + initial_status: JobStatus + server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None + + async def _process_running_job(session: AsyncSession, job_model: JobModel): - job_model = await _refetch_job_model(session, job_model) - run_model = await _fetch_run_model(session, job_model.run_id) - repo_model = run_model.repo - project = run_model.project - run = run_model_to_run(run_model, include_sensitive=True) - job_submission = job_model_to_job_submission(job_model) - job_provisioning_data = job_submission.job_provisioning_data - if job_provisioning_data is None: - logger.error("%s: job_provisioning_data of an active job is None", fmt(job_model)) + context = await _load_running_job_context(session=session, job_model=job_model) + if context.job_provisioning_data is None: + logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) await _terminate_running_job( session=session, - job_model=job_model, + job_model=context.job_model, termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, termination_reason_message="Unexpected server error: job_provisioning_data of an active job is None", ) return - job = find_job(run.jobs, job_model.replica_num, job_model.job_num) + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) volumes = [] secrets = {} cluster_info = None repo_creds = None - initial_status = job_model.status - if initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: - for other_job in run.jobs: + if context.initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: + for other_job in context.run.jobs: if ( - other_job.job_spec.replica_num == job.job_spec.replica_num + other_job.job_spec.replica_num == context.job.job_spec.replica_num and other_job.job_submissions[-1].status == JobStatus.SUBMITTED ): logger.debug( "%s: waiting for all jobs in the replica to be provisioned", - fmt(job_model), + fmt(context.job_model), ) - await _mark_job_processed(session=session, job_model=job_model) + await _mark_job_processed(session=session, job_model=context.job_model) return cluster_info = _get_cluster_info( - jobs=run.jobs, - replica_num=job.job_spec.replica_num, + jobs=context.run.jobs, + replica_num=context.job.job_spec.replica_num, job_provisioning_data=job_provisioning_data, - job_runtime_data=job_submission.job_runtime_data, + job_runtime_data=context.job_submission.job_runtime_data, ) volumes = await get_job_attached_volumes( session=session, - project=project, - run_spec=run.run_spec, - job_num=job.job_spec.job_num, + project=context.project, + run_spec=context.run.run_spec, + job_num=context.job.job_spec.job_num, job_provisioning_data=job_provisioning_data, ) repo_creds_model = await get_repo_creds( - session=session, repo=repo_model, user=run_model.user + session=session, repo=context.repo_model, user=context.run_model.user ) - repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds + repo_creds = repo_model_to_repo_head_with_creds( + context.repo_model, repo_creds_model + ).repo_creds - secrets = await get_project_secrets_mapping(session=session, project=project) + secrets = await get_project_secrets_mapping(session=session, project=context.project) try: - _interpolate_secrets(secrets, job.job_spec) + _interpolate_secrets(secrets, context.job.job_spec) except InterpolatorError as e: await _terminate_running_job( session=session, - job_model=job_model, + job_model=context.job_model, termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, termination_reason_message=f"Secrets interpolation error: {e.args[0]}", ) return - server_ssh_private_keys = get_instance_ssh_private_keys( - common_utils.get_or_error(job_model.instance) + context.server_ssh_private_keys = get_instance_ssh_private_keys( + common_utils.get_or_error(context.job_model.instance) ) - if initial_status == JobStatus.PROVISIONING: + if context.initial_status == JobStatus.PROVISIONING: await _process_running_job_provisioning_state( session=session, - run=run, - run_model=run_model, - repo_model=repo_model, - project=project, - job_model=job_model, - job=job, - job_submission=job_submission, - job_provisioning_data=job_provisioning_data, + context=context, volumes=volumes, cluster_info=cluster_info, secrets=secrets, repo_creds=repo_creds, - server_ssh_private_keys=server_ssh_private_keys, ) - elif initial_status == JobStatus.PULLING: + elif context.initial_status == JobStatus.PULLING: await _process_running_job_pulling_state( session=session, - run=run, - run_model=run_model, - repo_model=repo_model, - project=project, - job_model=job_model, - job=job, - job_submission=job_submission, - job_provisioning_data=job_provisioning_data, + context=context, cluster_info=cluster_info, secrets=secrets, repo_creds=repo_creds, - server_ssh_private_keys=server_ssh_private_keys, ) else: await _process_running_job_running_state( session=session, - run_model=run_model, - job_model=job_model, - job_submission=job_submission, - job_provisioning_data=job_provisioning_data, - server_ssh_private_keys=server_ssh_private_keys, + context=context, ) - if initial_status != job_model.status and job_model.status == JobStatus.RUNNING: - _initialize_running_job_probes(job_model=job_model, job=job) + if ( + context.initial_status != context.job_model.status + and context.job_model.status == JobStatus.RUNNING + ): + _initialize_running_job_probes(job_model=context.job_model, job=context.job) - if job_model.status == JobStatus.RUNNING: - await _maybe_register_replica(session, run_model, run, job_model, job.job_spec.probes) - if job_model.status == JobStatus.RUNNING: - await _check_gpu_utilization(session, job_model, job) + if context.job_model.status == JobStatus.RUNNING: + await _maybe_register_replica( + session, + context.run_model, + context.run, + context.job_model, + context.job.job_spec.probes, + ) + if context.job_model.status == JobStatus.RUNNING: + await _check_gpu_utilization(session, context.job_model, context.job) - await _mark_job_processed(session=session, job_model=job_model) + await _mark_job_processed(session=session, job_model=context.job_model) + + +async def _load_running_job_context( + session: AsyncSession, job_model: JobModel +) -> _RunningJobContext: + job_model = await _refetch_job_model(session, job_model) + run_model = await _fetch_run_model(session, job_model.run_id) + run = run_model_to_run(run_model, include_sensitive=True) + job_submission = job_model_to_job_submission(job_model) + return _RunningJobContext( + job_model=job_model, + run_model=run_model, + repo_model=run_model.repo, + project=run_model.project, + run=run, + job=find_job(run.jobs, job_model.replica_num, job_model.job_num), + job_submission=job_submission, + job_provisioning_data=job_submission.job_provisioning_data, + initial_status=job_model.status, + ) async def _refetch_job_model(session: AsyncSession, job_model: JobModel) -> JobModel: @@ -319,37 +338,32 @@ async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel async def _process_running_job_provisioning_state( session: AsyncSession, - run: Run, - run_model: RunModel, - repo_model: RepoModel, - project: ProjectModel, - job_model: JobModel, - job: Job, - job_submission: JobSubmission, - job_provisioning_data: JobProvisioningData, + context: _RunningJobContext, volumes: list[Volume], cluster_info: Optional[ClusterInfo], secrets: dict[str, str], repo_creds: Optional[RemoteRepoCreds], - server_ssh_private_keys: tuple[str, Optional[str]], ) -> None: + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) + server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) + if job_provisioning_data.hostname is None: - await _wait_for_instance_provisioning_data(session, job_model) + await _wait_for_instance_provisioning_data(session, context.job_model) return - if _should_wait_for_other_nodes(run, job, job_model): + if _should_wait_for_other_nodes(context.run, context.job, context.job_model): return # fails are acceptable until timeout is exceeded if job_provisioning_data.dockerized: logger.debug( "%s: process provisioning job with shim, age=%s", - fmt(job_model), - job_submission.age, + fmt(context.job_model), + context.job_submission.age, ) ssh_user = job_provisioning_data.username - assert run.run_spec.ssh_key_pub is not None - user_ssh_key = run.run_spec.ssh_key_pub.strip() - public_keys = [project.ssh_public_key.strip(), user_ssh_key] + assert context.run.run_spec.ssh_key_pub is not None + user_ssh_key = context.run.run_spec.ssh_key_pub.strip() + public_keys = [context.project.ssh_public_key.strip(), user_ssh_key] if job_provisioning_data.backend == BackendType.LOCAL: # No need to update ~/.ssh/authorized_keys when running shim locally user_ssh_key = "" @@ -359,11 +373,11 @@ async def _process_running_job_provisioning_state( job_provisioning_data, None, session, - run, - job_model, + context.run, + context.job_model, job_provisioning_data, volumes, - job.job_spec.registry_auth, + context.job.job_spec.registry_auth, public_keys, ssh_user, user_ssh_key, @@ -372,21 +386,21 @@ async def _process_running_job_provisioning_state( assert cluster_info is not None logger.debug( "%s: process provisioning job without shim, age=%s", - fmt(job_model), - job_submission.age, + fmt(context.job_model), + context.job_submission.age, ) # FIXME: downloading file archives and code here is a waste of time if # the runner is not ready yet file_archives = await _get_job_file_archives( session=session, - archive_mappings=job.job_spec.file_archives, - user=run_model.user, + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, ) code = await _get_job_code( session=session, - project=project, - repo=repo_model, - code_hash=_get_repo_code_hash(run, job), + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), ) success = await common_utils.run_async( _submit_job_to_runner, @@ -394,9 +408,9 @@ async def _process_running_job_provisioning_state( job_provisioning_data, None, session, - run, - job_model, - job, + context.run, + context.job_model, + context.job, cluster_info, code, file_archives, @@ -413,45 +427,44 @@ async def _process_running_job_provisioning_state( backend_type=job_provisioning_data.get_base_backend(), instance_type_name=job_provisioning_data.instance_type.name, ) - if job_submission.age > provisioning_timeout: - job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED - job_model.termination_reason_message = ( + if context.job_submission.age > provisioning_timeout: + context.job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED + context.job_model.termination_reason_message = ( f"Runner did not become available within {provisioning_timeout.total_seconds()}s." - f" Job submission age: {job_submission.age.total_seconds()}s)" + f" Job submission age: {context.job_submission.age.total_seconds()}s)" ) - switch_job_status(session, job_model, JobStatus.TERMINATING) + switch_job_status(session, context.job_model, JobStatus.TERMINATING) # instance will be emptied by process_terminating_jobs async def _process_running_job_pulling_state( session: AsyncSession, - run: Run, - run_model: RunModel, - repo_model: RepoModel, - project: ProjectModel, - job_model: JobModel, - job: Job, - job_submission: JobSubmission, - job_provisioning_data: JobProvisioningData, + context: _RunningJobContext, cluster_info: Optional[ClusterInfo], secrets: dict[str, str], repo_creds: Optional[RemoteRepoCreds], - server_ssh_private_keys: tuple[str, Optional[str]], ) -> None: + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) + server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) + assert cluster_info is not None - logger.debug("%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age) + logger.debug( + "%s: process pulling job with shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) # FIXME: downloading file archives and code here is a waste of time if # the runner is not ready yet file_archives = await _get_job_file_archives( session=session, - archive_mappings=job.job_spec.file_archives, - user=run_model.user, + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, ) code = await _get_job_code( session=session, - project=project, - repo=repo_model, - code_hash=_get_repo_code_hash(run, job), + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), ) success = await common_utils.run_async( _process_pulling_with_shim, @@ -459,9 +472,9 @@ async def _process_running_job_pulling_state( job_provisioning_data, None, session, - run, - job_model, - job, + context.run, + context.job_model, + context.job, cluster_info, code, file_archives, @@ -472,86 +485,89 @@ async def _process_running_job_pulling_state( ) if success: - _reset_disconnected_at(session, job_model) + _reset_disconnected_at(session, context.job_model) return - if job_model.termination_reason: + if context.job_model.termination_reason: logger.warning( "%s: failed due to %s, age=%s", - fmt(job_model), - job_model.termination_reason.value, - job_submission.age, + fmt(context.job_model), + context.job_model.termination_reason.value, + context.job_submission.age, ) - switch_job_status(session, job_model, JobStatus.TERMINATING) + switch_job_status(session, context.job_model, JobStatus.TERMINATING) # job will be terminated and instance will be emptied by process_terminating_jobs return # No job_model.termination_reason set means ssh connection failed - _set_disconnected_at_now(session, job_model) - if not _should_terminate_job_due_to_disconnect(job_model): + _set_disconnected_at_now(session, context.job_model) + if not _should_terminate_job_due_to_disconnect(context.job_model): logger.warning( "%s: is unreachable, waiting for the instance to become reachable again, age=%s", - fmt(job_model), - job_submission.age, + fmt(context.job_model), + context.job_submission.age, ) return if job_provisioning_data.instance_type.resources.spot: - job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY else: - job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - job_model.termination_reason_message = "Instance is unreachable" - switch_job_status(session, job_model, JobStatus.TERMINATING) + context.job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + context.job_model.termination_reason_message = "Instance is unreachable" + switch_job_status(session, context.job_model, JobStatus.TERMINATING) async def _process_running_job_running_state( session: AsyncSession, - run_model: RunModel, - job_model: JobModel, - job_submission: JobSubmission, - job_provisioning_data: JobProvisioningData, - server_ssh_private_keys: tuple[str, Optional[str]], + context: _RunningJobContext, ) -> None: - logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age) + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) + server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) + + logger.debug( + "%s: process running job, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) success = await common_utils.run_async( _process_running, server_ssh_private_keys, job_provisioning_data, - job_submission.job_runtime_data, + context.job_submission.job_runtime_data, session, - run_model, - job_model, + context.run_model, + context.job_model, ) if success: - _reset_disconnected_at(session, job_model) + _reset_disconnected_at(session, context.job_model) return - if job_model.termination_reason: + if context.job_model.termination_reason: logger.warning( "%s: failed due to %s, age=%s", - fmt(job_model), - job_model.termination_reason.value, - job_submission.age, + fmt(context.job_model), + context.job_model.termination_reason.value, + context.job_submission.age, ) - switch_job_status(session, job_model, JobStatus.TERMINATING) + switch_job_status(session, context.job_model, JobStatus.TERMINATING) # job will be terminated and instance will be emptied by process_terminating_jobs return # No job_model.termination_reason set means ssh connection failed - _set_disconnected_at_now(session, job_model) - if not _should_terminate_job_due_to_disconnect(job_model): + _set_disconnected_at_now(session, context.job_model) + if not _should_terminate_job_due_to_disconnect(context.job_model): logger.warning( "%s: is unreachable, waiting for the instance to become reachable again, age=%s", - fmt(job_model), - job_submission.age, + fmt(context.job_model), + context.job_submission.age, ) return if job_provisioning_data.instance_type.resources.spot: - job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY else: - job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - job_model.termination_reason_message = "Instance is unreachable" - switch_job_status(session, job_model, JobStatus.TERMINATING) + context.job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + context.job_model.termination_reason_message = "Instance is unreachable" + switch_job_status(session, context.job_model, JobStatus.TERMINATING) async def _mark_job_processed(session: AsyncSession, job_model: JobModel) -> None: From 17cf5e8ee4b7e8c08087d76eb5693ff908f597fa Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 Mar 2026 10:35:12 +0500 Subject: [PATCH 4/8] Extract running job startup context --- .../scheduled_tasks/running_jobs.py | 185 +++++++++--------- 1 file changed, 97 insertions(+), 88 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 83a1e3fc7..edeee46eb 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -156,6 +156,14 @@ class _RunningJobContext: server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None +@dataclass +class _RunningJobStartupContext: + cluster_info: ClusterInfo + volumes: list[Volume] + secrets: dict[str, str] + repo_creds: Optional[RemoteRepoCreds] + + async def _process_running_job(session: AsyncSession, job_model: JobModel): context = await _load_running_job_context(session=session, job_model=job_model) if context.job_provisioning_data is None: @@ -168,80 +176,26 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): ) return - job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) - - volumes = [] - secrets = {} - cluster_info = None - repo_creds = None - + startup_context = None if context.initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: - for other_job in context.run.jobs: - if ( - other_job.job_spec.replica_num == context.job.job_spec.replica_num - and other_job.job_submissions[-1].status == JobStatus.SUBMITTED - ): - logger.debug( - "%s: waiting for all jobs in the replica to be provisioned", - fmt(context.job_model), - ) - await _mark_job_processed(session=session, job_model=context.job_model) - return - - cluster_info = _get_cluster_info( - jobs=context.run.jobs, - replica_num=context.job.job_spec.replica_num, - job_provisioning_data=job_provisioning_data, - job_runtime_data=context.job_submission.job_runtime_data, - ) - - volumes = await get_job_attached_volumes( + startup_context = await _prepare_running_job_startup_context( session=session, - project=context.project, - run_spec=context.run.run_spec, - job_num=context.job.job_spec.job_num, - job_provisioning_data=job_provisioning_data, - ) - - repo_creds_model = await get_repo_creds( - session=session, repo=context.repo_model, user=context.run_model.user + context=context, ) - repo_creds = repo_model_to_repo_head_with_creds( - context.repo_model, repo_creds_model - ).repo_creds - - secrets = await get_project_secrets_mapping(session=session, project=context.project) - try: - _interpolate_secrets(secrets, context.job.job_spec) - except InterpolatorError as e: - await _terminate_running_job( - session=session, - job_model=context.job_model, - termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, - termination_reason_message=f"Secrets interpolation error: {e.args[0]}", - ) + if startup_context is None: return - context.server_ssh_private_keys = get_instance_ssh_private_keys( - common_utils.get_or_error(context.job_model.instance) - ) - if context.initial_status == JobStatus.PROVISIONING: await _process_running_job_provisioning_state( session=session, context=context, - volumes=volumes, - cluster_info=cluster_info, - secrets=secrets, - repo_creds=repo_creds, + startup_context=common_utils.get_or_error(startup_context), ) elif context.initial_status == JobStatus.PULLING: await _process_running_job_pulling_state( session=session, context=context, - cluster_info=cluster_info, - secrets=secrets, - repo_creds=repo_creds, + startup_context=common_utils.get_or_error(startup_context), ) else: await _process_running_job_running_state( @@ -249,22 +203,17 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): context=context, ) - if ( - context.initial_status != context.job_model.status - and context.job_model.status == JobStatus.RUNNING - ): - _initialize_running_job_probes(job_model=context.job_model, job=context.job) - if context.job_model.status == JobStatus.RUNNING: + if context.initial_status != JobStatus.RUNNING: + _initialize_running_job_probes(job_model=context.job_model, job=context.job) await _maybe_register_replica( session, - context.run_model, - context.run, - context.job_model, - context.job.job_spec.probes, + run_model=context.run_model, + run=context.run, + job_model=context.job_model, + probe_specs=context.job.job_spec.probes, ) - if context.job_model.status == JobStatus.RUNNING: - await _check_gpu_utilization(session, context.job_model, context.job) + await _check_gpu_utilization(session, job_model=context.job_model, job=context.job) await _mark_job_processed(session=session, job_model=context.job_model) @@ -276,6 +225,9 @@ async def _load_running_job_context( run_model = await _fetch_run_model(session, job_model.run_id) run = run_model_to_run(run_model, include_sensitive=True) job_submission = job_model_to_job_submission(job_model) + server_ssh_private_keys = get_instance_ssh_private_keys( + common_utils.get_or_error(job_model.instance) + ) return _RunningJobContext( job_model=job_model, run_model=run_model, @@ -286,6 +238,70 @@ async def _load_running_job_context( job_submission=job_submission, job_provisioning_data=job_submission.job_provisioning_data, initial_status=job_model.status, + server_ssh_private_keys=server_ssh_private_keys, + ) + + +async def _prepare_running_job_startup_context( + session: AsyncSession, + context: _RunningJobContext, +) -> Optional[_RunningJobStartupContext]: + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) + + for other_job in context.run.jobs: + if ( + other_job.job_spec.replica_num == context.job.job_spec.replica_num + and other_job.job_submissions[-1].status == JobStatus.SUBMITTED + ): + logger.debug( + "%s: waiting for all jobs in the replica to be provisioned", + fmt(context.job_model), + ) + await _mark_job_processed(session=session, job_model=context.job_model) + return None + + cluster_info = _get_cluster_info( + jobs=context.run.jobs, + replica_num=context.job.job_spec.replica_num, + job_provisioning_data=job_provisioning_data, + job_runtime_data=context.job_submission.job_runtime_data, + ) + + volumes = await get_job_attached_volumes( + session=session, + project=context.project, + run_spec=context.run.run_spec, + job_num=context.job.job_spec.job_num, + job_provisioning_data=job_provisioning_data, + ) + + repo_creds_model = await get_repo_creds( + session=session, + repo=context.repo_model, + user=context.run_model.user, + ) + repo_creds = repo_model_to_repo_head_with_creds( + context.repo_model, + repo_creds_model, + ).repo_creds + + secrets = await get_project_secrets_mapping(session=session, project=context.project) + try: + _interpolate_secrets(secrets, context.job.job_spec) + except InterpolatorError as e: + await _terminate_running_job( + session=session, + job_model=context.job_model, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=f"Secrets interpolation error: {e.args[0]}", + ) + return None + + return _RunningJobStartupContext( + cluster_info=cluster_info, + volumes=volumes, + secrets=secrets, + repo_creds=repo_creds, ) @@ -339,10 +355,7 @@ async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel async def _process_running_job_provisioning_state( session: AsyncSession, context: _RunningJobContext, - volumes: list[Volume], - cluster_info: Optional[ClusterInfo], - secrets: dict[str, str], - repo_creds: Optional[RemoteRepoCreds], + startup_context: _RunningJobStartupContext, ) -> None: job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) @@ -376,14 +389,13 @@ async def _process_running_job_provisioning_state( context.run, context.job_model, job_provisioning_data, - volumes, + startup_context.volumes, context.job.job_spec.registry_auth, public_keys, ssh_user, user_ssh_key, ) else: - assert cluster_info is not None logger.debug( "%s: process provisioning job without shim, age=%s", fmt(context.job_model), @@ -411,11 +423,11 @@ async def _process_running_job_provisioning_state( context.run, context.job_model, context.job, - cluster_info, + startup_context.cluster_info, code, file_archives, - secrets, - repo_creds, + startup_context.secrets, + startup_context.repo_creds, success_if_not_available=False, ) @@ -440,14 +452,11 @@ async def _process_running_job_provisioning_state( async def _process_running_job_pulling_state( session: AsyncSession, context: _RunningJobContext, - cluster_info: Optional[ClusterInfo], - secrets: dict[str, str], - repo_creds: Optional[RemoteRepoCreds], + startup_context: _RunningJobStartupContext, ) -> None: job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) - assert cluster_info is not None logger.debug( "%s: process pulling job with shim, age=%s", fmt(context.job_model), @@ -475,11 +484,11 @@ async def _process_running_job_pulling_state( context.run, context.job_model, context.job, - cluster_info, + startup_context.cluster_info, code, file_archives, - secrets, - repo_creds, + startup_context.secrets, + startup_context.repo_creds, server_ssh_private_keys, job_provisioning_data, ) From cdc67fe1120a9b7ed2010ffe9550c831566df0e3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 Mar 2026 11:25:12 +0500 Subject: [PATCH 5/8] Use keyword args for running job RPCs --- .../scheduled_tasks/running_jobs.py | 84 +++++++++---------- 1 file changed, 40 insertions(+), 44 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index edeee46eb..831069649 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -385,15 +385,15 @@ async def _process_running_job_provisioning_state( server_ssh_private_keys, job_provisioning_data, None, - session, - context.run, - context.job_model, - job_provisioning_data, - startup_context.volumes, - context.job.job_spec.registry_auth, - public_keys, - ssh_user, - user_ssh_key, + session=session, + run=context.run, + job_model=context.job_model, + jpd=job_provisioning_data, + volumes=startup_context.volumes, + registry_auth=context.job.job_spec.registry_auth, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=user_ssh_key, ) else: logger.debug( @@ -419,15 +419,15 @@ async def _process_running_job_provisioning_state( server_ssh_private_keys, job_provisioning_data, None, - session, - context.run, - context.job_model, - context.job, - startup_context.cluster_info, - code, - file_archives, - startup_context.secrets, - startup_context.repo_creds, + session=session, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, success_if_not_available=False, ) @@ -480,17 +480,17 @@ async def _process_running_job_pulling_state( server_ssh_private_keys, job_provisioning_data, None, - session, - context.run, - context.job_model, - context.job, - startup_context.cluster_info, - code, - file_archives, - startup_context.secrets, - startup_context.repo_creds, - server_ssh_private_keys, - job_provisioning_data, + session=session, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + server_ssh_private_keys=server_ssh_private_keys, + jpd=job_provisioning_data, ) if success: @@ -543,9 +543,9 @@ async def _process_running_job_running_state( server_ssh_private_keys, job_provisioning_data, context.job_submission.job_runtime_data, - session, - context.run_model, - context.job_model, + session=session, + run_model=context.run_model, + job_model=context.job_model, ) if success: @@ -685,7 +685,7 @@ def _process_provisioning_with_shim( session: AsyncSession, run: Run, job_model: JobModel, - job_provisioning_data: JobProvisioningData, + jpd: JobProvisioningData, volumes: List[Volume], registry_auth: Optional[RegistryAuth], public_keys: List[str], @@ -730,13 +730,9 @@ def _process_provisioning_with_shim( for volume, volume_mount in zip(volumes, volume_mounts): volume_mount.name = volume.name - instance_mounts += _get_instance_specific_mounts( - job_provisioning_data.backend, job_provisioning_data.instance_type.name - ) + instance_mounts += _get_instance_specific_mounts(jpd.backend, jpd.instance_type.name) - gpu_devices = _get_instance_specific_gpu_devices( - job_provisioning_data.backend, job_provisioning_data.instance_type.name - ) + gpu_devices = _get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name) container_user = "root" @@ -753,7 +749,7 @@ def _process_provisioning_with_shim( cpu = None memory = None network_mode = NetworkMode.HOST - image_name = _patch_base_image_for_aws_efa(job_spec, job_provisioning_data) + image_name = _patch_base_image_for_aws_efa(job_spec, jpd) if shim_client.is_api_v2_supported(): shim_client.submit_task( task_id=job_model.id, @@ -775,7 +771,7 @@ def _process_provisioning_with_shim( host_ssh_user=ssh_user, host_ssh_keys=[ssh_key] if ssh_key else [], container_ssh_keys=public_keys, - instance_id=job_provisioning_data.instance_id, + instance_id=jpd.instance_id, ) else: submitted = shim_client.submit( @@ -792,7 +788,7 @@ def _process_provisioning_with_shim( mounts=volume_mounts, volumes=volumes, instance_mounts=instance_mounts, - instance_id=job_provisioning_data.instance_id, + instance_id=jpd.instance_id, ) if not submitted: # This can happen when we lost connection to the runner (e.g., network issues), marked @@ -826,7 +822,7 @@ def _process_pulling_with_shim( secrets: Dict[str, str], repo_credentials: Optional[RemoteRepoCreds], server_ssh_private_keys: tuple[str, Optional[str]], - job_provisioning_data: JobProvisioningData, + jpd: JobProvisioningData, ) -> bool: """ Possible next states: @@ -892,7 +888,7 @@ def _process_pulling_with_shim( return _submit_job_to_runner( server_ssh_private_keys, - job_provisioning_data, + jpd, job_runtime_data, session=session, run=run, From f975c2d2fbb462ef601eadbd0da989eaa3472e91 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 Mar 2026 12:24:11 +0500 Subject: [PATCH 6/8] Defer running job payload loading --- .../scheduled_tasks/running_jobs.py | 198 ++++++++++-------- .../scheduled_tasks/test_running_jobs.py | 105 +++++++++- 2 files changed, 200 insertions(+), 103 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 831069649..80328445e 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -1,10 +1,11 @@ import asyncio +import enum import re import uuid from collections.abc import Iterable from dataclasses import dataclass from datetime import timedelta -from typing import Dict, List, Optional +from typing import Dict, List, Literal, Optional, Union from sqlalchemy import and_, func, select from sqlalchemy.ext.asyncio import AsyncSession @@ -367,6 +368,7 @@ async def _process_running_job_provisioning_state( return # fails are acceptable until timeout is exceeded + success = False if job_provisioning_data.dockerized: logger.debug( "%s: process provisioning job with shim, age=%s", @@ -401,35 +403,40 @@ async def _process_running_job_provisioning_state( fmt(context.job_model), context.job_submission.age, ) - # FIXME: downloading file archives and code here is a waste of time if - # the runner is not ready yet - file_archives = await _get_job_file_archives( - session=session, - archive_mappings=context.job.job_spec.file_archives, - user=context.run_model.user, - ) - code = await _get_job_code( - session=session, - project=context.project, - repo=context.repo_model, - code_hash=_get_repo_code_hash(context.run, context.job), - ) - success = await common_utils.run_async( - _submit_job_to_runner, + runner_availability = await common_utils.run_async( + _get_runner_availability, server_ssh_private_keys, job_provisioning_data, None, - session=session, - run=context.run, - job_model=context.job_model, - job=context.job, - cluster_info=startup_context.cluster_info, - code=code, - file_archives=file_archives, - secrets=startup_context.secrets, - repo_credentials=startup_context.repo_creds, - success_if_not_available=False, ) + if runner_availability == _RunnerAvailability.AVAILABLE: + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, + ) + code = await _get_job_code( + session=session, + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + success = await common_utils.run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + None, + session=session, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=False, + ) if success: return @@ -462,41 +469,60 @@ async def _process_running_job_pulling_state( fmt(context.job_model), context.job_submission.age, ) - # FIXME: downloading file archives and code here is a waste of time if - # the runner is not ready yet - file_archives = await _get_job_file_archives( - session=session, - archive_mappings=context.job.job_spec.file_archives, - user=context.run_model.user, - ) - code = await _get_job_code( - session=session, - project=context.project, - repo=context.repo_model, - code_hash=_get_repo_code_hash(context.run, context.job), - ) - success = await common_utils.run_async( - _process_pulling_with_shim, + shim_state = await common_utils.run_async( + _get_shim_pulling_state, server_ssh_private_keys, job_provisioning_data, None, - session=session, - run=context.run, job_model=context.job_model, - job=context.job, - cluster_info=startup_context.cluster_info, - code=code, - file_archives=file_archives, - secrets=startup_context.secrets, - repo_credentials=startup_context.repo_creds, - server_ssh_private_keys=server_ssh_private_keys, - jpd=job_provisioning_data, ) - - if success: + if shim_state == _ShimPullingState.WAITING: _reset_disconnected_at(session, context.job_model) return + if shim_state == _ShimPullingState.READY: + runner_availability = await common_utils.run_async( + _get_runner_availability, + server_ssh_private_keys, + job_provisioning_data, + None, + ) + if runner_availability == _RunnerAvailability.UNAVAILABLE: + _reset_disconnected_at(session, context.job_model) + return + + if runner_availability == _RunnerAvailability.AVAILABLE: + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, + ) + code = await _get_job_code( + session=session, + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + success = await common_utils.run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + None, + session=session, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=True, + ) + if success: + _reset_disconnected_at(session, context.job_model) + return + if context.job_model.termination_reason: logger.warning( "%s: failed due to %s, age=%s", @@ -562,6 +588,7 @@ async def _process_running_job_running_state( switch_job_status(session, context.job_model, JobStatus.TERMINATING) # job will be terminated and instance will be emptied by process_terminating_jobs return + # No job_model.termination_reason set means ssh connection failed _set_disconnected_at_now(session, context.job_model) if not _should_terminate_job_due_to_disconnect(context.job_model): @@ -571,6 +598,7 @@ async def _process_running_job_running_state( context.job_submission.age, ) return + if job_provisioning_data.instance_type.resources.spot: context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY else: @@ -809,31 +837,30 @@ def _process_provisioning_with_shim( return True +class _RunnerAvailability(enum.Enum): + AVAILABLE = "available" + UNAVAILABLE = "unavailable" + + +class _ShimPullingState(enum.Enum): + WAITING = "waiting" + READY = "ready" + + +@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: + runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + if runner_client.healthcheck() is None: + return _RunnerAvailability.UNAVAILABLE + return _RunnerAvailability.AVAILABLE + + @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _process_pulling_with_shim( +def _get_shim_pulling_state( ports: Dict[int, int], - session: AsyncSession, - run: Run, job_model: JobModel, - job: Job, - cluster_info: ClusterInfo, - code: bytes, - file_archives: Iterable[tuple[uuid.UUID, bytes]], - secrets: Dict[str, str], - repo_credentials: Optional[RemoteRepoCreds], - server_ssh_private_keys: tuple[str, Optional[str]], - jpd: JobProvisioningData, -) -> bool: - """ - Possible next states: - - JobStatus.RUNNING if runner is available - - JobStatus.TERMINATING if shim is not available - - Returns: - is successful - """ +) -> Union[Literal[False], _ShimPullingState]: shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) - job_runtime_data = None if shim_client.is_api_v2_supported(): # raises error if shim is down, causes retry task = shim_client.get_task(job_model.id) @@ -851,7 +878,7 @@ def _process_pulling_with_shim( return False if task.status != TaskStatus.RUNNING: - return True + return _ShimPullingState.WAITING job_runtime_data = get_job_runtime_data(job_model) # should check for None, as there may be older jobs submitted before @@ -859,10 +886,9 @@ def _process_pulling_with_shim( if job_runtime_data is not None: # port mapping is not yet available, waiting if task.ports is None: - return True + return _ShimPullingState.WAITING job_runtime_data.ports = {pm.container: pm.host for pm in task.ports} job_model.job_runtime_data = job_runtime_data.json() - else: shim_status = shim_client.pull() # raises error if shim is down, causes retry @@ -884,23 +910,9 @@ def _process_pulling_with_shim( return False if shim_status.state in ("pulling", "creating"): - return True + return _ShimPullingState.WAITING - return _submit_job_to_runner( - server_ssh_private_keys, - jpd, - job_runtime_data, - session=session, - run=run, - job_model=job_model, - job=job, - cluster_info=cluster_info, - code=code, - file_archives=file_archives, - secrets=secrets, - repo_credentials=repo_credentials, - success_if_not_available=True, - ) + return _ShimPullingState.READY @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index aad8615bf..3fa0035af 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Optional -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from freezegun import freeze_time @@ -150,6 +150,14 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( patch( "dstack._internal.server.services.runner.client.RunnerClient" ) as RunnerClientMock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, ): datetime_mock.return_value = datetime(2023, 1, 2, 5, 12, 30, 10, tzinfo=timezone.utc) @@ -159,6 +167,8 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( await process_running_jobs() SSHTunnelMock.assert_called_once() runner_client_mock.healthcheck.assert_called_once() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() await session.refresh(job) assert job is not None assert job.status == JobStatus.PROVISIONING @@ -207,8 +217,8 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): working_dir="/dstack/run", username="dstack" ) await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.healthcheck.assert_called_once() + assert SSHTunnelMock.call_count == 2 + assert runner_client_mock.healthcheck.call_count == 2 runner_client_mock.submit_job.assert_called_once() runner_client_mock.upload_code.assert_called_once() runner_client_mock.run_job.assert_called_once() @@ -430,9 +440,9 @@ async def test_pulling_shim( await process_running_jobs() - assert ssh_tunnel_mock.call_count == 2 + assert ssh_tunnel_mock.call_count == 3 shim_client_mock.get_task.assert_called_once() - runner_client_mock.healthcheck.assert_called_once() + assert runner_client_mock.healthcheck.call_count == 2 runner_client_mock.submit_job.assert_called_once() runner_client_mock.upload_code.assert_called_once() runner_client_mock.run_job.assert_called_once() @@ -484,12 +494,87 @@ async def test_pulling_shim_port_mapping_not_ready( shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING shim_client_mock.get_task.return_value.ports = None - await process_running_jobs() + with ( + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await process_running_jobs() + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_not_called() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + await session.refresh(job) + assert job is not None + assert job.status == JobStatus.PULLING + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_pulling_shim_runner_not_ready( + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + runner_client_mock.healthcheck.return_value = None + + with ( + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await process_running_jobs() + + assert ssh_tunnel_mock.call_count == 2 + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_called_once() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() - ssh_tunnel_mock.assert_called_once() - shim_client_mock.get_task.assert_called_once() - runner_client_mock.healthcheck.assert_not_called() - runner_client_mock.submit_job.assert_not_called() await session.refresh(job) assert job is not None assert job.status == JobStatus.PULLING From 9b0a11e6f407f5b083b425d10441e3a33fc27ce4 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 Mar 2026 12:56:20 +0500 Subject: [PATCH 7/8] Fix missing jrd update --- .../scheduled_tasks/running_jobs.py | 9 +- .../scheduled_tasks/test_running_jobs.py | 86 +++++++++++++++++++ 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 80328445e..127ef3b70 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -470,7 +470,7 @@ async def _process_running_job_pulling_state( context.job_submission.age, ) shim_state = await common_utils.run_async( - _get_shim_pulling_state, + _sync_shim_pulling_state, server_ssh_private_keys, job_provisioning_data, None, @@ -481,11 +481,12 @@ async def _process_running_job_pulling_state( return if shim_state == _ShimPullingState.READY: + job_runtime_data = get_job_runtime_data(context.job_model) runner_availability = await common_utils.run_async( _get_runner_availability, server_ssh_private_keys, job_provisioning_data, - None, + job_runtime_data, ) if runner_availability == _RunnerAvailability.UNAVAILABLE: _reset_disconnected_at(session, context.job_model) @@ -507,7 +508,7 @@ async def _process_running_job_pulling_state( _submit_job_to_runner, server_ssh_private_keys, job_provisioning_data, - None, + job_runtime_data, session=session, run=context.run, job_model=context.job_model, @@ -856,7 +857,7 @@ def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _get_shim_pulling_state( +def _sync_shim_pulling_state( ports: Dict[int, int], job_model: JobModel, ) -> Union[Literal[False], _ShimPullingState]: diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 3fa0035af..a1ca2a501 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -39,6 +39,7 @@ from dstack._internal.server import settings as server_settings from dstack._internal.server.background.scheduled_tasks.running_jobs import ( _patch_base_image_for_aws_efa, + _RunnerAvailability, process_running_jobs, ) from dstack._internal.server.models import JobModel @@ -579,6 +580,91 @@ async def test_pulling_shim_runner_not_ready( assert job is not None assert job.status == JobStatus.PULLING + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_pulling_shim_uses_runtime_port_mapping_for_runner_calls( + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + + expected_ports = { + 10022: 32771, + 10999: 32772, + } + + def assert_runner_availability(_, __, job_runtime_data): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return _RunnerAvailability.AVAILABLE + + def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return True + + with ( + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_runner_availability", + side_effect=assert_runner_availability, + ) as get_runner_availability_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._submit_job_to_runner", + side_effect=assert_submit_job_to_runner, + ) as submit_job_to_runner_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + return_value=b"", + ), + ): + await process_running_jobs() + + ssh_tunnel_mock.assert_called_once() + get_runner_availability_mock.assert_called_once() + submit_job_to_runner_mock.assert_called_once() + + await session.refresh(job) + assert job is not None + assert job.status == JobStatus.PULLING + jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert jrd.ports == expected_ports + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_pulling_shim_failed(self, test_db, session: AsyncSession): From c051458856eedf072de5468bad35e4bac343bb6d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 Mar 2026 14:12:36 +0500 Subject: [PATCH 8/8] Extract backend provisioning helpers --- .../scheduled_tasks/running_jobs.py | 117 +----------- .../server/services/backends/provisioning.py | 122 +++++++++++++ .../services/jobs/configurators/base.py | 3 +- .../scheduled_tasks/test_running_jobs.py | 133 +------------- .../services/backends/test_provisioning.py | 168 ++++++++++++++++++ 5 files changed, 302 insertions(+), 241 deletions(-) create mode 100644 src/dstack/_internal/server/services/backends/provisioning.py create mode 100644 src/tests/_internal/server/services/backends/test_provisioning.py diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 127ef3b70..ea3d53973 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -1,6 +1,5 @@ import asyncio import enum -import re import uuid from collections.abc import Iterable from dataclasses import dataclass @@ -11,7 +10,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only -from dstack._internal import settings from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.backends.base import BackendType @@ -52,10 +50,15 @@ RunModel, UserModel, ) -from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus +from dstack._internal.server.schemas.runner import TaskStatus from dstack._internal.server.services import events, services from dstack._internal.server.services import files as files_services from dstack._internal.server.services import logs as logs_services +from dstack._internal.server.services.backends.provisioning import ( + get_instance_specific_gpu_devices, + get_instance_specific_mounts, + resolve_provisioning_image_name, +) from dstack._internal.server.services.instances import ( get_instance_remote_connection_info, get_instance_ssh_private_keys, @@ -759,9 +762,9 @@ def _process_provisioning_with_shim( for volume, volume_mount in zip(volumes, volume_mounts): volume_mount.name = volume.name - instance_mounts += _get_instance_specific_mounts(jpd.backend, jpd.instance_type.name) + instance_mounts += get_instance_specific_mounts(jpd.backend, jpd.instance_type.name) - gpu_devices = _get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name) + gpu_devices = get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name) container_user = "root" @@ -778,7 +781,7 @@ def _process_provisioning_with_shim( cpu = None memory = None network_mode = NetworkMode.HOST - image_name = _patch_base_image_for_aws_efa(job_spec, jpd) + image_name = resolve_provisioning_image_name(job_spec, jpd) if shim_client.is_api_v2_supported(): shim_client.submit_task( task_id=job_model.id, @@ -1301,105 +1304,3 @@ def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec): username=interpolate(job_spec.registry_auth.username), password=interpolate(job_spec.registry_auth.password), ) - - -def _get_instance_specific_mounts( - backend_type: BackendType, instance_type_name: str -) -> List[InstanceMountPoint]: - if backend_type == BackendType.GCP: - if instance_type_name == "a3-megagpu-8g": - return [ - InstanceMountPoint( - instance_path="/dev/aperture_devices", - path="/dev/aperture_devices", - ), - InstanceMountPoint( - instance_path="/var/lib/tcpxo/lib64", - path="/var/lib/tcpxo/lib64", - ), - InstanceMountPoint( - instance_path="/var/lib/fastrak/lib64", - path="/var/lib/fastrak/lib64", - ), - ] - if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: - return [ - InstanceMountPoint( - instance_path="/var/lib/nvidia/lib64", - path="/usr/local/nvidia/lib64", - ), - InstanceMountPoint( - instance_path="/var/lib/nvidia/bin", - path="/usr/local/nvidia/bin", - ), - InstanceMountPoint( - instance_path="/var/lib/tcpx/lib64", - path="/usr/local/tcpx/lib64", - ), - InstanceMountPoint( - instance_path="/run/tcpx", - path="/run/tcpx", - ), - ] - return [] - - -def _get_instance_specific_gpu_devices( - backend_type: BackendType, instance_type_name: str -) -> List[GPUDevice]: - gpu_devices = [] - if backend_type == BackendType.GCP and instance_type_name in [ - "a3-edgegpu-8g", - "a3-highgpu-8g", - ]: - for i in range(8): - gpu_devices.append( - GPUDevice(path_on_host=f"/dev/nvidia{i}", path_in_container=f"/dev/nvidia{i}") - ) - gpu_devices.append( - GPUDevice(path_on_host="/dev/nvidia-uvm", path_in_container="/dev/nvidia-uvm") - ) - gpu_devices.append( - GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl") - ) - return gpu_devices - - -def _patch_base_image_for_aws_efa( - job_spec: JobSpec, job_provisioning_data: JobProvisioningData -) -> str: - image_name = job_spec.image_name - - if job_provisioning_data.backend != BackendType.AWS: - return image_name - - instance_type = job_provisioning_data.instance_type.name - efa_enabled_patterns = [ - # TODO: p6-b200 isn't supported yet in gpuhunt - r"^p6-b200\.(48xlarge)$", - r"^p5\.(4xlarge|48xlarge)$", - r"^p5e\.(48xlarge)$", - r"^p5en\.(48xlarge)$", - r"^p4d\.(24xlarge)$", - r"^p4de\.(24xlarge)$", - r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", - r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", - r"^gr6\.8xlarge$", - r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", - r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$", - r"^p3dn\.(24xlarge)$", - ] - - is_efa_enabled = any(re.match(pattern, instance_type) for pattern in efa_enabled_patterns) - if not is_efa_enabled: - return image_name - - if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"): - return image_name - - if image_name.endswith(f"-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): - return image_name[:-17] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - elif image_name.endswith(f"-devel-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): - return image_name[:-18] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - - return image_name diff --git a/src/dstack/_internal/server/services/backends/provisioning.py b/src/dstack/_internal/server/services/backends/provisioning.py new file mode 100644 index 000000000..13fd362b3 --- /dev/null +++ b/src/dstack/_internal/server/services/backends/provisioning.py @@ -0,0 +1,122 @@ +import re + +from dstack._internal import settings +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.runs import JobProvisioningData, JobSpec +from dstack._internal.core.models.volumes import InstanceMountPoint +from dstack._internal.server.schemas.runner import GPUDevice + +_AWS_EFA_ENABLED_INSTANCE_TYPE_PATTERNS = [ + # TODO: p6-b200 isn't supported yet in gpuhunt + r"^p6-b200\.(48xlarge)$", + r"^p5\.(4xlarge|48xlarge)$", + r"^p5e\.(48xlarge)$", + r"^p5en\.(48xlarge)$", + r"^p4d\.(24xlarge)$", + r"^p4de\.(24xlarge)$", + r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", + r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", + r"^gr6\.8xlarge$", + r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", + r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$", + r"^p3dn\.(24xlarge)$", +] + + +def get_instance_specific_mounts( + backend_type: BackendType, + instance_type_name: str, +) -> list[InstanceMountPoint]: + if backend_type == BackendType.GCP: + if instance_type_name == "a3-megagpu-8g": + return [ + InstanceMountPoint( + instance_path="/dev/aperture_devices", + path="/dev/aperture_devices", + ), + InstanceMountPoint( + instance_path="/var/lib/tcpxo/lib64", + path="/var/lib/tcpxo/lib64", + ), + InstanceMountPoint( + instance_path="/var/lib/fastrak/lib64", + path="/var/lib/fastrak/lib64", + ), + ] + if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: + return [ + InstanceMountPoint( + instance_path="/var/lib/nvidia/lib64", + path="/usr/local/nvidia/lib64", + ), + InstanceMountPoint( + instance_path="/var/lib/nvidia/bin", + path="/usr/local/nvidia/bin", + ), + InstanceMountPoint( + instance_path="/var/lib/tcpx/lib64", + path="/usr/local/tcpx/lib64", + ), + InstanceMountPoint( + instance_path="/run/tcpx", + path="/run/tcpx", + ), + ] + return [] + + +def get_instance_specific_gpu_devices( + backend_type: BackendType, + instance_type_name: str, +) -> list[GPUDevice]: + gpu_devices = [] + if backend_type == BackendType.GCP and instance_type_name in [ + "a3-edgegpu-8g", + "a3-highgpu-8g", + ]: + for i in range(8): + gpu_devices.append( + GPUDevice(path_on_host=f"/dev/nvidia{i}", path_in_container=f"/dev/nvidia{i}") + ) + gpu_devices.append( + GPUDevice(path_on_host="/dev/nvidia-uvm", path_in_container="/dev/nvidia-uvm") + ) + gpu_devices.append( + GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl") + ) + return gpu_devices + + +def resolve_provisioning_image_name( + job_spec: JobSpec, + job_provisioning_data: JobProvisioningData, +) -> str: + image_name = job_spec.image_name + if job_provisioning_data.backend == BackendType.AWS: + return _patch_base_image_for_aws_efa( + image_name, + job_provisioning_data.instance_type.name, + ) + return image_name + + +def _patch_base_image_for_aws_efa( + image_name: str, + instance_type_name: str, +) -> str: + is_efa_enabled = any( + re.match(pattern, instance_type_name) + for pattern in _AWS_EFA_ENABLED_INSTANCE_TYPE_PATTERNS + ) + if not is_efa_enabled: + return image_name + + if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"): + return image_name + + if image_name.endswith(f"-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): + return image_name[:-17] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + if image_name.endswith(f"-devel-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): + return image_name[:-18] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + + return image_name diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 3310cda99..b73c9bbe6 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -77,7 +77,8 @@ def get_default_python_verison() -> str: def get_default_image(nvcc: bool = False) -> str: """ Note: May be overridden by dstack (e.g., EFA-enabled version for AWS EFA-capable instances). - See `dstack._internal.server.background.scheduled_tasks.running_jobs._patch_base_image_for_aws_efa` for details. + See `dstack._internal.server.services.backends.provisioning.resolve_provisioning_image_name` + for details. Args: nvcc: If True, returns 'devel' variant, otherwise 'base'. diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index a1ca2a501..80a18dc11 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -19,16 +19,12 @@ ProbeConfig, ServiceConfiguration, ) -from dstack._internal.core.models.instances import InstanceStatus, InstanceType +from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy -from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( - JobProvisioningData, JobRuntimeData, - JobSpec, JobStatus, JobTerminationReason, - Requirements, RunStatus, ) from dstack._internal.core.models.volumes import ( @@ -38,7 +34,6 @@ ) from dstack._internal.server import settings as server_settings from dstack._internal.server.background.scheduled_tasks.running_jobs import ( - _patch_base_image_for_aws_efa, _RunnerAvailability, process_running_jobs, ) @@ -1302,129 +1297,3 @@ async def test_registers_service_replica_only_after_probes_pass( else: assert not job.registered assert not events - - -class TestPatchBaseImageForAwsEfa: - @staticmethod - def _create_job_spec(image_name: str) -> "JobSpec": - return JobSpec( - job_num=0, - job_name="test-job", - commands=["echo hello"], - env={}, - image_name=image_name, - requirements=Requirements(resources=ResourcesSpec()), - ) - - @staticmethod - def _create_job_provisioning_data_with_instance_type( - backend: BackendType, instance_type: str - ) -> JobProvisioningData: - job_provisioning_data = get_job_provisioning_data(backend=backend) - job_provisioning_data.instance_type = InstanceType( - name=instance_type, - resources=job_provisioning_data.instance_type.resources, - ) - return job_provisioning_data - - @staticmethod - def _call_patch_base_image_for_aws_efa( - image_name: str, backend: BackendType, instance_type: str - ) -> str: - job_spec = TestPatchBaseImageForAwsEfa._create_job_spec(image_name) - job_provisioning_data = ( - TestPatchBaseImageForAwsEfa._create_job_provisioning_data_with_instance_type( - backend, instance_type - ) - ) - return _patch_base_image_for_aws_efa(job_spec, job_provisioning_data) - - @pytest.mark.parametrize( - "suffix,instance_type", - [ - ("-base", "p6-b200.48xlarge"), - ("-devel", "p5.48xlarge"), - ], - ) - def test_patch_aws_efa_instance_with_suffix(self, suffix: str, instance_type: str): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - expected = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - assert result == expected - - @pytest.mark.parametrize("suffix", ["-base", "-devel"]) - @pytest.mark.parametrize( - "instance_type", - [ - "p5.48xlarge", - "p5e.48xlarge", - "p4d.24xlarge", - "p4de.24xlarge", - "g6.8xlarge", - "g6e.8xlarge", - ], - ) - def test_patch_all_efa_instance_types(self, instance_type: str, suffix: str): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - expected = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - assert result == expected - - @pytest.mark.parametrize("suffix", ["-base", "-devel"]) - @pytest.mark.parametrize( - "backend", - [BackendType.GCP, BackendType.AZURE, BackendType.LAMBDA, BackendType.LOCAL], - ) - @pytest.mark.parametrize( - "instance_type", - [ - "standard-4", - "p5.xlarge", - "p6.2xlarge", - "g6.xlarge", - ], # Mix of generic and EFA-named types - ) - def test_no_patch_non_aws_backends( - self, backend: BackendType, suffix: str, instance_type: str - ): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - result = self._call_patch_base_image_for_aws_efa(image_name, backend, instance_type) - assert result == image_name - - @pytest.mark.parametrize("suffix", ["-base", "-devel"]) - @pytest.mark.parametrize( - "instance_type", - ["t3.micro", "m5.large", "c5.xlarge", "r5.2xlarge", "m6i.large", "g6.xlarge"], - ) - def test_no_patch_non_efa_aws_instances(self, instance_type: str, suffix: str): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - assert result == image_name - - @pytest.mark.parametrize( - "instance_type", - ["p5.xlarge", "p6.2xlarge", "t3.micro", "m5.large"], # Mix of EFA and non-EFA instances - ) - @pytest.mark.parametrize( - "image_name", - [ - "ubuntu:20.04", - "nvidia/cuda:11.8-runtime-ubuntu20.04", - "python:3.9-slim", - "custom/image:latest", - f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-custom", - f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa", - f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}", - ], - ) - def test_no_patch_other_images(self, instance_type: str, image_name: str): - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - assert result == image_name diff --git a/src/tests/_internal/server/services/backends/test_provisioning.py b/src/tests/_internal/server/services/backends/test_provisioning.py new file mode 100644 index 000000000..3145a56af --- /dev/null +++ b/src/tests/_internal/server/services/backends/test_provisioning.py @@ -0,0 +1,168 @@ +import pytest + +from dstack._internal import settings +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceType +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import JobProvisioningData, JobSpec, Requirements +from dstack._internal.server.services.backends.provisioning import ( + resolve_provisioning_image_name, +) +from dstack._internal.server.testing.common import get_job_provisioning_data + + +class TestResolveProvisioningImageName: + @staticmethod + def _create_job_spec(image_name: str) -> JobSpec: + return JobSpec( + job_num=0, + job_name="test-job", + app_specs=None, + commands=["echo hello"], + env={}, + home_dir=None, + image_name=image_name, + max_duration=None, + registry_auth=None, + requirements=Requirements(resources=ResourcesSpec()), + retry=None, + working_dir=None, + ) + + @staticmethod + def _create_job_provisioning_data_with_instance_type( + backend: BackendType, + instance_type: str, + ) -> JobProvisioningData: + job_provisioning_data = get_job_provisioning_data(backend=backend) + job_provisioning_data.instance_type = InstanceType( + name=instance_type, + resources=job_provisioning_data.instance_type.resources, + ) + return job_provisioning_data + + @staticmethod + def _call_resolve_provisioning_image_name( + image_name: str, + backend: BackendType, + instance_type: str, + ) -> str: + job_spec = TestResolveProvisioningImageName._create_job_spec(image_name) + job_provisioning_data = ( + TestResolveProvisioningImageName._create_job_provisioning_data_with_instance_type( + backend, + instance_type, + ) + ) + return resolve_provisioning_image_name(job_spec, job_provisioning_data) + + @pytest.mark.parametrize( + ("suffix", "instance_type"), + [ + ("-base", "p6-b200.48xlarge"), + ("-devel", "p5.48xlarge"), + ], + ) + def test_patch_aws_efa_instance_with_suffix(self, suffix: str, instance_type: str) -> None: + image_name = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + f"-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + expected = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}" + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + assert result == expected + + @pytest.mark.parametrize("suffix", ["-base", "-devel"]) + @pytest.mark.parametrize( + "instance_type", + [ + "p5.48xlarge", + "p5e.48xlarge", + "p4d.24xlarge", + "p4de.24xlarge", + "g6.8xlarge", + "g6e.8xlarge", + ], + ) + def test_patch_all_efa_instance_types(self, instance_type: str, suffix: str) -> None: + image_name = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + f"-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + expected = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}" + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + assert result == expected + + @pytest.mark.parametrize("suffix", ["-base", "-devel"]) + @pytest.mark.parametrize( + "backend", + [BackendType.GCP, BackendType.AZURE, BackendType.LAMBDA, BackendType.LOCAL], + ) + @pytest.mark.parametrize( + "instance_type", + ["standard-4", "p5.xlarge", "p6.2xlarge", "g6.xlarge"], + ) + def test_no_patch_non_aws_backends( + self, + backend: BackendType, + suffix: str, + instance_type: str, + ) -> None: + image_name = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + f"-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + result = self._call_resolve_provisioning_image_name(image_name, backend, instance_type) + assert result == image_name + + @pytest.mark.parametrize("suffix", ["-base", "-devel"]) + @pytest.mark.parametrize( + "instance_type", + ["t3.micro", "m5.large", "c5.xlarge", "r5.2xlarge", "m6i.large", "g6.xlarge"], + ) + def test_no_patch_non_efa_aws_instances(self, instance_type: str, suffix: str) -> None: + image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + assert result == image_name + + @pytest.mark.parametrize( + "instance_type", + ["p5.xlarge", "p6.2xlarge", "t3.micro", "m5.large"], + ) + @pytest.mark.parametrize( + "image_name", + [ + "ubuntu:20.04", + "nvidia/cuda:11.8-runtime-ubuntu20.04", + "python:3.9-slim", + "custom/image:latest", + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-custom", + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa", + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}", + ], + ) + def test_no_patch_other_images(self, instance_type: str, image_name: str) -> None: + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + assert result == image_name