From fe6cd36b11799f48abe106dc29538420b52adc1f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 26 Feb 2026 11:06:29 +0500 Subject: [PATCH 01/21] Load only fleets active runs in apply_plan --- src/dstack/_internal/server/services/fleets.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 8febbd126..469bcb7c7 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -485,7 +485,12 @@ async def apply_plan( .joinedload(InstanceModel.jobs) .load_only(JobModel.id) ) - .options(selectinload(FleetModel.runs)) + # `is_fleet_in_use()` only needs active run presence/status. + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.id, RunModel.status) + ) .execution_options(populate_existing=True) .order_by(FleetModel.id) # take locks in order .with_for_update(key_share=True) From f6827f97924660bbc1afc4ce669175ee606ae095 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 26 Feb 2026 11:16:54 +0500 Subject: [PATCH 02/21] Replace fleets many-to-many joinedloads with selectinloads --- src/dstack/_internal/server/services/fleets.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 469bcb7c7..008b48a03 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -225,7 +225,7 @@ async def list_projects_fleet_models( .where(*filters) .order_by(*order_by) .limit(limit) - .options(joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) ) fleet_models = list(res.unique().scalars().all()) return fleet_models @@ -256,7 +256,7 @@ async def list_project_fleet_models( res = await session.execute( select(FleetModel) .where(*filters) - .options(joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) ) return list(res.unique().scalars().all()) @@ -661,12 +661,12 @@ async def delete_fleets( select(FleetModel) .where(FleetModel.id.in_(fleets_ids)) .options( - joinedload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) - .joinedload(InstanceModel.jobs) + selectinload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) + .selectinload(InstanceModel.jobs) .load_only(JobModel.id) ) .options( - joinedload( + selectinload( FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) ) ) From 55041af47268dbd8d438fc100d9254377d8241d3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 26 Feb 2026 11:30:47 +0500 Subject: [PATCH 03/21] Optimize selects --- .../background/scheduled_tasks/fleets.py | 25 +++++++++++-------- .../scheduled_tasks/submitted_jobs.py | 11 +++----- .../_internal/server/services/fleets.py | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py index a758f86ad..c19d5c516 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py @@ -5,10 +5,11 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, load_only, selectinload, with_loader_criteria +from sqlalchemy.orm import joinedload, load_only, selectinload from dstack._internal.core.models.fleets import FleetSpec, FleetStatus from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import RunStatus from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( FleetModel, @@ -59,10 +60,9 @@ async def process_fleets(): ) .options( load_only(FleetModel.id, FleetModel.name), - selectinload(FleetModel.instances).load_only(InstanceModel.id), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True - ), + selectinload( + FleetModel.instances.and_(InstanceModel.deleted == False) + ).load_only(InstanceModel.id), ) .order_by(FleetModel.last_processed_at.asc()) .limit(BATCH_SIZE) @@ -115,14 +115,17 @@ async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]) res = await session.execute( select(FleetModel) .where(FleetModel.id.in_(fleet_ids)) + .options(joinedload(FleetModel.project)) .options( - joinedload(FleetModel.instances).joinedload(InstanceModel.jobs).load_only(JobModel.id), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True - ), + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) + .joinedload(InstanceModel.jobs) + .load_only(JobModel.id), + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) ) - .options(joinedload(FleetModel.project)) - .options(joinedload(FleetModel.runs).load_only(RunModel.status)) .execution_options(populate_existing=True) ) fleet_models = list(res.unique().scalars().all()) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 5d1b2e1a7..932db5294 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -13,7 +13,6 @@ load_only, noload, selectinload, - with_loader_criteria, ) from dstack._internal.core.backends.base.backend import Backend @@ -223,9 +222,8 @@ async def _process_submitted_job( .where(JobModel.id == job_model.id) .options(joinedload(JobModel.instance)) .options( - joinedload(JobModel.fleet).joinedload(FleetModel.instances), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True + joinedload(JobModel.fleet).selectinload( + FleetModel.instances.and_(InstanceModel.deleted == False) ), ) ) @@ -236,9 +234,8 @@ async def _process_submitted_job( .options(joinedload(RunModel.project).joinedload(ProjectModel.backends)) .options(joinedload(RunModel.user).load_only(UserModel.name)) .options( - joinedload(RunModel.fleet).joinedload(FleetModel.instances), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True + joinedload(RunModel.fleet).selectinload( + FleetModel.instances.and_(InstanceModel.deleted == False) ), ) ) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 008b48a03..b557fc4b5 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -668,7 +668,7 @@ async def delete_fleets( .options( selectinload( FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) - ) + ).load_only(RunModel.status) ) .execution_options(populate_existing=True) ) From 2e3e0bc74cef521bf12cf07a5a5f8e24a3e6213f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 26 Feb 2026 15:23:00 +0500 Subject: [PATCH 04/21] WIP: FleetPipeline --- .../background/pipeline_tasks/fleets.py | 519 ++++++++++++++++++ src/dstack/_internal/server/models.py | 2 +- .../_internal/server/services/fleets.py | 36 +- 3 files changed, 555 insertions(+), 2 deletions(-) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/fleets.py diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py new file mode 100644 index 000000000..2417f514c --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -0,0 +1,519 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Sequence + +from sqlalchemy import or_, select, update +from sqlalchemy.orm import joinedload, load_only, selectinload + +from dstack._internal.core.models.fleets import FleetSpec, FleetStatus +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import RunStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + UpdateMap, + Worker, + get_processed_update_map, + get_unlock_update_map, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + FleetModel, + InstanceModel, + JobModel, + PlacementGroupModel, + RunModel, +) +from dstack._internal.server.services import events +from dstack._internal.server.services.fleets import ( + create_fleet_instance_model, + emit_fleet_status_change_event, + get_fleet_spec, + get_next_instance_num, + is_fleet_empty, + is_fleet_in_use, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class FleetPipelineItem(PipelineItem): + status: FleetStatus + + +class FleetPipeline(Pipeline[FleetPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=30), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[FleetPipelineItem]( + model_type=FleetModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = FleetFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + FleetWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return FleetModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[FleetPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[FleetPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["FleetWorker"]: + return self.__workers + + +class FleetFetcher(Fetcher[FleetPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[FleetPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[FleetPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.FleetFetcher.fetch") + async def fetch(self, limit: int) -> list[FleetPipelineItem]: + fleet_lock, _ = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__) + async with fleet_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(FleetModel) + .where( + FleetModel.deleted == False, + or_( + FleetModel.last_processed_at <= now - self._min_processing_interval, + FleetModel.last_processed_at == FleetModel.created_at, + ), + or_( + FleetModel.lock_expires_at.is_(None), + FleetModel.lock_expires_at < now, + ), + or_( + FleetModel.lock_owner.is_(None), + FleetModel.lock_owner == FleetPipeline.__name__, + ), + ) + .order_by(FleetModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True) + .options( + load_only( + FleetModel.id, + FleetModel.lock_token, + FleetModel.lock_expires_at, + FleetModel.status, + ) + ) + ) + fleet_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for fleet_model in fleet_models: + prev_lock_expired = fleet_model.lock_expires_at is not None + fleet_model.lock_expires_at = lock_expires_at + fleet_model.lock_token = lock_token + fleet_model.lock_owner = FleetPipeline.__name__ + items.append( + FleetPipelineItem( + __tablename__=FleetModel.__tablename__, + id=fleet_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=fleet_model.status, + ) + ) + await session.commit() + return items + + +class FleetWorker(Worker[FleetPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[FleetPipelineItem], + heartbeater: Heartbeater[FleetPipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.FleetWorker.process") + async def process(self, item: FleetPipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .options(joinedload(FleetModel.project)) + .options( + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) + .joinedload(InstanceModel.jobs) + .load_only(JobModel.id), + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) + ) + ) + fleet_model = res.unique().scalar_one_or_none() + if fleet_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( + InstanceModel.__tablename__ + ) + async with instance_lock: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.fleet_id == item.id, + InstanceModel.deleted != False, + # TODO: Lock instance models in the DB + # or_( + # InstanceModel.lock_expires_at.is_(None), + # InstanceModel.lock_expires_at < get_current_datetime(), + # ), + # or_( + # InstanceModel.lock_owner.is_(None), + # InstanceModel.lock_owner == FleetPipeline.__name__, + # ), + ) + .with_for_update(skip_locked=True, key_share=True) + ) + locked_instance_models = res.scalars().all() + if len(fleet_model.instances) != len(locked_instance_models): + logger.debug( + "Failed to lock fleet %s instances. The fleet will be processed later.", + item.id, + ) + now = get_current_datetime() + # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked + # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). + # Unset `lock_token` so that heartbeater can no longer update the item. + res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .values( + lock_expires_at=None, + lock_token=None, + last_processed_at=now, + ) + ) + if res.rowcount == 0: # pyright: ignore[reportAttributeAccessIssue] + logger.warning( + "Failed to reset lock: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration." + ) + return + + # TODO: Lock instance models in the DB + # for instance_model in locked_instance_models: + # instance_model.lock_expires_at = item.lock_expires_at + # instance_model.lock_token = item.lock_token + # instance_model.lock_owner = FleetPipeline.__name__ + # await session.commit() + + result = await _process_fleet(fleet_model) + fleet_update_map = ( + result.fleet_update_map | get_processed_update_map() | get_unlock_update_map() + ) + async with get_session_ctx() as session: + res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == fleet_model.id, + FleetModel.lock_token == fleet_model.lock_token, + ) + .values(**fleet_update_map) + .returning(FleetModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + # TODO: Clean up fleet. + return + + if fleet_update_map.get("deleted"): + await session.execute( + update(PlacementGroupModel) + .where(PlacementGroupModel.fleet_id == item.id) + .values(fleet_deleted=True) + ) + + instance_update_rows = [] + for instance_id, instance_update_map in result.instance_id_to_update_map.items(): + update_row = dict(instance_update_map) + update_row["id"] = instance_id + instance_update_rows.append(update_row) + if instance_update_rows: + await session.execute( + update(InstanceModel).execution_options(synchronize_session=False), + instance_update_rows, + ) + + if result.new_instances_count > 0: + fleet_spec = get_fleet_spec(fleet_model) + res = await session.execute( + select(InstanceModel.instance_num).where( + InstanceModel.fleet_id == fleet_model.id, + InstanceModel.deleted == False, + ) + ) + taken_instance_nums = set(res.scalars().all()) + for _ in range(result.new_instances_count): + instance_num = get_next_instance_num(taken_instance_nums) + instance_model = create_fleet_instance_model( + session=session, + project=fleet_model.project, + # TODO: Store fleet.user and pass it instead of the project owner. + username=fleet_model.project.owner.name, + spec=fleet_spec, + instance_num=instance_num, + ) + instance_model.fleet_id = fleet_model.id + taken_instance_nums.add(instance_num) + events.emit( + session=session, + message=( + "Instance created to meet target fleet node count." + f" Status: {instance_model.status.upper()}" + ), + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + logger.info( + "Added %d instances to fleet %s", + result.new_instances_count, + fleet_model.name, + ) + + emit_fleet_status_change_event( + session=session, + fleet_model=fleet_model, + old_status=fleet_model.status, + new_status=fleet_update_map.get("status", fleet_model.status), + status_message=fleet_update_map.get("status_message", fleet_model.status_message), + ) + + +@dataclass +class _ProcessResult: + fleet_update_map: UpdateMap = field(default_factory=dict) + instance_id_to_update_map: dict[uuid.UUID, UpdateMap] = field(default_factory=dict) + new_instances_count: int = 0 + + +@dataclass +class _MaintainNodexResult: + instance_id_to_update_map: dict[uuid.UUID, UpdateMap] = field(default_factory=dict) + new_instances_count: int = 0 + + @property + def has_changes(self) -> bool: + return len(self.instance_id_to_update_map) > 0 or self.new_instances_count > 0 + + +async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: + result = _consolidate_fleet_state_with_spec(fleet_model) + if result.new_instances_count > 0: + # Avoid auto-deleting empty fleets that are about to receive new instances. + return result + # TODO: Drop fleets auto-deletion after dropping fleets auto-creation. + deleted = _autodelete_fleet(fleet_model) + if deleted: + result.fleet_update_map["deleted"] = True + result.fleet_update_map["deleted_at"] = get_current_datetime() + return result + + +def _consolidate_fleet_state_with_spec(fleet_model: FleetModel) -> _ProcessResult: + result = _ProcessResult() + if fleet_model.status == FleetStatus.TERMINATING: + return result + fleet_spec = get_fleet_spec(fleet_model) + if fleet_spec.configuration.nodes is None or fleet_spec.autocreated: + # Only explicitly created cloud fleets are consolidated. + return result + if not _is_fleet_ready_for_consolidation(fleet_model): + return result + maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(fleet_model, fleet_spec) + if maintain_nodes_result.has_changes: + result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map + result.new_instances_count = maintain_nodes_result.new_instances_count + result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1 + else: + # The fleet is already consolidated or consolidation is in progress. + # We reset consolidation_attempt in both cases for simplicity. + # The second case does not need reset but is ok to do since + # it means consolidation is longer than delay, so it won't happen too often. + # TODO: Reset consolidation_attempt on fleet in-place update. + result.fleet_update_map["consolidation_attempt"] = 0 + result.fleet_update_map["last_consolidated_at"] = get_current_datetime() + return result + + +def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool: + consolidation_retry_delay = _get_consolidation_retry_delay(fleet_model.consolidation_attempt) + last_consolidated_at = fleet_model.last_consolidated_at or fleet_model.last_processed_at + duration_since_last_consolidation = get_current_datetime() - last_consolidated_at + return duration_since_last_consolidation >= consolidation_retry_delay + + +# We use exponentially increasing consolidation retry delays so that +# consolidation does not happen too often. In particular, this prevents +# retrying instance provisioning constantly in case of no offers. +# TODO: Adjust delays. +_CONSOLIDATION_RETRY_DELAYS = [ + timedelta(seconds=30), + timedelta(minutes=1), + timedelta(minutes=2), + timedelta(minutes=5), + timedelta(minutes=10), +] + + +def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta: + if consolidation_attempt < len(_CONSOLIDATION_RETRY_DELAYS): + return _CONSOLIDATION_RETRY_DELAYS[consolidation_attempt] + return _CONSOLIDATION_RETRY_DELAYS[-1] + + +def _maintain_fleet_nodes_in_min_max_range( + fleet_model: FleetModel, + fleet_spec: FleetSpec, +) -> _MaintainNodexResult: + """ + Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances. + """ + assert fleet_spec.configuration.nodes is not None + result = _MaintainNodexResult() + for instance in fleet_model.instances: + # Delete terminated but not deleted instances since + # they are going to be replaced with new pending instances. + if instance.status == InstanceStatus.TERMINATED and not instance.deleted: + result.instance_id_to_update_map[instance.id] = { + "deleted": True, + "deleted_at": get_current_datetime(), + } + active_instances = [ + i for i in fleet_model.instances if i.status != InstanceStatus.TERMINATED and not i.deleted + ] + active_instances_num = len(active_instances) + if active_instances_num < fleet_spec.configuration.nodes.min: + nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num + result.new_instances_count = nodes_missing + return result + if ( + fleet_spec.configuration.nodes.max is None + or active_instances_num <= fleet_spec.configuration.nodes.max + ): + return result + # Fleet has more instances than allowed by nodes.max. + # This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently) + # or if nodes.max is updated. + nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max + for instance in fleet_model.instances: + if nodes_redundant == 0: + break + if instance.status in [InstanceStatus.IDLE]: + result.instance_id_to_update_map[instance.id] = { + "termination_reason": InstanceTerminationReason.MAX_INSTANCES_LIMIT, + "termination_reason_message": "Fleet has too many instances", + "deleted": True, + "deleted_at": get_current_datetime(), + } + nodes_redundant -= 1 + return result + + +def _autodelete_fleet(fleet_model: FleetModel) -> bool: + if fleet_model.project.deleted: + # It used to be possible to delete project with active resources: + # https://github.com/dstackai/dstack/issues/3077 + logger.info("Fleet %s deleted due to deleted project", fleet_model.name) + return True + + if is_fleet_in_use(fleet_model) or not is_fleet_empty(fleet_model): + return False + + fleet_spec = get_fleet_spec(fleet_model) + if ( + fleet_model.status != FleetStatus.TERMINATING + and fleet_spec.configuration.nodes is not None + and fleet_spec.configuration.nodes.min == 0 + ): + # Empty fleets that allow 0 nodes should not be auto-deleted + return False + + logger.info("Automatic cleanup of an empty fleet %s", fleet_model.name) + return True diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index a7a8ec0bd..1a8a3aa86 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -576,7 +576,7 @@ class PoolModel(BaseModel): instances: Mapped[List["InstanceModel"]] = relationship(back_populates="pool", lazy="selectin") -class FleetModel(BaseModel): +class FleetModel(PipelineModelMixin, BaseModel): __tablename__ = "fleets" id: Mapped[uuid.UUID] = mapped_column( diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index b557fc4b5..db35331fd 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -102,11 +102,45 @@ def switch_fleet_status( return fleet_model.status = new_status + emit_fleet_status_change_event( + session=session, + fleet_model=fleet_model, + old_status=old_status, + new_status=new_status, + status_message=fleet_model.status_message, + actor=actor, + ) - msg = f"Fleet status changed {old_status.upper()} -> {new_status.upper()}" + +def emit_fleet_status_change_event( + session: AsyncSession, + fleet_model: FleetModel, + old_status: FleetStatus, + new_status: FleetStatus, + status_message: Optional[str], + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + msg = get_fleet_status_change_message( + old_status=old_status, + new_status=new_status, + status_message=status_message, + ) events.emit(session, msg, actor=actor, targets=[events.Target.from_model(fleet_model)]) +def get_fleet_status_change_message( + old_status: FleetStatus, + new_status: FleetStatus, + status_message: Optional[str], +) -> str: + msg = f"Fleet status changed {old_status.upper()} -> {new_status.upper()}" + if status_message is not None: + msg += f" ({status_message})" + return msg + + async def list_projects_with_no_active_fleets( session: AsyncSession, user: UserModel, From d171238fa4a788c496dca1f4cee46d51ceda82c0 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 10:40:19 +0500 Subject: [PATCH 05/21] Fixes --- .../_internal/server/background/pipeline_tasks/fleets.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 2417f514c..0f5300b66 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -229,7 +229,7 @@ async def process(self, item: FleetPipelineItem): select(InstanceModel) .where( InstanceModel.fleet_id == item.id, - InstanceModel.deleted != False, + InstanceModel.deleted == False, # TODO: Lock instance models in the DB # or_( # InstanceModel.lock_expires_at.is_(None), @@ -314,6 +314,7 @@ async def process(self, item: FleetPipelineItem): for instance_id, instance_update_map in result.instance_id_to_update_map.items(): update_row = dict(instance_update_map) update_row["id"] = instance_id + update_row |= get_processed_update_map() # | get_unlock_update_map() instance_update_rows.append(update_row) if instance_update_rows: await session.execute( @@ -391,6 +392,7 @@ async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: # TODO: Drop fleets auto-deletion after dropping fleets auto-creation. deleted = _autodelete_fleet(fleet_model) if deleted: + result.fleet_update_map["status"] = FleetStatus.TERMINATED result.fleet_update_map["deleted"] = True result.fleet_update_map["deleted_at"] = get_current_datetime() return result @@ -489,8 +491,7 @@ def _maintain_fleet_nodes_in_min_max_range( result.instance_id_to_update_map[instance.id] = { "termination_reason": InstanceTerminationReason.MAX_INSTANCES_LIMIT, "termination_reason_message": "Fleet has too many instances", - "deleted": True, - "deleted_at": get_current_datetime(), + "status": InstanceStatus.TERMINATING, } nodes_redundant -= 1 return result From a3672886e094baf96df557ac4341af44e8afa602 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 10:59:16 +0500 Subject: [PATCH 06/21] Add TestFleetWorker --- .../background/pipeline_tasks/test_fleets.py | 357 ++++++++++++++++++ .../background/scheduled_tasks/test_fleets.py | 9 + 2 files changed, 366 insertions(+) create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_fleets.py diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py new file mode 100644 index 000000000..09b8a3400 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -0,0 +1,357 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import Mock + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.fleets import FleetNodesSpec, FleetStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.runs import RunStatus +from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.background.pipeline_tasks.fleets import ( + FleetPipelineItem, + FleetWorker, +) +from dstack._internal.server.models import FleetModel, InstanceModel +from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.testing.common import ( + create_fleet, + create_instance, + create_placement_group, + create_project, + create_repo, + create_run, + create_user, + get_fleet_spec, +) + + +@pytest.fixture +def worker() -> FleetWorker: + return FleetWorker(queue=Mock(), heartbeater=Mock()) + + +def _fleet_to_pipeline_item(fleet: FleetModel) -> FleetPipelineItem: + assert fleet.lock_token is not None + assert fleet.lock_expires_at is not None + return FleetPipelineItem( + __tablename__=fleet.__tablename__, + id=fleet.id, + lock_token=fleet.lock_token, + lock_expires_at=fleet.lock_expires_at, + prev_lock_expired=False, + status=fleet.status, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestFleetWorker: + async def test_deletes_empty_autocreated_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.autocreated = True + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.deleted + + async def test_deletes_terminating_user_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.autocreated = False + fleet = await create_fleet( + session=session, + project=project, + status=FleetStatus.TERMINATING, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.deleted + + async def test_does_not_delete_fleet_with_active_run( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + user = await create_user(session=session, global_role=GlobalRole.USER) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + ) + fleet.runs.append(run) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert not fleet.deleted + + async def test_does_not_delete_fleet_with_instance( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + user = await create_user(session=session, global_role=GlobalRole.USER) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + fleet.instances.append(instance) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert not fleet.deleted + + async def test_consolidation_creates_missing_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=1, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = (await session.execute(select(InstanceModel))).scalars().all() + assert len(instances) == 2 + assert {i.instance_num for i in instances} == {0, 1} + assert fleet.consolidation_attempt == 1 + + async def test_consolidation_terminates_redundant_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance1 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + instance_num=0, + ) + instance2 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=1, + ) + instance3 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + instance_num=2, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance1) + await session.refresh(instance2) + await session.refresh(instance3) + assert instance1.status == InstanceStatus.BUSY + assert instance2.status == InstanceStatus.TERMINATING + assert instance3.deleted + assert fleet.consolidation_attempt == 1 + + async def test_marks_placement_groups_fleet_deleted_on_fleet_delete( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + status=FleetStatus.TERMINATING, + ) + placement_group1 = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test-pg-1", + ) + placement_group2 = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test-pg-2", + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(placement_group1) + await session.refresh(placement_group2) + assert fleet.deleted + assert placement_group1.fleet_deleted + assert placement_group2.fleet_deleted + + async def test_consolidation_respects_retry_delay( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + fleet.consolidation_attempt = 1 + fleet.last_consolidated_at = datetime.now(timezone.utc) + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = ( + ( + await session.execute( + select(InstanceModel).where( + InstanceModel.fleet_id == fleet.id, + InstanceModel.deleted == False, + ) + ) + ) + .scalars() + .all() + ) + assert len(instances) == 1 + assert fleet.consolidation_attempt == 1 + assert not fleet.deleted + + async def test_consolidation_attempt_resets_when_no_changes( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + fleet.consolidation_attempt = 3 + previous_last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + fleet.last_consolidated_at = previous_last_consolidated_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = ( + ( + await session.execute( + select(InstanceModel).where( + InstanceModel.fleet_id == fleet.id, + InstanceModel.deleted == False, + ) + ) + ) + .scalars() + .all() + ) + assert len(instances) == 1 + assert fleet.consolidation_attempt == 0 + assert fleet.last_consolidated_at is not None + assert fleet.last_consolidated_at > previous_last_consolidated_at diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py b/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py index 2ef1b27ab..2136a2c96 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py @@ -154,8 +154,17 @@ async def test_consolidation_terminates_redundant_instances( status=InstanceStatus.IDLE, instance_num=1, ) + instance3 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + instance_num=2, + ) await process_fleets() await session.refresh(instance1) await session.refresh(instance2) + await session.refresh(instance3) assert instance1.status == InstanceStatus.BUSY assert instance2.status == InstanceStatus.TERMINATING + assert instance3.deleted From 8cad574579b7e424c5593984bad3d789f9426c19 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 11:40:16 +0500 Subject: [PATCH 07/21] Fix consolidation_attempt reset --- .../background/pipeline_tasks/fleets.py | 23 +++++------ src/dstack/_internal/server/models.py | 2 + .../background/pipeline_tasks/test_fleets.py | 41 ++++++++++++++++++- 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 0f5300b66..e906cabc3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -47,7 +47,7 @@ @dataclass class FleetPipelineItem(PipelineItem): - status: FleetStatus + pass class FleetPipeline(Pipeline[FleetPipelineItem]): @@ -152,7 +152,6 @@ async def fetch(self, limit: int) -> list[FleetPipelineItem]: FleetModel.id, FleetModel.lock_token, FleetModel.lock_expires_at, - FleetModel.status, ) ) ) @@ -172,7 +171,6 @@ async def fetch(self, limit: int) -> list[FleetPipelineItem]: lock_expires_at=lock_expires_at, lock_token=lock_token, prev_lock_expired=prev_lock_expired, - status=fleet_model.status, ) ) await session.commit() @@ -375,9 +373,10 @@ class _ProcessResult: @dataclass -class _MaintainNodexResult: +class _MaintainNodesResult: instance_id_to_update_map: dict[uuid.UUID, UpdateMap] = field(default_factory=dict) new_instances_count: int = 0 + changes_required: bool = False @property def has_changes(self) -> bool: @@ -412,13 +411,10 @@ def _consolidate_fleet_state_with_spec(fleet_model: FleetModel) -> _ProcessResul if maintain_nodes_result.has_changes: result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map result.new_instances_count = maintain_nodes_result.new_instances_count + if maintain_nodes_result.changes_required: result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1 else: - # The fleet is already consolidated or consolidation is in progress. - # We reset consolidation_attempt in both cases for simplicity. - # The second case does not need reset but is ok to do since - # it means consolidation is longer than delay, so it won't happen too often. - # TODO: Reset consolidation_attempt on fleet in-place update. + # The fleet is consolidated with respect to nodes min/max. result.fleet_update_map["consolidation_attempt"] = 0 result.fleet_update_map["last_consolidated_at"] = get_current_datetime() return result @@ -453,16 +449,17 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta: def _maintain_fleet_nodes_in_min_max_range( fleet_model: FleetModel, fleet_spec: FleetSpec, -) -> _MaintainNodexResult: +) -> _MaintainNodesResult: """ Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances. """ assert fleet_spec.configuration.nodes is not None - result = _MaintainNodexResult() + result = _MaintainNodesResult() for instance in fleet_model.instances: # Delete terminated but not deleted instances since # they are going to be replaced with new pending instances. if instance.status == InstanceStatus.TERMINATED and not instance.deleted: + result.changes_required = True result.instance_id_to_update_map[instance.id] = { "deleted": True, "deleted_at": get_current_datetime(), @@ -472,6 +469,7 @@ def _maintain_fleet_nodes_in_min_max_range( ] active_instances_num = len(active_instances) if active_instances_num < fleet_spec.configuration.nodes.min: + result.changes_required = True nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num result.new_instances_count = nodes_missing return result @@ -483,11 +481,12 @@ def _maintain_fleet_nodes_in_min_max_range( # Fleet has more instances than allowed by nodes.max. # This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently) # or if nodes.max is updated. + result.changes_required = True nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max for instance in fleet_model.instances: if nodes_redundant == 0: break - if instance.status in [InstanceStatus.IDLE]: + if instance.status == InstanceStatus.IDLE: result.instance_id_to_update_map[instance.id] = { "termination_reason": InstanceTerminationReason.MAX_INSTANCES_LIMIT, "termination_reason_message": "Fleet has too many instances", diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 1a8a3aa86..35f7eb4ec 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -604,6 +604,8 @@ class FleetModel(PipelineModelMixin, BaseModel): jobs: Mapped[List["JobModel"]] = relationship(back_populates="fleet") instances: Mapped[List["InstanceModel"]] = relationship(back_populates="fleet") + # `consolidation_attempt` counts how many times in a row fleet needed consolidation. + # Allows increasing delays between attempts. consolidation_attempt: Mapped[int] = mapped_column(Integer, server_default="0") last_consolidated_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 09b8a3400..f6c7acd50 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -42,7 +42,6 @@ def _fleet_to_pipeline_item(fleet: FleetModel) -> FleetPipelineItem: lock_token=fleet.lock_token, lock_expires_at=fleet.lock_expires_at, prev_lock_expired=False, - status=fleet.status, ) @@ -233,6 +232,46 @@ async def test_consolidation_terminates_redundant_instances( assert instance3.deleted assert fleet.consolidation_attempt == 1 + async def test_consolidation_attempt_increments_when_over_max_and_no_idle_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance1 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + instance_num=0, + ) + instance2 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + instance_num=1, + ) + + fleet.consolidation_attempt = 2 + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance1) + await session.refresh(instance2) + assert instance1.status == InstanceStatus.BUSY + assert instance2.status == InstanceStatus.BUSY + assert fleet.consolidation_attempt == 3 + async def test_marks_placement_groups_fleet_deleted_on_fleet_delete( self, test_db, session: AsyncSession, worker: FleetWorker ): From bc0689aa90990a0d38b850fdc9b5b25643db6711 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 13:43:04 +0500 Subject: [PATCH 08/21] Use typed dicts for update maps --- .../server/background/pipeline_tasks/base.py | 37 ++++++--- .../pipeline_tasks/compute_groups.py | 55 +++++++++---- .../background/pipeline_tasks/fleets.py | 78 +++++++++++-------- .../background/pipeline_tasks/gateways.py | 46 ++++++++--- .../pipeline_tasks/placement_groups.py | 35 +++++---- .../background/pipeline_tasks/volumes.py | 29 +++++-- .../background/pipeline_tasks/test_fleets.py | 6 +- 7 files changed, 193 insertions(+), 93 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 9d016934c..779e8eab6 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypeVar +from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypedDict, TypeVar from sqlalchemy import and_, or_, update from sqlalchemy.orm import Mapped @@ -337,16 +337,33 @@ async def process(self, item: ItemT): pass -UpdateMap = dict[str, Any] +class UnlockUpdateMap(TypedDict, total=False): + lock_expires_at: Optional[datetime] + lock_token: Optional[uuid.UUID] + lock_owner: Optional[str] -def get_unlock_update_map() -> UpdateMap: - return { - "lock_expires_at": None, - "lock_token": None, - "lock_owner": None, - } +class ProcessedUpdateMap(TypedDict, total=False): + last_processed_at: datetime -def get_processed_update_map() -> UpdateMap: - return {"last_processed_at": get_current_datetime()} +class ItemUpdateMap(UnlockUpdateMap, ProcessedUpdateMap, total=False): + lock_expires_at: Optional[datetime] + lock_token: Optional[uuid.UUID] + lock_owner: Optional[str] + last_processed_at: datetime + + +def set_unlock_update_map_fields(update_map: UnlockUpdateMap) -> None: + update_map["lock_expires_at"] = None + update_map["lock_token"] = None + update_map["lock_owner"] = None + + +def set_processed_update_map_fields( + update_map: ProcessedUpdateMap, + processed_at: Optional[datetime] = None, +) -> None: + if processed_at is None: + processed_at = get_current_datetime() + update_map["last_processed_at"] = processed_at diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index 33e839b8b..cf14dfbd1 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import Sequence +from typing import Sequence, TypedDict from sqlalchemy import or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -14,12 +14,12 @@ from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, Worker, - get_processed_update_map, - get_unlock_update_map, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel @@ -206,9 +206,9 @@ async def process(self, item: PipelineItem): if terminate_result.compute_group_update_map: logger.info("Terminated compute group %s", compute_group_model.id) else: - terminate_result.compute_group_update_map = get_processed_update_map() + set_processed_update_map_fields(terminate_result.compute_group_update_map) - terminate_result.compute_group_update_map |= get_unlock_update_map() + set_unlock_update_map_fields(terminate_result.compute_group_update_map) async with get_session_ctx() as session: res = await session.execute( @@ -246,10 +246,28 @@ async def process(self, item: PipelineItem): ) +class _ComputeGroupUpdateMap(ItemUpdateMap, total=False): + status: ComputeGroupStatus + deleted: bool + deleted_at: datetime + first_termination_retry_at: datetime + last_termination_retry_at: datetime + + +class _InstanceBulkUpdateMap(TypedDict, total=False): + last_processed_at: datetime + deleted: bool + deleted_at: datetime + finished_at: datetime + status: InstanceStatus + + @dataclass class _TerminateResult: - compute_group_update_map: UpdateMap = field(default_factory=dict) - instances_update_map: UpdateMap = field(default_factory=dict) + compute_group_update_map: _ComputeGroupUpdateMap = field( + default_factory=_ComputeGroupUpdateMap + ) + instances_update_map: _InstanceBulkUpdateMap = field(default_factory=_InstanceBulkUpdateMap) async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _TerminateResult: @@ -286,13 +304,13 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T if compute_group_model.first_termination_retry_at is None: result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime() result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime() + first_termination_retry_at = result.compute_group_update_map.get( + "first_termination_retry_at", compute_group_model.first_termination_retry_at + ) + assert first_termination_retry_at is not None if _next_termination_retry_at( result.compute_group_update_map["last_termination_retry_at"] - ) < _get_termination_deadline( - result.compute_group_update_map.get( - "first_termination_retry_at", compute_group_model.first_termination_retry_at - ) - ): + ) < _get_termination_deadline(first_termination_retry_at): logger.warning( "Failed to terminate compute group %s. Will retry. Error: %r", compute_group.name, @@ -309,10 +327,15 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T exc_info=not isinstance(e, BackendError), ) terminated_result = _get_terminated_result() + compute_group_update_map = _ComputeGroupUpdateMap() + compute_group_update_map.update(result.compute_group_update_map) + compute_group_update_map.update(terminated_result.compute_group_update_map) + instances_update_map = _InstanceBulkUpdateMap() + instances_update_map.update(result.instances_update_map) + instances_update_map.update(terminated_result.instances_update_map) return _TerminateResult( - compute_group_update_map=result.compute_group_update_map - | terminated_result.compute_group_update_map, - instances_update_map=result.instances_update_map | terminated_result.instances_update_map, + compute_group_update_map=compute_group_update_map, + instances_update_map=instances_update_map, ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index e906cabc3..9ded1bf7d 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -1,8 +1,8 @@ import asyncio import uuid from dataclasses import dataclass, field -from datetime import timedelta -from typing import Sequence +from datetime import datetime, timedelta +from typing import Sequence, TypedDict from sqlalchemy import or_, select, update from sqlalchemy.orm import joinedload, load_only, selectinload @@ -13,12 +13,12 @@ from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, Worker, - get_processed_update_map, - get_unlock_update_map, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -45,12 +45,7 @@ logger = get_logger(__name__) -@dataclass -class FleetPipelineItem(PipelineItem): - pass - - -class FleetPipeline(Pipeline[FleetPipelineItem]): +class FleetPipeline(Pipeline[PipelineItem]): def __init__( self, workers_num: int = 10, @@ -68,7 +63,7 @@ def __init__( lock_timeout=lock_timeout, heartbeat_trigger=heartbeat_trigger, ) - self.__heartbeater = Heartbeater[FleetPipelineItem]( + self.__heartbeater = Heartbeater[PipelineItem]( model_type=FleetModel, lock_timeout=self._lock_timeout, heartbeat_trigger=self._heartbeat_trigger, @@ -90,11 +85,11 @@ def hint_fetch_model_name(self) -> str: return FleetModel.__name__ @property - def _heartbeater(self) -> Heartbeater[FleetPipelineItem]: + def _heartbeater(self) -> Heartbeater[PipelineItem]: return self.__heartbeater @property - def _fetcher(self) -> Fetcher[FleetPipelineItem]: + def _fetcher(self) -> Fetcher[PipelineItem]: return self.__fetcher @property @@ -102,14 +97,14 @@ def _workers(self) -> Sequence["FleetWorker"]: return self.__workers -class FleetFetcher(Fetcher[FleetPipelineItem]): +class FleetFetcher(Fetcher[PipelineItem]): def __init__( self, - queue: asyncio.Queue[FleetPipelineItem], + queue: asyncio.Queue[PipelineItem], queue_desired_minsize: int, min_processing_interval: timedelta, lock_timeout: timedelta, - heartbeater: Heartbeater[FleetPipelineItem], + heartbeater: Heartbeater[PipelineItem], queue_check_delay: float = 1.0, ) -> None: super().__init__( @@ -122,7 +117,7 @@ def __init__( ) @sentry_utils.instrument_named_task("pipeline_tasks.FleetFetcher.fetch") - async def fetch(self, limit: int) -> list[FleetPipelineItem]: + async def fetch(self, limit: int) -> list[PipelineItem]: fleet_lock, _ = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__) async with fleet_lock: async with get_session_ctx() as session: @@ -165,7 +160,7 @@ async def fetch(self, limit: int) -> list[FleetPipelineItem]: fleet_model.lock_token = lock_token fleet_model.lock_owner = FleetPipeline.__name__ items.append( - FleetPipelineItem( + PipelineItem( __tablename__=FleetModel.__tablename__, id=fleet_model.id, lock_expires_at=lock_expires_at, @@ -177,11 +172,11 @@ async def fetch(self, limit: int) -> list[FleetPipelineItem]: return items -class FleetWorker(Worker[FleetPipelineItem]): +class FleetWorker(Worker[PipelineItem]): def __init__( self, - queue: asyncio.Queue[FleetPipelineItem], - heartbeater: Heartbeater[FleetPipelineItem], + queue: asyncio.Queue[PipelineItem], + heartbeater: Heartbeater[PipelineItem], ) -> None: super().__init__( queue=queue, @@ -189,7 +184,7 @@ def __init__( ) @sentry_utils.instrument_named_task("pipeline_tasks.FleetWorker.process") - async def process(self, item: FleetPipelineItem): + async def process(self, item: PipelineItem): async with get_session_ctx() as session: res = await session.execute( select(FleetModel) @@ -277,9 +272,10 @@ async def process(self, item: FleetPipelineItem): # await session.commit() result = await _process_fleet(fleet_model) - fleet_update_map = ( - result.fleet_update_map | get_processed_update_map() | get_unlock_update_map() - ) + fleet_update_map = _FleetUpdateMap() + fleet_update_map.update(result.fleet_update_map) + set_processed_update_map_fields(fleet_update_map) + set_unlock_update_map_fields(fleet_update_map) async with get_session_ctx() as session: res = await session.execute( update(FleetModel) @@ -310,9 +306,10 @@ async def process(self, item: FleetPipelineItem): instance_update_rows = [] for instance_id, instance_update_map in result.instance_id_to_update_map.items(): - update_row = dict(instance_update_map) + update_row = _InstanceUpdateMap() + update_row.update(instance_update_map) update_row["id"] = instance_id - update_row |= get_processed_update_map() # | get_unlock_update_map() + set_processed_update_map_fields(update_row) instance_update_rows.append(update_row) if instance_update_rows: await session.execute( @@ -365,16 +362,35 @@ async def process(self, item: FleetPipelineItem): ) +class _FleetUpdateMap(ItemUpdateMap, total=False): + status: FleetStatus + status_message: str + deleted: bool + deleted_at: datetime + consolidation_attempt: int + last_consolidated_at: datetime + + +class _InstanceUpdateMap(TypedDict, total=False): + status: InstanceStatus + termination_reason: InstanceTerminationReason + termination_reason_message: str + deleted: bool + deleted_at: datetime + last_processed_at: datetime + id: uuid.UUID + + @dataclass class _ProcessResult: - fleet_update_map: UpdateMap = field(default_factory=dict) - instance_id_to_update_map: dict[uuid.UUID, UpdateMap] = field(default_factory=dict) + fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) new_instances_count: int = 0 @dataclass class _MaintainNodesResult: - instance_id_to_update_map: dict[uuid.UUID, UpdateMap] = field(default_factory=dict) + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) new_instances_count: int = 0 changes_required: bool = False diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index cdd0904e1..e448a1cb6 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Optional, Sequence +from typing import Optional, Sequence, TypedDict from sqlalchemy import delete, or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -14,12 +14,13 @@ from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, + ProcessedUpdateMap, Worker, - get_processed_update_map, - get_unlock_update_map, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -227,7 +228,10 @@ async def _process_submitted_item(item: GatewayPipelineItem): return result = await _process_submitted_gateway(gateway_model) - update_map = result.update_map | get_processed_update_map() | get_unlock_update_map() + update_map = _GatewayUpdateMap() + update_map.update(result.update_map) + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: gateway_compute_model = result.gateway_compute_model if gateway_compute_model is not None: @@ -262,9 +266,20 @@ async def _process_submitted_item(item: GatewayPipelineItem): ) +class _GatewayUpdateMap(ItemUpdateMap, total=False): + status: GatewayStatus + status_message: str + gateway_compute_id: uuid.UUID + + +class _GatewayComputeUpdateMap(TypedDict, total=False): + active: bool + deleted: bool + + @dataclass class _SubmittedResult: - update_map: UpdateMap = field(default_factory=dict) + update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) gateway_compute_model: Optional[GatewayComputeModel] = None @@ -337,7 +352,10 @@ async def _process_provisioning_item(item: GatewayPipelineItem): return result = await _process_provisioning_gateway(gateway_model) - update_map = result.gateway_update_map | get_processed_update_map() | get_unlock_update_map() + update_map = _GatewayUpdateMap() + update_map.update(result.gateway_update_map) + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: res = await session.execute( update(GatewayModel) @@ -383,8 +401,10 @@ async def _process_provisioning_item(item: GatewayPipelineItem): @dataclass class _ProvisioningResult: - gateway_update_map: UpdateMap = field(default_factory=dict) - gateway_compute_update_map: UpdateMap = field(default_factory=dict) + gateway_update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) + gateway_compute_update_map: _GatewayComputeUpdateMap = field( + default_factory=_GatewayComputeUpdateMap + ) async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult: @@ -475,13 +495,15 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): targets=[events.Target.from_model(gateway_model)], ) else: + processed_update_map: ProcessedUpdateMap = {} + set_processed_update_map_fields(processed_update_map) res = await session.execute( update(GatewayModel) .where( GatewayModel.id == gateway_model.id, GatewayModel.lock_token == gateway_model.lock_token, ) - .values(**get_processed_update_map()) + .values(**processed_update_map) .returning(GatewayModel.id) ) updated_ids = list(res.scalars().all()) @@ -515,7 +537,9 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): @dataclass class _DeletedResult: delete_gateway: bool - gateway_compute_update_map: UpdateMap = field(default_factory=dict) + gateway_compute_update_map: _GatewayComputeUpdateMap = field( + default_factory=_GatewayComputeUpdateMap + ) async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _DeletedResult: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 193358ec0..70028fe67 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -1,6 +1,6 @@ import asyncio import uuid -from datetime import timedelta +from datetime import datetime, timedelta from typing import Sequence from sqlalchemy import or_, select, update @@ -11,12 +11,12 @@ from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, Worker, - get_processed_update_map, - get_unlock_update_map, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -197,9 +197,9 @@ async def process(self, item: PipelineItem): if update_map: logger.info("Deleted placement group %s", placement_group_model.name) else: - update_map = get_processed_update_map() + set_processed_update_map_fields(update_map) - update_map |= get_unlock_update_map() + set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: res = await session.execute( @@ -221,7 +221,14 @@ async def process(self, item: PipelineItem): ) -async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> UpdateMap: +class _PlacementGroupUpdateMap(ItemUpdateMap, total=False): + deleted: bool + deleted_at: datetime + + +async def _delete_placement_group( + placement_group_model: PlacementGroupModel, +) -> _PlacementGroupUpdateMap: placement_group = placement_group_model_to_placement_group(placement_group_model) if placement_group.provisioning_data is None: logger.error( @@ -247,7 +254,7 @@ async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> logger.info( "Placement group %s is still in use. Skipping deletion for now.", placement_group.name ) - return {} + return _PlacementGroupUpdateMap() except Exception: # TODO: Retry deletion logger.exception( @@ -259,10 +266,10 @@ async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> return _get_deleted_update_map() -def _get_deleted_update_map() -> UpdateMap: +def _get_deleted_update_map() -> _PlacementGroupUpdateMap: now = get_current_datetime() - return { - "last_processed_at": now, - "deleted": True, - "deleted_at": now, - } + update_map = _PlacementGroupUpdateMap() + update_map["last_processed_at"] = now + update_map["deleted"] = True + update_map["deleted_at"] = now + return update_map diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 578fe8423..e4f4f5a15 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -1,7 +1,7 @@ import asyncio import uuid from dataclasses import dataclass, field -from datetime import timedelta +from datetime import datetime, timedelta from typing import Sequence from sqlalchemy import or_, select, update @@ -13,12 +13,12 @@ from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, Worker, - get_processed_update_map, - get_unlock_update_map, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -233,7 +233,10 @@ async def _process_submitted_item(item: VolumePipelineItem): return result = await _process_submitted_volume(volume_model) - update_map = result.update_map | get_processed_update_map() | get_unlock_update_map() + update_map = _VolumeUpdateMap() + update_map.update(result.update_map) + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: res = await session.execute( update(VolumeModel) @@ -263,9 +266,17 @@ async def _process_submitted_item(item: VolumePipelineItem): ) +class _VolumeUpdateMap(ItemUpdateMap, total=False): + status: VolumeStatus + status_message: str + volume_provisioning_data: str + deleted: bool + deleted_at: datetime + + @dataclass class _SubmittedResult: - update_map: UpdateMap = field(default_factory=dict) + update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResult: @@ -363,7 +374,9 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): return result = await _process_to_be_deleted_volume(volume_model) - update_map = result.update_map | get_unlock_update_map() + update_map = _VolumeUpdateMap() + update_map.update(result.update_map) + set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: res = await session.execute( update(VolumeModel) @@ -393,7 +406,7 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): @dataclass class _DeletedResult: - update_map: UpdateMap = field(default_factory=dict) + update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _DeletedResult: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index f6c7acd50..133e9084e 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -10,8 +10,8 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.fleets import ( - FleetPipelineItem, FleetWorker, ) from dstack._internal.server.models import FleetModel, InstanceModel @@ -33,10 +33,10 @@ def worker() -> FleetWorker: return FleetWorker(queue=Mock(), heartbeater=Mock()) -def _fleet_to_pipeline_item(fleet: FleetModel) -> FleetPipelineItem: +def _fleet_to_pipeline_item(fleet: FleetModel) -> PipelineItem: assert fleet.lock_token is not None assert fleet.lock_expires_at is not None - return FleetPipelineItem( + return PipelineItem( __tablename__=fleet.__tablename__, id=fleet.id, lock_token=fleet.lock_token, From e2a27933d1ab15a3358e1a2ad7c7326a84871b3d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 14:36:51 +0500 Subject: [PATCH 09/21] Unify processing result classes --- .../pipeline_tasks/compute_groups.py | 7 +++- .../background/pipeline_tasks/gateways.py | 8 ++-- .../pipeline_tasks/placement_groups.py | 24 ++++++++---- .../pipeline_tasks/test_gateways.py | 4 ++ .../pipeline_tasks/test_placement_groups.py | 39 +++++++++++++++++++ 5 files changed, 69 insertions(+), 13 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index cf14dfbd1..c7facb739 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -203,7 +203,12 @@ async def process(self, item: PipelineItem): # TODO: Fetch only compute groups with all instances terminating. if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances): terminate_result = await _terminate_compute_group(compute_group_model) - if terminate_result.compute_group_update_map: + terminated = terminate_result.compute_group_update_map.get( + "status" + ) == ComputeGroupStatus.TERMINATED or terminate_result.compute_group_update_map.get( + "deleted", False + ) + if terminated: logger.info("Terminated compute group %s", compute_group_model.id) else: set_processed_update_map_fields(terminate_result.compute_group_update_map) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index e448a1cb6..e802043c0 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -17,7 +17,6 @@ ItemUpdateMap, Pipeline, PipelineItem, - ProcessedUpdateMap, Worker, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -495,15 +494,16 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): targets=[events.Target.from_model(gateway_model)], ) else: - processed_update_map: ProcessedUpdateMap = {} - set_processed_update_map_fields(processed_update_map) + update_map = _GatewayUpdateMap() + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) res = await session.execute( update(GatewayModel) .where( GatewayModel.id == gateway_model.id, GatewayModel.lock_token == gateway_model.lock_token, ) - .values(**processed_update_map) + .values(**update_map) .returning(GatewayModel.id) ) updated_ids = list(res.scalars().all()) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 70028fe67..a5965b015 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -1,5 +1,6 @@ import asyncio import uuid +from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Sequence @@ -193,8 +194,10 @@ async def process(self, item: PipelineItem): ) return - update_map = await _delete_placement_group(placement_group_model) - if update_map: + result = await _delete_placement_group(placement_group_model) + update_map = _PlacementGroupUpdateMap() + update_map.update(result.update_map) + if update_map.get("deleted", False): logger.info("Deleted placement group %s", placement_group_model.name) else: set_processed_update_map_fields(update_map) @@ -226,15 +229,20 @@ class _PlacementGroupUpdateMap(ItemUpdateMap, total=False): deleted_at: datetime +@dataclass +class _DeletePlacementGroupResult: + update_map: _PlacementGroupUpdateMap = field(default_factory=_PlacementGroupUpdateMap) + + async def _delete_placement_group( placement_group_model: PlacementGroupModel, -) -> _PlacementGroupUpdateMap: +) -> _DeletePlacementGroupResult: placement_group = placement_group_model_to_placement_group(placement_group_model) if placement_group.provisioning_data is None: logger.error( "Failed to delete placement group %s. provisioning_data is None.", placement_group.name ) - return _get_deleted_update_map() + return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) backend = await backends_services.get_project_backend_by_type( project=placement_group_model.project, backend_type=placement_group.provisioning_data.backend, @@ -245,7 +253,7 @@ async def _delete_placement_group( "Failed to delete placement group %s. Backend not available. Please delete it manually.", placement_group.name, ) - return _get_deleted_update_map() + return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) compute = backend.compute() assert isinstance(compute, ComputeWithPlacementGroupSupport) try: @@ -254,16 +262,16 @@ async def _delete_placement_group( logger.info( "Placement group %s is still in use. Skipping deletion for now.", placement_group.name ) - return _PlacementGroupUpdateMap() + return _DeletePlacementGroupResult() except Exception: # TODO: Retry deletion logger.exception( "Got exception when deleting placement group %s. Please delete it manually.", placement_group.name, ) - return _get_deleted_update_map() + return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) - return _get_deleted_update_map() + return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) def _get_deleted_update_map() -> _PlacementGroupUpdateMap: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py index 9628451bd..59cbd370e 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -257,6 +257,7 @@ async def test_keeps_gateway_if_terminate_fails( ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.lock_owner = "GatewayPipeline" gateway.to_be_deleted = True original_last_processed_at = gateway.last_processed_at await session.commit() @@ -286,6 +287,9 @@ async def test_keeps_gateway_if_terminate_fails( await session.refresh(gateway_compute) assert gateway.to_be_deleted is True assert gateway.last_processed_at > original_last_processed_at + assert gateway.lock_token is None + assert gateway.lock_expires_at is None + assert gateway.lock_owner is None assert gateway_compute.active is True assert gateway_compute.deleted is False events = await list_events(session) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py index 7baed58b6..c23d5e604 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -5,6 +5,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.errors import PlacementGroupInUseError from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker from dstack._internal.server.models import PlacementGroupModel @@ -62,3 +63,41 @@ async def test_deletes_placement_group( aws_mock.compute.return_value.delete_placement_group.assert_called_once() await session.refresh(placement_group) assert placement_group.deleted + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_retries_placement_group_deletion_if_still_in_use( + self, test_db, session: AsyncSession, worker: PlacementGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + placement_group = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test2-pg", + fleet_deleted=True, + ) + placement_group.lock_token = uuid.uuid4() + placement_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + placement_group.lock_owner = "PlacementGroupPipeline" + original_last_processed_at = placement_group.last_processed_at + await session.commit() + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + aws_mock = Mock() + m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.delete_placement_group.side_effect = ( + PlacementGroupInUseError() + ) + await worker.process(_placement_group_to_pipeline_item(placement_group)) + aws_mock.compute.return_value.delete_placement_group.assert_called_once() + await session.refresh(placement_group) + assert not placement_group.deleted + assert placement_group.last_processed_at > original_last_processed_at + assert placement_group.lock_token is None + assert placement_group.lock_expires_at is None + assert placement_group.lock_owner is None From 55d98811864fbb59879145d768902471897a1484 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 15:38:04 +0500 Subject: [PATCH 10/21] Centralize last_processed_at setting --- .../server/background/pipeline_tasks/base.py | 47 ++++++++++--- .../pipeline_tasks/compute_groups.py | 68 +++++++++---------- .../background/pipeline_tasks/fleets.py | 37 +++++----- .../background/pipeline_tasks/gateways.py | 20 ++++-- .../pipeline_tasks/placement_groups.py | 21 +++--- .../background/pipeline_tasks/volumes.py | 19 ++++-- 6 files changed, 127 insertions(+), 85 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 779e8eab6..59a372760 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -5,7 +5,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypedDict, TypeVar +from typing import ( + Any, + ClassVar, + Final, + Generic, + Optional, + Protocol, + Sequence, + TypedDict, + TypeVar, + Union, +) from sqlalchemy import and_, or_, update from sqlalchemy.orm import Mapped @@ -337,6 +348,16 @@ async def process(self, item: ItemT): pass +class _NowPlaceholder: + pass + + +NOW_PLACEHOLDER: Final = _NowPlaceholder() + +# Timestamp value stored in update maps before being resolved to current time. +UpdateMapDateTime = Union[datetime, _NowPlaceholder] + + class UnlockUpdateMap(TypedDict, total=False): lock_expires_at: Optional[datetime] lock_token: Optional[uuid.UUID] @@ -344,17 +365,17 @@ class UnlockUpdateMap(TypedDict, total=False): class ProcessedUpdateMap(TypedDict, total=False): - last_processed_at: datetime + last_processed_at: UpdateMapDateTime class ItemUpdateMap(UnlockUpdateMap, ProcessedUpdateMap, total=False): lock_expires_at: Optional[datetime] lock_token: Optional[uuid.UUID] lock_owner: Optional[str] - last_processed_at: datetime + last_processed_at: UpdateMapDateTime -def set_unlock_update_map_fields(update_map: UnlockUpdateMap) -> None: +def set_unlock_update_map_fields(update_map: UnlockUpdateMap): update_map["lock_expires_at"] = None update_map["lock_token"] = None update_map["lock_owner"] = None @@ -362,8 +383,16 @@ def set_unlock_update_map_fields(update_map: UnlockUpdateMap) -> None: def set_processed_update_map_fields( update_map: ProcessedUpdateMap, - processed_at: Optional[datetime] = None, -) -> None: - if processed_at is None: - processed_at = get_current_datetime() - update_map["last_processed_at"] = processed_at + now: UpdateMapDateTime = NOW_PLACEHOLDER, +): + update_map["last_processed_at"] = now + + +def resolve_now_placeholders(update_values: Any, now: datetime): + if isinstance(update_values, list): + for update_row in update_values: + resolve_now_placeholders(update_row, now) + return + for key, value in update_values.items(): + if value is NOW_PLACEHOLDER: + update_values[key] = now diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index c7facb739..822f3adec 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -12,12 +12,15 @@ from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, ItemUpdateMap, Pipeline, PipelineItem, + UpdateMapDateTime, Worker, + resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, ) @@ -199,30 +202,28 @@ async def process(self, item: PipelineItem): ) return - terminate_result = _TerminateResult() + result = _TerminateResult() # TODO: Fetch only compute groups with all instances terminating. if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances): - terminate_result = await _terminate_compute_group(compute_group_model) - terminated = terminate_result.compute_group_update_map.get( - "status" - ) == ComputeGroupStatus.TERMINATED or terminate_result.compute_group_update_map.get( - "deleted", False - ) - if terminated: + result = await _terminate_compute_group(compute_group_model) + set_processed_update_map_fields(result.compute_group_update_map) + if result.instances_update_map: + set_processed_update_map_fields(result.instances_update_map) + set_unlock_update_map_fields(result.compute_group_update_map) + if result.compute_group_update_map.get("deleted", False): logger.info("Terminated compute group %s", compute_group_model.id) - else: - set_processed_update_map_fields(terminate_result.compute_group_update_map) - - set_unlock_update_map_fields(terminate_result.compute_group_update_map) async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(result.compute_group_update_map, now=now) + resolve_now_placeholders(result.instances_update_map, now=now) res = await session.execute( update(ComputeGroupModel) .where( ComputeGroupModel.id == compute_group_model.id, ComputeGroupModel.lock_token == compute_group_model.lock_token, ) - .values(**terminate_result.compute_group_update_map) + .values(**result.compute_group_update_map) .returning(ComputeGroupModel.id) ) updated_ids = list(res.scalars().all()) @@ -234,13 +235,13 @@ async def process(self, item: PipelineItem): item.id, ) return - if not terminate_result.instances_update_map: + if not result.instances_update_map: return instances_ids = [i.id for i in compute_group_model.instances] res = await session.execute( update(InstanceModel) .where(InstanceModel.id.in_(instances_ids)) - .values(**terminate_result.instances_update_map) + .values(**result.instances_update_map) ) for instance_model in compute_group_model.instances: emit_instance_status_change_event( @@ -254,16 +255,16 @@ async def process(self, item: PipelineItem): class _ComputeGroupUpdateMap(ItemUpdateMap, total=False): status: ComputeGroupStatus deleted: bool - deleted_at: datetime - first_termination_retry_at: datetime - last_termination_retry_at: datetime + deleted_at: UpdateMapDateTime + first_termination_retry_at: UpdateMapDateTime + last_termination_retry_at: UpdateMapDateTime class _InstanceBulkUpdateMap(TypedDict, total=False): - last_processed_at: datetime + last_processed_at: UpdateMapDateTime deleted: bool - deleted_at: datetime - finished_at: datetime + deleted_at: UpdateMapDateTime + finished_at: UpdateMapDateTime status: InstanceStatus @@ -306,16 +307,16 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T compute_group, ) except Exception as e: + retry_at = get_current_datetime() + first_termination_retry_at = compute_group_model.first_termination_retry_at if compute_group_model.first_termination_retry_at is None: - result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime() - result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime() - first_termination_retry_at = result.compute_group_update_map.get( - "first_termination_retry_at", compute_group_model.first_termination_retry_at - ) + result.compute_group_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER + first_termination_retry_at = retry_at assert first_termination_retry_at is not None - if _next_termination_retry_at( - result.compute_group_update_map["last_termination_retry_at"] - ) < _get_termination_deadline(first_termination_retry_at): + result.compute_group_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER + if _next_termination_retry_at(retry_at) < _get_termination_deadline( + first_termination_retry_at + ): logger.warning( "Failed to terminate compute group %s. Will retry. Error: %r", compute_group.name, @@ -353,19 +354,16 @@ def _get_termination_deadline(first_termination_retry_at: datetime) -> datetime: def _get_terminated_result() -> _TerminateResult: - now = get_current_datetime() return _TerminateResult( compute_group_update_map={ - "last_processed_at": now, "deleted": True, - "deleted_at": now, + "deleted_at": NOW_PLACEHOLDER, "status": ComputeGroupStatus.TERMINATED, }, instances_update_map={ - "last_processed_at": now, "deleted": True, - "deleted_at": now, - "finished_at": now, + "deleted_at": NOW_PLACEHOLDER, + "finished_at": NOW_PLACEHOLDER, "status": InstanceStatus.TERMINATED, }, ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 9ded1bf7d..945c7fa76 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -1,7 +1,7 @@ import asyncio import uuid from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import timedelta from typing import Sequence, TypedDict from sqlalchemy import or_, select, update @@ -11,12 +11,15 @@ from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, ItemUpdateMap, Pipeline, PipelineItem, + UpdateMapDateTime, Worker, + resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, ) @@ -276,7 +279,17 @@ async def process(self, item: PipelineItem): fleet_update_map.update(result.fleet_update_map) set_processed_update_map_fields(fleet_update_map) set_unlock_update_map_fields(fleet_update_map) + instance_update_rows = [] + for instance_id, instance_update_map in result.instance_id_to_update_map.items(): + update_row = _InstanceUpdateMap() + update_row.update(instance_update_map) + update_row["id"] = instance_id + set_processed_update_map_fields(update_row) + instance_update_rows.append(update_row) async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(fleet_update_map, now=now) + resolve_now_placeholders(instance_update_rows, now=now) res = await session.execute( update(FleetModel) .where( @@ -303,14 +316,6 @@ async def process(self, item: PipelineItem): .where(PlacementGroupModel.fleet_id == item.id) .values(fleet_deleted=True) ) - - instance_update_rows = [] - for instance_id, instance_update_map in result.instance_id_to_update_map.items(): - update_row = _InstanceUpdateMap() - update_row.update(instance_update_map) - update_row["id"] = instance_id - set_processed_update_map_fields(update_row) - instance_update_rows.append(update_row) if instance_update_rows: await session.execute( update(InstanceModel).execution_options(synchronize_session=False), @@ -366,9 +371,9 @@ class _FleetUpdateMap(ItemUpdateMap, total=False): status: FleetStatus status_message: str deleted: bool - deleted_at: datetime + deleted_at: UpdateMapDateTime consolidation_attempt: int - last_consolidated_at: datetime + last_consolidated_at: UpdateMapDateTime class _InstanceUpdateMap(TypedDict, total=False): @@ -376,8 +381,8 @@ class _InstanceUpdateMap(TypedDict, total=False): termination_reason: InstanceTerminationReason termination_reason_message: str deleted: bool - deleted_at: datetime - last_processed_at: datetime + deleted_at: UpdateMapDateTime + last_processed_at: UpdateMapDateTime id: uuid.UUID @@ -409,7 +414,7 @@ async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: if deleted: result.fleet_update_map["status"] = FleetStatus.TERMINATED result.fleet_update_map["deleted"] = True - result.fleet_update_map["deleted_at"] = get_current_datetime() + result.fleet_update_map["deleted_at"] = NOW_PLACEHOLDER return result @@ -432,7 +437,7 @@ def _consolidate_fleet_state_with_spec(fleet_model: FleetModel) -> _ProcessResul else: # The fleet is consolidated with respect to nodes min/max. result.fleet_update_map["consolidation_attempt"] = 0 - result.fleet_update_map["last_consolidated_at"] = get_current_datetime() + result.fleet_update_map["last_consolidated_at"] = NOW_PLACEHOLDER return result @@ -478,7 +483,7 @@ def _maintain_fleet_nodes_in_min_max_range( result.changes_required = True result.instance_id_to_update_map[instance.id] = { "deleted": True, - "deleted_at": get_current_datetime(), + "deleted_at": NOW_PLACEHOLDER, } active_instances = [ i for i in fleet_model.instances if i.status != InstanceStatus.TERMINATED and not i.deleted diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index e802043c0..05b5cc03a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -18,6 +18,7 @@ Pipeline, PipelineItem, Worker, + resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, ) @@ -237,6 +238,8 @@ async def _process_submitted_item(item: GatewayPipelineItem): session.add(gateway_compute_model) await session.flush() update_map["gateway_compute_id"] = gateway_compute_model.id + now = get_current_datetime() + resolve_now_placeholders(update_map, now=now) res = await session.execute( update(GatewayModel) .where( @@ -351,18 +354,20 @@ async def _process_provisioning_item(item: GatewayPipelineItem): return result = await _process_provisioning_gateway(gateway_model) - update_map = _GatewayUpdateMap() - update_map.update(result.gateway_update_map) - set_processed_update_map_fields(update_map) - set_unlock_update_map_fields(update_map) + gateway_update_map = result.gateway_update_map + set_processed_update_map_fields(gateway_update_map) + set_unlock_update_map_fields(gateway_update_map) + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(gateway_update_map, now=now) res = await session.execute( update(GatewayModel) .where( GatewayModel.id == gateway_model.id, GatewayModel.lock_token == gateway_model.lock_token, ) - .values(**update_map) + .values(**gateway_update_map) .returning(GatewayModel.id) ) updated_ids = list(res.scalars().all()) @@ -378,8 +383,8 @@ async def _process_provisioning_item(item: GatewayPipelineItem): session=session, gateway_model=gateway_model, old_status=gateway_model.status, - new_status=update_map.get("status", gateway_model.status), - status_message=update_map.get("status_message", gateway_model.status_message), + new_status=gateway_update_map.get("status", gateway_model.status), + status_message=gateway_update_map.get("status_message", gateway_model.status_message), ) if result.gateway_compute_update_map: res = await session.execute( @@ -497,6 +502,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): update_map = _GatewayUpdateMap() set_processed_update_map_fields(update_map) set_unlock_update_map_fields(update_map) + resolve_now_placeholders(update_map, now=get_current_datetime()) res = await session.execute( update(GatewayModel) .where( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index a5965b015..41f32873c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -1,7 +1,7 @@ import asyncio import uuid from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import timedelta from typing import Sequence from sqlalchemy import or_, select, update @@ -10,12 +10,15 @@ from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport from dstack._internal.core.errors import PlacementGroupInUseError from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, ItemUpdateMap, Pipeline, PipelineItem, + UpdateMapDateTime, Worker, + resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, ) @@ -195,16 +198,14 @@ async def process(self, item: PipelineItem): return result = await _delete_placement_group(placement_group_model) - update_map = _PlacementGroupUpdateMap() - update_map.update(result.update_map) + update_map = result.update_map + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) if update_map.get("deleted", False): logger.info("Deleted placement group %s", placement_group_model.name) - else: - set_processed_update_map_fields(update_map) - - set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: + resolve_now_placeholders(update_map, now=get_current_datetime()) res = await session.execute( update(PlacementGroupModel) .where( @@ -226,7 +227,7 @@ async def process(self, item: PipelineItem): class _PlacementGroupUpdateMap(ItemUpdateMap, total=False): deleted: bool - deleted_at: datetime + deleted_at: UpdateMapDateTime @dataclass @@ -275,9 +276,7 @@ async def _delete_placement_group( def _get_deleted_update_map() -> _PlacementGroupUpdateMap: - now = get_current_datetime() update_map = _PlacementGroupUpdateMap() - update_map["last_processed_at"] = now update_map["deleted"] = True - update_map["deleted_at"] = now + update_map["deleted_at"] = NOW_PLACEHOLDER return update_map diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index e4f4f5a15..1c9461a87 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -1,7 +1,7 @@ import asyncio import uuid from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import timedelta from typing import Sequence from sqlalchemy import or_, select, update @@ -11,12 +11,15 @@ from dstack._internal.core.errors import BackendError, BackendNotAvailable from dstack._internal.core.models.volumes import VolumeStatus from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, ItemUpdateMap, Pipeline, PipelineItem, + UpdateMapDateTime, Worker, + resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, ) @@ -233,11 +236,12 @@ async def _process_submitted_item(item: VolumePipelineItem): return result = await _process_submitted_volume(volume_model) - update_map = _VolumeUpdateMap() - update_map.update(result.update_map) + update_map = result.update_map set_processed_update_map_fields(update_map) set_unlock_update_map_fields(update_map) + async with get_session_ctx() as session: + resolve_now_placeholders(update_map, now=get_current_datetime()) res = await session.execute( update(VolumeModel) .where( @@ -271,7 +275,7 @@ class _VolumeUpdateMap(ItemUpdateMap, total=False): status_message: str volume_provisioning_data: str deleted: bool - deleted_at: datetime + deleted_at: UpdateMapDateTime @dataclass @@ -376,8 +380,11 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): result = await _process_to_be_deleted_volume(volume_model) update_map = _VolumeUpdateMap() update_map.update(result.update_map) + set_processed_update_map_fields(update_map) set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(update_map, now=now) res = await session.execute( update(VolumeModel) .where( @@ -451,11 +458,9 @@ async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _DeletedRe def _get_deleted_result() -> _DeletedResult: - now = get_current_datetime() return _DeletedResult( update_map={ - "last_processed_at": now, "deleted": True, - "deleted_at": now, + "deleted_at": NOW_PLACEHOLDER, } ) From 818eda73a44e2a31736e1ae7e89e1415266ea53c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 15:43:41 +0500 Subject: [PATCH 11/21] Refactor _build_instance_update_rows() --- .../background/pipeline_tasks/fleets.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 945c7fa76..40bc5a860 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -279,13 +279,8 @@ async def process(self, item: PipelineItem): fleet_update_map.update(result.fleet_update_map) set_processed_update_map_fields(fleet_update_map) set_unlock_update_map_fields(fleet_update_map) - instance_update_rows = [] - for instance_id, instance_update_map in result.instance_id_to_update_map.items(): - update_row = _InstanceUpdateMap() - update_row.update(instance_update_map) - update_row["id"] = instance_id - set_processed_update_map_fields(update_row) - instance_update_rows.append(update_row) + instance_update_rows = _build_instance_update_rows(result.instance_id_to_update_map) + async with get_session_ctx() as session: now = get_current_datetime() resolve_now_placeholders(fleet_update_map, now=now) @@ -538,3 +533,16 @@ def _autodelete_fleet(fleet_model: FleetModel) -> bool: logger.info("Automatic cleanup of an empty fleet %s", fleet_model.name) return True + + +def _build_instance_update_rows( + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> list[_InstanceUpdateMap]: + instance_update_rows = [] + for instance_id, instance_update_map in instance_id_to_update_map.items(): + update_row = _InstanceUpdateMap() + update_row.update(instance_update_map) + update_row["id"] = instance_id + set_processed_update_map_fields(update_row) + instance_update_rows.append(update_row) + return instance_update_rows From f20342c8f82ace055e355a24542d92d58f55f0f0 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 15:56:57 +0500 Subject: [PATCH 12/21] Make result naming consistent --- .../pipeline_tasks/compute_groups.py | 13 +++---------- .../background/pipeline_tasks/gateways.py | 8 ++++---- .../pipeline_tasks/placement_groups.py | 18 ++++++++---------- .../background/pipeline_tasks/volumes.py | 8 ++++---- 4 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index 822f3adec..0ee2975eb 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -333,16 +333,9 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T exc_info=not isinstance(e, BackendError), ) terminated_result = _get_terminated_result() - compute_group_update_map = _ComputeGroupUpdateMap() - compute_group_update_map.update(result.compute_group_update_map) - compute_group_update_map.update(terminated_result.compute_group_update_map) - instances_update_map = _InstanceBulkUpdateMap() - instances_update_map.update(result.instances_update_map) - instances_update_map.update(terminated_result.instances_update_map) - return _TerminateResult( - compute_group_update_map=compute_group_update_map, - instances_update_map=instances_update_map, - ) + terminated_result.compute_group_update_map.update(result.compute_group_update_map) + terminated_result.instances_update_map.update(result.instances_update_map) + return terminated_result def _next_termination_retry_at(last_termination_retry_at: datetime) -> datetime: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 05b5cc03a..2d5f0a947 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -541,14 +541,14 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): @dataclass -class _DeletedResult: +class _ProcessToBeDeletedResult: delete_gateway: bool gateway_compute_update_map: _GatewayComputeUpdateMap = field( default_factory=_GatewayComputeUpdateMap ) -async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _DeletedResult: +async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _ProcessToBeDeletedResult: assert gateway_model.backend.type != BackendType.DSTACK backend = await backends_services.get_project_backend_by_type_or_error( project=gateway_model.project, backend_type=gateway_model.backend.type @@ -572,9 +572,9 @@ async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _Delete "Error when deleting gateway compute for %s", gateway_model.name, ) - return _DeletedResult(delete_gateway=False) + return _ProcessToBeDeletedResult(delete_gateway=False) logger.info("Deleted gateway compute for %s", gateway_model.name) - result = _DeletedResult(delete_gateway=True) + result = _ProcessToBeDeletedResult(delete_gateway=True) if gateway_model.gateway_compute is not None: await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) result.gateway_compute_update_map = {"active": False, "deleted": True} diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 41f32873c..703cfe154 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -231,19 +231,19 @@ class _PlacementGroupUpdateMap(ItemUpdateMap, total=False): @dataclass -class _DeletePlacementGroupResult: +class _DeleteResult: update_map: _PlacementGroupUpdateMap = field(default_factory=_PlacementGroupUpdateMap) async def _delete_placement_group( placement_group_model: PlacementGroupModel, -) -> _DeletePlacementGroupResult: +) -> _DeleteResult: placement_group = placement_group_model_to_placement_group(placement_group_model) if placement_group.provisioning_data is None: logger.error( "Failed to delete placement group %s. provisioning_data is None.", placement_group.name ) - return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) + return _get_deleted_result() backend = await backends_services.get_project_backend_by_type( project=placement_group_model.project, backend_type=placement_group.provisioning_data.backend, @@ -254,7 +254,7 @@ async def _delete_placement_group( "Failed to delete placement group %s. Backend not available. Please delete it manually.", placement_group.name, ) - return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) + return _get_deleted_result() compute = backend.compute() assert isinstance(compute, ComputeWithPlacementGroupSupport) try: @@ -263,20 +263,18 @@ async def _delete_placement_group( logger.info( "Placement group %s is still in use. Skipping deletion for now.", placement_group.name ) - return _DeletePlacementGroupResult() + return _DeleteResult() except Exception: # TODO: Retry deletion logger.exception( "Got exception when deleting placement group %s. Please delete it manually.", placement_group.name, ) - return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) + return _get_deleted_result() - return _DeletePlacementGroupResult(update_map=_get_deleted_update_map()) - -def _get_deleted_update_map() -> _PlacementGroupUpdateMap: +def _get_deleted_result() -> _DeleteResult: update_map = _PlacementGroupUpdateMap() update_map["deleted"] = True update_map["deleted_at"] = NOW_PLACEHOLDER - return update_map + return _DeleteResult(update_map=update_map) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 1c9461a87..c7a8f5761 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -412,11 +412,11 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): @dataclass -class _DeletedResult: +class _ProcessToBeDeletedResult: update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) -async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _DeletedResult: +async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessToBeDeletedResult: volume = volume_model_to_volume(volume_model) if volume.external: return _get_deleted_result() @@ -457,8 +457,8 @@ async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _DeletedRe return _get_deleted_result() -def _get_deleted_result() -> _DeletedResult: - return _DeletedResult( +def _get_deleted_result() -> _ProcessToBeDeletedResult: + return _ProcessToBeDeletedResult( update_map={ "deleted": True, "deleted_at": NOW_PLACEHOLDER, From 5f0c1d834377ca1b954836de308d40ea3675da26 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 16:29:37 +0500 Subject: [PATCH 13/21] Refactor _create_missing_fleet_instances() --- .../server/background/pipeline_tasks/base.py | 36 +++++++-- .../background/pipeline_tasks/fleets.py | 81 +++++++++++-------- 2 files changed, 74 insertions(+), 43 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 59a372760..ff86ba288 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -3,6 +3,7 @@ import random import uuid from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence from dataclasses import dataclass from datetime import datetime, timedelta from typing import ( @@ -12,7 +13,6 @@ Generic, Optional, Protocol, - Sequence, TypedDict, TypeVar, Union, @@ -353,46 +353,66 @@ class _NowPlaceholder: NOW_PLACEHOLDER: Final = _NowPlaceholder() +""" +Use `NOW_PLACEHOLDER` together with `resolve_now_placeholders()` in pipeline update maps +instead of `get_current_time()` to have the same current time for all updates in the transaction. +""" # Timestamp value stored in update maps before being resolved to current time. UpdateMapDateTime = Union[datetime, _NowPlaceholder] -class UnlockUpdateMap(TypedDict, total=False): +class _UnlockUpdateMap(TypedDict, total=False): lock_expires_at: Optional[datetime] lock_token: Optional[uuid.UUID] lock_owner: Optional[str] -class ProcessedUpdateMap(TypedDict, total=False): +class _ProcessedUpdateMap(TypedDict, total=False): last_processed_at: UpdateMapDateTime -class ItemUpdateMap(UnlockUpdateMap, ProcessedUpdateMap, total=False): +class ItemUpdateMap(_UnlockUpdateMap, _ProcessedUpdateMap, total=False): lock_expires_at: Optional[datetime] lock_token: Optional[uuid.UUID] lock_owner: Optional[str] last_processed_at: UpdateMapDateTime -def set_unlock_update_map_fields(update_map: UnlockUpdateMap): +def set_unlock_update_map_fields(update_map: _UnlockUpdateMap): update_map["lock_expires_at"] = None update_map["lock_token"] = None update_map["lock_owner"] = None def set_processed_update_map_fields( - update_map: ProcessedUpdateMap, + update_map: _ProcessedUpdateMap, now: UpdateMapDateTime = NOW_PLACEHOLDER, ): update_map["last_processed_at"] = now -def resolve_now_placeholders(update_values: Any, now: datetime): - if isinstance(update_values, list): +class _ResolveNowUpdateMap(Protocol): + def items(self) -> Iterable[tuple[str, object]]: ... + + +_ResolveNowInput = Union[_ResolveNowUpdateMap, Sequence[_ResolveNowUpdateMap]] + + +def resolve_now_placeholders(update_values: _ResolveNowInput, now: datetime): + """ + Replaces `NOW_PLACEHOLDER` with `now` in an update map or a sequence of update rows. + """ + if isinstance(update_values, Sequence): for update_row in update_values: resolve_now_placeholders(update_row, now) return + # Runtime dict narrowing is required here: pyright doesn't model TypedDicts as + # supporting generic dynamic-key mutation via protocol methods. + if not isinstance(update_values, dict): + raise TypeError( + "resolve_now_placeholders() expects update maps or sequences of update maps" + ) for key, value in update_values.items(): if value is NOW_PLACEHOLDER: update_values[key] = now diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 40bc5a860..824749df8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -5,6 +5,7 @@ from typing import Sequence, TypedDict from sqlalchemy import or_, select, update +from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.orm import joinedload, load_only, selectinload from dstack._internal.core.models.fleets import FleetSpec, FleetStatus @@ -316,43 +317,12 @@ async def process(self, item: PipelineItem): update(InstanceModel).execution_options(synchronize_session=False), instance_update_rows, ) - if result.new_instances_count > 0: - fleet_spec = get_fleet_spec(fleet_model) - res = await session.execute( - select(InstanceModel.instance_num).where( - InstanceModel.fleet_id == fleet_model.id, - InstanceModel.deleted == False, - ) - ) - taken_instance_nums = set(res.scalars().all()) - for _ in range(result.new_instances_count): - instance_num = get_next_instance_num(taken_instance_nums) - instance_model = create_fleet_instance_model( - session=session, - project=fleet_model.project, - # TODO: Store fleet.user and pass it instead of the project owner. - username=fleet_model.project.owner.name, - spec=fleet_spec, - instance_num=instance_num, - ) - instance_model.fleet_id = fleet_model.id - taken_instance_nums.add(instance_num) - events.emit( - session=session, - message=( - "Instance created to meet target fleet node count." - f" Status: {instance_model.status.upper()}" - ), - actor=events.SystemActor(), - targets=[events.Target.from_model(instance_model)], - ) - logger.info( - "Added %d instances to fleet %s", - result.new_instances_count, - fleet_model.name, + await _create_missing_fleet_instances( + session=session, + fleet_model=fleet_model, + new_instances_count=result.new_instances_count, ) - emit_fleet_status_change_event( session=session, fleet_model=fleet_model, @@ -546,3 +516,44 @@ def _build_instance_update_rows( set_processed_update_map_fields(update_row) instance_update_rows.append(update_row) return instance_update_rows + + +async def _create_missing_fleet_instances( + session: AsyncSession, + fleet_model: FleetModel, + new_instances_count: int, +): + fleet_spec = get_fleet_spec(fleet_model) + res = await session.execute( + select(InstanceModel.instance_num).where( + InstanceModel.fleet_id == fleet_model.id, + InstanceModel.deleted == False, + ) + ) + taken_instance_nums = set(res.scalars().all()) + for _ in range(new_instances_count): + instance_num = get_next_instance_num(taken_instance_nums) + instance_model = create_fleet_instance_model( + session=session, + project=fleet_model.project, + # TODO: Store fleet.user and pass it instead of the project owner. + username=fleet_model.project.owner.name, + spec=fleet_spec, + instance_num=instance_num, + ) + instance_model.fleet_id = fleet_model.id + taken_instance_nums.add(instance_num) + events.emit( + session=session, + message=( + "Instance created to meet target fleet node count." + f" Status: {instance_model.status.upper()}" + ), + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + logger.info( + "Added %d instances to fleet %s", + new_instances_count, + fleet_model.name, + ) From 486b89b91a8c006e293cc9a384819286f7c46532 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 16:31:34 +0500 Subject: [PATCH 14/21] Enable FleetPipeline --- .../server/background/pipeline_tasks/__init__.py | 2 ++ .../server/background/pipeline_tasks/fleets.py | 4 ++-- .../server/background/scheduled_tasks/__init__.py | 10 +++++----- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index d9f67680c..6b3762419 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -2,6 +2,7 @@ from dstack._internal.server.background.pipeline_tasks.base import Pipeline from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline +from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, @@ -16,6 +17,7 @@ class PipelineManager: def __init__(self) -> None: self._pipelines: list[Pipeline] = [ ComputeGroupPipeline(), + FleetPipeline(), GatewayPipeline(), PlacementGroupPipeline(), VolumePipeline(), diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 824749df8..ed41cafa1 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -56,8 +56,8 @@ def __init__( queue_lower_limit_factor: float = 0.5, queue_upper_limit_factor: float = 2.0, min_processing_interval: timedelta = timedelta(seconds=30), - lock_timeout: timedelta = timedelta(seconds=30), - heartbeat_trigger: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=20), + heartbeat_trigger: timedelta = timedelta(seconds=10), ) -> None: super().__init__( workers_num=workers_num, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 45ae8ec7f..9c7cd6ac1 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -102,13 +102,13 @@ def start_scheduled_tasks() -> AsyncIOScheduler: _scheduler.add_job( process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 ) - _scheduler.add_job( - process_fleets, - IntervalTrigger(seconds=10, jitter=2), - max_instances=1, - ) _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_fleets, + IntervalTrigger(seconds=10, jitter=2), + max_instances=1, + ) _scheduler.add_job( process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) From 05003466837a4207d59caa94e5260292a741c995 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 17:16:27 +0500 Subject: [PATCH 15/21] Respect fleet locks in the API endpoints --- .../scheduled_tasks/submitted_jobs.py | 2 + .../_internal/server/services/fleets.py | 52 ++++++++++++--- .../server/services/gateways/__init__.py | 2 +- .../_internal/server/services/volumes.py | 2 +- .../_internal/server/routers/test_fleets.py | 63 +++++++++++++++++++ 5 files changed, 112 insertions(+), 9 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 932db5294..151f07dee 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -581,6 +581,8 @@ async def _fetch_fleet_with_master_instance_provisioning_data( # To avoid violating fleet placement cluster during master provisioning, # we must lock empty fleets and respect existing instances in non-empty fleets. # On SQLite always take the lock during master provisioning for simplicity. + # It's fine to lock fleets currently locked by pipelines (with lock_* fields set) + # since we won't update fleets – we only need to ensure there is no parallel provisioning. await exit_stack.enter_async_context( get_locker(get_db().dialect_name).lock_ctx( FleetModel.__tablename__, [fleet_model.id] diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index db35331fd..5321f04f5 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -531,6 +531,12 @@ async def apply_plan( ) fleet_model = res.scalars().unique().one_or_none() if fleet_model is not None: + if fleet_model.lock_expires_at is not None: + # TODO: Make the endpoint fully async so we don't need to lock and error: + # put the request in queue and process in the background. + raise ServerClientError( + "Failed to update fleet: fleet is being processed currently. Try again later." + ) return await _update_fleet( session=session, user=user, @@ -668,8 +674,7 @@ async def delete_fleets( FleetModel.name.in_(names), FleetModel.deleted == False, ) - .order_by(FleetModel.id) # take locks in order - .with_for_update(key_share=True) + .order_by(FleetModel.id) ) fleets_ids = list(res.scalars().unique().all()) res = await session.execute( @@ -678,8 +683,7 @@ async def delete_fleets( InstanceModel.fleet_id.in_(fleets_ids), InstanceModel.deleted == False, ) - .order_by(InstanceModel.id) # take locks in order - .with_for_update(key_share=True) + .order_by(InstanceModel.id) ) instances_ids = list(res.scalars().unique().all()) if is_db_sqlite(): @@ -693,7 +697,12 @@ async def delete_fleets( # TODO: Do not lock fleet when deleting only instances. res = await session.execute( select(FleetModel) - .where(FleetModel.id.in_(fleets_ids)) + .where( + FleetModel.project_id == project.id, + FleetModel.id.in_(fleets_ids), + FleetModel.deleted == False, + FleetModel.lock_expires_at.is_(None), + ) .options( selectinload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) .selectinload(InstanceModel.jobs) @@ -705,10 +714,39 @@ async def delete_fleets( ).load_only(RunModel.status) ) .execution_options(populate_existing=True) + .order_by(FleetModel.id) # take locks in order + .with_for_update(key_share=True, of=FleetModel) ) fleet_models = res.scalars().unique().all() - fleets = [fleet_model_to_fleet(m) for m in fleet_models] - for fleet in fleets: + if len(fleet_models) != len(fleets_ids): + # TODO: Make the endpoint fully async so we don't need to lock and error: + # put the request in queue and process in the background. + msg = ( + "Failed to delete fleets: fleets are being processed currently. Try again later." + if instance_nums is None + else "Failed to delete fleet instances: fleets are being processed currently. Try again later." + ) + raise ServerClientError(msg) + res = await session.execute( + select(InstanceModel.id) + .where( + InstanceModel.id.in_(instances_ids), + InstanceModel.deleted == False, + ) + .order_by(InstanceModel.id) # take locks in order + .with_for_update(key_share=True, of=InstanceModel) + .execution_options(populate_existing=True) + ) + instance_models_ids = list(res.scalars().unique().all()) + if len(instance_models_ids) != len(instances_ids): + msg = ( + "Failed to delete fleets: fleet instances are being processed currently. Try again later." + if instance_nums is None + else "Failed to delete fleet instances: fleet instances are being processed currently. Try again later." + ) + raise ServerClientError(msg) + for fleet_model in fleet_models: + fleet = fleet_model_to_fleet(fleet_model) if fleet.spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) if instance_nums is None: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 762af8bef..ddc3d64c4 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -356,7 +356,7 @@ async def _delete_gateways_pipeline( ) gateway_models = res.scalars().all() if len(gateway_models) != len(gateways_ids): - # TODO: Make the delete endpoint fully async so we don't need to lock and error: + # TODO: Make the endpoint fully async so we don't need to lock and error: # put the request in queue and process in the background. raise ServerClientError( "Failed to delete gateways: gateways are being processed currently. Try again later." diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index f0d2fc703..1c846c724 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -369,7 +369,7 @@ async def _delete_volumes_pipeline( ) volume_models = res.scalars().unique().all() if len(volume_models) != len(volumes_ids): - # TODO: Make the delete endpoint fully async so we don't need to lock and error: + # TODO: Make the endpoint fully async so we don't need to lock and error: # put the request in queue and process in the background. raise ServerClientError( "Failed to delete volumes: volumes are being processed currently. Try again later." diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index fef712aca..4121acb26 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -927,6 +927,37 @@ async def test_returns_400_when_fleets_in_use( assert not fleet.deleted assert instance.status == InstanceStatus.BUSY + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_400_when_fleet_locked( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + ) + fleet.instances.append(instance) + fleet.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc) + await session.commit() + + response = await client.post( + f"/api/project/{project.name}/fleets/delete", + headers=get_auth_headers(user.token), + json={"names": [fleet.name]}, + ) + assert response.status_code == 400 + + await session.refresh(fleet) + await session.refresh(instance) + assert fleet.status != FleetStatus.TERMINATING + assert instance.status != InstanceStatus.TERMINATING + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_forbids_if_no_permission_to_manage_ssh_fleets( @@ -1053,6 +1084,38 @@ async def test_returns_400_when_deleting_busy_instances( assert instance.status != InstanceStatus.TERMINATING assert fleet.status != FleetStatus.TERMINATING + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_400_when_fleet_locked( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + instance_num=1, + ) + fleet.instances.append(instance) + fleet.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc) + await session.commit() + + response = await client.post( + f"/api/project/{project.name}/fleets/delete_instances", + headers=get_auth_headers(user.token), + json={"name": fleet.name, "instance_nums": [1]}, + ) + assert response.status_code == 400 + + await session.refresh(fleet) + await session.refresh(instance) + assert fleet.status != FleetStatus.TERMINATING + assert instance.status != InstanceStatus.TERMINATING + class TestGetPlan: @pytest.mark.asyncio From d64072474e3349e0aca3b1adaf5aaae50d2e266e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 17:26:26 +0500 Subject: [PATCH 16/21] Add FleetModel pipeline migration --- ...e61de27_add_fleetmodel_pipeline_columns.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/02_27_1218_d21d3e61de27_add_fleetmodel_pipeline_columns.py diff --git a/src/dstack/_internal/server/migrations/versions/2026/02_27_1218_d21d3e61de27_add_fleetmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/02_27_1218_d21d3e61de27_add_fleetmodel_pipeline_columns.py new file mode 100644 index 000000000..fad3da790 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/02_27_1218_d21d3e61de27_add_fleetmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add FleetModel pipeline columns + +Revision ID: d21d3e61de27 +Revises: 9a363c3cbe04 +Create Date: 2026-02-27 12:18:01.768776+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "d21d3e61de27" +down_revision = "9a363c3cbe04" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### From 557d1a5ac25bfbfa09607a52d3ee59ea27316c67 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 17:37:04 +0500 Subject: [PATCH 17/21] Adjust CONSOLIDATION_RETRY_DELAYS --- .../_internal/server/background/pipeline_tasks/fleets.py | 5 ++--- .../server/background/pipeline_tasks/test_fleets.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index ed41cafa1..075e2f1b9 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -55,7 +55,7 @@ def __init__( workers_num: int = 10, queue_lower_limit_factor: float = 0.5, queue_upper_limit_factor: float = 2.0, - min_processing_interval: timedelta = timedelta(seconds=30), + min_processing_interval: timedelta = timedelta(seconds=60), lock_timeout: timedelta = timedelta(seconds=20), heartbeat_trigger: timedelta = timedelta(seconds=10), ) -> None: @@ -416,13 +416,12 @@ def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool: # We use exponentially increasing consolidation retry delays so that # consolidation does not happen too often. In particular, this prevents # retrying instance provisioning constantly in case of no offers. -# TODO: Adjust delays. _CONSOLIDATION_RETRY_DELAYS = [ - timedelta(seconds=30), timedelta(minutes=1), timedelta(minutes=2), timedelta(minutes=5), timedelta(minutes=10), + timedelta(minutes=30), ] diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 133e9084e..746ddf2ea 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -392,5 +392,7 @@ async def test_consolidation_attempt_resets_when_no_changes( ) assert len(instances) == 1 assert fleet.consolidation_attempt == 0 - assert fleet.last_consolidated_at is not None - assert fleet.last_consolidated_at > previous_last_consolidated_at + assert ( + fleet.last_consolidated_at is not None + and fleet.last_consolidated_at > previous_last_consolidated_at + ) From bd5b998e8cbe4c4caf3ede19781e36329022b0b3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 17:48:33 +0500 Subject: [PATCH 18/21] Fix fleet autodelete comments --- .../server/background/pipeline_tasks/fleets.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 075e2f1b9..55ffcd7f9 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -372,11 +372,10 @@ def has_changes(self) -> bool: async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: result = _consolidate_fleet_state_with_spec(fleet_model) if result.new_instances_count > 0: - # Avoid auto-deleting empty fleets that are about to receive new instances. + # Avoid deleting fleets that are about to provision new instances. return result - # TODO: Drop fleets auto-deletion after dropping fleets auto-creation. - deleted = _autodelete_fleet(fleet_model) - if deleted: + delete = _should_delete_fleet(fleet_model) + if delete: result.fleet_update_map["status"] = FleetStatus.TERMINATED result.fleet_update_map["deleted"] = True result.fleet_update_map["deleted_at"] = NOW_PLACEHOLDER @@ -481,7 +480,7 @@ def _maintain_fleet_nodes_in_min_max_range( return result -def _autodelete_fleet(fleet_model: FleetModel) -> bool: +def _should_delete_fleet(fleet_model: FleetModel) -> bool: if fleet_model.project.deleted: # It used to be possible to delete project with active resources: # https://github.com/dstackai/dstack/issues/3077 @@ -491,6 +490,7 @@ def _autodelete_fleet(fleet_model: FleetModel) -> bool: if is_fleet_in_use(fleet_model) or not is_fleet_empty(fleet_model): return False + # TODO: Drop non-terminating fleets auto-deletion after dropping fleets auto-creation. fleet_spec = get_fleet_spec(fleet_model) if ( fleet_model.status != FleetStatus.TERMINATING From 25aa3b057d17503e22c00a84de124d90aab04813 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 17:57:44 +0500 Subject: [PATCH 19/21] Add scheduled tasks deprecated note --- .../server/background/scheduled_tasks/compute_groups.py | 2 ++ .../_internal/server/background/scheduled_tasks/fleets.py | 2 ++ .../_internal/server/background/scheduled_tasks/gateways.py | 2 ++ .../server/background/scheduled_tasks/placement_groups.py | 2 ++ .../_internal/server/background/scheduled_tasks/volumes.py | 2 ++ 5 files changed, 10 insertions(+) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py index feb1cc507..58d6b2c8b 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py @@ -32,6 +32,8 @@ TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) +# NOTE: This scheduled task is going to be deprecated in favor of `ComputeGroupPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/compute_groups.py` in sync. async def process_compute_groups(batch_size: int = 1): tasks = [] for _ in range(batch_size): diff --git a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py index c19d5c516..6b1ba7667 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py @@ -40,6 +40,8 @@ MIN_PROCESSING_INTERVAL = timedelta(seconds=30) +# NOTE: This scheduled task is going to be deprecated in favor of `FleetPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/fleets.py` in sync. @sentry_utils.instrument_scheduled_task async def process_fleets(): fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset( diff --git a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py index fc12e8e3b..262f45a18 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py @@ -35,6 +35,8 @@ async def process_gateways_connections(): await _process_active_connections() +# NOTE: This scheduled task is going to be deprecated in favor of `GatewayPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/gateways.py` in sync. @sentry_utils.instrument_scheduled_task async def process_gateways(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py index 71ab51b07..1106ce491 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py @@ -19,6 +19,8 @@ logger = get_logger(__name__) +# NOTE: This scheduled task is going to be deprecated in favor of `PlacementGroupPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/placement_groups.py` in sync. @sentry_utils.instrument_scheduled_task async def process_placement_groups(): lock, lockset = get_locker(get_db().dialect_name).get_lockset( diff --git a/src/dstack/_internal/server/background/scheduled_tasks/volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py index a61f79694..11e6f3c59 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/volumes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py @@ -24,6 +24,8 @@ logger = get_logger(__name__) +# NOTE: This scheduled task is going to be deprecated in favor of `VolumePipeline`. +# If this logic changes before removal, keep `pipeline_tasks/volumes.py` in sync. @sentry_utils.instrument_scheduled_task async def process_submitted_volumes(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__) From 0499b0e4b59a8503c24599f706929aad5a716c10 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 27 Feb 2026 18:09:28 +0500 Subject: [PATCH 20/21] Cleanup comment --- src/dstack/_internal/server/background/pipeline_tasks/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index ff86ba288..aa5af9a4a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -358,7 +358,7 @@ class _NowPlaceholder: instead of `get_current_time()` to have the same current time for all updates in the transaction. """ -# Timestamp value stored in update maps before being resolved to current time. + UpdateMapDateTime = Union[datetime, _NowPlaceholder] From 0b821394680a79f3da9131258f260d1c7cef016b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 2 Mar 2026 10:32:35 +0500 Subject: [PATCH 21/21] Add ix_fleets_pipeline_fetch_q index --- ...ec_add_ix_fleets_pipeline_fetch_q_index.py | 49 +++++++++++++++++++ src/dstack/_internal/server/models.py | 9 ++++ 2 files changed, 58 insertions(+) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_02_0530_46150101edec_add_ix_fleets_pipeline_fetch_q_index.py diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_02_0530_46150101edec_add_ix_fleets_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_02_0530_46150101edec_add_ix_fleets_pipeline_fetch_q_index.py new file mode 100644 index 000000000..365aac41c --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_02_0530_46150101edec_add_ix_fleets_pipeline_fetch_q_index.py @@ -0,0 +1,49 @@ +"""Add ix_fleets_pipeline_fetch_q index + +Revision ID: 46150101edec +Revises: d21d3e61de27 +Create Date: 2026-03-02 05:30:07.196407+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "46150101edec" +down_revision = "d21d3e61de27" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_pipeline_fetch_q", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_fleets_pipeline_fetch_q", + "fleets", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("deleted = 0"), + postgresql_where=sa.text("deleted IS FALSE"), + postgresql_concurrently=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_pipeline_fetch_q", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 35f7eb4ec..15c5488da 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -609,6 +609,15 @@ class FleetModel(PipelineModelMixin, BaseModel): consolidation_attempt: Mapped[int] = mapped_column(Integer, server_default="0") last_consolidated_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + __table_args__ = ( + Index( + "ix_fleets_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + class InstanceModel(BaseModel): __tablename__ = "instances"