Skip to content

Commit 4432cdf

Browse files
authored
Do not return NO_BALANCE to older clients (#3462)
Since only newer CLIs can correctly display `InstanceAvailability.NO_BALANCE`, replace `NO_BALANCE` with `NOT_AVAILABLE` in server responses for older clients for the following API methods: - `/api/project/{project_name}/fleets/get_plan` - `/api/project/{project_name}/runs/get_plan` - `/api/project/{project_name}/gpus/list` Additionally, refactor the code to make it easy to retrieve the client version using FastAPI dependencies. ```python client_version: Annotated[Optional[Version], Depends(get_client_version)] ```
1 parent c90cdf1 commit 4432cdf

13 files changed

Lines changed: 399 additions & 104 deletions

File tree

src/dstack/_internal/server/app.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@
55
from concurrent.futures import ThreadPoolExecutor
66
from contextlib import asynccontextmanager
77
from pathlib import Path
8-
from typing import Awaitable, Callable, List, Optional
8+
from typing import Annotated, Awaitable, Callable, List, Optional
99

1010
import sentry_sdk
11-
from fastapi import FastAPI, Request, Response, status
11+
from fastapi import Depends, FastAPI, Request, Response, status
1212
from fastapi.datastructures import URL
1313
from fastapi.responses import HTMLResponse, RedirectResponse
1414
from fastapi.staticfiles import StaticFiles
15+
from packaging.version import Version
1516
from prometheus_client import Counter, Histogram
1617
from sentry_sdk.types import SamplingContext
1718

19+
from dstack._internal import settings as core_settings
1820
from dstack._internal.cli.utils.common import console
1921
from dstack._internal.core.errors import ForbiddenError, ServerClientError
2022
from dstack._internal.core.services.configs import update_default_project
@@ -68,7 +70,6 @@
6870
get_client_version,
6971
get_server_client_error_details,
7072
)
71-
from dstack._internal.settings import DSTACK_VERSION
7273
from dstack._internal.utils.logging import get_logger
7374
from dstack._internal.utils.ssh import check_required_ssh_version
7475

@@ -91,6 +92,9 @@ def create_app() -> FastAPI:
9192
app = FastAPI(
9293
docs_url="/api/docs",
9394
lifespan=lifespan,
95+
dependencies=[
96+
Depends(_check_client_version),
97+
],
9498
)
9599
app.state.proxy_dependency_injector = ServerProxyDependencyInjector()
96100
return app
@@ -102,7 +106,7 @@ async def lifespan(app: FastAPI):
102106
if settings.SENTRY_DSN is not None:
103107
sentry_sdk.init(
104108
dsn=settings.SENTRY_DSN,
105-
release=DSTACK_VERSION,
109+
release=core_settings.DSTACK_VERSION,
106110
environment=settings.SERVER_ENVIRONMENT,
107111
enable_tracing=True,
108112
traces_sampler=_sentry_traces_sampler,
@@ -164,7 +168,9 @@ async def lifespan(app: FastAPI):
164168
else:
165169
logger.info("Background processing is disabled")
166170
PROBES_SCHEDULER.start()
167-
dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
171+
dstack_version = (
172+
core_settings.DSTACK_VERSION if core_settings.DSTACK_VERSION else "(no version)"
173+
)
168174
job_network_mode_log = (
169175
logger.info
170176
if settings.JOB_NETWORK_MODE != settings.DEFAULT_JOB_NETWORK_MODE
@@ -336,32 +342,6 @@ def _extract_endpoint_label(request: Request, response: Response) -> str:
336342
).inc()
337343
return response
338344

339-
@app.middleware("http")
340-
async def check_client_version(request: Request, call_next):
341-
if (
342-
not request.url.path.startswith("/api/")
343-
or request.url.path in _NO_API_VERSION_CHECK_ROUTES
344-
):
345-
return await call_next(request)
346-
try:
347-
client_version = get_client_version(request)
348-
except ValueError as e:
349-
return CustomORJSONResponse(
350-
status_code=status.HTTP_400_BAD_REQUEST,
351-
content={"detail": [error_detail(str(e))]},
352-
)
353-
client_release: Optional[tuple[int, ...]] = None
354-
if client_version is not None:
355-
client_release = client_version.release
356-
request.state.client_release = client_release
357-
response = check_client_server_compatibility(
358-
client_version=client_version,
359-
server_version=DSTACK_VERSION,
360-
)
361-
if response is not None:
362-
return response
363-
return await call_next(request)
364-
365345
@app.get("/healthcheck")
366346
async def healthcheck():
367347
return CustomORJSONResponse(content={"status": "running"})
@@ -396,6 +376,19 @@ async def index():
396376
return RedirectResponse("/api/docs")
397377

398378

379+
def _check_client_version(
380+
request: Request, client_version: Annotated[Optional[Version], Depends(get_client_version)]
381+
) -> None:
382+
if (
383+
request.url.path.startswith("/api/")
384+
and request.url.path not in _NO_API_VERSION_CHECK_ROUTES
385+
):
386+
check_client_server_compatibility(
387+
client_version=client_version,
388+
server_version=core_settings.DSTACK_VERSION,
389+
)
390+
391+
399392
def _is_proxy_request(request: Request) -> bool:
400393
if request.url.path.startswith("/proxy"):
401394
return True

src/dstack/_internal/server/compatibility/__init__.py

Whitespace-only changes.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Optional
2+
3+
from packaging.version import Version
4+
5+
from dstack._internal.core.models.instances import (
6+
InstanceAvailability,
7+
InstanceOfferWithAvailability,
8+
)
9+
10+
11+
def patch_offers_list(
12+
offers: list[InstanceOfferWithAvailability], client_version: Optional[Version]
13+
) -> None:
14+
if client_version is None:
15+
return
16+
# CLIs prior to 0.20.4 incorrectly display the `no_balance` availability in the run/fleet plan
17+
if client_version < Version("0.20.4"):
18+
for offer in offers:
19+
if offer.availability == InstanceAvailability.NO_BALANCE:
20+
offer.availability = InstanceAvailability.NOT_AVAILABLE
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Optional
2+
3+
from packaging.version import Version
4+
5+
from dstack._internal.core.models.instances import InstanceAvailability
6+
from dstack._internal.server.schemas.gpus import ListGpusResponse
7+
8+
9+
def patch_list_gpus_response(
10+
response: ListGpusResponse, client_version: Optional[Version]
11+
) -> None:
12+
if client_version is None:
13+
return
14+
# CLIs prior to 0.20.4 incorrectly display the `no_balance` availability in `dstack offer --group-by gpu`
15+
if client_version < Version("0.20.4"):
16+
for gpu in response.gpus:
17+
if InstanceAvailability.NO_BALANCE in gpu.availability:
18+
gpu.availability = [
19+
a for a in gpu.availability if a != InstanceAvailability.NO_BALANCE
20+
]
21+
if InstanceAvailability.NOT_AVAILABLE not in gpu.availability:
22+
gpu.availability.append(InstanceAvailability.NOT_AVAILABLE)

src/dstack/_internal/server/routers/fleets.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import List, Tuple
1+
from typing import List, Optional, Tuple
22

33
from fastapi import APIRouter, Depends
4+
from packaging.version import Version
45
from sqlalchemy.ext.asyncio import AsyncSession
56

67
import dstack._internal.server.services.fleets as fleets_services
78
from dstack._internal.core.errors import ResourceNotExistsError
89
from dstack._internal.core.models.fleets import Fleet, FleetPlan
10+
from dstack._internal.server.compatibility.common import patch_offers_list
911
from dstack._internal.server.db import get_session
1012
from dstack._internal.server.models import ProjectModel, UserModel
1113
from dstack._internal.server.schemas.fleets import (
@@ -21,6 +23,7 @@
2123
from dstack._internal.server.utils.routers import (
2224
CustomORJSONResponse,
2325
get_base_api_additional_responses,
26+
get_client_version,
2427
)
2528

2629
root_router = APIRouter(
@@ -101,6 +104,7 @@ async def get_plan(
101104
body: GetFleetPlanRequest,
102105
session: AsyncSession = Depends(get_session),
103106
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
107+
client_version: Optional[Version] = Depends(get_client_version),
104108
):
105109
"""
106110
Returns a fleet plan for the given fleet configuration.
@@ -112,6 +116,7 @@ async def get_plan(
112116
user=user,
113117
spec=body.spec,
114118
)
119+
patch_offers_list(plan.offers, client_version)
115120
return CustomORJSONResponse(plan)
116121

117122

src/dstack/_internal/server/routers/gpus.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
from typing import Tuple
1+
from typing import Annotated, Optional, Tuple
22

33
from fastapi import APIRouter, Depends
4+
from packaging.version import Version
45

6+
from dstack._internal.server.compatibility.gpus import patch_list_gpus_response
57
from dstack._internal.server.models import ProjectModel, UserModel
68
from dstack._internal.server.schemas.gpus import ListGpusRequest, ListGpusResponse
79
from dstack._internal.server.security.permissions import ProjectMember
810
from dstack._internal.server.services.gpus import list_gpus_grouped
9-
from dstack._internal.server.utils.routers import get_base_api_additional_responses
11+
from dstack._internal.server.utils.routers import (
12+
get_base_api_additional_responses,
13+
get_client_version,
14+
)
1015

1116
project_router = APIRouter(
1217
prefix="/api/project/{project_name}/gpus",
@@ -18,7 +23,10 @@
1823
@project_router.post("/list", response_model=ListGpusResponse, response_model_exclude_none=True)
1924
async def list_gpus(
2025
body: ListGpusRequest,
26+
client_version: Annotated[Optional[Version], Depends(get_client_version)],
2127
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
2228
) -> ListGpusResponse:
2329
_, project = user_project
24-
return await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by)
30+
resp = await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by)
31+
patch_list_gpus_response(resp, client_version)
32+
return resp

src/dstack/_internal/server/routers/runs.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Annotated, List, Optional, Tuple, cast
1+
from typing import Annotated, List, Optional, Tuple
22

3-
from fastapi import APIRouter, Depends, Request
3+
from fastapi import APIRouter, Depends
4+
from packaging.version import Version
45
from sqlalchemy.ext.asyncio import AsyncSession
56

67
from dstack._internal.core.errors import ResourceNotExistsError
78
from dstack._internal.core.models.runs import Run, RunPlan
9+
from dstack._internal.server.compatibility.common import patch_offers_list
810
from dstack._internal.server.db import get_session
911
from dstack._internal.server.models import ProjectModel, UserModel
1012
from dstack._internal.server.schemas.runs import (
@@ -21,6 +23,7 @@
2123
from dstack._internal.server.utils.routers import (
2224
CustomORJSONResponse,
2325
get_base_api_additional_responses,
26+
get_client_version,
2427
)
2528

2629
root_router = APIRouter(
@@ -35,9 +38,10 @@
3538
)
3639

3740

38-
def use_legacy_repo_dir(request: Request) -> bool:
39-
client_release = cast(Optional[tuple[int, ...]], request.state.client_release)
40-
return client_release is not None and client_release < (0, 19, 27)
41+
def use_legacy_repo_dir(
42+
client_version: Annotated[Optional[Version], Depends(get_client_version)],
43+
) -> bool:
44+
return client_version is not None and client_version < Version("0.19.27")
4145

4246

4347
@root_router.post(
@@ -110,6 +114,7 @@ async def get_plan(
110114
body: GetRunPlanRequest,
111115
session: Annotated[AsyncSession, Depends(get_session)],
112116
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
117+
client_version: Annotated[Optional[Version], Depends(get_client_version)],
113118
legacy_repo_dir: Annotated[bool, Depends(use_legacy_repo_dir)],
114119
):
115120
"""
@@ -127,6 +132,8 @@ async def get_plan(
127132
max_offers=body.max_offers,
128133
legacy_repo_dir=legacy_repo_dir,
129134
)
135+
for job_plan in run_plan.job_plans:
136+
patch_offers_list(job_plan.offers, client_version)
130137
return CustomORJSONResponse(run_plan)
131138

132139

src/dstack/_internal/server/utils/routers.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,28 @@ def get_request_size(request: Request) -> int:
124124

125125

126126
def get_client_version(request: Request) -> Optional[packaging.version.Version]:
127+
"""
128+
FastAPI dependency that returns the dstack client version or None if the version is latest/dev.
129+
"""
130+
127131
version = request.headers.get("x-api-version")
128132
if version is None:
129133
return None
130-
return parse_version(version)
134+
try:
135+
return parse_version(version)
136+
except ValueError as e:
137+
raise HTTPException(
138+
status_code=status.HTTP_400_BAD_REQUEST,
139+
detail=[error_detail(str(e))],
140+
)
131141

132142

133143
def check_client_server_compatibility(
134144
client_version: Optional[packaging.version.Version],
135145
server_version: Optional[str],
136-
) -> Optional[CustomORJSONResponse]:
146+
) -> None:
137147
"""
138-
Returns `JSONResponse` with error if client/server versions are incompatible.
139-
Returns `None` otherwise.
148+
Raise HTTP exception if the client is incompatible with the server.
140149
"""
141150
if client_version is None or server_version is None:
142151
return None
@@ -149,21 +158,9 @@ def check_client_server_compatibility(
149158
client_version.major > parsed_server_version.major
150159
or client_version.minor > parsed_server_version.minor
151160
):
152-
return error_incompatible_versions(
153-
str(client_version), server_version, ask_cli_update=False
161+
msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})."
162+
raise HTTPException(
163+
status_code=status.HTTP_400_BAD_REQUEST,
164+
detail=get_server_client_error_details(ServerClientError(msg=msg)),
154165
)
155166
return None
156-
157-
158-
def error_incompatible_versions(
159-
client_version: Optional[str],
160-
server_version: str,
161-
ask_cli_update: bool,
162-
) -> CustomORJSONResponse:
163-
msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})."
164-
if ask_cli_update:
165-
msg += f" Update the dstack CLI: `pip install dstack=={server_version}`."
166-
return CustomORJSONResponse(
167-
status_code=status.HTTP_400_BAD_REQUEST,
168-
content={"detail": get_server_client_error_details(ServerClientError(msg=msg))},
169-
)

0 commit comments

Comments
 (0)