Skip to content

Commit c9be23a

Browse files
authored
Implement gateway pipeline (#3599)
* Make pipelines generic over ItemT * Update AGENTS.md * Refactor instance emit events * Implement GatewayPipeline * Refactor * Make gateway deletion async * Add TestGatewayWorkerDeleted * Drop deleted_by_user cols * Handle to_be_deleted gateways in API * Fix delete gateways API tests * Merge migrations * Hint gateway pipeline * Restore sync gateways delete API * Do not run pipelines without feature flag * Fix server_default in migration * Prevent dstack Sky gateway deletion * Remove implicit gateway_compute load * Remove implicit gateway.backend load
1 parent 768bdc1 commit c9be23a

21 files changed

Lines changed: 1334 additions & 88 deletions

File tree

AGENTS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
## Coding Style & Naming Conventions
1818
- Python targets 3.9+ with 4-space indentation and max line length of 99 (see `ruff.toml`; `E501` is ignored but keep lines readable).
1919
- Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party).
20+
- Keep primary/public functions before local helper functions in a module section.
2021
- Prefer pydantic-style models in `core/models`.
2122
- Tests use `test_*.py` modules and `test_*` functions; fixtures live near usage.
2223

src/dstack/_internal/server/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ async def lifespan(app: FastAPI):
167167
pipeline_manager = None
168168
if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
169169
scheduler = start_scheduled_tasks()
170-
pipeline_manager = start_pipeline_tasks()
171-
app.state.pipeline_manager = pipeline_manager
170+
if core_settings.FeatureFlags.PIPELINE_PROCESSING_ENABLED:
171+
pipeline_manager = start_pipeline_tasks()
172+
app.state.pipeline_manager = pipeline_manager
172173
else:
173174
logger.info("Background processing is disabled")
174175
PROBES_SCHEDULER.start()

src/dstack/_internal/server/background/pipeline_tasks/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,22 @@
22

33
from dstack._internal.server.background.pipeline_tasks.base import Pipeline
44
from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline
5+
from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline
56
from dstack._internal.server.background.pipeline_tasks.placement_groups import (
67
PlacementGroupPipeline,
78
)
8-
from dstack._internal.settings import FeatureFlags
99
from dstack._internal.utils.logging import get_logger
1010

1111
logger = get_logger(__name__)
1212

1313

1414
class PipelineManager:
1515
def __init__(self) -> None:
16-
self._pipelines: list[Pipeline] = []
17-
if FeatureFlags.PIPELINE_PROCESSING_ENABLED:
18-
self._pipelines += [
19-
ComputeGroupPipeline(),
20-
PlacementGroupPipeline(),
21-
]
16+
self._pipelines: list[Pipeline] = [
17+
ComputeGroupPipeline(),
18+
GatewayPipeline(),
19+
PlacementGroupPipeline(),
20+
]
2221
self._hinter = PipelineHinter(self._pipelines)
2322

2423
def start(self):

src/dstack/_internal/server/background/pipeline_tasks/base.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,25 @@
1919

2020
@dataclass
2121
class PipelineItem:
22+
"""
23+
Pipelines can work with this class or its subclass if the worker needs to access extra attributes.
24+
"""
25+
2226
__tablename__: str
2327
id: uuid.UUID
2428
lock_expires_at: datetime
2529
lock_token: uuid.UUID
2630
prev_lock_expired: bool
2731

2832

33+
ItemT = TypeVar("ItemT", bound=PipelineItem)
34+
35+
2936
class PipelineModel(Protocol):
37+
"""
38+
Heartbeater can work with any DB model implementing this protocol.
39+
"""
40+
3041
__tablename__: str
3142
__mapper__: ClassVar[Any]
3243
__table__: ClassVar[Any]
@@ -39,7 +50,7 @@ class PipelineError(Exception):
3950
pass
4051

4152

42-
class Pipeline(ABC):
53+
class Pipeline(Generic[ItemT], ABC):
4354
def __init__(
4455
self,
4556
workers_num: int,
@@ -57,7 +68,7 @@ def __init__(
5768
self._min_processing_interval = min_processing_interval
5869
self._lock_timeout = lock_timeout
5970
self._heartbeat_trigger = heartbeat_trigger
60-
self._queue = asyncio.Queue[PipelineItem](maxsize=self._queue_maxsize)
71+
self._queue = asyncio.Queue[ItemT](maxsize=self._queue_maxsize)
6172
self._tasks: list[asyncio.Task] = []
6273
self._running = False
6374
self._shutdown = False
@@ -119,35 +130,32 @@ def hint_fetch_model_name(self) -> str:
119130

120131
@property
121132
@abstractmethod
122-
def _heartbeater(self) -> "Heartbeater":
133+
def _heartbeater(self) -> "Heartbeater[ItemT]":
123134
pass
124135

125136
@property
126137
@abstractmethod
127-
def _fetcher(self) -> "Fetcher":
138+
def _fetcher(self) -> "Fetcher[ItemT]":
128139
pass
129140

130141
@property
131142
@abstractmethod
132-
def _workers(self) -> Sequence["Worker"]:
143+
def _workers(self) -> Sequence["Worker[ItemT]"]:
133144
pass
134145

135146

136-
ModelT = TypeVar("ModelT", bound=PipelineModel)
137-
138-
139-
class Heartbeater(Generic[ModelT]):
147+
class Heartbeater(Generic[ItemT]):
140148
def __init__(
141149
self,
142-
model_type: type[ModelT],
150+
model_type: type[PipelineModel],
143151
lock_timeout: timedelta,
144152
heartbeat_trigger: timedelta,
145153
heartbeat_delay: float = 1.0,
146154
) -> None:
147155
self._model_type = model_type
148156
self._lock_timeout = lock_timeout
149157
self._hearbeat_margin = heartbeat_trigger
150-
self._items: dict[uuid.UUID, PipelineItem] = {}
158+
self._items: dict[uuid.UUID, ItemT] = {}
151159
self._untrack_lock = asyncio.Lock()
152160
self._heartbeat_delay = heartbeat_delay
153161
self._running = False
@@ -164,18 +172,18 @@ async def start(self):
164172
def stop(self):
165173
self._running = False
166174

167-
async def track(self, item: PipelineItem):
175+
async def track(self, item: ItemT):
168176
self._items[item.id] = item
169177

170-
async def untrack(self, item: PipelineItem):
178+
async def untrack(self, item: ItemT):
171179
async with self._untrack_lock:
172180
tracked = self._items.get(item.id)
173181
# Prevent expired fetch iteration to unlock item processed by new iteration.
174182
if tracked is not None and tracked.lock_token == item.lock_token:
175183
del self._items[item.id]
176184

177185
async def heartbeat(self):
178-
items_to_update: list[PipelineItem] = []
186+
items_to_update: list[ItemT] = []
179187
now = get_current_datetime()
180188
items = list(self._items.values())
181189
failed_to_heartbeat_count = 0
@@ -227,16 +235,16 @@ async def heartbeat(self):
227235
)
228236

229237

230-
class Fetcher(ABC):
238+
class Fetcher(Generic[ItemT], ABC):
231239
_DEFAULT_FETCH_DELAYS = [0.5, 1, 2, 5]
232240

233241
def __init__(
234242
self,
235-
queue: asyncio.Queue[PipelineItem],
243+
queue: asyncio.Queue[ItemT],
236244
queue_desired_minsize: int,
237245
min_processing_interval: timedelta,
238246
lock_timeout: timedelta,
239-
heartbeater: Heartbeater,
247+
heartbeater: Heartbeater[ItemT],
240248
queue_check_delay: float = 1.0,
241249
fetch_delays: Optional[list[float]] = None,
242250
) -> None:
@@ -289,7 +297,7 @@ def hint(self):
289297
self._fetch_event.set()
290298

291299
@abstractmethod
292-
async def fetch(self, limit: int) -> list[PipelineItem]:
300+
async def fetch(self, limit: int) -> list[ItemT]:
293301
pass
294302

295303
def _next_fetch_delay(self, empty_fetch_count: int) -> float:
@@ -298,11 +306,11 @@ def _next_fetch_delay(self, empty_fetch_count: int) -> float:
298306
return next_delay * (1 + jitter)
299307

300308

301-
class Worker(ABC):
309+
class Worker(Generic[ItemT], ABC):
302310
def __init__(
303311
self,
304-
queue: asyncio.Queue[PipelineItem],
305-
heartbeater: Heartbeater,
312+
queue: asyncio.Queue[ItemT],
313+
heartbeater: Heartbeater[ItemT],
306314
) -> None:
307315
self._queue = queue
308316
self._heartbeater = heartbeater
@@ -325,7 +333,7 @@ def stop(self):
325333
self._running = False
326334

327335
@abstractmethod
328-
async def process(self, item: PipelineItem):
336+
async def process(self, item: ItemT):
329337
pass
330338

331339

src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel
2626
from dstack._internal.server.services import backends as backends_services
2727
from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group
28-
from dstack._internal.server.services.instances import switch_instance_status
28+
from dstack._internal.server.services.instances import emit_instance_status_change_event
2929
from dstack._internal.server.services.locking import get_locker
3030
from dstack._internal.utils.common import get_current_datetime, run_async
3131
from dstack._internal.utils.logging import get_logger
@@ -36,7 +36,7 @@
3636
TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15)
3737

3838

39-
class ComputeGroupPipeline(Pipeline):
39+
class ComputeGroupPipeline(Pipeline[PipelineItem]):
4040
def __init__(
4141
self,
4242
workers_num: int = 10,
@@ -54,7 +54,7 @@ def __init__(
5454
lock_timeout=lock_timeout,
5555
heartbeat_trigger=heartbeat_trigger,
5656
)
57-
self.__heartbeater = Heartbeater[ComputeGroupModel](
57+
self.__heartbeater = Heartbeater[PipelineItem](
5858
model_type=ComputeGroupModel,
5959
lock_timeout=self._lock_timeout,
6060
heartbeat_trigger=self._heartbeat_trigger,
@@ -76,26 +76,26 @@ def hint_fetch_model_name(self) -> str:
7676
return ComputeGroupModel.__name__
7777

7878
@property
79-
def _heartbeater(self) -> Heartbeater:
79+
def _heartbeater(self) -> Heartbeater[PipelineItem]:
8080
return self.__heartbeater
8181

8282
@property
83-
def _fetcher(self) -> Fetcher:
83+
def _fetcher(self) -> Fetcher[PipelineItem]:
8484
return self.__fetcher
8585

8686
@property
8787
def _workers(self) -> Sequence["ComputeGroupWorker"]:
8888
return self.__workers
8989

9090

91-
class ComputeGroupFetcher(Fetcher):
91+
class ComputeGroupFetcher(Fetcher[PipelineItem]):
9292
def __init__(
9393
self,
9494
queue: asyncio.Queue[PipelineItem],
9595
queue_desired_minsize: int,
9696
min_processing_interval: timedelta,
9797
lock_timeout: timedelta,
98-
heartbeater: Heartbeater[ComputeGroupModel],
98+
heartbeater: Heartbeater[PipelineItem],
9999
queue_check_delay: float = 1.0,
100100
) -> None:
101101
super().__init__(
@@ -161,11 +161,11 @@ async def fetch(self, limit: int) -> list[PipelineItem]:
161161
return items
162162

163163

164-
class ComputeGroupWorker(Worker):
164+
class ComputeGroupWorker(Worker[PipelineItem]):
165165
def __init__(
166166
self,
167167
queue: asyncio.Queue[PipelineItem],
168-
heartbeater: Heartbeater[ComputeGroupModel],
168+
heartbeater: Heartbeater[PipelineItem],
169169
) -> None:
170170
super().__init__(
171171
queue=queue,
@@ -235,7 +235,12 @@ async def process(self, item: PipelineItem):
235235
.values(**terminate_result.instances_update_map)
236236
)
237237
for instance_model in compute_group_model.instances:
238-
switch_instance_status(session, instance_model, InstanceStatus.TERMINATED)
238+
emit_instance_status_change_event(
239+
session=session,
240+
instance_model=instance_model,
241+
old_status=instance_model.status,
242+
new_status=InstanceStatus.TERMINATED,
243+
)
239244

240245

241246
@dataclass

0 commit comments

Comments
 (0)