diff --git a/roborock/containers.py b/roborock/containers.py index ab320723..e293984c 100644 --- a/roborock/containers.py +++ b/roborock/containers.py @@ -1,9 +1,9 @@ -from __future__ import annotations - +import dataclasses import datetime import json import logging import re +import types from dataclasses import asdict, dataclass, field from datetime import timezone from enum import Enum @@ -95,105 +95,73 @@ _LOGGER = logging.getLogger(__name__) -def camelize(s: str): +def _camelize(s: str): first, *others = s.split("_") if len(others) == 0: return s return "".join([first.lower(), *map(str.title, others)]) -def decamelize(s: str): +def _decamelize(s: str): return re.sub("([A-Z]+)", "_\\1", s).lower() -def decamelize_obj(d: dict | list, ignore_keys: list[str]): - if isinstance(d, RoborockBase): - d = d.as_dict() - if isinstance(d, list): - return [decamelize_obj(i, ignore_keys) if isinstance(i, dict | list) else i for i in d] - return { - (decamelize(a) if a not in ignore_keys else a): decamelize_obj(b, ignore_keys) - if isinstance(b, dict | list) - else b - for a, b in d.items() - } - - @dataclass class RoborockBase: _ignore_keys = [] # type: ignore - is_cached = False @staticmethod - def convert_to_class_obj(type, value): - try: - class_type = eval(type) - if get_origin(class_type) is list: - return_list = [] - cls_type = get_args(class_type)[0] - for obj in value: - if issubclass(cls_type, RoborockBase): - return_list.append(cls_type.from_dict(obj)) - elif cls_type in {str, int, float}: - return_list.append(cls_type(obj)) - else: - return_list.append(cls_type(**obj)) - return return_list - if issubclass(class_type, RoborockBase): - converted_value = class_type.from_dict(value) - else: - converted_value = class_type(value) - return converted_value - except NameError as err: - _LOGGER.exception(err) - except ValueError as err: - _LOGGER.exception(err) - except Exception as err: - _LOGGER.exception(err) - raise Exception("Fail") + def _convert_to_class_obj(class_type: type, value): + if get_origin(class_type) is list: + sub_type = get_args(class_type)[0] + return [RoborockBase._convert_to_class_obj(sub_type, obj) for obj in value] + if get_origin(class_type) is dict: + _, value_type = get_args(class_type) # assume keys are only basic types + return {k: RoborockBase._convert_to_class_obj(value_type, v) for k, v in value.items()} + if issubclass(class_type, RoborockBase): + return class_type.from_dict(value) + if class_type is Any: + return value + return class_type(value) # type: ignore[call-arg] @classmethod def from_dict(cls, data: dict[str, Any]): - if isinstance(data, dict): - ignore_keys = cls._ignore_keys - data = decamelize_obj(data, ignore_keys) - cls_annotations: dict[str, str] = {} - for base in reversed(cls.__mro__): - cls_annotations.update(getattr(base, "__annotations__", {})) - remove_keys = [] - for key, value in data.items(): - if key not in cls_annotations: - remove_keys.append(key) - continue - if value == "None" or value is None: - data[key] = None - continue - field_type: str = cls_annotations[key] - if "|" in field_type: - # It's a union - types = field_type.split("|") - for type in types: - if "None" in type or "Any" in type: - continue - try: - data[key] = RoborockBase.convert_to_class_obj(type, value) - break - except Exception: - ... - else: + """Create an instance of the class from a dictionary.""" + if not isinstance(data, dict): + return None + field_types = {field.name: field.type for field in dataclasses.fields(cls)} + result: dict[str, Any] = {} + for key, value in data.items(): + key = _decamelize(key) + if (field_type := field_types.get(key)) is None: + continue + if value == "None" or value is None: + result[key] = None + continue + if isinstance(field_type, types.UnionType): + for subtype in get_args(field_type): + if subtype is types.NoneType: + continue try: - data[key] = RoborockBase.convert_to_class_obj(field_type, value) + result[key] = RoborockBase._convert_to_class_obj(subtype, value) + break except Exception: - ... - for key in remove_keys: - del data[key] - return cls(**data) + _LOGGER.exception(f"Failed to convert {key} with value {value} to type {subtype}") + continue + else: + try: + result[key] = RoborockBase._convert_to_class_obj(field_type, value) + except Exception: + _LOGGER.exception(f"Failed to convert {key} with value {value} to type {field_type}") + continue + + return cls(**result) def as_dict(self) -> dict: return asdict( self, dict_factory=lambda _fields: { - camelize(key): value.value if isinstance(value, Enum) else value + _camelize(key): value.value if isinstance(value, Enum) else value for (key, value) in _fields if value is not None }, diff --git a/tests/test_containers.py b/tests/test_containers.py index b3522984..1f0bda70 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -1,3 +1,8 @@ +"""Test cases for the containers module.""" + +from dataclasses import dataclass +from typing import Any + from roborock import CleanRecord, CleanSummary, Consumable, DnDTimer, HomeData, S7MaxVStatus, UserData from roborock.code_mappings import ( RoborockCategory, @@ -9,6 +14,7 @@ RoborockMopModeS7, RoborockStateCode, ) +from roborock.containers import RoborockBase from .mock_data import ( CLEAN_RECORD, @@ -23,6 +29,94 @@ ) +@dataclass +class SimpleObject(RoborockBase): + """Simple object for testing serialization.""" + + name: str | None = None + value: int | None = None + + +@dataclass +class ComplexObject(RoborockBase): + """Complex object for testing serialization.""" + + simple: SimpleObject | None = None + items: list[str] | None = None + value: int | None = None + nested_dict: dict[str, SimpleObject] | None = None + nested_list: list[SimpleObject] | None = None + any: Any | None = None + + +def test_simple_object() -> None: + """Test serialization and deserialization of a simple object.""" + + obj = SimpleObject(name="Test", value=42) + serialized = obj.as_dict() + assert serialized == {"name": "Test", "value": 42} + deserialized = SimpleObject.from_dict(serialized) + assert deserialized.name == "Test" + assert deserialized.value == 42 + + +def test_complex_object() -> None: + """Test serialization and deserialization of a complex object.""" + simple = SimpleObject(name="Nested", value=100) + obj = ComplexObject( + simple=simple, + items=["item1", "item2"], + value=200, + nested_dict={ + "nested1": SimpleObject(name="Nested1", value=1), + "nested2": SimpleObject(name="Nested2", value=2), + }, + nested_list=[SimpleObject(name="Nested3", value=3), SimpleObject(name="Nested4", value=4)], + any="This can be anything", + ) + serialized = obj.as_dict() + assert serialized == { + "simple": {"name": "Nested", "value": 100}, + "items": ["item1", "item2"], + "value": 200, + "nestedDict": { + "nested1": {"name": "Nested1", "value": 1}, + "nested2": {"name": "Nested2", "value": 2}, + }, + "nestedList": [ + {"name": "Nested3", "value": 3}, + {"name": "Nested4", "value": 4}, + ], + "any": "This can be anything", + } + deserialized = ComplexObject.from_dict(serialized) + assert deserialized.simple.name == "Nested" + assert deserialized.simple.value == 100 + assert deserialized.items == ["item1", "item2"] + assert deserialized.value == 200 + assert deserialized.nested_dict == { + "nested1": SimpleObject(name="Nested1", value=1), + "nested2": SimpleObject(name="Nested2", value=2), + } + assert deserialized.nested_list == [ + SimpleObject(name="Nested3", value=3), + SimpleObject(name="Nested4", value=4), + ] + assert deserialized.any == "This can be anything" + + +def test_ignore_unknown_keys() -> None: + """Test that we don't fail on unknown keys.""" + data = { + "ignored_key": "This key should be ignored", + "name": "named_object", + "value": 42, + } + deserialized = SimpleObject.from_dict(data) + assert deserialized.name == "named_object" + assert deserialized.value == 42 + + def test_user_data(): ud = UserData.from_dict(USER_DATA) assert ud.uid == 123456 @@ -184,6 +278,7 @@ def test_clean_summary(): assert cs.square_meter_clean_area == 1159.2 assert cs.clean_count == 31 assert cs.dust_collection_count == 25 + assert cs.records assert len(cs.records) == 2 assert cs.records[1] == 1672458041