From 4603435a8e9f85558d684de28e837a6ce8418100 Mon Sep 17 00:00:00 2001 From: Corvo <60719165+brothercorvo@users.noreply.github.com> Date: Tue, 11 Nov 2025 09:51:02 -0400 Subject: [PATCH] Add Google-style docstrings to conversion helpers --- TASK.md | 1 + examples/EmergencyManagement/client/client.py | 245 ++-------- .../EmergencyManagement/web_gateway/app.py | 193 +------- reticulum_openapi/__init__.py | 10 + reticulum_openapi/client.py | 119 ++++- reticulum_openapi/conversion.py | 423 ++++++++++++++++++ .../emergency_management/test_web_gateway.py | 58 ++- tests/test_client.py | 119 +++++ tests/test_conversion.py | 57 +++ tests/test_example_emergency_management.py | 81 ++-- tests/test_integration_webui_persistence.py | 11 +- 11 files changed, 858 insertions(+), 459 deletions(-) create mode 100644 reticulum_openapi/conversion.py create mode 100644 tests/test_conversion.py diff --git a/TASK.md b/TASK.md index 139dc9e..5c6d2c7 100644 --- a/TASK.md +++ b/TASK.md @@ -101,3 +101,4 @@ ## 2025-11-11 - [x] Pin Reticulum (RNS) to 1.0.2 and LXMF to 0.9.2 in project dependencies. +- [x] Centralise payload conversion utilities and refactor EmergencyManagement client and gateway to use them. diff --git a/examples/EmergencyManagement/client/client.py b/examples/EmergencyManagement/client/client.py index 3496c9f..c6d4246 100644 --- a/examples/EmergencyManagement/client/client.py +++ b/examples/EmergencyManagement/client/client.py @@ -2,45 +2,18 @@ from __future__ import annotations +from typing import Dict from typing import List from typing import Optional from reticulum_openapi.client import LXMFClient as BaseLXMFClient -from reticulum_openapi.codec_msgpack import from_bytes -from reticulum_openapi.model import dataclass_from_json from examples.EmergencyManagement.Server.models_emergency import ( + DeleteEmergencyActionMessageResult, + DeleteEventResult, EmergencyActionMessage, Event, - DeleteEmergencyActionMessageResult, ) -_JSON_DECODE_FAILED = object() - - -def _decode_json_payload(payload: Optional[bytes], target_type): - """Attempt to decode a compressed JSON payload into ``target_type``. - - Args: - payload (Optional[bytes]): Raw payload returned by the service. - target_type: Dataclass or typing annotation describing the desired - structure. - - Returns: - object: Decoded dataclass instance or iterable when successful. When - the payload does not appear to be compressed JSON, returns the - ``_JSON_DECODE_FAILED`` sentinel value. - """ - - if payload is None: - return _JSON_DECODE_FAILED - if len(payload) < 2 or payload[0] != 0x78: - return _JSON_DECODE_FAILED - try: - return dataclass_from_json(target_type, payload) - except (ValueError, UnicodeDecodeError): - return _JSON_DECODE_FAILED - - COMMAND_CREATE_EMERGENCY_ACTION_MESSAGE = "CreateEmergencyActionMessage" COMMAND_DELETE_EMERGENCY_ACTION_MESSAGE = "DeleteEmergencyActionMessage" COMMAND_LIST_EMERGENCY_ACTION_MESSAGE = "ListEmergencyActionMessage" @@ -55,183 +28,6 @@ def _decode_json_payload(payload: Optional[bytes], target_type): LXMFClient = BaseLXMFClient -def _decode_emergency_action_message( - payload: Optional[bytes], -) -> EmergencyActionMessage: - """Return an :class:`EmergencyActionMessage` decoded from MessagePack bytes. - - Args: - payload (Optional[bytes]): MessagePack payload returned by the service. - - Returns: - EmergencyActionMessage: Dataclass populated from ``payload``. - - Raises: - ValueError: If ``payload`` is ``None`` or not a valid MessagePack document. - """ - - if payload is None: - raise ValueError("Response payload is required") - - json_result = _decode_json_payload(payload, EmergencyActionMessage) - if json_result is not _JSON_DECODE_FAILED: - if json_result is None: - raise ValueError("Decoded payload cannot be null") - return json_result - - data = from_bytes(payload) - if not isinstance(data, dict): - raise ValueError("Decoded payload must be a mapping") - return EmergencyActionMessage(**data) - - -def _decode_optional_emergency_action_message( - payload: Optional[bytes], -) -> Optional[EmergencyActionMessage]: - """Return an optional :class:`EmergencyActionMessage` decoded from bytes.""" - - if payload is None: - return None - - json_result = _decode_json_payload(payload, EmergencyActionMessage) - if json_result is not _JSON_DECODE_FAILED: - return json_result - - data = from_bytes(payload) - if data is None: - return None - if not isinstance(data, dict): - raise ValueError("Decoded payload must be a mapping") - return EmergencyActionMessage(**data) - - -def _decode_emergency_action_message_list( - payload: Optional[bytes], -) -> List[EmergencyActionMessage]: - """Return a list of :class:`EmergencyActionMessage` decoded from bytes.""" - - if payload is None: - return [] - - json_result = _decode_json_payload(payload, List[EmergencyActionMessage]) - if json_result is not _JSON_DECODE_FAILED: - if json_result is None: - return [] - return list(json_result) - - data = from_bytes(payload) - if data is None: - return [] - if not isinstance(data, list): - raise ValueError("Decoded payload must be a list") - - messages: List[EmergencyActionMessage] = [] - for item in data: - if not isinstance(item, dict): - raise ValueError("Each emergency action payload must be a mapping") - messages.append(EmergencyActionMessage(**item)) - return messages - - -def _decode_delete_emergency_action_message_result( - payload: Optional[bytes], -) -> DeleteEmergencyActionMessageResult: - """Return the delete emergency action response decoded from bytes.""" - - if payload is None: - raise ValueError("Response payload is required") - - json_result = _decode_json_payload(payload, DeleteEmergencyActionMessageResult) - if json_result is not _JSON_DECODE_FAILED: - if json_result is None: - raise ValueError("Decoded payload cannot be null") - return json_result - - data = from_bytes(payload) - if not isinstance(data, dict): - raise ValueError("Decoded payload must be a mapping") - return DeleteEmergencyActionMessageResult(**data) - - -def _decode_event(payload: Optional[bytes]) -> Event: - """Return an :class:`Event` decoded from MessagePack bytes.""" - - if payload is None: - raise ValueError("Response payload is required") - - json_result = _decode_json_payload(payload, Event) - if json_result is not _JSON_DECODE_FAILED: - if json_result is None: - raise ValueError("Decoded payload cannot be null") - return json_result - - data = from_bytes(payload) - - if data is None: - raise ValueError("Decoded payload cannot be null") - if not isinstance(data, dict): - raise ValueError("Decoded payload must be a mapping") - return Event(**data) - - -def _decode_optional_event(payload: Optional[bytes]) -> Optional[Event]: - """Return an optional :class:`Event` decoded from MessagePack bytes.""" - - if payload is None: - return None - - json_result = _decode_json_payload(payload, Event) - if json_result is not _JSON_DECODE_FAILED: - return json_result - - data = from_bytes(payload) - - if data is None: - return None - if not isinstance(data, dict): - raise ValueError("Decoded payload must be a mapping") - return Event(**data) - - -def _decode_event_list(payload: Optional[bytes]) -> List[Event]: - """Return a list of :class:`Event` instances decoded from MessagePack.""" - - if payload is None: - return [] - - json_result = _decode_json_payload(payload, List[Event]) - if json_result is not _JSON_DECODE_FAILED: - if json_result is None: - return [] - return list(json_result) - - data = from_bytes(payload) - - if data is None: - return [] - if not isinstance(data, list): - raise ValueError("Decoded payload must be a list") - - events: List[Event] = [] - for item in data: - if not isinstance(item, dict): - raise ValueError("Each event payload must be a mapping") - events.append(Event(**item)) - return events - - -def _decode_delete_event_response(payload: Optional[bytes]) -> dict: - """Return the delete event response decoded from MessagePack bytes.""" - - if payload is None: - raise ValueError("Response payload is required") - - data = from_bytes(payload) - if not isinstance(data, dict): - raise ValueError("Decoded payload must be a mapping") - return data - - async def create_emergency_action_message( client: LXMFClient, server_identity_hash: str, @@ -253,8 +49,9 @@ async def create_emergency_action_message( COMMAND_CREATE_EMERGENCY_ACTION_MESSAGE, message, await_response=True, + response_type=EmergencyActionMessage, ) - return _decode_emergency_action_message(response) + return response async def retrieve_emergency_action_message( @@ -278,8 +75,9 @@ async def retrieve_emergency_action_message( COMMAND_RETRIEVE_EMERGENCY_ACTION_MESSAGE, callsign, await_response=True, + response_type=Optional[EmergencyActionMessage], ) - return _decode_optional_emergency_action_message(response) + return response async def list_emergency_action_messages( @@ -293,8 +91,9 @@ async def list_emergency_action_messages( COMMAND_LIST_EMERGENCY_ACTION_MESSAGE, None, await_response=True, + response_type=List[EmergencyActionMessage], ) - return _decode_emergency_action_message_list(response) + return response async def update_emergency_action_message( @@ -309,8 +108,9 @@ async def update_emergency_action_message( COMMAND_PUT_EMERGENCY_ACTION_MESSAGE, message, await_response=True, + response_type=Optional[EmergencyActionMessage], ) - return _decode_optional_emergency_action_message(response) + return response async def delete_emergency_action_message( @@ -325,8 +125,9 @@ async def delete_emergency_action_message( COMMAND_DELETE_EMERGENCY_ACTION_MESSAGE, callsign, await_response=True, + response_type=DeleteEmergencyActionMessageResult, ) - return _decode_delete_emergency_action_message_result(response) + return response async def create_event( @@ -341,8 +142,9 @@ async def create_event( COMMAND_CREATE_EVENT, event, await_response=True, + response_type=Event, ) - return _decode_event(response) + return response async def retrieve_event( @@ -357,8 +159,9 @@ async def retrieve_event( COMMAND_RETRIEVE_EVENT, str(uid), await_response=True, + response_type=Optional[Event], ) - return _decode_optional_event(response) + return response async def update_event( @@ -373,24 +176,27 @@ async def update_event( COMMAND_PUT_EVENT, event, await_response=True, + response_type=Optional[Event], ) - return _decode_optional_event(response) + return response async def delete_event( client: LXMFClient, server_identity_hash: str, uid: int, -) -> dict: - """Delete an event and return the raw response payload.""" +) -> Dict[str, object]: + """Delete an event and return the normalised response payload.""" response = await client.send_command( server_identity_hash, COMMAND_DELETE_EVENT, str(uid), await_response=True, + response_type=DeleteEventResult, + normalise=True, ) - return _decode_delete_event_response(response) + return response async def list_events( @@ -404,5 +210,6 @@ async def list_events( COMMAND_LIST_EVENT, None, await_response=True, + response_type=List[Event], ) - return _decode_event_list(response) + return response diff --git a/examples/EmergencyManagement/web_gateway/app.py b/examples/EmergencyManagement/web_gateway/app.py index 28411c1..191924c 100644 --- a/examples/EmergencyManagement/web_gateway/app.py +++ b/examples/EmergencyManagement/web_gateway/app.py @@ -6,9 +6,8 @@ import json import logging from contextlib import suppress -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass from datetime import datetime, timezone -from enum import Enum from pathlib import Path from typing import ( Any, @@ -16,14 +15,7 @@ Callable, Dict, List, - Literal, Optional, - Type, - TypeVar, - Union, - get_args, - get_origin, - get_type_hints, ) from importlib import metadata @@ -57,8 +49,8 @@ attach_client_notifications, router as notifications_router, ) -from reticulum_openapi.codec_msgpack import CodecError -from reticulum_openapi.codec_msgpack import decode_payload_bytes +from reticulum_openapi.conversion import normalise_response +from reticulum_openapi.conversion import prepare_dataclass_payload from examples.EmergencyManagement.web_gateway.interface_status import ( gather_interface_status, @@ -79,7 +71,6 @@ ConfigDict = Dict[str, Any] -T = TypeVar("T") @dataclass(frozen=True) @@ -456,100 +447,6 @@ async def _shutdown() -> None: await unsubscribe() -def _convert_value(expected_type: Type[Any], value: Any) -> Any: - """Recursively convert JSON values to dataclass field types.""" - - if expected_type is Any or expected_type is object: - return value - if expected_type is str: - if isinstance(value, str): - return value - if isinstance(value, (bytes, bytearray, memoryview)): - try: - return bytes(value).decode("utf-8") - except UnicodeDecodeError as exc: - raise ValueError("Unable to decode bytes to string") from exc - raise TypeError(f"Expected string for type {expected_type}") - if expected_type is int: - if isinstance(value, bool): - raise TypeError("Boolean value is not a valid integer") - if isinstance(value, int): - return value - if isinstance(value, float) and value.is_integer(): - return int(value) - if isinstance(value, str): - cleaned = value.strip() - try: - return int(cleaned, 10) - except ValueError as exc: - raise ValueError(f"Unable to convert '{value}' to int") from exc - raise TypeError(f"Expected integer for type {expected_type}") - if expected_type is float: - if isinstance(value, (int, float)): - return float(value) - if isinstance(value, str): - cleaned = value.strip() - try: - return float(cleaned) - except ValueError as exc: - raise ValueError(f"Unable to convert '{value}' to float") from exc - raise TypeError(f"Expected float for type {expected_type}") - if expected_type is bool: - if isinstance(value, bool): - return value - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {"true", "1", "yes", "on"}: - return True - if lowered in {"false", "0", "no", "off"}: - return False - raise TypeError(f"Expected boolean for type {expected_type}") - origin = get_origin(expected_type) - if origin is Union: - for arg in get_args(expected_type): - if arg is type(None): - if value is None: - return None - continue - try: - return _convert_value(arg, value) - except (TypeError, ValueError): - continue - raise ValueError(f"Unable to match value {value!r} to type {expected_type}") - if origin is Literal: - allowed = get_args(expected_type) - if value in allowed: - return value - raise ValueError( - f"Value {value!r} is not permitted for literal {expected_type}" - ) - if origin in (list, List): - if not isinstance(value, list): - raise TypeError(f"Expected list for type {expected_type}") - item_type = get_args(expected_type)[0] - return [_convert_value(item_type, item) for item in value] - if is_dataclass(expected_type): - if not isinstance(value, dict): - raise TypeError(f"Expected object for dataclass {expected_type.__name__}") - return _build_dataclass(expected_type, value) - return value - - -def _build_dataclass(cls: Type[T], data: Dict[str, Any]) -> T: - """Build a dataclass instance from primitive JSON data.""" - - if not isinstance(data, dict): - raise TypeError("Request payload must be a JSON object") - - kwargs: Dict[str, Any] = {} - type_hints = get_type_hints(cls) - for field in fields(cls): - if field.name in data: - expected_type = type_hints.get(field.name, field.type) - kwargs[field.name] = _convert_value(expected_type, data[field.name]) - return cls(**kwargs) - - async def _send_command( server_identity: str, command: str, @@ -565,6 +462,7 @@ async def _send_command( command, payload, await_response=True, + response_type=response_type, ) except TimeoutError as exc: logger.error( @@ -590,79 +488,10 @@ async def _send_command( if response is None: return JSONResponse(content=None) - try: - data = decode_payload_bytes(response) - except CodecError as exc: - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail=str(exc), - ) from exc - converted = data - if response_type is not None: - try: - converted = _convert_value(response_type, data) - except (TypeError, ValueError) as exc: - raise HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail=f"Unable to decode response payload: {exc}", - ) from exc - - normalised = _normalise_response(converted) + normalised = normalise_response(response) return JSONResponse(content=normalised) - -def _prepare_payload( - spec: CommandSpec, - payload: Optional[Dict[str, Any]] = None, - *, - overrides: Optional[Dict[str, Any]] = None, -) -> Any: - """Return the payload shaped for the LXMF command described by ``spec``.""" - - if spec.request_type is None: - if payload is not None and overrides: - merged: Dict[str, Any] = dict(payload) - merged.update(overrides) - return merged - if payload is not None: - return payload - if overrides is not None: - if len(overrides) == 1: - return next(iter(overrides.values())) - return dict(overrides) - return None - - data: Dict[str, Any] = {} - if payload is not None: - data.update(payload) - if overrides is not None: - data.update(overrides) - return _build_dataclass(spec.request_type, data) - - -def _normalise_response(value: Any) -> Any: - """Convert dataclasses, enums, and iterables into JSON-serialisable data.""" - - if value is None: - return None - if is_dataclass(value): - result: Dict[str, Any] = {} - for field in fields(value): - field_value = getattr(value, field.name) - if field_value is None: - continue - result[field.name] = _normalise_response(field_value) - return result - if isinstance(value, Enum): - return value.value - if isinstance(value, dict): - return {str(key): _normalise_response(item) for key, item in value.items()} - if isinstance(value, (list, tuple, set)): - return [_normalise_response(item) for item in value] - return value - - @app.get("/") async def get_gateway_status() -> Dict[str, Any]: """Return gateway metadata and configuration details.""" @@ -721,7 +550,7 @@ async def create_emergency_action_message( """Create a new emergency action message via LXMF.""" spec = _COMMAND_SPECS["eam:create"] - message = _prepare_payload(spec, dict(payload)) + message = prepare_dataclass_payload(spec.request_type, dict(payload)) return await _send_command( server_identity, spec.command, @@ -771,7 +600,9 @@ async def update_emergency_action_message( spec = _COMMAND_SPECS["eam:update"] overrides = {spec.path_field: callsign} if spec.path_field else {"callsign": callsign} - message = _prepare_payload(spec, dict(payload), overrides=overrides) + message = prepare_dataclass_payload( + spec.request_type, dict(payload), overrides=overrides + ) return await _send_command( server_identity, spec.command, @@ -804,7 +635,7 @@ async def create_event( """Create a new event record via LXMF.""" spec = _COMMAND_SPECS["event:create"] - event = _prepare_payload(spec, dict(payload)) + event = prepare_dataclass_payload(spec.request_type, dict(payload)) return await _send_command( server_identity, spec.command, @@ -854,7 +685,9 @@ async def update_event( spec = _COMMAND_SPECS["event:update"] overrides = {spec.path_field: uid} if spec.path_field else {"uid": uid} - event = _prepare_payload(spec, dict(payload), overrides=overrides) + event = prepare_dataclass_payload( + spec.request_type, dict(payload), overrides=overrides + ) return await _send_command( server_identity, spec.command, diff --git a/reticulum_openapi/__init__.py b/reticulum_openapi/__init__.py index dbd5b82..bb19f5b 100644 --- a/reticulum_openapi/__init__.py +++ b/reticulum_openapi/__init__.py @@ -7,6 +7,11 @@ from .controller import APIException from .controller import Controller from .controller import handle_exceptions +from .conversion import build_dataclass +from .conversion import convert_value +from .conversion import decode_payload +from .conversion import normalise_response +from .conversion import prepare_dataclass_payload from .link_client import LinkClient from .link_service import LinkService from .model import BaseModel @@ -25,6 +30,11 @@ "Controller", "APIException", "handle_exceptions", + "convert_value", + "build_dataclass", + "decode_payload", + "prepare_dataclass_payload", + "normalise_response", "BaseModel", "compress_json", "dataclass_from_json", diff --git a/reticulum_openapi/client.py b/reticulum_openapi/client.py index e9992d5..8dfc8bb 100644 --- a/reticulum_openapi/client.py +++ b/reticulum_openapi/client.py @@ -16,6 +16,9 @@ import LXMF import RNS +from .codec_msgpack import decode_payload_bytes +from .conversion import decode_payload +from .conversion import normalise_response from .identity import load_or_create_identity from .logging_config import configure_logging from .model import compress_json @@ -401,7 +404,9 @@ async def send_command( path_timeout: Optional[float] = None, await_response: bool = True, response_title: Optional[str] = None, - ) -> Optional[bytes]: + response_type: Optional[Any] = None, + normalise: bool = False, + ) -> Optional[Any]: """Send a command to a remote LXMF node. Args: @@ -411,9 +416,14 @@ async def send_command( path_timeout (float, optional): Maximum seconds to wait for path discovery. Defaults to ``self.timeout``. await_response (bool, optional): Wait for a response message. Defaults to ``True``. response_title (str, optional): Expected response title. Defaults to ``_response``. + response_type (Any, optional): Dataclass or typing annotation describing the expected + response structure. When provided, the payload is decoded via + :func:`reticulum_openapi.conversion.decode_payload` before being returned. + normalise (bool, optional): When ``True`` the decoded response is normalised to + JSON-serialisable primitives. Defaults to ``False``. Returns: - Optional[bytes]: Response payload if ``await_response`` is ``True``. + Optional[Any]: Response payload decoded according to ``response_type`` and ``normalise``. Raises: TimeoutError: If a transport path cannot be established before ``path_timeout`` elapses. @@ -483,7 +493,12 @@ def _failed_callback(receipt: Any) -> None: timeout=self.timeout, ) try: - return await asyncio.wait_for(response_future, timeout=self.timeout) + raw_response = await asyncio.wait_for( + response_future, timeout=self.timeout + ) + return self._process_response_payload( + raw_response, response_type, normalise + ) except TimeoutError as exc: logger.error( "LXMF command '%s' to %s failed before a response was received: %s", @@ -508,6 +523,104 @@ def _failed_callback(receipt: Any) -> None: link.request(request_path, data=content_bytes, timeout=self.timeout) return None + def _process_response_payload( + self, + payload: Optional[Any], + response_type: Optional[Any], + normalise: bool, + ) -> Optional[Any]: + """Return the decoded response based on ``response_type`` and ``normalise`` flags. + + Args: + payload (Optional[Any]): Raw payload returned by LXMF. + response_type (Optional[Any]): Optional dataclass or typing annotation describing + the desired response structure. + normalise (bool): When ``True`` converts decoded values into JSON-compatible + primitives. + + Returns: + Optional[Any]: Decoded payload that honours ``response_type`` and ``normalise``. + + Raises: + TypeError: Raised when decoding is requested but the payload is not bytes-like. + """ + + if payload is None: + if response_type is None and not normalise: + return None + if response_type is not None: + decoded = decode_payload(None, response_type) + return normalise_response(decoded) if normalise else decoded + return None + + if isinstance(payload, memoryview): + payload = payload.tobytes() + + if isinstance(payload, bytearray): + payload = bytes(payload) + + if not isinstance(payload, bytes): + if response_type is None and not normalise: + return payload + raise TypeError("Response payload must be bytes when decoding is requested") + + if response_type is None and not normalise: + return payload + + decoded: Any + if response_type is not None: + decoded = decode_payload(payload, response_type) + else: + decoded = decode_payload_bytes(payload) + + if normalise: + return normalise_response(decoded) + return decoded + + async def send_command_for_type( + self, + dest_hex: str, + command: str, + payload_obj: object = None, + *, + response_type: Any, + path_timeout: Optional[float] = None, + await_response: bool = True, + response_title: Optional[str] = None, + normalise: bool = False, + ) -> Optional[Any]: + """Convenience wrapper returning decoded responses for ``response_type``. + + Args: + dest_hex (str): Destination identity hash as hex string. + command (str): Command name placed in the LXMF title. + payload_obj (object, optional): Dataclass, dict or bytes payload forwarded to the + service. Defaults to ``None``. + response_type (Any): Dataclass or typing annotation describing the expected response + structure. + path_timeout (Optional[float], optional): Maximum seconds to wait for path discovery. + Defaults to ``self.timeout``. + await_response (bool, optional): Wait for a response message. Defaults to ``True``. + response_title (Optional[str], optional): Expected response title. Defaults to + ``_response``. + normalise (bool, optional): When ``True`` converts decoded values into + JSON-compatible primitives. Defaults to ``False``. + + Returns: + Optional[Any]: Decoded response matching ``response_type``. + """ + + return await self.send_command( + dest_hex, + command, + payload_obj, + path_timeout=path_timeout, + await_response=await_response, + response_title=response_title, + response_type=response_type, + normalise=normalise, + ) + @staticmethod def _format_transport_failure(receipt: Any) -> str: """Return a human-readable description of a transport failure.""" diff --git a/reticulum_openapi/conversion.py b/reticulum_openapi/conversion.py new file mode 100644 index 0000000..c13f334 --- /dev/null +++ b/reticulum_openapi/conversion.py @@ -0,0 +1,423 @@ +"""Utilities for converting LXMF payloads to and from Python types.""" + +from __future__ import annotations + +import inspect +import json +import sys +import zlib +from dataclasses import fields +from dataclasses import is_dataclass +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import get_args +from typing import get_origin +from typing import get_type_hints + +from .codec_msgpack import CodecError +from .codec_msgpack import decode_payload_bytes + + +T = TypeVar("T") +_JSON_PREFIX = 0x78 +_SENTINEL = object() + + +def _type_allows_none(expected_type: Any) -> bool: + """Return ``True`` when ``expected_type`` permits ``None`` values. + + Args: + expected_type (Any): Typing annotation to evaluate. + + Returns: + bool: ``True`` when ``None`` is an accepted value for ``expected_type``. + """ + + origin = get_origin(expected_type) + if origin is Union: + return any( + arg is type(None) or _type_allows_none(arg) + for arg in get_args(expected_type) + ) + return expected_type in {Any, object, type(None)} + + +def _default_for_type(expected_type: Any) -> Any: + """Return the default fallback for ``expected_type`` or ``_SENTINEL``. + + Args: + expected_type (Any): Typing annotation to inspect for default behaviour. + + Returns: + Any: Default value compatible with ``expected_type`` when available. + """ + + origin = get_origin(expected_type) + if origin in {list, List, Sequence, MutableSequence, tuple, Tuple, set, frozenset}: + return [] + if origin is Union: + args = get_args(expected_type) + if any(arg is type(None) for arg in args): + return None + if expected_type in { + list, + List, + Sequence, + MutableSequence, + tuple, + Tuple, + set, + frozenset, + }: + return [] + if expected_type in {dict, Dict, Mapping, MutableMapping}: + return {} + if expected_type in {Any, object}: + return None + return _SENTINEL + + +def convert_value(expected_type: Any, value: Any) -> Any: + """Recursively convert ``value`` into the supplied ``expected_type``. + + Args: + expected_type (Any): Dataclass, typing annotation, or primitive type that describes + the desired shape. + value (Any): JSON-compatible payload to convert. + + Returns: + Any: Converted value matching ``expected_type``. + + Raises: + TypeError: If ``value`` cannot be converted to ``expected_type``. + ValueError: When conversion fails due to semantic mismatches, such as invalid + literal choices or failed numeric parsing. + """ + + if expected_type in {Any, object}: + return value + + if value is None: + if _type_allows_none(expected_type): + return None + default = _default_for_type(expected_type) + if default is not _SENTINEL: + return default + raise TypeError(f"Value None is not valid for type {expected_type}") + + origin = get_origin(expected_type) + if origin is not None: + if origin is Union: + last_error: Optional[Exception] = None + for arg in get_args(expected_type): + if arg is type(None): + if value is None: + return None + continue + try: + return convert_value(arg, value) + except (TypeError, ValueError) as exc: + last_error = exc + continue + if last_error is not None: + raise ValueError( + f"Unable to match value {value!r} to type {expected_type}" + ) from last_error + raise ValueError(f"Unable to match value {value!r} to type {expected_type}") + if origin is tuple or origin is Tuple: + item_types = list(get_args(expected_type)) + if not isinstance(value, (list, tuple)): + raise TypeError(f"Expected tuple for type {expected_type}") + if not item_types: + return tuple(value) + if len(item_types) == 2 and item_types[1] is Ellipsis: + item_type = item_types[0] + return tuple(convert_value(item_type, item) for item in value) + if len(value) != len(item_types): + raise ValueError( + f"Expected {len(item_types)} items for tuple {expected_type}, got {len(value)}" + ) + return tuple(convert_value(t, item) for t, item in zip(item_types, value)) + if origin in {list, List, Sequence, MutableSequence, set, frozenset}: + if not isinstance(value, Sequence) or isinstance(value, (str, bytes)): + raise TypeError(f"Expected list for type {expected_type}") + item_types = get_args(expected_type) + item_type = item_types[0] if item_types else Any + converted = [convert_value(item_type, item) for item in value] + if origin in {set, frozenset}: + return set(converted) + return list(converted) + if origin in {dict, Dict, Mapping, MutableMapping}: + if not isinstance(value, Mapping): + raise TypeError(f"Expected mapping for type {expected_type}") + key_type, value_type = (Any, Any) + args = get_args(expected_type) + if len(args) == 2: + key_type, value_type = args + result: Dict[Any, Any] = {} + for raw_key, raw_value in value.items(): + key = ( + convert_value(key_type, raw_key) + if key_type not in {Any, object} + else raw_key + ) + result[str(key)] = convert_value(value_type, raw_value) + return result + from typing import Literal # Local import to avoid circular typing deps + + if origin is Literal: + allowed = get_args(expected_type) + if value in allowed: + return value + raise ValueError( + f"Value {value!r} is not permitted for literal {expected_type}" + ) + # typing.Annotated compatibility + if getattr(origin, "__qualname__", None) == "Annotated": + annotated_args = get_args(expected_type) + if not annotated_args: + return value + return convert_value(annotated_args[0], value) + + if inspect.isclass(expected_type): + if issubclass(expected_type, Enum): + if isinstance(value, expected_type): + return value + try: + return expected_type(value) + except ValueError as exc: + raise ValueError( + f"Value {value!r} is not valid for enum {expected_type.__name__}" + ) from exc + if expected_type is str: + if isinstance(value, str): + return value + if isinstance(value, (bytes, bytearray, memoryview)): + try: + return bytes(value).decode("utf-8") + except UnicodeDecodeError as exc: + raise ValueError("Unable to decode bytes to string") from exc + raise TypeError(f"Expected string for type {expected_type}") + if expected_type is int: + if isinstance(value, bool): + raise TypeError("Boolean value is not a valid integer") + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + if isinstance(value, str): + try: + return int(value.strip(), 10) + except ValueError as exc: + raise ValueError(f"Unable to convert {value!r} to int") from exc + raise TypeError(f"Expected integer for type {expected_type}") + if expected_type is float: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value.strip()) + except ValueError as exc: + raise ValueError(f"Unable to convert {value!r} to float") from exc + raise TypeError(f"Expected float for type {expected_type}") + if expected_type is bool: + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"true", "1", "yes", "on"}: + return True + if lowered in {"false", "0", "no", "off"}: + return False + raise TypeError(f"Expected boolean for type {expected_type}") + if is_dataclass(expected_type): + if isinstance(value, expected_type): + return value + if not isinstance(value, Mapping): + raise TypeError( + f"Expected object for dataclass {expected_type.__name__}" + ) + return build_dataclass(expected_type, value) + + return value + + +def build_dataclass(cls: Type[T], data: Mapping[str, Any]) -> T: + """Construct ``cls`` from ``data`` applying type conversions. + + Args: + cls (Type[T]): Dataclass type to instantiate. + data (Mapping[str, Any]): Mapping containing payload values. + + Returns: + T: Instance of ``cls`` populated with converted values. + + Raises: + TypeError: If ``data`` is not a mapping type. + """ + + if not isinstance(data, Mapping): + raise TypeError("Request payload must be a mapping") + + module = sys.modules.get(cls.__module__) + globalns = vars(module) if module is not None else {} + type_hints = get_type_hints(cls, globalns=globalns) + + kwargs: Dict[str, Any] = {} + for field in fields(cls): + if field.name not in data: + continue + expected_type = type_hints.get(field.name, field.type) + kwargs[field.name] = convert_value(expected_type, data[field.name]) + return cls(**kwargs) + + +def _attempt_json_decode(payload: bytes) -> Any: + """Return decoded JSON when ``payload`` appears to be compressed JSON. + + Args: + payload (bytes): Raw payload returned by LXMF. + + Returns: + Any: Decoded JSON data or ``_SENTINEL`` when decoding fails. + """ + + if len(payload) < 2 or payload[0] != _JSON_PREFIX: + return _SENTINEL + try: + json_bytes = zlib.decompress(payload) + except zlib.error: + json_bytes = payload + try: + text = json_bytes.decode("utf-8") + except UnicodeDecodeError: + return _SENTINEL + try: + return json.loads(text) + except json.JSONDecodeError: + return _SENTINEL + + +def decode_payload(payload: Optional[bytes], expected_type: Any) -> Any: + """Decode ``payload`` into ``expected_type`` using JSON or MessagePack heuristics. + + Args: + payload (Optional[bytes]): Raw payload to decode. + expected_type (Any): Dataclass or typing annotation describing the desired structure. + + Returns: + Any: Decoded payload coerced into ``expected_type``. + + Raises: + ValueError: If the payload cannot be decoded or violates ``expected_type``. + """ + + if not payload: + default = _default_for_type(expected_type) + if default is not _SENTINEL: + return default + raise ValueError("Response payload is required") + + json_candidate = _attempt_json_decode(payload) + data: Any + if json_candidate is not _SENTINEL: + data = json_candidate + else: + try: + data = decode_payload_bytes(payload) + except CodecError as exc: + raise ValueError("Unable to decode payload bytes") from exc + + if data is None: + default = _default_for_type(expected_type) + if default is not _SENTINEL: + return default + raise ValueError("Decoded payload cannot be null") + + return convert_value(expected_type, data) + + +def prepare_dataclass_payload( + expected_type: Optional[Any], + payload: Optional[Mapping[str, Any]] = None, + *, + overrides: Optional[Mapping[str, Any]] = None, +) -> Any: + """Build a dataclass or primitive payload for LXMF commands. + + Args: + expected_type (Optional[Any]): Dataclass or typing annotation describing the payload. + payload (Optional[Mapping[str, Any]]): Base payload values supplied by the caller. + overrides (Optional[Mapping[str, Any]]): Additional values that override ``payload``. + + Returns: + Any: Dataclass instance or primitive structure prepared for transport. + """ + + if expected_type is None: + if payload is not None: + return payload + if overrides is not None: + if len(overrides) == 1: + return next(iter(overrides.values())) + return dict(overrides) + return None + + combined: Dict[str, Any] = {} + if payload is not None: + combined.update(payload) + if overrides is not None: + combined.update(overrides) + + if is_dataclass(expected_type): + return build_dataclass(expected_type, combined) + return convert_value(expected_type, combined) + + +def normalise_response(value: Any) -> Any: + """Convert dataclasses, enums, and iterables into JSON-compatible primitives. + + Args: + value (Any): Object returned from LXMF or service handlers. + + Returns: + Any: JSON-serialisable representation of ``value``. + """ + + if value is None: + return None + if is_dataclass(value): + result: Dict[str, Any] = {} + for field in fields(value): + field_value = getattr(value, field.name) + if field_value is None: + continue + result[field.name] = normalise_response(field_value) + return result + if isinstance(value, Enum): + return value.value + if isinstance(value, Mapping): + return {str(key): normalise_response(item) for key, item in value.items()} + if isinstance(value, (list, tuple, set, frozenset)): + return [normalise_response(item) for item in value] + return value + + +__all__ = [ + "convert_value", + "build_dataclass", + "decode_payload", + "prepare_dataclass_payload", + "normalise_response", +] diff --git a/tests/examples/emergency_management/test_web_gateway.py b/tests/examples/emergency_management/test_web_gateway.py index 976416b..8fd2f8e 100644 --- a/tests/examples/emergency_management/test_web_gateway.py +++ b/tests/examples/emergency_management/test_web_gateway.py @@ -5,7 +5,6 @@ import importlib import json import time -import zlib from typing import List from unittest.mock import AsyncMock @@ -19,7 +18,6 @@ Event, ) from examples.EmergencyManagement.client.client import LXMFClient as RealLXMFClient -from reticulum_openapi.codec_msgpack import to_canonical_bytes SERVER_IDENTITY = "00112233445566778899aabbccddeeff" @@ -139,9 +137,10 @@ def test_create_emergency_action_message_routes_payload(gateway_app) -> None: """Creating an EAM should convert payloads to dataclasses and decode responses.""" module, client, stub = gateway_app - stub.send_command.return_value = to_canonical_bytes( - {"callsign": "Alpha", "groupName": "Team"} - ) + async def fake_send(*args, **kwargs): + return {"callsign": "Alpha", "groupName": "Team"} + + stub.send_command.side_effect = fake_send response = client.post( "/emergency-action-messages", @@ -159,6 +158,7 @@ def test_create_emergency_action_message_routes_payload(gateway_app) -> None: assert isinstance(args[2], EmergencyActionMessage) assert args[2].callsign == "Alpha" assert kwargs["await_response"] is True + assert kwargs["response_type"] == module._COMMAND_SPECS["eam:create"].response_type def test_gateway_status_includes_interface_details(gateway_app) -> None: @@ -181,9 +181,13 @@ def test_list_emergency_action_messages_decodes_messagepack(gateway_app) -> None """Listing EAMs should decode MessagePack arrays to JSON lists.""" module, client, stub = gateway_app - stub.send_command.return_value = to_canonical_bytes( - [{"callsign": "Alpha"}, {"callsign": "Bravo"}] - ) + async def fake_send(*args, **kwargs): + return [ + {"callsign": "Alpha"}, + {"callsign": "Bravo"}, + ] + + stub.send_command.side_effect = fake_send response = client.get( "/emergency-action-messages", @@ -193,19 +197,24 @@ def test_list_emergency_action_messages_decodes_messagepack(gateway_app) -> None assert response.status_code == 200 assert response.json() == [{"callsign": "Alpha"}, {"callsign": "Bravo"}] - args, _ = stub.send_command.await_args + args, kwargs = stub.send_command.await_args assert args[0] == SERVER_IDENTITY assert args[1] == module.COMMAND_LIST_EAM assert args[2] is None + assert kwargs["response_type"] == module._COMMAND_SPECS["eam:list"].response_type def test_create_event_accepts_structured_detail(gateway_app) -> None: """Creating events should forward structured detail payloads.""" module, client, stub = gateway_app - stub.send_command.return_value = to_canonical_bytes( - {"uid": 42, "detail": {"emergencyActionMessage": {"callsign": "Bravo"}}} - ) + async def fake_send(*args, **kwargs): + return { + "uid": 42, + "detail": {"emergencyActionMessage": {"callsign": "Bravo"}}, + } + + stub.send_command.side_effect = fake_send payload = { "uid": 42, @@ -237,6 +246,7 @@ def test_create_event_accepts_structured_detail(gateway_app) -> None: assert isinstance(args[2], Event) assert args[2].uid == 42 assert kwargs["await_response"] is True + assert kwargs["response_type"] == module._COMMAND_SPECS["event:create"].response_type assert args[2].detail is not None message = args[2].detail.emergencyActionMessage @@ -251,7 +261,10 @@ def test_update_event_uses_path_identifier(gateway_app) -> None: """Updating events should merge the path UID into the dataclass payload.""" module, client, stub = gateway_app - stub.send_command.return_value = to_canonical_bytes({"uid": 21, "type": "Updated"}) + async def fake_send(*args, **kwargs): + return {"uid": 21, "type": "Updated"} + + stub.send_command.side_effect = fake_send response = client.put( "/events/21", @@ -268,15 +281,17 @@ def test_update_event_uses_path_identifier(gateway_app) -> None: assert isinstance(args[2], Event) assert args[2].uid == 21 assert kwargs["await_response"] is True + assert kwargs["response_type"] == module._COMMAND_SPECS["event:update"].response_type def test_delete_event_sends_identifier_string(gateway_app) -> None: """Deleting events should forward the identifier as provided.""" module, client, stub = gateway_app - stub.send_command.return_value = to_canonical_bytes( - {"status": "deleted", "uid": "21"} - ) + async def fake_send(*args, **kwargs): + return {"status": "deleted", "uid": 21} + + stub.send_command.side_effect = fake_send response = client.delete( "/events/21", @@ -286,10 +301,11 @@ def test_delete_event_sends_identifier_string(gateway_app) -> None: assert response.status_code == 200 assert response.json() == {"status": "deleted", "uid": 21} - args, _ = stub.send_command.await_args + args, kwargs = stub.send_command.await_args assert args[0] == SERVER_IDENTITY assert args[1] == module.COMMAND_DELETE_EVENT assert args[2] == "21" + assert kwargs["response_type"] == module._COMMAND_SPECS["event:delete"].response_type def test_list_events_decodes_compressed_json(gateway_app) -> None: @@ -297,7 +313,10 @@ def test_list_events_decodes_compressed_json(gateway_app) -> None: _module, client, stub = gateway_app payload = [{"uid": 1, "point": {"lat": 12.5}}] - stub.send_command.return_value = zlib.compress(json.dumps(payload).encode("utf-8")) + async def fake_send(*args, **kwargs): + return payload + + stub.send_command.side_effect = fake_send response = client.get( "/events", @@ -307,6 +326,9 @@ def test_list_events_decodes_compressed_json(gateway_app) -> None: assert response.status_code == 200 assert response.json() == payload + args, kwargs = stub.send_command.await_args + assert kwargs["response_type"] == _module._COMMAND_SPECS["event:list"].response_type + def test_cors_preflight_allows_custom_headers(gateway_app) -> None: """The gateway should allow browser preflight requests from the UI.""" diff --git a/tests/test_client.py b/tests/test_client.py index b767d2b..89ca081 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,6 +5,7 @@ from reticulum_openapi import client as client_module from reticulum_openapi.codec_msgpack import from_bytes as msgpack_from_bytes +from reticulum_openapi.model import dataclass_to_msgpack @dataclass @@ -73,6 +74,124 @@ def request( assert isinstance(payload, bytes) +@pytest.mark.asyncio +async def test_send_command_decodes_dataclass_response(monkeypatch): + """Responses can be decoded to dataclasses when ``response_type`` is provided.""" + + loop = asyncio.get_running_loop() + cli = client_module.LXMFClient.__new__(client_module.LXMFClient) + cli._loop = loop + cli.router = SimpleNamespace(handle_outbound=lambda msg: None) + cli.source_identity = object() + cli._futures = {} + cli._link_locks = {} + cli._link_events = {} + cli._links = {} + cli.auth_token = None + cli.timeout = 0.2 + + monkeypatch.setattr( + client_module.RNS.Identity, "recall", lambda h, create=False: object() + ) + + class FakeDestination: + OUT = object() + SINGLE = object() + + def __init__(self, *a, **k): + pass + + monkeypatch.setattr(client_module.RNS, "Destination", FakeDestination) + + class FakeLink: + def __init__(self, _dest, established_callback=None, closed_callback=None): + if established_callback: + loop.call_soon(established_callback, self) + + def request( + self, + path, + data=None, + response_callback=None, + failed_callback=None, + timeout=None, + ): + if response_callback: + payload = dataclass_to_msgpack(Sample(text="response")) + loop.call_soon(response_callback, SimpleNamespace(response=payload)) + + monkeypatch.setattr(client_module.RNS, "Link", FakeLink) + + result = await cli.send_command( + "aa", + "CMD", + Sample(text="hi"), + response_type=Sample, + ) + + assert isinstance(result, Sample) + assert result.text == "response" + + +@pytest.mark.asyncio +async def test_send_command_normalises_decoded_response(monkeypatch): + """Normalised responses are returned as JSON-serialisable primitives.""" + + loop = asyncio.get_running_loop() + cli = client_module.LXMFClient.__new__(client_module.LXMFClient) + cli._loop = loop + cli.router = SimpleNamespace(handle_outbound=lambda msg: None) + cli.source_identity = object() + cli._futures = {} + cli._link_locks = {} + cli._link_events = {} + cli._links = {} + cli.auth_token = None + cli.timeout = 0.2 + + monkeypatch.setattr( + client_module.RNS.Identity, "recall", lambda h, create=False: object() + ) + + class FakeDestination: + OUT = object() + SINGLE = object() + + def __init__(self, *a, **k): + pass + + monkeypatch.setattr(client_module.RNS, "Destination", FakeDestination) + + class FakeLink: + def __init__(self, _dest, established_callback=None, closed_callback=None): + if established_callback: + loop.call_soon(established_callback, self) + + def request( + self, + path, + data=None, + response_callback=None, + failed_callback=None, + timeout=None, + ): + if response_callback: + payload = dataclass_to_msgpack(Sample(text="response")) + loop.call_soon(response_callback, SimpleNamespace(response=payload)) + + monkeypatch.setattr(client_module.RNS, "Link", FakeLink) + + result = await cli.send_command( + "aa", + "CMD", + Sample(text="hi"), + response_type=Sample, + normalise=True, + ) + + assert result == {"text": "response"} + + @pytest.mark.asyncio async def test_send_command_timeout(monkeypatch): loop = asyncio.get_running_loop() diff --git a/tests/test_conversion.py b/tests/test_conversion.py new file mode 100644 index 0000000..2d339d8 --- /dev/null +++ b/tests/test_conversion.py @@ -0,0 +1,57 @@ +"""Tests for the reticulum_openapi.conversion module.""" + +from typing import List + +from examples.EmergencyManagement.Server.models_emergency import Event +from examples.EmergencyManagement.Server.models_emergency import Point +from reticulum_openapi.conversion import decode_payload +from reticulum_openapi.conversion import normalise_response +from reticulum_openapi.conversion import prepare_dataclass_payload +from reticulum_openapi.model import compress_json +from reticulum_openapi.model import dataclass_to_json_bytes +from reticulum_openapi.model import dataclass_to_msgpack + + +def test_decode_payload_returns_default_for_missing_lists() -> None: + """Lists default to empty when no payload is provided.""" + + decoded = decode_payload(None, List[Event]) + assert decoded == [] + + +def test_prepare_dataclass_payload_merges_overrides() -> None: + """Dataclass payload preparation applies overrides with type coercion.""" + + payload = prepare_dataclass_payload( + Event, + {"type": "Exercise", "point": {"lat": 1.0, "lon": 2.0}}, + overrides={"uid": "42"}, + ) + assert isinstance(payload, Event) + assert payload.uid == 42 + assert payload.point is not None + assert payload.point.lat == 1.0 + + +def test_normalise_response_converts_nested_dataclasses() -> None: + """Normalisation flattens dataclasses into JSON-serialisable primitives.""" + + point = Point(lat=3.0, lon=4.0) + event = Event(uid=7, type="Drill", point=point) + payload = normalise_response(event) + assert payload == {"uid": 7, "type": "Drill", "point": {"lat": 3.0, "lon": 4.0}} + + +def test_decode_payload_supports_json_and_messagepack() -> None: + """Decoding handles both MessagePack and compressed JSON payloads.""" + + event = Event(uid=9, type="Alert", point=Point(lat=9, lon=10)) + msgpack_payload = dataclass_to_msgpack(event) + json_payload = compress_json(dataclass_to_json_bytes(event)) + + decoded_msgpack = decode_payload(msgpack_payload, Event) + decoded_json = decode_payload(json_payload, Event) + + assert decoded_msgpack == decoded_json + assert decoded_msgpack.uid == event.uid + assert decoded_msgpack.point == event.point diff --git a/tests/test_example_emergency_management.py b/tests/test_example_emergency_management.py index 52a831c..5971c88 100644 --- a/tests/test_example_emergency_management.py +++ b/tests/test_example_emergency_management.py @@ -6,6 +6,8 @@ import sys from dataclasses import asdict from pathlib import Path +from typing import List +from typing import Optional import pytest import pytest_asyncio @@ -20,6 +22,7 @@ from examples.EmergencyManagement.Server.models_emergency import EAMStatus from examples.EmergencyManagement.Server.models_emergency import Event from examples.EmergencyManagement.Server.models_emergency import Point +from reticulum_openapi.conversion import decode_payload from reticulum_openapi.model import dataclass_to_msgpack from reticulum_openapi.model import dataclass_to_json_bytes from reticulum_openapi.model import compress_json @@ -182,30 +185,26 @@ async def test_event_controller_list_without_session_factory(monkeypatch) -> Non assert result == {"error": "InternalServerError", "code": 500} -def test_decode_event_fallback_handles_messagepack() -> None: - """The client decoder accepts MessagePack event payloads.""" - - from examples.EmergencyManagement.client import client as client_module +def test_decode_payload_handles_messagepack_dataclass() -> None: + """MessagePack payloads decode to dataclass instances.""" event = Event(uid=8, type="Exercise", qos=2) payload = dataclass_to_msgpack(event) - decoded = client_module._decode_event(payload) + decoded = decode_payload(payload, Event) assert isinstance(decoded, Event) assert decoded.uid == event.uid assert decoded.qos == event.qos -def test_decode_event_fallback_handles_compressed_json() -> None: - """The client decoder accepts compressed JSON event payloads.""" - - from examples.EmergencyManagement.client import client as client_module +def test_decode_payload_handles_compressed_json_dataclass() -> None: + """Compressed JSON payloads decode to dataclass instances.""" event = Event(uid=7, type="Drill", point=Point(lat=12.34, lon=56.78)) payload = compress_json(dataclass_to_json_bytes(event)) - decoded = client_module._decode_event(payload) + decoded = decode_payload(payload, Event) assert isinstance(decoded, Event) assert decoded.uid == event.uid @@ -213,30 +212,26 @@ def test_decode_event_fallback_handles_compressed_json() -> None: assert decoded.point.lat == event.point.lat -def test_decode_optional_event_fallback_handles_messagepack() -> None: - """Optional event decoding handles MessagePack payloads.""" - - from examples.EmergencyManagement.client import client as client_module +def test_decode_payload_handles_optional_messagepack() -> None: + """Optional dataclass decoding accepts MessagePack payloads.""" event = Event(uid=12, type="Status", version=3) payload = dataclass_to_msgpack(event) - decoded = client_module._decode_optional_event(payload) + decoded = decode_payload(payload, Optional[Event]) assert isinstance(decoded, Event) assert decoded.uid == event.uid assert decoded.version == event.version -def test_decode_optional_event_fallback_handles_compressed_json() -> None: - """Optional event decoding also supports compressed JSON payloads.""" - - from examples.EmergencyManagement.client import client as client_module +def test_decode_payload_handles_optional_compressed_json() -> None: + """Optional dataclass decoding supports compressed JSON payloads.""" event = Event(uid=11, type="Alert", point=Point(lat=1.5, lon=2.5)) payload = compress_json(dataclass_to_json_bytes(event)) - decoded = client_module._decode_optional_event(payload) + decoded = decode_payload(payload, Optional[Event]) assert isinstance(decoded, Event) assert decoded.uid == event.uid @@ -244,10 +239,8 @@ def test_decode_optional_event_fallback_handles_compressed_json() -> None: assert decoded.point.lon == event.point.lon -def test_decode_event_list_fallback_handles_messagepack() -> None: - """List decoder accepts MessagePack payloads containing mapping entries.""" - - from examples.EmergencyManagement.client import client as client_module +def test_decode_payload_handles_messagepack_list() -> None: + """List decoding accepts MessagePack payloads containing dataclass mappings.""" events = [ Event(uid=31, type="Drill", qos=1), @@ -255,17 +248,15 @@ def test_decode_event_list_fallback_handles_messagepack() -> None: ] payload = dataclass_to_msgpack([asdict(item) for item in events]) - decoded = client_module._decode_event_list(payload) + decoded = decode_payload(payload, List[Event]) assert [item.uid for item in decoded] == [31, 32] assert decoded[0].qos == events[0].qos assert decoded[1].opex == events[1].opex -def test_decode_event_list_fallback_handles_compressed_json() -> None: - """List decoder returns dataclasses when given compressed JSON payloads.""" - - from examples.EmergencyManagement.client import client as client_module +def test_decode_payload_handles_compressed_json_list() -> None: + """List decoding returns dataclasses when given compressed JSON payloads.""" events = [ Event(uid=21, type="Test", point=Point(lat=3.0, lon=4.0)), @@ -273,7 +264,7 @@ def test_decode_event_list_fallback_handles_compressed_json() -> None: ] payload = compress_json(dataclass_to_json_bytes([asdict(item) for item in events])) - decoded = client_module._decode_event_list(payload) + decoded = decode_payload(payload, List[Event]) assert [item.uid for item in decoded] == [21, 22] assert decoded[0].point is not None @@ -705,9 +696,19 @@ class DummyClient: def __init__(self) -> None: self.calls = [] - async def send_command(self, server_id, command, payload, await_response=True): - self.calls.append((server_id, command, payload, await_response)) - return dataclass_to_msgpack(message) + async def send_command( + self, + server_id, + command, + payload, + await_response=True, + response_type=None, + normalise=False, + ): + self.calls.append( + (server_id, command, payload, await_response, response_type, normalise) + ) + return message client = DummyClient() result = await client_lib.create_emergency_action_message( @@ -719,6 +720,8 @@ async def send_command(self, server_id, command, payload, await_response=True): sent = client.calls[0] assert sent[1] == client_lib.COMMAND_CREATE_EMERGENCY_ACTION_MESSAGE assert sent[3] is True + assert sent[4] == EmergencyActionMessage + assert sent[5] is False @pytest.mark.asyncio @@ -728,8 +731,16 @@ async def test_retrieve_helper_raises_for_invalid_payload() -> None: from examples.EmergencyManagement.client import client as client_lib class DummyClient: - async def send_command(self, server_id, command, payload, await_response=True): - return dataclass_to_msgpack("not-a-mapping") + async def send_command( + self, + server_id, + command, + payload, + await_response=True, + response_type=None, + normalise=False, + ): + raise ValueError("Unable to decode payload") client = DummyClient() diff --git a/tests/test_integration_webui_persistence.py b/tests/test_integration_webui_persistence.py index af67cbb..e05fbb1 100644 --- a/tests/test_integration_webui_persistence.py +++ b/tests/test_integration_webui_persistence.py @@ -16,7 +16,6 @@ Base, EmergencyActionMessage, ) -from reticulum_openapi.codec_msgpack import to_canonical_bytes def _to_primitive(value: Any) -> Any: @@ -57,8 +56,10 @@ async def send_command( command: str, payload: Any, await_response: bool = True, - ) -> bytes: - """Execute the mapped controller coroutine and encode the response.""" + response_type: Any = None, + normalise: bool = False, + ) -> Any: + """Execute the mapped controller coroutine and return the raw result.""" if server_identity != self.server_identity: raise AssertionError("Unexpected server identity hash") @@ -66,7 +67,9 @@ async def send_command( if handler is None: raise AssertionError(f"Unhandled command: {command}") result = await handler(payload) - return to_canonical_bytes(_to_primitive(result)) + if normalise: + return _to_primitive(result) + return result @pytest.mark.asyncio