diff --git a/TASK.md b/TASK.md index 83a4d77..edf3d66 100644 --- a/TASK.md +++ b/TASK.md @@ -91,4 +91,5 @@ ## 2025-10-06 - [x] Add async EmergencyManagement controller handlers for new CRUD endpoints. - [x] Expand EmergencyManagement OpenAPI specification with notifications streaming and updated schemas. +- [x] Resolve dataclass JSON decoding for postponed annotations in nested payloads. diff --git a/reticulum_openapi/model.py b/reticulum_openapi/model.py index c5e57f1..ab3f743 100644 --- a/reticulum_openapi/model.py +++ b/reticulum_openapi/model.py @@ -5,6 +5,7 @@ from dataclasses import is_dataclass import json import zlib +import sys from typing import List from typing import Optional from typing import Type @@ -12,6 +13,7 @@ from typing import Union from typing import get_args from typing import get_origin +from typing import get_type_hints from .codec_msgpack import from_bytes as msgpack_from_bytes from .codec_msgpack import to_canonical_bytes @@ -85,24 +87,55 @@ def dataclass_to_json(data_obj: T, *, compress: bool = True) -> bytes: return compress_json(json_bytes, enabled=compress) -def _construct(tp, value): +def _construct(tp, value, *, module_globals=None): + if isinstance(tp, str): + if module_globals is None: + return value + try: + tp = eval(tp, module_globals) + except NameError: + return value origin = get_origin(tp) if origin is Union: + none_type = type(None) for sub in get_args(tp): + if sub is none_type: + if value is None: + return None + continue try: - return _construct(sub, value) + return _construct(sub, value, module_globals=module_globals) except Exception: continue + if value is None: + return None raise ValueError(f"No matching type for Union {tp}") if is_dataclass(tp): + tp_module = sys.modules.get(tp.__module__) + tp_globals = module_globals + if isinstance(tp_module, type(sys)): + tp_globals = vars(tp_module) kwargs = {} - for f in fields(tp): - if isinstance(value, dict) and f.name in value: - kwargs[f.name] = _construct(f.type, value[f.name]) + if isinstance(value, dict): + type_hints = get_type_hints(tp, globalns=tp_globals) + for f in fields(tp): + if f.name in value: + field_type = type_hints.get(f.name, f.type) + kwargs[f.name] = _construct( + field_type, + value[f.name], + module_globals=tp_globals, + ) return tp(**kwargs) # type: ignore if origin is list and isinstance(value, list): - item_type = get_args(tp)[0] - return [_construct(item_type, v) for v in value] + args = get_args(tp) + if not args: + return value + item_type = args[0] + return [ + _construct(item_type, v, module_globals=module_globals) + for v in value + ] return value @@ -124,7 +157,8 @@ def dataclass_from_json(cls: Type[T], data: bytes) -> T: else: json_bytes = data obj_dict = json.loads(json_bytes.decode("utf-8")) - return _construct(cls, obj_dict) + module_globals = vars(sys.modules.get(cls.__module__, {})) + return _construct(cls, obj_dict, module_globals=module_globals) def dataclass_to_msgpack(data_obj: T) -> bytes: @@ -153,7 +187,8 @@ def dataclass_from_msgpack(cls: Type[T], data: bytes) -> T: T: Deserialised dataclass instance. """ obj_dict = msgpack_from_bytes(data) - return _construct(cls, obj_dict) + module_globals = vars(sys.modules.get(cls.__module__, {})) + return _construct(cls, obj_dict, module_globals=module_globals) @dataclass diff --git a/tests/test_model.py b/tests/test_model.py index 2ce38bf..6751178 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -8,10 +8,16 @@ dataclass_to_json_bytes, dataclass_to_msgpack, ) -from typing import List, Union +from typing import List +from typing import Optional +from typing import Union +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import declarative_base -from sqlalchemy import Column, Integer, String -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine import pytest @@ -26,6 +32,17 @@ class ItemList: items: List[Item] +@dataclass +class Location: + lat: float + lon: float + + +@dataclass +class OptionalLocationRecord: + location: Optional[Location] + + def test_serialization_roundtrip(): item = Item(name="foo", value=42) data = dataclass_to_msgpack(item) @@ -49,6 +66,15 @@ def test_json_roundtrip_with_compression(): assert compress_json(json_bytes, enabled=False) == json_bytes +def test_json_roundtrip_with_optional_nested_dataclass(): + record = OptionalLocationRecord(location=Location(lat=1.0, lon=2.0)) + compressed = compress_json(dataclass_to_json_bytes(record)) + obj = dataclass_from_json(OptionalLocationRecord, compressed) + assert isinstance(obj.location, Location) + assert obj.location.lat == pytest.approx(1.0) + assert obj.location.lon == pytest.approx(2.0) + + def test_list_of_items_roundtrip(): obj = ItemList(items=[Item(name="a", value=1), Item(name="b", value=2)]) data = dataclass_to_msgpack(obj)