Skip to content
This repository was archived by the owner on May 3, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TASK.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,5 @@
## 2025-10-06
- [x] Add async EmergencyManagement controller handlers for new CRUD endpoints.
- [x] Expand EmergencyManagement OpenAPI specification with notifications streaming and updated schemas.
- [x] Resolve dataclass JSON decoding for postponed annotations in nested payloads.

53 changes: 44 additions & 9 deletions reticulum_openapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from dataclasses import is_dataclass
import json
import zlib
import sys
from typing import List
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union
from typing import get_args
from typing import get_origin
from typing import get_type_hints

from .codec_msgpack import from_bytes as msgpack_from_bytes
from .codec_msgpack import to_canonical_bytes
Expand Down Expand Up @@ -85,24 +87,55 @@ def dataclass_to_json(data_obj: T, *, compress: bool = True) -> bytes:
return compress_json(json_bytes, enabled=compress)


def _construct(tp, value):
def _construct(tp, value, *, module_globals=None):
if isinstance(tp, str):
if module_globals is None:
return value
try:
tp = eval(tp, module_globals)
except NameError:
return value
origin = get_origin(tp)
if origin is Union:
none_type = type(None)
for sub in get_args(tp):
if sub is none_type:
if value is None:
return None
continue
try:
return _construct(sub, value)
return _construct(sub, value, module_globals=module_globals)
except Exception:
continue
if value is None:
return None
raise ValueError(f"No matching type for Union {tp}")
if is_dataclass(tp):
tp_module = sys.modules.get(tp.__module__)
tp_globals = module_globals
if isinstance(tp_module, type(sys)):
tp_globals = vars(tp_module)
kwargs = {}
for f in fields(tp):
if isinstance(value, dict) and f.name in value:
kwargs[f.name] = _construct(f.type, value[f.name])
if isinstance(value, dict):
type_hints = get_type_hints(tp, globalns=tp_globals)
for f in fields(tp):
if f.name in value:
field_type = type_hints.get(f.name, f.type)
kwargs[f.name] = _construct(
field_type,
value[f.name],
module_globals=tp_globals,
)
return tp(**kwargs) # type: ignore
if origin is list and isinstance(value, list):
item_type = get_args(tp)[0]
return [_construct(item_type, v) for v in value]
args = get_args(tp)
if not args:
return value
item_type = args[0]
return [
_construct(item_type, v, module_globals=module_globals)
for v in value
]
return value


Expand All @@ -124,7 +157,8 @@ def dataclass_from_json(cls: Type[T], data: bytes) -> T:
else:
json_bytes = data
obj_dict = json.loads(json_bytes.decode("utf-8"))
return _construct(cls, obj_dict)
module_globals = vars(sys.modules.get(cls.__module__, {}))
return _construct(cls, obj_dict, module_globals=module_globals)


def dataclass_to_msgpack(data_obj: T) -> bytes:
Expand Down Expand Up @@ -153,7 +187,8 @@ def dataclass_from_msgpack(cls: Type[T], data: bytes) -> T:
T: Deserialised dataclass instance.
"""
obj_dict = msgpack_from_bytes(data)
return _construct(cls, obj_dict)
module_globals = vars(sys.modules.get(cls.__module__, {}))
return _construct(cls, obj_dict, module_globals=module_globals)


@dataclass
Expand Down
32 changes: 29 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@
dataclass_to_json_bytes,
dataclass_to_msgpack,
)
from typing import List, Union
from typing import List
from typing import Optional
from typing import Union
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
import pytest


Expand All @@ -26,6 +32,17 @@ class ItemList:
items: List[Item]


@dataclass
class Location:
lat: float
lon: float


@dataclass
class OptionalLocationRecord:
location: Optional[Location]


def test_serialization_roundtrip():
item = Item(name="foo", value=42)
data = dataclass_to_msgpack(item)
Expand All @@ -49,6 +66,15 @@ def test_json_roundtrip_with_compression():
assert compress_json(json_bytes, enabled=False) == json_bytes


def test_json_roundtrip_with_optional_nested_dataclass():
record = OptionalLocationRecord(location=Location(lat=1.0, lon=2.0))
compressed = compress_json(dataclass_to_json_bytes(record))
obj = dataclass_from_json(OptionalLocationRecord, compressed)
assert isinstance(obj.location, Location)
assert obj.location.lat == pytest.approx(1.0)
assert obj.location.lon == pytest.approx(2.0)


def test_list_of_items_roundtrip():
obj = ItemList(items=[Item(name="a", value=1), Item(name="b", value=2)])
data = dataclass_to_msgpack(obj)
Expand Down