diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 52150d36fd..4e5db254ae 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable, Mapping from typing import Any, Literal, TypeVar @@ -10,6 +11,7 @@ ) from pydantic import BaseModel, TypeAdapter, ValidationError +from blueapi import __version__ from blueapi.config import RestConfig from blueapi.service.authentication import JWTAuth, SessionManager from blueapi.service.model import ( @@ -32,6 +34,8 @@ TRACER = get_tracer("rest") +LOGGER = logging.getLogger(__name__) + class UnauthorisedAccessError(Exception): pass @@ -271,6 +275,17 @@ def _request_and_deserialize( raise exception if response.status_code == status.HTTP_204_NO_CONTENT: raise NoContentError(target_type) + if (server_version := response.headers.get("x-blueapi-version")) is not None: + from packaging.version import Version + + if (server_version := Version(server_version).release) != ( + client_version := Version(__version__).release + ): + LOGGER.warning( + f"Version mismatch: Blueapi server version is {server_version}" + f"but client version is {client_version}." + f"Some features may not work as expected." + ) deserialized = TypeAdapter(target_type).validate_python(response.json()) return deserialized diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index c79dd3df34..50332374d8 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -34,6 +34,8 @@ from starlette.responses import JSONResponse from super_state_machine.errors import TransitionError +import blueapi +import blueapi.cli from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface from blueapi.worker import TrackableTask, WorkerState @@ -123,7 +125,7 @@ def get_app(config: ApplicationConfig): app.include_router(secure_router, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) - app.middleware("http")(add_api_version_header) + app.middleware("http")(add_version_headers) app.middleware("http")(inject_propagated_observability_context) app.middleware("http")(log_request_details) if config.api.cors: @@ -568,11 +570,12 @@ def start(config: ApplicationConfig): ) -async def add_api_version_header( +async def add_version_headers( request: Request, call_next: Callable[[Request], Awaitable[Response]] ): response = await call_next(request) response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION + response.headers["X-BlueAPI-Version"] = blueapi.__version__ return response diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index 2ddcdd3800..e34ee4f8ca 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -1,11 +1,13 @@ import uuid from pathlib import Path -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest import requests import responses +from packaging.version import Version +from blueapi import __version__ from blueapi.client.rest import ( BlueapiRestClient, BlueskyRemoteControlError, @@ -196,3 +198,36 @@ def test_parameter_error_other_string(): input=34, ) assert str(p1) == "Invalid value 34 for field field_one.0: error_message" + + +@pytest.mark.parametrize( + "server_version,logging_warning_present", + [(__version__, False), ("0.0.1", True), (None, False)], +) +@patch("blueapi.client.rest.TypeAdapter") +@patch("blueapi.client.rest.requests.Session.request") +@patch("blueapi.client.rest.LOGGER") +def test_server_and_client_versions( + mock_logger: MagicMock, + mock_request: Mock, + mock_type_adapter: Mock, + rest: BlueapiRestClient, + server_version: str, + logging_warning_present: bool, +): + response = Mock(spec=requests.Response) + response.status_code = 200 + response.headers = {"x-blueapi-version": server_version} + mock_request.return_value = response + + rest.get_plans() + + if logging_warning_present: + mock_logger.warning.assert_called_once_with( + f"Version mismatch: Blueapi server version is" + f"{Version(server_version).release}" + f"but client version is {Version(__version__).release}." + f"Some features may not work as expected." + ) + else: + mock_logger.assert_not_called() diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index 2e109d38c6..4a4bcca634 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -5,7 +5,28 @@ from fastapi import FastAPI, Request from fastapi.testclient import TestClient -from blueapi.service.main import get_passthrough_headers, log_request_details +from blueapi import __version__ +from blueapi.config import ApplicationConfig +from blueapi.service.main import ( + add_version_headers, + get_passthrough_headers, + log_request_details, +) + + +async def test_add_version_header(): + app = FastAPI() + app.middleware("http")(add_version_headers) + + @app.get("/") + async def root(): + return {"message": "Hello World"} + + client = TestClient(app) + response = client.get("/") + + assert response.headers["X-API-VERSION"] == ApplicationConfig.REST_API_VERSION + assert response.headers["X-BlueAPI-VERSION"] == __version__ async def test_log_request_details():