55from concurrent .futures import ThreadPoolExecutor
66from contextlib import asynccontextmanager
77from pathlib import Path
8- from typing import Awaitable , Callable , List , Optional
8+ from typing import Annotated , Awaitable , Callable , List , Optional
99
1010import sentry_sdk
11- from fastapi import FastAPI , Request , Response , status
11+ from fastapi import Depends , FastAPI , Request , Response , status
1212from fastapi .datastructures import URL
1313from fastapi .responses import HTMLResponse , RedirectResponse
1414from fastapi .staticfiles import StaticFiles
15+ from packaging .version import Version
1516from prometheus_client import Counter , Histogram
1617from sentry_sdk .types import SamplingContext
1718
19+ from dstack ._internal import settings as core_settings
1820from dstack ._internal .cli .utils .common import console
1921from dstack ._internal .core .errors import ForbiddenError , ServerClientError
2022from dstack ._internal .core .services .configs import update_default_project
6870 get_client_version ,
6971 get_server_client_error_details ,
7072)
71- from dstack ._internal .settings import DSTACK_VERSION
7273from dstack ._internal .utils .logging import get_logger
7374from dstack ._internal .utils .ssh import check_required_ssh_version
7475
@@ -91,6 +92,9 @@ def create_app() -> FastAPI:
9192 app = FastAPI (
9293 docs_url = "/api/docs" ,
9394 lifespan = lifespan ,
95+ dependencies = [
96+ Depends (_check_client_version ),
97+ ],
9498 )
9599 app .state .proxy_dependency_injector = ServerProxyDependencyInjector ()
96100 return app
@@ -102,7 +106,7 @@ async def lifespan(app: FastAPI):
102106 if settings .SENTRY_DSN is not None :
103107 sentry_sdk .init (
104108 dsn = settings .SENTRY_DSN ,
105- release = DSTACK_VERSION ,
109+ release = core_settings . DSTACK_VERSION ,
106110 environment = settings .SERVER_ENVIRONMENT ,
107111 enable_tracing = True ,
108112 traces_sampler = _sentry_traces_sampler ,
@@ -164,7 +168,9 @@ async def lifespan(app: FastAPI):
164168 else :
165169 logger .info ("Background processing is disabled" )
166170 PROBES_SCHEDULER .start ()
167- dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
171+ dstack_version = (
172+ core_settings .DSTACK_VERSION if core_settings .DSTACK_VERSION else "(no version)"
173+ )
168174 job_network_mode_log = (
169175 logger .info
170176 if settings .JOB_NETWORK_MODE != settings .DEFAULT_JOB_NETWORK_MODE
@@ -336,32 +342,6 @@ def _extract_endpoint_label(request: Request, response: Response) -> str:
336342 ).inc ()
337343 return response
338344
339- @app .middleware ("http" )
340- async def check_client_version (request : Request , call_next ):
341- if (
342- not request .url .path .startswith ("/api/" )
343- or request .url .path in _NO_API_VERSION_CHECK_ROUTES
344- ):
345- return await call_next (request )
346- try :
347- client_version = get_client_version (request )
348- except ValueError as e :
349- return CustomORJSONResponse (
350- status_code = status .HTTP_400_BAD_REQUEST ,
351- content = {"detail" : [error_detail (str (e ))]},
352- )
353- client_release : Optional [tuple [int , ...]] = None
354- if client_version is not None :
355- client_release = client_version .release
356- request .state .client_release = client_release
357- response = check_client_server_compatibility (
358- client_version = client_version ,
359- server_version = DSTACK_VERSION ,
360- )
361- if response is not None :
362- return response
363- return await call_next (request )
364-
365345 @app .get ("/healthcheck" )
366346 async def healthcheck ():
367347 return CustomORJSONResponse (content = {"status" : "running" })
@@ -396,6 +376,19 @@ async def index():
396376 return RedirectResponse ("/api/docs" )
397377
398378
379+ def _check_client_version (
380+ request : Request , client_version : Annotated [Optional [Version ], Depends (get_client_version )]
381+ ) -> None :
382+ if (
383+ request .url .path .startswith ("/api/" )
384+ and request .url .path not in _NO_API_VERSION_CHECK_ROUTES
385+ ):
386+ check_client_server_compatibility (
387+ client_version = client_version ,
388+ server_version = core_settings .DSTACK_VERSION ,
389+ )
390+
391+
399392def _is_proxy_request (request : Request ) -> bool :
400393 if request .url .path .startswith ("/proxy" ):
401394 return True
0 commit comments