Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, Literal, TypeVar

Expand All @@ -10,6 +11,7 @@
)
from pydantic import BaseModel, TypeAdapter, ValidationError

from blueapi import __version__
from blueapi.config import RestConfig
from blueapi.service.authentication import JWTAuth, SessionManager
from blueapi.service.model import (
Expand All @@ -32,6 +34,8 @@

TRACER = get_tracer("rest")

LOGGER = logging.getLogger(__name__)


class UnauthorisedAccessError(Exception):
pass
Expand Down Expand Up @@ -271,6 +275,17 @@ def _request_and_deserialize(
raise exception
if response.status_code == status.HTTP_204_NO_CONTENT:
raise NoContentError(target_type)
if (server_version := response.headers.get("x-blueapi-version")) is not None:
from packaging.version import Version

if (server_version := Version(server_version).release) != (
client_version := Version(__version__).release
):
LOGGER.warning(
f"Version mismatch: Blueapi server version is {server_version}"
f"but client version is {client_version}."
f"Some features may not work as expected."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this as a draft.

It can be some thing like

"Version mismatch: server is , client is {Version(version).release}"

Suggested change
)
if (server_version := Version(server_version).release) != (client_version:= Version(__version__).release):
LOGGER.warning(
f"Version mismatch : Blueapi server version is {server_version}"
f" and client version is {client_version}.Some features may not work as expected."
)

deserialized = TypeAdapter(target_type).validate_python(response.json())
return deserialized

Expand Down
7 changes: 5 additions & 2 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError

import blueapi
import blueapi.cli
Comment on lines 36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import blueapi
import blueapi.cli
from blueapi import __version__

from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.service import interface
from blueapi.worker import TrackableTask, WorkerState
Expand Down Expand Up @@ -123,7 +125,7 @@ def get_app(config: ApplicationConfig):
app.include_router(secure_router, dependencies=dependencies)
app.add_exception_handler(KeyError, on_key_error_404)
app.add_exception_handler(jwt.PyJWTError, on_token_error_401)
app.middleware("http")(add_api_version_header)
app.middleware("http")(add_version_headers)
app.middleware("http")(inject_propagated_observability_context)
app.middleware("http")(log_request_details)
if config.api.cors:
Expand Down Expand Up @@ -568,11 +570,12 @@ def start(config: ApplicationConfig):
)


async def add_api_version_header(
async def add_version_headers(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION
response.headers["X-BlueAPI-Version"] = blueapi.__version__
return response


Expand Down
37 changes: 36 additions & 1 deletion tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import uuid
from pathlib import Path
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
import requests
import responses
from packaging.version import Version

from blueapi import __version__
from blueapi.client.rest import (
BlueapiRestClient,
BlueskyRemoteControlError,
Expand Down Expand Up @@ -196,3 +198,36 @@ def test_parameter_error_other_string():
input=34,
)
assert str(p1) == "Invalid value 34 for field field_one.0: error_message"


@pytest.mark.parametrize(
"server_version,logging_warning_present",
[(__version__, False), ("0.0.1", True), (None, False)],
)
@patch("blueapi.client.rest.TypeAdapter")
@patch("blueapi.client.rest.requests.Session.request")
@patch("blueapi.client.rest.LOGGER")
def test_server_and_client_versions(
mock_logger: MagicMock,
mock_request: Mock,
mock_type_adapter: Mock,
rest: BlueapiRestClient,
server_version: str,
logging_warning_present: bool,
):
response = Mock(spec=requests.Response)
response.status_code = 200
response.headers = {"x-blueapi-version": server_version}
mock_request.return_value = response

rest.get_plans()

if logging_warning_present:
mock_logger.warning.assert_called_once_with(
f"Version mismatch: Blueapi server version is"
f"{Version(server_version).release}"
f"but client version is {Version(__version__).release}."
f"Some features may not work as expected."
)
else:
mock_logger.assert_not_called()
23 changes: 22 additions & 1 deletion tests/unit_tests/service/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,28 @@
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

from blueapi.service.main import get_passthrough_headers, log_request_details
from blueapi import __version__
from blueapi.config import ApplicationConfig
from blueapi.service.main import (
add_version_headers,
get_passthrough_headers,
log_request_details,
)


async def test_add_version_header():
app = FastAPI()
app.middleware("http")(add_version_headers)

@app.get("/")
async def root():
return {"message": "Hello World"}

client = TestClient(app)
response = client.get("/")

assert response.headers["X-API-VERSION"] == ApplicationConfig.REST_API_VERSION
assert response.headers["X-BlueAPI-VERSION"] == __version__


async def test_log_request_details():
Expand Down
Loading