diff --git a/AGENTS.md b/AGENTS.md index bb1a7aac0f..3e1f85eaf5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -18,6 +18,8 @@ - Python targets 3.9+ with 4-space indentation and max line length of 99 (see `ruff.toml`; `E501` is ignored but keep lines readable). - Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party). - Keep primary/public functions before local helper functions in a module section. +- Roughly keep function definitions in the order they are referenced within a file so call flow stays easy to follow. +- Prefer early returns over nested `if`/`else` blocks when they make the control flow simpler. - Keep private classes, exceptions, and similar implementation-specific types close to the private functions that use them unless they are shared more broadly in the module. - Prefer pydantic-style models in `core/models`. - Document attributes when the note adds behavior, compatibility, or semantic context that is not obvious from the name and type. Use attribute docstrings without leading newline. diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 5916c9054a..ea3d539734 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -1,15 +1,15 @@ import asyncio -import re +import enum import uuid from collections.abc import Iterable +from dataclasses import dataclass from datetime import timedelta -from typing import Dict, List, Optional +from typing import Dict, List, Literal, Optional, Union from sqlalchemy import and_, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only -from dstack._internal import settings from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.backends.base import BackendType @@ -30,6 +30,7 @@ JobRuntimeData, JobSpec, JobStatus, + JobSubmission, JobTerminationReason, ProbeSpec, Run, @@ -49,10 +50,15 @@ RunModel, UserModel, ) -from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus +from dstack._internal.server.schemas.runner import TaskStatus from dstack._internal.server.services import events, services from dstack._internal.server.services import files as files_services from dstack._internal.server.services import logs as logs_services +from dstack._internal.server.services.backends.provisioning import ( + get_instance_specific_gpu_devices, + get_instance_specific_mounts, + resolve_provisioning_image_name, +) from dstack._internal.server.services.instances import ( get_instance_remote_connection_info, get_instance_ssh_private_keys, @@ -140,273 +146,167 @@ async def _process_next_running_job(): lockset.difference_update([job_model_id]) -async def _process_running_job(session: AsyncSession, job_model: JobModel): - job_model = await _refetch_job_model(session, job_model) - run_model = await _fetch_run_model(session, job_model.run_id) - repo_model = run_model.repo - project = run_model.project - run = run_model_to_run(run_model, include_sensitive=True) - job_submission = job_model_to_job_submission(job_model) - job_provisioning_data = job_submission.job_provisioning_data - if job_provisioning_data is None: - logger.error("%s: job_provisioning_data of an active job is None", fmt(job_model)) - job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER - job_model.termination_reason_message = ( - "Unexpected server error: job_provisioning_data of an active job is None" - ) - switch_job_status(session, job_model, JobStatus.TERMINATING) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return +@dataclass +class _RunningJobContext: + job_model: JobModel + run_model: RunModel + repo_model: RepoModel + project: ProjectModel + run: Run + job: Job + job_submission: JobSubmission + job_provisioning_data: Optional[JobProvisioningData] + initial_status: JobStatus + server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None - job = find_job(run.jobs, job_model.replica_num, job_model.job_num) - volumes = [] - secrets = {} - cluster_info = None - repo_creds = None +@dataclass +class _RunningJobStartupContext: + cluster_info: ClusterInfo + volumes: list[Volume] + secrets: dict[str, str] + repo_creds: Optional[RemoteRepoCreds] - initial_status = job_model.status - if initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: - for other_job in run.jobs: - if ( - other_job.job_spec.replica_num == job.job_spec.replica_num - and other_job.job_submissions[-1].status == JobStatus.SUBMITTED - ): - logger.debug( - "%s: waiting for all jobs in the replica to be provisioned", - fmt(job_model), - ) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return - cluster_info = _get_cluster_info( - jobs=run.jobs, - replica_num=job.job_spec.replica_num, - job_provisioning_data=job_provisioning_data, - job_runtime_data=job_submission.job_runtime_data, +async def _process_running_job(session: AsyncSession, job_model: JobModel): + context = await _load_running_job_context(session=session, job_model=job_model) + if context.job_provisioning_data is None: + logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) + await _terminate_running_job( + session=session, + job_model=context.job_model, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message="Unexpected server error: job_provisioning_data of an active job is None", ) + return - volumes = await get_job_attached_volumes( + startup_context = None + if context.initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: + startup_context = await _prepare_running_job_startup_context( session=session, - project=project, - run_spec=run.run_spec, - job_num=job.job_spec.job_num, - job_provisioning_data=job_provisioning_data, + context=context, ) + if startup_context is None: + return - repo_creds_model = await get_repo_creds( - session=session, repo=repo_model, user=run_model.user + if context.initial_status == JobStatus.PROVISIONING: + await _process_running_job_provisioning_state( + session=session, + context=context, + startup_context=common_utils.get_or_error(startup_context), + ) + elif context.initial_status == JobStatus.PULLING: + await _process_running_job_pulling_state( + session=session, + context=context, + startup_context=common_utils.get_or_error(startup_context), + ) + else: + await _process_running_job_running_state( + session=session, + context=context, ) - repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds - secrets = await get_project_secrets_mapping(session=session, project=project) - try: - _interpolate_secrets(secrets, job.job_spec) - except InterpolatorError as e: - job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER - job_model.termination_reason_message = f"Secrets interpolation error: {e.args[0]}" - switch_job_status(session, job_model, JobStatus.TERMINATING) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return + if context.job_model.status == JobStatus.RUNNING: + if context.initial_status != JobStatus.RUNNING: + _initialize_running_job_probes(job_model=context.job_model, job=context.job) + await _maybe_register_replica( + session, + run_model=context.run_model, + run=context.run, + job_model=context.job_model, + probe_specs=context.job.job_spec.probes, + ) + await _check_gpu_utilization(session, job_model=context.job_model, job=context.job) + await _mark_job_processed(session=session, job_model=context.job_model) + + +async def _load_running_job_context( + session: AsyncSession, job_model: JobModel +) -> _RunningJobContext: + job_model = await _refetch_job_model(session, job_model) + run_model = await _fetch_run_model(session, job_model.run_id) + run = run_model_to_run(run_model, include_sensitive=True) + job_submission = job_model_to_job_submission(job_model) server_ssh_private_keys = get_instance_ssh_private_keys( common_utils.get_or_error(job_model.instance) ) + return _RunningJobContext( + job_model=job_model, + run_model=run_model, + repo_model=run_model.repo, + project=run_model.project, + run=run, + job=find_job(run.jobs, job_model.replica_num, job_model.job_num), + job_submission=job_submission, + job_provisioning_data=job_submission.job_provisioning_data, + initial_status=job_model.status, + server_ssh_private_keys=server_ssh_private_keys, + ) - if initial_status == JobStatus.PROVISIONING: - if job_provisioning_data.hostname is None: - await _wait_for_instance_provisioning_data(session, job_model) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return - if _should_wait_for_other_nodes(run, job, job_model): - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return - - # fails are acceptable until timeout is exceeded - if job_provisioning_data.dockerized: - logger.debug( - "%s: process provisioning job with shim, age=%s", - fmt(job_model), - job_submission.age, - ) - ssh_user = job_provisioning_data.username - assert run.run_spec.ssh_key_pub is not None - user_ssh_key = run.run_spec.ssh_key_pub.strip() - public_keys = [project.ssh_public_key.strip(), user_ssh_key] - if job_provisioning_data.backend == BackendType.LOCAL: - # No need to update ~/.ssh/authorized_keys when running shim locally - user_ssh_key = "" - success = await common_utils.run_async( - _process_provisioning_with_shim, - server_ssh_private_keys, - job_provisioning_data, - None, - session, - run, - job_model, - job_provisioning_data, - volumes, - job.job_spec.registry_auth, - public_keys, - ssh_user, - user_ssh_key, - ) - else: - assert cluster_info is not None - logger.debug( - "%s: process provisioning job without shim, age=%s", - fmt(job_model), - job_submission.age, - ) - # FIXME: downloading file archives and code here is a waste of time if - # the runner is not ready yet - file_archives = await _get_job_file_archives( - session=session, - archive_mappings=job.job_spec.file_archives, - user=run_model.user, - ) - code = await _get_job_code( - session=session, - project=project, - repo=repo_model, - code_hash=_get_repo_code_hash(run, job), - ) - success = await common_utils.run_async( - _submit_job_to_runner, - server_ssh_private_keys, - job_provisioning_data, - None, - session, - run, - job_model, - job, - cluster_info, - code, - file_archives, - secrets, - repo_creds, - success_if_not_available=False, - ) - if not success: - # check timeout - provisioning_timeout = get_provisioning_timeout( - backend_type=job_provisioning_data.get_base_backend(), - instance_type_name=job_provisioning_data.instance_type.name, - ) - if job_submission.age > provisioning_timeout: - job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED - job_model.termination_reason_message = ( - f"Runner did not become available within {provisioning_timeout.total_seconds()}s." - f" Job submission age: {job_submission.age.total_seconds()}s)" - ) - switch_job_status(session, job_model, JobStatus.TERMINATING) - # instance will be emptied by process_terminating_jobs +async def _prepare_running_job_startup_context( + session: AsyncSession, + context: _RunningJobContext, +) -> Optional[_RunningJobStartupContext]: + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) - else: # fails are not acceptable - if initial_status == JobStatus.PULLING: - assert cluster_info is not None + for other_job in context.run.jobs: + if ( + other_job.job_spec.replica_num == context.job.job_spec.replica_num + and other_job.job_submissions[-1].status == JobStatus.SUBMITTED + ): logger.debug( - "%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age - ) - # FIXME: downloading file archives and code here is a waste of time if - # the runner is not ready yet - file_archives = await _get_job_file_archives( - session=session, - archive_mappings=job.job_spec.file_archives, - user=run_model.user, - ) - code = await _get_job_code( - session=session, - project=project, - repo=repo_model, - code_hash=_get_repo_code_hash(run, job), - ) - success = await common_utils.run_async( - _process_pulling_with_shim, - server_ssh_private_keys, - job_provisioning_data, - None, - session, - run, - job_model, - job, - cluster_info, - code, - file_archives, - secrets, - repo_creds, - server_ssh_private_keys, - job_provisioning_data, - ) - else: - logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age) - success = await common_utils.run_async( - _process_running, - server_ssh_private_keys, - job_provisioning_data, - job_submission.job_runtime_data, - session, - run_model, - job_model, + "%s: waiting for all jobs in the replica to be provisioned", + fmt(context.job_model), ) + await _mark_job_processed(session=session, job_model=context.job_model) + return None + + cluster_info = _get_cluster_info( + jobs=context.run.jobs, + replica_num=context.job.job_spec.replica_num, + job_provisioning_data=job_provisioning_data, + job_runtime_data=context.job_submission.job_runtime_data, + ) - if success: - _reset_disconnected_at(session, job_model) - else: - if job_model.termination_reason: - logger.warning( - "%s: failed due to %s, age=%s", - fmt(job_model), - job_model.termination_reason.value, - job_submission.age, - ) - switch_job_status(session, job_model, JobStatus.TERMINATING) - # job will be terminated and instance will be emptied by process_terminating_jobs - else: - # No job_model.termination_reason set means ssh connection failed - _set_disconnected_at_now(session, job_model) - if _should_terminate_job_due_to_disconnect(job_model): - if job_provisioning_data.instance_type.resources.spot: - job_model.termination_reason = ( - JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY - ) - else: - job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - job_model.termination_reason_message = "Instance is unreachable" - switch_job_status(session, job_model, JobStatus.TERMINATING) - else: - logger.warning( - "%s: is unreachable, waiting for the instance to become reachable again, age=%s", - fmt(job_model), - job_submission.age, - ) - - if initial_status != job_model.status and job_model.status == JobStatus.RUNNING: - job_model.probes = [] - for probe_num in range(len(job.job_spec.probes)): - job_model.probes.append( - ProbeModel( - name=f"{job_model.job_name}-{probe_num}", - probe_num=probe_num, - due=common_utils.get_current_datetime(), - success_streak=0, - active=True, - ) - ) + volumes = await get_job_attached_volumes( + session=session, + project=context.project, + run_spec=context.run.run_spec, + job_num=context.job.job_spec.job_num, + job_provisioning_data=job_provisioning_data, + ) - if job_model.status == JobStatus.RUNNING: - await _maybe_register_replica(session, run_model, run, job_model, job.job_spec.probes) - if job_model.status == JobStatus.RUNNING: - await _check_gpu_utilization(session, job_model, job) + repo_creds_model = await get_repo_creds( + session=session, + repo=context.repo_model, + user=context.run_model.user, + ) + repo_creds = repo_model_to_repo_head_with_creds( + context.repo_model, + repo_creds_model, + ).repo_creds - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() + secrets = await get_project_secrets_mapping(session=session, project=context.project) + try: + _interpolate_secrets(secrets, context.job.job_spec) + except InterpolatorError as e: + await _terminate_running_job( + session=session, + job_model=context.job_model, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=f"Secrets interpolation error: {e.args[0]}", + ) + return None + + return _RunningJobStartupContext( + cluster_info=cluster_info, + volumes=volumes, + secrets=secrets, + repo_creds=repo_creds, + ) async def _refetch_job_model(session: AsyncSession, job_model: JobModel) -> JobModel: @@ -456,6 +356,292 @@ async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel return res.unique().scalar_one() +async def _process_running_job_provisioning_state( + session: AsyncSession, + context: _RunningJobContext, + startup_context: _RunningJobStartupContext, +) -> None: + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) + server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) + + if job_provisioning_data.hostname is None: + await _wait_for_instance_provisioning_data(session, context.job_model) + return + if _should_wait_for_other_nodes(context.run, context.job, context.job_model): + return + + # fails are acceptable until timeout is exceeded + success = False + if job_provisioning_data.dockerized: + logger.debug( + "%s: process provisioning job with shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + ssh_user = job_provisioning_data.username + assert context.run.run_spec.ssh_key_pub is not None + user_ssh_key = context.run.run_spec.ssh_key_pub.strip() + public_keys = [context.project.ssh_public_key.strip(), user_ssh_key] + if job_provisioning_data.backend == BackendType.LOCAL: + # No need to update ~/.ssh/authorized_keys when running shim locally + user_ssh_key = "" + success = await common_utils.run_async( + _process_provisioning_with_shim, + server_ssh_private_keys, + job_provisioning_data, + None, + session=session, + run=context.run, + job_model=context.job_model, + jpd=job_provisioning_data, + volumes=startup_context.volumes, + registry_auth=context.job.job_spec.registry_auth, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=user_ssh_key, + ) + else: + logger.debug( + "%s: process provisioning job without shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + runner_availability = await common_utils.run_async( + _get_runner_availability, + server_ssh_private_keys, + job_provisioning_data, + None, + ) + if runner_availability == _RunnerAvailability.AVAILABLE: + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, + ) + code = await _get_job_code( + session=session, + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + success = await common_utils.run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + None, + session=session, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=False, + ) + + if success: + return + + # check timeout + provisioning_timeout = get_provisioning_timeout( + backend_type=job_provisioning_data.get_base_backend(), + instance_type_name=job_provisioning_data.instance_type.name, + ) + if context.job_submission.age > provisioning_timeout: + context.job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED + context.job_model.termination_reason_message = ( + f"Runner did not become available within {provisioning_timeout.total_seconds()}s." + f" Job submission age: {context.job_submission.age.total_seconds()}s)" + ) + switch_job_status(session, context.job_model, JobStatus.TERMINATING) + # instance will be emptied by process_terminating_jobs + + +async def _process_running_job_pulling_state( + session: AsyncSession, + context: _RunningJobContext, + startup_context: _RunningJobStartupContext, +) -> None: + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) + server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) + + logger.debug( + "%s: process pulling job with shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + shim_state = await common_utils.run_async( + _sync_shim_pulling_state, + server_ssh_private_keys, + job_provisioning_data, + None, + job_model=context.job_model, + ) + if shim_state == _ShimPullingState.WAITING: + _reset_disconnected_at(session, context.job_model) + return + + if shim_state == _ShimPullingState.READY: + job_runtime_data = get_job_runtime_data(context.job_model) + runner_availability = await common_utils.run_async( + _get_runner_availability, + server_ssh_private_keys, + job_provisioning_data, + job_runtime_data, + ) + if runner_availability == _RunnerAvailability.UNAVAILABLE: + _reset_disconnected_at(session, context.job_model) + return + + if runner_availability == _RunnerAvailability.AVAILABLE: + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, + ) + code = await _get_job_code( + session=session, + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + success = await common_utils.run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + job_runtime_data, + session=session, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=True, + ) + if success: + _reset_disconnected_at(session, context.job_model) + return + + if context.job_model.termination_reason: + logger.warning( + "%s: failed due to %s, age=%s", + fmt(context.job_model), + context.job_model.termination_reason.value, + context.job_submission.age, + ) + switch_job_status(session, context.job_model, JobStatus.TERMINATING) + # job will be terminated and instance will be emptied by process_terminating_jobs + return + + # No job_model.termination_reason set means ssh connection failed + _set_disconnected_at_now(session, context.job_model) + if not _should_terminate_job_due_to_disconnect(context.job_model): + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + return + + if job_provisioning_data.instance_type.resources.spot: + context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + else: + context.job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + context.job_model.termination_reason_message = "Instance is unreachable" + switch_job_status(session, context.job_model, JobStatus.TERMINATING) + + +async def _process_running_job_running_state( + session: AsyncSession, + context: _RunningJobContext, +) -> None: + job_provisioning_data = common_utils.get_or_error(context.job_provisioning_data) + server_ssh_private_keys = common_utils.get_or_error(context.server_ssh_private_keys) + + logger.debug( + "%s: process running job, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + success = await common_utils.run_async( + _process_running, + server_ssh_private_keys, + job_provisioning_data, + context.job_submission.job_runtime_data, + session=session, + run_model=context.run_model, + job_model=context.job_model, + ) + + if success: + _reset_disconnected_at(session, context.job_model) + return + + if context.job_model.termination_reason: + logger.warning( + "%s: failed due to %s, age=%s", + fmt(context.job_model), + context.job_model.termination_reason.value, + context.job_submission.age, + ) + switch_job_status(session, context.job_model, JobStatus.TERMINATING) + # job will be terminated and instance will be emptied by process_terminating_jobs + return + + # No job_model.termination_reason set means ssh connection failed + _set_disconnected_at_now(session, context.job_model) + if not _should_terminate_job_due_to_disconnect(context.job_model): + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + return + + if job_provisioning_data.instance_type.resources.spot: + context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + else: + context.job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + context.job_model.termination_reason_message = "Instance is unreachable" + switch_job_status(session, context.job_model, JobStatus.TERMINATING) + + +async def _mark_job_processed(session: AsyncSession, job_model: JobModel) -> None: + job_model.last_processed_at = common_utils.get_current_datetime() + await session.commit() + + +async def _terminate_running_job( + session: AsyncSession, + job_model: JobModel, + termination_reason: JobTerminationReason, + termination_reason_message: str, +) -> None: + job_model.termination_reason = termination_reason + job_model.termination_reason_message = termination_reason_message + switch_job_status(session, job_model, JobStatus.TERMINATING) + await _mark_job_processed(session=session, job_model=job_model) + + +def _initialize_running_job_probes(job_model: JobModel, job: Job) -> None: + job_model.probes = [] + for probe_num in range(len(job.job_spec.probes)): + job_model.probes.append( + ProbeModel( + name=f"{job_model.job_name}-{probe_num}", + probe_num=probe_num, + due=common_utils.get_current_datetime(), + success_streak=0, + active=True, + ) + ) + + async def _wait_for_instance_provisioning_data(session: AsyncSession, job_model: JobModel): """ This function will be called until instance IP address appears @@ -531,7 +717,7 @@ def _process_provisioning_with_shim( session: AsyncSession, run: Run, job_model: JobModel, - job_provisioning_data: JobProvisioningData, + jpd: JobProvisioningData, volumes: List[Volume], registry_auth: Optional[RegistryAuth], public_keys: List[str], @@ -576,13 +762,9 @@ def _process_provisioning_with_shim( for volume, volume_mount in zip(volumes, volume_mounts): volume_mount.name = volume.name - instance_mounts += _get_instance_specific_mounts( - job_provisioning_data.backend, job_provisioning_data.instance_type.name - ) + instance_mounts += get_instance_specific_mounts(jpd.backend, jpd.instance_type.name) - gpu_devices = _get_instance_specific_gpu_devices( - job_provisioning_data.backend, job_provisioning_data.instance_type.name - ) + gpu_devices = get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name) container_user = "root" @@ -599,7 +781,7 @@ def _process_provisioning_with_shim( cpu = None memory = None network_mode = NetworkMode.HOST - image_name = _patch_base_image_for_aws_efa(job_spec, job_provisioning_data) + image_name = resolve_provisioning_image_name(job_spec, jpd) if shim_client.is_api_v2_supported(): shim_client.submit_task( task_id=job_model.id, @@ -621,7 +803,7 @@ def _process_provisioning_with_shim( host_ssh_user=ssh_user, host_ssh_keys=[ssh_key] if ssh_key else [], container_ssh_keys=public_keys, - instance_id=job_provisioning_data.instance_id, + instance_id=jpd.instance_id, ) else: submitted = shim_client.submit( @@ -638,7 +820,7 @@ def _process_provisioning_with_shim( mounts=volume_mounts, volumes=volumes, instance_mounts=instance_mounts, - instance_id=job_provisioning_data.instance_id, + instance_id=jpd.instance_id, ) if not submitted: # This can happen when we lost connection to the runner (e.g., network issues), marked @@ -659,31 +841,30 @@ def _process_provisioning_with_shim( return True +class _RunnerAvailability(enum.Enum): + AVAILABLE = "available" + UNAVAILABLE = "unavailable" + + +class _ShimPullingState(enum.Enum): + WAITING = "waiting" + READY = "ready" + + +@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: + runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + if runner_client.healthcheck() is None: + return _RunnerAvailability.UNAVAILABLE + return _RunnerAvailability.AVAILABLE + + @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _process_pulling_with_shim( +def _sync_shim_pulling_state( ports: Dict[int, int], - session: AsyncSession, - run: Run, job_model: JobModel, - job: Job, - cluster_info: ClusterInfo, - code: bytes, - file_archives: Iterable[tuple[uuid.UUID, bytes]], - secrets: Dict[str, str], - repo_credentials: Optional[RemoteRepoCreds], - server_ssh_private_keys: tuple[str, Optional[str]], - job_provisioning_data: JobProvisioningData, -) -> bool: - """ - Possible next states: - - JobStatus.RUNNING if runner is available - - JobStatus.TERMINATING if shim is not available - - Returns: - is successful - """ +) -> Union[Literal[False], _ShimPullingState]: shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) - job_runtime_data = None if shim_client.is_api_v2_supported(): # raises error if shim is down, causes retry task = shim_client.get_task(job_model.id) @@ -701,7 +882,7 @@ def _process_pulling_with_shim( return False if task.status != TaskStatus.RUNNING: - return True + return _ShimPullingState.WAITING job_runtime_data = get_job_runtime_data(job_model) # should check for None, as there may be older jobs submitted before @@ -709,10 +890,9 @@ def _process_pulling_with_shim( if job_runtime_data is not None: # port mapping is not yet available, waiting if task.ports is None: - return True + return _ShimPullingState.WAITING job_runtime_data.ports = {pm.container: pm.host for pm in task.ports} job_model.job_runtime_data = job_runtime_data.json() - else: shim_status = shim_client.pull() # raises error if shim is down, causes retry @@ -734,23 +914,9 @@ def _process_pulling_with_shim( return False if shim_status.state in ("pulling", "creating"): - return True + return _ShimPullingState.WAITING - return _submit_job_to_runner( - server_ssh_private_keys, - job_provisioning_data, - job_runtime_data, - session=session, - run=run, - job_model=job_model, - job=job, - cluster_info=cluster_info, - code=code, - file_archives=file_archives, - secrets=secrets, - repo_credentials=repo_credentials, - success_if_not_available=True, - ) + return _ShimPullingState.READY @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) @@ -1138,105 +1304,3 @@ def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec): username=interpolate(job_spec.registry_auth.username), password=interpolate(job_spec.registry_auth.password), ) - - -def _get_instance_specific_mounts( - backend_type: BackendType, instance_type_name: str -) -> List[InstanceMountPoint]: - if backend_type == BackendType.GCP: - if instance_type_name == "a3-megagpu-8g": - return [ - InstanceMountPoint( - instance_path="/dev/aperture_devices", - path="/dev/aperture_devices", - ), - InstanceMountPoint( - instance_path="/var/lib/tcpxo/lib64", - path="/var/lib/tcpxo/lib64", - ), - InstanceMountPoint( - instance_path="/var/lib/fastrak/lib64", - path="/var/lib/fastrak/lib64", - ), - ] - if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: - return [ - InstanceMountPoint( - instance_path="/var/lib/nvidia/lib64", - path="/usr/local/nvidia/lib64", - ), - InstanceMountPoint( - instance_path="/var/lib/nvidia/bin", - path="/usr/local/nvidia/bin", - ), - InstanceMountPoint( - instance_path="/var/lib/tcpx/lib64", - path="/usr/local/tcpx/lib64", - ), - InstanceMountPoint( - instance_path="/run/tcpx", - path="/run/tcpx", - ), - ] - return [] - - -def _get_instance_specific_gpu_devices( - backend_type: BackendType, instance_type_name: str -) -> List[GPUDevice]: - gpu_devices = [] - if backend_type == BackendType.GCP and instance_type_name in [ - "a3-edgegpu-8g", - "a3-highgpu-8g", - ]: - for i in range(8): - gpu_devices.append( - GPUDevice(path_on_host=f"/dev/nvidia{i}", path_in_container=f"/dev/nvidia{i}") - ) - gpu_devices.append( - GPUDevice(path_on_host="/dev/nvidia-uvm", path_in_container="/dev/nvidia-uvm") - ) - gpu_devices.append( - GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl") - ) - return gpu_devices - - -def _patch_base_image_for_aws_efa( - job_spec: JobSpec, job_provisioning_data: JobProvisioningData -) -> str: - image_name = job_spec.image_name - - if job_provisioning_data.backend != BackendType.AWS: - return image_name - - instance_type = job_provisioning_data.instance_type.name - efa_enabled_patterns = [ - # TODO: p6-b200 isn't supported yet in gpuhunt - r"^p6-b200\.(48xlarge)$", - r"^p5\.(4xlarge|48xlarge)$", - r"^p5e\.(48xlarge)$", - r"^p5en\.(48xlarge)$", - r"^p4d\.(24xlarge)$", - r"^p4de\.(24xlarge)$", - r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", - r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", - r"^gr6\.8xlarge$", - r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", - r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$", - r"^p3dn\.(24xlarge)$", - ] - - is_efa_enabled = any(re.match(pattern, instance_type) for pattern in efa_enabled_patterns) - if not is_efa_enabled: - return image_name - - if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"): - return image_name - - if image_name.endswith(f"-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): - return image_name[:-17] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - elif image_name.endswith(f"-devel-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): - return image_name[:-18] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - - return image_name diff --git a/src/dstack/_internal/server/services/backends/provisioning.py b/src/dstack/_internal/server/services/backends/provisioning.py new file mode 100644 index 0000000000..13fd362b34 --- /dev/null +++ b/src/dstack/_internal/server/services/backends/provisioning.py @@ -0,0 +1,122 @@ +import re + +from dstack._internal import settings +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.runs import JobProvisioningData, JobSpec +from dstack._internal.core.models.volumes import InstanceMountPoint +from dstack._internal.server.schemas.runner import GPUDevice + +_AWS_EFA_ENABLED_INSTANCE_TYPE_PATTERNS = [ + # TODO: p6-b200 isn't supported yet in gpuhunt + r"^p6-b200\.(48xlarge)$", + r"^p5\.(4xlarge|48xlarge)$", + r"^p5e\.(48xlarge)$", + r"^p5en\.(48xlarge)$", + r"^p4d\.(24xlarge)$", + r"^p4de\.(24xlarge)$", + r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", + r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", + r"^gr6\.8xlarge$", + r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$", + r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$", + r"^p3dn\.(24xlarge)$", +] + + +def get_instance_specific_mounts( + backend_type: BackendType, + instance_type_name: str, +) -> list[InstanceMountPoint]: + if backend_type == BackendType.GCP: + if instance_type_name == "a3-megagpu-8g": + return [ + InstanceMountPoint( + instance_path="/dev/aperture_devices", + path="/dev/aperture_devices", + ), + InstanceMountPoint( + instance_path="/var/lib/tcpxo/lib64", + path="/var/lib/tcpxo/lib64", + ), + InstanceMountPoint( + instance_path="/var/lib/fastrak/lib64", + path="/var/lib/fastrak/lib64", + ), + ] + if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: + return [ + InstanceMountPoint( + instance_path="/var/lib/nvidia/lib64", + path="/usr/local/nvidia/lib64", + ), + InstanceMountPoint( + instance_path="/var/lib/nvidia/bin", + path="/usr/local/nvidia/bin", + ), + InstanceMountPoint( + instance_path="/var/lib/tcpx/lib64", + path="/usr/local/tcpx/lib64", + ), + InstanceMountPoint( + instance_path="/run/tcpx", + path="/run/tcpx", + ), + ] + return [] + + +def get_instance_specific_gpu_devices( + backend_type: BackendType, + instance_type_name: str, +) -> list[GPUDevice]: + gpu_devices = [] + if backend_type == BackendType.GCP and instance_type_name in [ + "a3-edgegpu-8g", + "a3-highgpu-8g", + ]: + for i in range(8): + gpu_devices.append( + GPUDevice(path_on_host=f"/dev/nvidia{i}", path_in_container=f"/dev/nvidia{i}") + ) + gpu_devices.append( + GPUDevice(path_on_host="/dev/nvidia-uvm", path_in_container="/dev/nvidia-uvm") + ) + gpu_devices.append( + GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl") + ) + return gpu_devices + + +def resolve_provisioning_image_name( + job_spec: JobSpec, + job_provisioning_data: JobProvisioningData, +) -> str: + image_name = job_spec.image_name + if job_provisioning_data.backend == BackendType.AWS: + return _patch_base_image_for_aws_efa( + image_name, + job_provisioning_data.instance_type.name, + ) + return image_name + + +def _patch_base_image_for_aws_efa( + image_name: str, + instance_type_name: str, +) -> str: + is_efa_enabled = any( + re.match(pattern, instance_type_name) + for pattern in _AWS_EFA_ENABLED_INSTANCE_TYPE_PATTERNS + ) + if not is_efa_enabled: + return image_name + + if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"): + return image_name + + if image_name.endswith(f"-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): + return image_name[:-17] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + if image_name.endswith(f"-devel-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"): + return image_name[:-18] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + + return image_name diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 3310cda996..b73c9bbe62 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -77,7 +77,8 @@ def get_default_python_verison() -> str: def get_default_image(nvcc: bool = False) -> str: """ Note: May be overridden by dstack (e.g., EFA-enabled version for AWS EFA-capable instances). - See `dstack._internal.server.background.scheduled_tasks.running_jobs._patch_base_image_for_aws_efa` for details. + See `dstack._internal.server.services.backends.provisioning.resolve_provisioning_image_name` + for details. Args: nvcc: If True, returns 'devel' variant, otherwise 'base'. diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index aad8615bf3..80a18dc11d 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Optional -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from freezegun import freeze_time @@ -19,16 +19,12 @@ ProbeConfig, ServiceConfiguration, ) -from dstack._internal.core.models.instances import InstanceStatus, InstanceType +from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy -from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( - JobProvisioningData, JobRuntimeData, - JobSpec, JobStatus, JobTerminationReason, - Requirements, RunStatus, ) from dstack._internal.core.models.volumes import ( @@ -38,7 +34,7 @@ ) from dstack._internal.server import settings as server_settings from dstack._internal.server.background.scheduled_tasks.running_jobs import ( - _patch_base_image_for_aws_efa, + _RunnerAvailability, process_running_jobs, ) from dstack._internal.server.models import JobModel @@ -150,6 +146,14 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( patch( "dstack._internal.server.services.runner.client.RunnerClient" ) as RunnerClientMock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, ): datetime_mock.return_value = datetime(2023, 1, 2, 5, 12, 30, 10, tzinfo=timezone.utc) @@ -159,6 +163,8 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( await process_running_jobs() SSHTunnelMock.assert_called_once() runner_client_mock.healthcheck.assert_called_once() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() await session.refresh(job) assert job is not None assert job.status == JobStatus.PROVISIONING @@ -207,8 +213,8 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): working_dir="/dstack/run", username="dstack" ) await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.healthcheck.assert_called_once() + assert SSHTunnelMock.call_count == 2 + assert runner_client_mock.healthcheck.call_count == 2 runner_client_mock.submit_job.assert_called_once() runner_client_mock.upload_code.assert_called_once() runner_client_mock.run_job.assert_called_once() @@ -430,9 +436,9 @@ async def test_pulling_shim( await process_running_jobs() - assert ssh_tunnel_mock.call_count == 2 + assert ssh_tunnel_mock.call_count == 3 shim_client_mock.get_task.assert_called_once() - runner_client_mock.healthcheck.assert_called_once() + assert runner_client_mock.healthcheck.call_count == 2 runner_client_mock.submit_job.assert_called_once() runner_client_mock.upload_code.assert_called_once() runner_client_mock.run_job.assert_called_once() @@ -484,15 +490,175 @@ async def test_pulling_shim_port_mapping_not_ready( shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING shim_client_mock.get_task.return_value.ports = None - await process_running_jobs() + with ( + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await process_running_jobs() + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_not_called() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + await session.refresh(job) + assert job is not None + assert job.status == JobStatus.PULLING + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_pulling_shim_runner_not_ready( + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + runner_client_mock.healthcheck.return_value = None + + with ( + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await process_running_jobs() + + assert ssh_tunnel_mock.call_count == 2 + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_called_once() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + + await session.refresh(job) + assert job is not None + assert job.status == JobStatus.PULLING + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_pulling_shim_uses_runtime_port_mapping_for_runner_calls( + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + + expected_ports = { + 10022: 32771, + 10999: 32772, + } + + def assert_runner_availability(_, __, job_runtime_data): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return _RunnerAvailability.AVAILABLE + + def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return True + + with ( + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_runner_availability", + side_effect=assert_runner_availability, + ) as get_runner_availability_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._submit_job_to_runner", + side_effect=assert_submit_job_to_runner, + ) as submit_job_to_runner_mock, + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", + new_callable=AsyncMock, + return_value=b"", + ), + ): + await process_running_jobs() + + ssh_tunnel_mock.assert_called_once() + get_runner_availability_mock.assert_called_once() + submit_job_to_runner_mock.assert_called_once() - ssh_tunnel_mock.assert_called_once() - shim_client_mock.get_task.assert_called_once() - runner_client_mock.healthcheck.assert_not_called() - runner_client_mock.submit_job.assert_not_called() await session.refresh(job) assert job is not None assert job.status == JobStatus.PULLING + jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert jrd.ports == expected_ports @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -1131,129 +1297,3 @@ async def test_registers_service_replica_only_after_probes_pass( else: assert not job.registered assert not events - - -class TestPatchBaseImageForAwsEfa: - @staticmethod - def _create_job_spec(image_name: str) -> "JobSpec": - return JobSpec( - job_num=0, - job_name="test-job", - commands=["echo hello"], - env={}, - image_name=image_name, - requirements=Requirements(resources=ResourcesSpec()), - ) - - @staticmethod - def _create_job_provisioning_data_with_instance_type( - backend: BackendType, instance_type: str - ) -> JobProvisioningData: - job_provisioning_data = get_job_provisioning_data(backend=backend) - job_provisioning_data.instance_type = InstanceType( - name=instance_type, - resources=job_provisioning_data.instance_type.resources, - ) - return job_provisioning_data - - @staticmethod - def _call_patch_base_image_for_aws_efa( - image_name: str, backend: BackendType, instance_type: str - ) -> str: - job_spec = TestPatchBaseImageForAwsEfa._create_job_spec(image_name) - job_provisioning_data = ( - TestPatchBaseImageForAwsEfa._create_job_provisioning_data_with_instance_type( - backend, instance_type - ) - ) - return _patch_base_image_for_aws_efa(job_spec, job_provisioning_data) - - @pytest.mark.parametrize( - "suffix,instance_type", - [ - ("-base", "p6-b200.48xlarge"), - ("-devel", "p5.48xlarge"), - ], - ) - def test_patch_aws_efa_instance_with_suffix(self, suffix: str, instance_type: str): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - expected = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - assert result == expected - - @pytest.mark.parametrize("suffix", ["-base", "-devel"]) - @pytest.mark.parametrize( - "instance_type", - [ - "p5.48xlarge", - "p5e.48xlarge", - "p4d.24xlarge", - "p4de.24xlarge", - "g6.8xlarge", - "g6e.8xlarge", - ], - ) - def test_patch_all_efa_instance_types(self, instance_type: str, suffix: str): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - expected = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - assert result == expected - - @pytest.mark.parametrize("suffix", ["-base", "-devel"]) - @pytest.mark.parametrize( - "backend", - [BackendType.GCP, BackendType.AZURE, BackendType.LAMBDA, BackendType.LOCAL], - ) - @pytest.mark.parametrize( - "instance_type", - [ - "standard-4", - "p5.xlarge", - "p6.2xlarge", - "g6.xlarge", - ], # Mix of generic and EFA-named types - ) - def test_no_patch_non_aws_backends( - self, backend: BackendType, suffix: str, instance_type: str - ): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" - result = self._call_patch_base_image_for_aws_efa(image_name, backend, instance_type) - assert result == image_name - - @pytest.mark.parametrize("suffix", ["-base", "-devel"]) - @pytest.mark.parametrize( - "instance_type", - ["t3.micro", "m5.large", "c5.xlarge", "r5.2xlarge", "m6i.large", "g6.xlarge"], - ) - def test_no_patch_non_efa_aws_instances(self, instance_type: str, suffix: str): - image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - assert result == image_name - - @pytest.mark.parametrize( - "instance_type", - ["p5.xlarge", "p6.2xlarge", "t3.micro", "m5.large"], # Mix of EFA and non-EFA instances - ) - @pytest.mark.parametrize( - "image_name", - [ - "ubuntu:20.04", - "nvidia/cuda:11.8-runtime-ubuntu20.04", - "python:3.9-slim", - "custom/image:latest", - f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-custom", - f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa", - f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}", - ], - ) - def test_no_patch_other_images(self, instance_type: str, image_name: str): - result = self._call_patch_base_image_for_aws_efa( - image_name, BackendType.AWS, instance_type - ) - assert result == image_name diff --git a/src/tests/_internal/server/services/backends/test_provisioning.py b/src/tests/_internal/server/services/backends/test_provisioning.py new file mode 100644 index 0000000000..3145a56afa --- /dev/null +++ b/src/tests/_internal/server/services/backends/test_provisioning.py @@ -0,0 +1,168 @@ +import pytest + +from dstack._internal import settings +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceType +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import JobProvisioningData, JobSpec, Requirements +from dstack._internal.server.services.backends.provisioning import ( + resolve_provisioning_image_name, +) +from dstack._internal.server.testing.common import get_job_provisioning_data + + +class TestResolveProvisioningImageName: + @staticmethod + def _create_job_spec(image_name: str) -> JobSpec: + return JobSpec( + job_num=0, + job_name="test-job", + app_specs=None, + commands=["echo hello"], + env={}, + home_dir=None, + image_name=image_name, + max_duration=None, + registry_auth=None, + requirements=Requirements(resources=ResourcesSpec()), + retry=None, + working_dir=None, + ) + + @staticmethod + def _create_job_provisioning_data_with_instance_type( + backend: BackendType, + instance_type: str, + ) -> JobProvisioningData: + job_provisioning_data = get_job_provisioning_data(backend=backend) + job_provisioning_data.instance_type = InstanceType( + name=instance_type, + resources=job_provisioning_data.instance_type.resources, + ) + return job_provisioning_data + + @staticmethod + def _call_resolve_provisioning_image_name( + image_name: str, + backend: BackendType, + instance_type: str, + ) -> str: + job_spec = TestResolveProvisioningImageName._create_job_spec(image_name) + job_provisioning_data = ( + TestResolveProvisioningImageName._create_job_provisioning_data_with_instance_type( + backend, + instance_type, + ) + ) + return resolve_provisioning_image_name(job_spec, job_provisioning_data) + + @pytest.mark.parametrize( + ("suffix", "instance_type"), + [ + ("-base", "p6-b200.48xlarge"), + ("-devel", "p5.48xlarge"), + ], + ) + def test_patch_aws_efa_instance_with_suffix(self, suffix: str, instance_type: str) -> None: + image_name = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + f"-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + expected = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}" + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + assert result == expected + + @pytest.mark.parametrize("suffix", ["-base", "-devel"]) + @pytest.mark.parametrize( + "instance_type", + [ + "p5.48xlarge", + "p5e.48xlarge", + "p4d.24xlarge", + "p4de.24xlarge", + "g6.8xlarge", + "g6e.8xlarge", + ], + ) + def test_patch_all_efa_instance_types(self, instance_type: str, suffix: str) -> None: + image_name = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + f"-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + expected = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}" + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + assert result == expected + + @pytest.mark.parametrize("suffix", ["-base", "-devel"]) + @pytest.mark.parametrize( + "backend", + [BackendType.GCP, BackendType.AZURE, BackendType.LAMBDA, BackendType.LOCAL], + ) + @pytest.mark.parametrize( + "instance_type", + ["standard-4", "p5.xlarge", "p6.2xlarge", "g6.xlarge"], + ) + def test_no_patch_non_aws_backends( + self, + backend: BackendType, + suffix: str, + instance_type: str, + ) -> None: + image_name = ( + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + f"-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ) + result = self._call_resolve_provisioning_image_name(image_name, backend, instance_type) + assert result == image_name + + @pytest.mark.parametrize("suffix", ["-base", "-devel"]) + @pytest.mark.parametrize( + "instance_type", + ["t3.micro", "m5.large", "c5.xlarge", "r5.2xlarge", "m6i.large", "g6.xlarge"], + ) + def test_no_patch_non_efa_aws_instances(self, instance_type: str, suffix: str) -> None: + image_name = f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}{suffix}" + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + assert result == image_name + + @pytest.mark.parametrize( + "instance_type", + ["p5.xlarge", "p6.2xlarge", "t3.micro", "m5.large"], + ) + @pytest.mark.parametrize( + "image_name", + [ + "ubuntu:20.04", + "nvidia/cuda:11.8-runtime-ubuntu20.04", + "python:3.9-slim", + "custom/image:latest", + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-custom", + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-devel-efa", + f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}", + ], + ) + def test_no_patch_other_images(self, instance_type: str, image_name: str) -> None: + result = self._call_resolve_provisioning_image_name( + image_name, + BackendType.AWS, + instance_type, + ) + assert result == image_name