From 1fd9331e092b4670c3f47c3a26f72a63f3e18784 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 19:15:09 +0000 Subject: [PATCH 1/3] refactor: reorganize serialization utilities and update logging integration --- pyproject.toml | 2 +- sqlspec/_serialization.py | 356 ------------------ sqlspec/utils/logging.py | 4 +- .../__init__.py} | 171 +++------ sqlspec/utils/serializers/_json.py | 286 ++++++++++++++ tests/unit/core/test_type_conversion.py | 4 +- tests/unit/utils/test_serialization.py | 9 +- tests/unit/utils/test_serializers.py | 39 +- 8 files changed, 377 insertions(+), 494 deletions(-) delete mode 100644 sqlspec/_serialization.py rename sqlspec/utils/{serializers.py => serializers/__init__.py} (68%) create mode 100644 sqlspec/utils/serializers/_json.py diff --git a/pyproject.toml b/pyproject.toml index 4a67af3c0..c89bf466f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,7 +220,7 @@ include = [ "sqlspec/utils/sync_tools.py", # Synchronous utility functions "sqlspec/utils/type_guards.py", # Type guard utilities "sqlspec/utils/fixtures.py", # File fixture loading - "sqlspec/utils/serializers.py", # Serialization helpers + "sqlspec/utils/serializers/**/*.py", # Serialization helpers package "sqlspec/utils/type_converters.py", # Adapter type converters "sqlspec/utils/correlation.py", # Correlation context helpers "sqlspec/utils/portal.py", # Thread portal utilities diff --git a/sqlspec/_serialization.py b/sqlspec/_serialization.py deleted file mode 100644 index fe4af788e..000000000 --- a/sqlspec/_serialization.py +++ /dev/null @@ -1,356 +0,0 @@ -"""Enhanced serialization module with byte-aware encoding and class-based architecture. - -Provides a Protocol-based serialization system that users can extend. -Supports msgspec, orjson, and standard library JSON with automatic fallback. - -Features optional numpy array serialization when numpy is installed. -Arrays are automatically converted to lists during JSON encoding. 
-""" - -import contextlib -import datetime -import enum -import json -import uuid as _uuid_mod -from abc import ABC, abstractmethod -from decimal import Decimal -from typing import Any, Final, Literal, Protocol, overload - -from sqlspec._typing import NUMPY_INSTALLED, UUID_UTILS_INSTALLED -from sqlspec.core.filters import OffsetPagination -from sqlspec.typing import MSGSPEC_INSTALLED, ORJSON_INSTALLED, PYDANTIC_INSTALLED, BaseModel - - -def _get_uuid_utils_type() -> "type[Any] | None": - if not UUID_UTILS_INSTALLED: - return None - try: - import uuid_utils as _uuid_utils_mod # pyright: ignore[reportMissingImports] - except ImportError: - return None - else: - return _uuid_utils_mod.UUID # type: ignore[no-any-return,unused-ignore] - - -_UUID_UTILS_TYPE: "type[Any] | None" = _get_uuid_utils_type() - - -def _type_to_string(value: Any) -> Any: # pragma: no cover - """Convert special types to strings for JSON serialization. - - Handles datetime, date, enums, Decimal, Pydantic models, and numpy arrays. - - Args: - value: Value to convert. - - Returns: - Serializable representation of the value (string, list, dict, etc.). - - Raises: - TypeError: If value cannot be serialized. 
- """ - if isinstance(value, datetime.datetime): - return convert_datetime_to_gmt_iso(value) - if isinstance(value, datetime.date): - return convert_date_to_iso(value) - if isinstance(value, Decimal): - return float(value) - if isinstance(value, enum.Enum): - return str(value.value) - if PYDANTIC_INSTALLED and isinstance(value, BaseModel): - return value.model_dump_json() - if isinstance(value, _uuid_mod.UUID): - return str(value) - if _UUID_UTILS_TYPE is not None and isinstance(value, _UUID_UTILS_TYPE): - return str(value) - if isinstance(value, OffsetPagination): - return {"items": value.items, "limit": value.limit, "offset": value.offset, "total": value.total} - if NUMPY_INSTALLED: - import numpy as np - - if isinstance(value, np.ndarray): - return value.tolist() - try: - return str(value) - except Exception as exc: - msg = f"Cannot serialize {type(value).__name__}" - raise TypeError(msg) from exc - - -class JSONSerializer(Protocol): - """Protocol for JSON serialization implementations. - - Users can implement this protocol to create custom serializers. - """ - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data to JSON. - - Args: - data: Data to encode. - as_bytes: Whether to return bytes instead of string. - - Returns: - JSON string or bytes depending on as_bytes parameter. - """ - ... - - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode from JSON. - - Args: - data: JSON string or bytes to decode. - decode_bytes: Whether to decode bytes input. - - Returns: - Decoded Python object. - """ - ... - - -class BaseJSONSerializer(ABC): - """Base class for JSON serializers with common functionality.""" - - __slots__ = () - - @abstractmethod - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data to JSON.""" - ... - - @abstractmethod - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode from JSON.""" - ... 
- - -# Module-level singleton fallback serializers for performance -# These avoid creating new instances on every fallback call in MsgspecSerializer -_orjson_fallback: "OrjsonSerializer | None" = None -_stdlib_fallback: "StandardLibSerializer | None" = None - - -def _get_orjson_fallback() -> "OrjsonSerializer": - """Get singleton OrjsonSerializer instance for fallback use.""" - global _orjson_fallback - if _orjson_fallback is None: - _orjson_fallback = OrjsonSerializer() - return _orjson_fallback - - -def _get_stdlib_fallback() -> "StandardLibSerializer": - """Get singleton StandardLibSerializer instance for fallback use.""" - global _stdlib_fallback - if _stdlib_fallback is None: - _stdlib_fallback = StandardLibSerializer() - return _stdlib_fallback - - -class MsgspecSerializer(BaseJSONSerializer): - """Msgspec-based JSON serializer.""" - - __slots__ = ("_decoder", "_encoder") - - def __init__(self) -> None: - """Initialize msgspec encoder and decoder.""" - from msgspec.json import Decoder, Encoder - - self._encoder: Final[Encoder] = Encoder(enc_hook=_type_to_string) - self._decoder: Final[Decoder] = Decoder() - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data using msgspec.""" - try: - if as_bytes: - return self._encoder.encode(data) - return self._encoder.encode(data).decode("utf-8") - except (TypeError, ValueError): - if ORJSON_INSTALLED: - return _get_orjson_fallback().encode(data, as_bytes=as_bytes) - return _get_stdlib_fallback().encode(data, as_bytes=as_bytes) - - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode data using msgspec.""" - if isinstance(data, bytes): - if decode_bytes: - try: - return self._decoder.decode(data) - except (TypeError, ValueError): - if ORJSON_INSTALLED: - return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) - return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) - return data - - try: - return 
self._decoder.decode(data.encode("utf-8")) - except (TypeError, ValueError): - if ORJSON_INSTALLED: - return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) - return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) - - -class OrjsonSerializer(BaseJSONSerializer): - """Orjson-based JSON serializer with native datetime/UUID support. - - Automatically enables numpy serialization if numpy is installed. - """ - - __slots__ = () - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data using orjson. - - Args: - data: Data to encode. - as_bytes: Whether to return bytes instead of string. - - Returns: - JSON string or bytes depending on as_bytes parameter. - """ - from orjson import ( - OPT_NAIVE_UTC, # pyright: ignore[reportUnknownVariableType] - OPT_SERIALIZE_UUID, # pyright: ignore[reportUnknownVariableType] - ) - from orjson import dumps as _orjson_dumps # pyright: ignore[reportMissingImports] - - options = OPT_NAIVE_UTC | OPT_SERIALIZE_UUID - - if NUMPY_INSTALLED: - from orjson import OPT_SERIALIZE_NUMPY # pyright: ignore[reportUnknownVariableType] - - options |= OPT_SERIALIZE_NUMPY - - result = _orjson_dumps(data, default=_type_to_string, option=options) - return result if as_bytes else result.decode("utf-8") - - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode data using orjson.""" - from orjson import loads as _orjson_loads # pyright: ignore[reportMissingImports] - - if isinstance(data, bytes): - if decode_bytes: - return _orjson_loads(data) - return data - return _orjson_loads(data) - - -class StandardLibSerializer(BaseJSONSerializer): - """Standard library JSON serializer as fallback.""" - - __slots__ = () - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data using standard library json.""" - json_str = json.dumps(data, default=_type_to_string) - return json_str.encode("utf-8") if as_bytes else json_str - - def decode(self, 
data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode data using standard library json.""" - if isinstance(data, bytes): - if decode_bytes: - return json.loads(data.decode("utf-8")) - return data - return json.loads(data) - - -_default_serializer: JSONSerializer | None = None - - -def get_default_serializer() -> JSONSerializer: - """Get the default serializer based on available libraries. - - Priority: msgspec > orjson > stdlib - - Returns: - The best available JSON serializer. - """ - global _default_serializer - - if _default_serializer is None: - if MSGSPEC_INSTALLED: - with contextlib.suppress(ImportError): - _default_serializer = MsgspecSerializer() - - if _default_serializer is None and ORJSON_INSTALLED: - with contextlib.suppress(ImportError): - _default_serializer = OrjsonSerializer() - - if _default_serializer is None: - _default_serializer = StandardLibSerializer() - - assert _default_serializer is not None - return _default_serializer - - -@overload -def encode_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... # pragma: no cover - - -@overload -def encode_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... # pragma: no cover - - -def encode_json(data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode to JSON, optionally returning bytes. - - Args: - data: The data to encode. - as_bytes: Whether to return bytes instead of string. - - Returns: - JSON string or bytes depending on as_bytes parameter. - """ - return get_default_serializer().encode(data, as_bytes=as_bytes) - - -def decode_json(data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode from JSON string or bytes efficiently. - - Args: - data: JSON string or bytes to decode. - decode_bytes: Whether to decode bytes input. - - Returns: - Decoded Python object. 
- """ - return get_default_serializer().decode(data, decode_bytes=decode_bytes) - - -def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str: # pragma: no cover - """Handle datetime serialization for nested timestamps. - - Args: - dt: The datetime to convert. - - Returns: - The ISO formatted datetime string. - """ - if not dt.tzinfo: - dt = dt.replace(tzinfo=datetime.timezone.utc) - return dt.isoformat().replace("+00:00", "Z") - - -def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover - """Handle datetime serialization for nested timestamps. - - Args: - dt: The date to convert. - - Returns: - The ISO formatted date string. - """ - return dt.isoformat() - - -__all__ = ( - "BaseJSONSerializer", - "JSONSerializer", - "MsgspecSerializer", - "OrjsonSerializer", - "StandardLibSerializer", - "convert_date_to_iso", - "convert_datetime_to_gmt_iso", - "decode_json", - "encode_json", - "get_default_serializer", -) diff --git a/sqlspec/utils/logging.py b/sqlspec/utils/logging.py index a8a9754ac..3522ad941 100644 --- a/sqlspec/utils/logging.py +++ b/sqlspec/utils/logging.py @@ -9,9 +9,9 @@ from logging import LogRecord from typing import TYPE_CHECKING, Any, cast -from sqlspec._serialization import encode_json from sqlspec.utils.correlation import CorrelationContext from sqlspec.utils.correlation import correlation_id_var as _correlation_id_var +from sqlspec.utils.serializers import to_json if TYPE_CHECKING: from contextvars import ContextVar @@ -155,7 +155,7 @@ def format(self, record: LogRecord) -> str: if record.exc_info: log_entry["exception"] = self.formatException(record.exc_info) - return encode_json(log_entry) + return to_json(log_entry) class CorrelationIDFilter(logging.Filter): diff --git a/sqlspec/utils/serializers.py b/sqlspec/utils/serializers/__init__.py similarity index 68% rename from sqlspec/utils/serializers.py rename to sqlspec/utils/serializers/__init__.py index 5c1de63b9..b671d48f9 100644 --- a/sqlspec/utils/serializers.py +++ 
b/sqlspec/utils/serializers/__init__.py @@ -1,7 +1,7 @@ """Serialization utilities for SQLSpec. -Provides JSON helpers, serializer pipelines, optional dependency hooks, -and cache instrumentation aligned with the core pipeline counters. +Provides the canonical public serialization surface, schema dump helpers, +optional NumPy hooks, and serializer-cache instrumentation. """ import os @@ -9,9 +9,10 @@ from threading import RLock from typing import TYPE_CHECKING, Any, Final, Literal, cast, overload -from sqlspec._serialization import decode_json, encode_json from sqlspec.typing import NUMPY_INSTALLED, UNSET, ArrowReturnFormat, attrs_asdict from sqlspec.utils.arrow_helpers import convert_dict_to_arrow +from sqlspec.utils.serializers._json import decode_json as _decode_json +from sqlspec.utils.serializers._json import encode_json as _encode_json from sqlspec.utils.type_guards import ( dataclass_to_dict, has_dict_attribute, @@ -41,6 +42,7 @@ DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" _PRIMITIVE_TYPES: Final[tuple[type[Any], ...]] = (str, bytes, int, float, bool) +_NUMPY_DECODER_SENTINEL: Final[object] = object() def _is_truthy(value: "str | None") -> bool: @@ -101,18 +103,8 @@ def to_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... def to_json(data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data to JSON string or bytes. - - Args: - data: Data to encode. - as_bytes: Whether to return bytes instead of string for optimal performance. - - Returns: - JSON string or bytes representation based on as_bytes parameter. - """ - if as_bytes: - return encode_json(data, as_bytes=True) - return encode_json(data, as_bytes=False) + """Encode data to JSON string or bytes.""" + return _encode_json(data, as_bytes=as_bytes) @overload @@ -124,44 +116,12 @@ def from_json(data: bytes, *, decode_bytes: bool = ...) -> Any: ... def from_json(data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode JSON string or bytes to Python object. 
- - Args: - data: JSON string or bytes to decode. - decode_bytes: Whether to decode bytes input (vs passing through). - - Returns: - Decoded Python object. - """ - if isinstance(data, bytes): - return decode_json(data, decode_bytes=decode_bytes) - return decode_json(data) + """Decode JSON string or bytes to Python objects.""" + return _decode_json(data, decode_bytes=decode_bytes) def numpy_array_enc_hook(value: Any) -> Any: - """Encode NumPy array to JSON-compatible list. - - Converts NumPy ndarrays to Python lists for JSON serialization. - Gracefully handles cases where NumPy is not installed by returning - the original value unchanged. - - Args: - value: Value to encode (checked for ndarray type). - - Returns: - List representation if value is ndarray, original value otherwise. - - Example: - >>> import numpy as np - >>> arr = np.array([1.0, 2.0, 3.0]) - >>> numpy_array_enc_hook(arr) - [1.0, 2.0, 3.0] - - >>> # Multi-dimensional arrays work automatically - >>> arr_2d = np.array([[1, 2], [3, 4]]) - >>> numpy_array_enc_hook(arr_2d) - [[1, 2], [3, 4]] - """ + """Encode NumPy arrays and scalars to JSON-compatible values.""" if not NUMPY_INSTALLED: return value @@ -169,78 +129,50 @@ def numpy_array_enc_hook(value: Any) -> Any: if isinstance(value, np.ndarray): return value.tolist() + if isinstance(value, np.generic): + return value.item() return value -def numpy_array_dec_hook(value: Any) -> Any: - """Decode list to NumPy array. - - Converts Python lists to NumPy arrays when appropriate. - Works best with typed schemas (Pydantic, msgspec) that expect ndarray. - - Args: - value: List to potentially convert to ndarray. - - Returns: - NumPy array if conversion successful, original value otherwise. - - Note: - Dtype is inferred by NumPy and may differ from original array. - For explicit dtype control, construct arrays manually in application code. 
+def numpy_array_dec_hook(target_or_value: Any, value: Any = _NUMPY_DECODER_SENTINEL) -> Any: + """Decode JSON list payloads into NumPy arrays. - Example: - >>> numpy_array_dec_hook([1.0, 2.0, 3.0]) - array([1., 2., 3.]) - - >>> # Returns original value if NumPy not installed - >>> # (when NUMPY_INSTALLED is False) - >>> numpy_array_dec_hook([1, 2, 3]) - [1, 2, 3] + Supports both direct one-argument usage and Litestar's + ``(target_type, value)`` decoder contract. """ + if value is _NUMPY_DECODER_SENTINEL: + raw_value = target_or_value + should_decode = True + else: + raw_value = value + should_decode = numpy_array_predicate(target_or_value) + if not NUMPY_INSTALLED: - return value + return raw_value + if not should_decode or not isinstance(raw_value, list): + return raw_value import numpy as np - if isinstance(value, list): - try: - return np.array(value) - except Exception: - return value - return value + try: + return np.array(raw_value) + except Exception: + return raw_value -def numpy_array_predicate(value: Any) -> bool: - """Check if value is NumPy array instance. - - Type checker for decoder registration in framework plugins. - Returns False when NumPy is not installed. - - Args: - value: Value to type-check. - - Returns: - True if value is ndarray, False otherwise. 
- - Example: - >>> import numpy as np - >>> numpy_array_predicate(np.array([1, 2, 3])) - True - - >>> numpy_array_predicate([1, 2, 3]) - False - - >>> # Returns False when NumPy not installed - >>> # (when NUMPY_INSTALLED is False) - >>> numpy_array_predicate([1, 2, 3]) - False - """ +def numpy_array_predicate(value_or_target: Any) -> bool: + """Check whether a value or target type represents a NumPy array.""" if not NUMPY_INSTALLED: return False import numpy as np - return isinstance(value, np.ndarray) + if isinstance(value_or_target, type): + try: + return issubclass(value_or_target, np.ndarray) + except TypeError: + return False + return isinstance(value_or_target, np.ndarray) class SchemaSerializer: @@ -285,11 +217,15 @@ def _dump_identity_dict(value: Any) -> "dict[str, Any]": def _dump_msgspec_fields(value: Any) -> "dict[str, Any]": - return {f: value.__getattribute__(f) for f in value.__struct_fields__} + return {field_name: value.__getattribute__(field_name) for field_name in value.__struct_fields__} def _dump_msgspec_excluding_unset(value: Any) -> "dict[str, Any]": - return {f: field_value for f in value.__struct_fields__ if (field_value := value.__getattribute__(f)) != UNSET} + return { + field_name: field_value + for field_name in value.__struct_fields__ + if (field_value := value.__getattribute__(field_name)) != UNSET + } def _dump_dataclass(value: Any, *, exclude_unset: bool) -> "dict[str, Any]": @@ -315,7 +251,6 @@ def _dump_mapping(value: Any) -> "dict[str, Any]": def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], dict[str, Any]]": if sample is None or isinstance(sample, dict): return _dump_identity_dict - if is_dataclass_instance(sample): return cast("Callable[[Any], dict[str, Any]]", partial(_dump_dataclass, exclude_unset=exclude_unset)) if is_pydantic_model(sample): @@ -324,19 +259,15 @@ def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], d if exclude_unset: return _dump_msgspec_excluding_unset 
return _dump_msgspec_fields - if is_attrs_instance(sample): return _dump_attrs - if has_dict_attribute(sample): return _dump_dict_attr - return _dump_mapping def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "SchemaSerializer": """Return cached serializer pipeline for the provided sample object.""" - key = _make_serializer_key(sample, exclude_unset) with _SERIALIZER_LOCK: pipeline = _SCHEMA_SERIALIZERS.get(key) @@ -353,7 +284,6 @@ def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "Sc def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) -> "list[Any]": """Serialize a collection using cached pipelines keyed by item type.""" - serialized: list[Any] = [] cache: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} @@ -373,7 +303,6 @@ def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) def reset_serializer_cache() -> None: """Clear cached serializer pipelines.""" - with _SERIALIZER_LOCK: _SCHEMA_SERIALIZERS.clear() _SERIALIZER_METRICS.reset() @@ -381,7 +310,6 @@ def reset_serializer_cache() -> None: def get_serializer_metrics() -> "dict[str, int]": """Return cache metrics aligned with the core pipeline counters.""" - with _SERIALIZER_LOCK: metrics = _SERIALIZER_METRICS.snapshot() metrics["size"] = len(_SCHEMA_SERIALIZERS) @@ -389,18 +317,9 @@ def get_serializer_metrics() -> "dict[str, int]": def schema_dump(data: Any, *, exclude_unset: bool = True) -> Any: - """Dump a schema model or dict to a plain representation. - - Args: - data: Schema model instance or dictionary to dump. - exclude_unset: Whether to exclude unset fields (for models that support it). - - Returns: - A plain representation of the schema model or value. 
- """ + """Dump a schema model or dict to a plain representation.""" if is_dict(data): return data - if isinstance(data, _PRIMITIVE_TYPES) or data is None: return data diff --git a/sqlspec/utils/serializers/_json.py b/sqlspec/utils/serializers/_json.py new file mode 100644 index 000000000..81b75f19f --- /dev/null +++ b/sqlspec/utils/serializers/_json.py @@ -0,0 +1,286 @@ +"""Private JSON serialization engine for ``sqlspec.utils.serializers``.""" + +# ruff: noqa: PLC2801 +import contextlib +import datetime +import enum +import json +import uuid as uuid_mod +from abc import ABC, abstractmethod +from decimal import Decimal +from typing import Any, Final, Literal, Protocol, overload + +from sqlspec.core.filters import OffsetPagination +from sqlspec.typing import ( + MSGSPEC_INSTALLED, + NUMPY_INSTALLED, + ORJSON_INSTALLED, + PYDANTIC_INSTALLED, + BaseModel, + attrs_asdict, +) +from sqlspec.utils.type_guards import dataclass_to_dict, is_attrs_instance, is_dataclass_instance, is_msgspec_struct +from sqlspec.utils.uuids import UUID_UTILS_INSTALLED, _load_uuid_utils + + +def _get_uuid_utils_type() -> "type[Any] | None": + if not UUID_UTILS_INSTALLED: + return None + module = _load_uuid_utils() + if module is None: + return None + return module.UUID # type: ignore[no-any-return] + + +_UUID_UTILS_TYPE: "type[Any] | None" = _get_uuid_utils_type() + + +def convert_datetime_to_gmt_iso(value: datetime.datetime) -> str: + """Normalize datetime values to ISO 8601 strings.""" + if value.tzinfo is None: + value = value.replace(tzinfo=datetime.timezone.utc) + return value.isoformat().replace("+00:00", "Z") + + +def convert_date_to_iso(value: datetime.date) -> str: + """Normalize date values to ISO 8601 strings.""" + return value.isoformat() + + +def _dump_pydantic_model(value: Any) -> Any: + if hasattr(value, "model_dump"): + return value.model_dump() + return value.dict() + + +def _dump_msgspec_struct(value: Any) -> "dict[str, Any]": + return {field_name: 
value.__getattribute__(field_name) for field_name in value.__struct_fields__} + + +def _normalize_numpy_value(value: Any) -> Any: + if not NUMPY_INSTALLED: + return value + + import numpy as np + + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, np.generic): + return value.item() + return value + + +def _normalize_supported_value(value: Any) -> Any: + """Convert supported non-native values into JSON-compatible objects.""" + if isinstance(value, datetime.datetime): + return convert_datetime_to_gmt_iso(value) + if isinstance(value, datetime.date): + return convert_date_to_iso(value) + if isinstance(value, uuid_mod.UUID): + return str(value) + if _UUID_UTILS_TYPE is not None and isinstance(value, _UUID_UTILS_TYPE): + return str(value) + if isinstance(value, Decimal): + return float(value) + if isinstance(value, enum.Enum): + return value.value + if isinstance(value, OffsetPagination): + return {"items": value.items, "limit": value.limit, "offset": value.offset, "total": value.total} + if PYDANTIC_INSTALLED and isinstance(value, BaseModel): + return _dump_pydantic_model(value) + if is_dataclass_instance(value): + return dataclass_to_dict(value) + if is_attrs_instance(value): + return attrs_asdict(value, recurse=True) + if is_msgspec_struct(value): + return _dump_msgspec_struct(value) + numpy_value = _normalize_numpy_value(value) + if numpy_value is not value: + return numpy_value + + msg = f"unsupported JSON value: {type(value).__name__}" + raise TypeError(msg) + + +def _is_explicit_unsupported_error(exc: Exception) -> bool: + return "unsupported json value" in str(exc).lower() + + +class JSONSerializer(Protocol): + """Protocol for JSON serializer implementations.""" + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + """Encode Python data into JSON.""" + ... + + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + """Decode JSON into Python data.""" + ... 
+ + +class BaseJSONSerializer(ABC): + """Base class shared by JSON serializer implementations.""" + + __slots__ = () + + @abstractmethod + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + """Encode Python data into JSON.""" + ... + + @abstractmethod + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + """Decode JSON into Python data.""" + ... + + +_orjson_fallback: "OrjsonSerializer | None" = None +_stdlib_fallback: "StandardLibSerializer | None" = None +_default_serializer: JSONSerializer | None = None + + +def _get_orjson_fallback() -> "OrjsonSerializer": + global _orjson_fallback + if _orjson_fallback is None: + _orjson_fallback = OrjsonSerializer() + return _orjson_fallback + + +def _get_stdlib_fallback() -> "StandardLibSerializer": + global _stdlib_fallback + if _stdlib_fallback is None: + _stdlib_fallback = StandardLibSerializer() + return _stdlib_fallback + + +class MsgspecSerializer(BaseJSONSerializer): + """Msgspec-based JSON serializer.""" + + __slots__ = ("_decoder", "_encoder") + + def __init__(self) -> None: + from msgspec.json import Decoder, Encoder + + self._encoder: Final[Encoder] = Encoder(enc_hook=_normalize_supported_value) + self._decoder: Final[Decoder] = Decoder() + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + try: + encoded = self._encoder.encode(data) + except TypeError as exc: + if _is_explicit_unsupported_error(exc): + raise + if ORJSON_INSTALLED: + return _get_orjson_fallback().encode(data, as_bytes=as_bytes) + return _get_stdlib_fallback().encode(data, as_bytes=as_bytes) + except ValueError: + if ORJSON_INSTALLED: + return _get_orjson_fallback().encode(data, as_bytes=as_bytes) + return _get_stdlib_fallback().encode(data, as_bytes=as_bytes) + return encoded if as_bytes else encoded.decode("utf-8") + + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + if isinstance(data, bytes): + if not decode_bytes: + return data + try: + return 
self._decoder.decode(data) + except (TypeError, ValueError): + if ORJSON_INSTALLED: + return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) + return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) + try: + return self._decoder.decode(data.encode("utf-8")) + except (TypeError, ValueError): + if ORJSON_INSTALLED: + return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) + return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) + + +class OrjsonSerializer(BaseJSONSerializer): + """Orjson-based JSON serializer.""" + + __slots__ = () + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + from orjson import OPT_NAIVE_UTC, OPT_SERIALIZE_UUID + from orjson import dumps as orjson_dumps # pyright: ignore[reportMissingImports] + + options = OPT_NAIVE_UTC | OPT_SERIALIZE_UUID + if NUMPY_INSTALLED: + from orjson import OPT_SERIALIZE_NUMPY + + options |= OPT_SERIALIZE_NUMPY + + try: + encoded = orjson_dumps(data, default=_normalize_supported_value, option=options) + except TypeError as exc: + if _is_explicit_unsupported_error(exc): + raise + if "type is not json serializable" in str(exc).lower(): + unsupported_msg = "unsupported JSON value" + raise TypeError(unsupported_msg) from exc + raise + return encoded if as_bytes else encoded.decode("utf-8") + + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + from orjson import loads as orjson_loads # pyright: ignore[reportMissingImports] + + if isinstance(data, bytes): + if not decode_bytes: + return data + return orjson_loads(data) + return orjson_loads(data) + + +class StandardLibSerializer(BaseJSONSerializer): + """Standard library JSON serializer fallback.""" + + __slots__ = () + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + encoded = json.dumps(data, default=_normalize_supported_value) + return encoded.encode("utf-8") if as_bytes else encoded + + def decode(self, data: str | bytes, *, 
decode_bytes: bool = True) -> Any: + if isinstance(data, bytes): + if not decode_bytes: + return data + return json.loads(data.decode("utf-8")) + return json.loads(data) + + +def get_default_serializer() -> JSONSerializer: + """Return the best available JSON serializer.""" + global _default_serializer + + if _default_serializer is None: + if MSGSPEC_INSTALLED: + with contextlib.suppress(ImportError): + _default_serializer = MsgspecSerializer() + if _default_serializer is None and ORJSON_INSTALLED: + with contextlib.suppress(ImportError): + _default_serializer = OrjsonSerializer() + if _default_serializer is None: + _default_serializer = StandardLibSerializer() + + assert _default_serializer is not None + return _default_serializer + + +@overload +def encode_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... + + +@overload +def encode_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... + + +def encode_json(data: Any, *, as_bytes: bool = False) -> str | bytes: + """Encode Python data into JSON.""" + return get_default_serializer().encode(data, as_bytes=as_bytes) + + +def decode_json(data: str | bytes, *, decode_bytes: bool = True) -> Any: + """Decode JSON input into Python data.""" + return get_default_serializer().decode(data, decode_bytes=decode_bytes) diff --git a/tests/unit/core/test_type_conversion.py b/tests/unit/core/test_type_conversion.py index 5b3f7d0c7..b474399db 100644 --- a/tests/unit/core/test_type_conversion.py +++ b/tests/unit/core/test_type_conversion.py @@ -11,7 +11,7 @@ import pytest -import sqlspec._serialization +import sqlspec.utils.serializers._json as json_serialization from sqlspec.core import ( BaseTypeConverter, convert_decimal, @@ -240,7 +240,7 @@ def test_convert_json_avoids_serializer_dispatch(monkeypatch: pytest.MonkeyPatch def fail_get_default_serializer() -> None: raise AssertionError("convert_json should not call serializer selection") - monkeypatch.setattr(sqlspec._serialization, "get_default_serializer", 
fail_get_default_serializer) + monkeypatch.setattr(json_serialization, "get_default_serializer", fail_get_default_serializer) result = convert_json('{"key": "value"}') diff --git a/tests/unit/utils/test_serialization.py b/tests/unit/utils/test_serialization.py index 07dce221e..c5991039f 100644 --- a/tests/unit/utils/test_serialization.py +++ b/tests/unit/utils/test_serialization.py @@ -1,7 +1,7 @@ -"""Tests for enhanced serialization functionality. +"""Tests for public JSON serialization functionality. -Tests for the byte-aware serialization system, including performance -improvements and compatibility with msgspec/orjson fallback patterns. +Tests for the byte-aware serialization system through the canonical +``sqlspec.utils.serializers`` surface. """ import json @@ -10,7 +10,8 @@ import pytest -from sqlspec._serialization import decode_json, encode_json +from sqlspec.utils.serializers import from_json as decode_json +from sqlspec.utils.serializers import to_json as encode_json def test_encode_json_as_string() -> None: diff --git a/tests/unit/utils/test_serializers.py b/tests/unit/utils/test_serializers.py index e22e351e2..cdf6b6648 100644 --- a/tests/unit/utils/test_serializers.py +++ b/tests/unit/utils/test_serializers.py @@ -1,7 +1,6 @@ """Tests for sqlspec.utils.serializers module. -Tests for JSON serialization utilities that are re-exported from sqlspec._serialization. -Covers all serialization scenarios including edge cases and type handling. +Tests for the canonical JSON serialization surface and contract-level regressions. 
""" import json @@ -455,7 +454,7 @@ def test_parametrized_round_trip(test_input: Any) -> None: def test_imports_work_correctly() -> None: - """Test that the imports from _serialization module work correctly.""" + """Test that the canonical serializer imports round-trip correctly.""" assert callable(to_json) assert callable(from_json) @@ -495,6 +494,27 @@ def test_error_messages_are_helpful() -> None: assert any(word in error_msg for word in ["json", "decode", "parse", "invalid", "expect", "malformed"]) +def test_to_json_embeds_pydantic_models_as_objects() -> None: + """Pydantic models should normalize to plain objects, not JSON strings.""" + + pydantic = pytest.importorskip("pydantic") + + class Payload(pydantic.BaseModel): + identifier: int + label: str + + result = to_json({"payload": Payload(identifier=1, label="alpha")}) + + assert json.loads(result) == {"payload": {"identifier": 1, "label": "alpha"}} + + +def test_to_json_raises_type_error_for_unsupported_objects() -> None: + """Unsupported objects should fail explicitly instead of stringifying.""" + + with pytest.raises(TypeError, match="unsupported"): + to_json({"payload": object()}) + + numpy_available = pytest.importorskip("numpy", reason="NumPy not installed") @@ -608,6 +628,18 @@ def test_numpy_array_dec_hook_non_list() -> None: assert numpy_array_dec_hook(None) is None +@pytest.mark.skipif(not numpy_available, reason="NumPy not installed") +def test_numpy_array_dec_hook_supports_litestar_decoder_signature() -> None: + """The decoder hook should accept Litestar's ``(target_type, value)`` contract.""" + + import numpy as np + + decoded = numpy_array_dec_hook(np.ndarray, [1.0, 2.0, 3.0]) + + assert isinstance(decoded, np.ndarray) + assert np.array_equal(decoded, np.array([1.0, 2.0, 3.0])) + + @pytest.mark.skipif(not numpy_available, reason="NumPy not installed") def test_numpy_array_predicate_basic() -> None: """Test NumPy array predicate for type checking.""" @@ -615,6 +647,7 @@ def 
test_numpy_array_predicate_basic() -> None: arr = np.array([1, 2, 3]) assert numpy_array_predicate(arr) is True + assert numpy_array_predicate(np.ndarray) is True assert numpy_array_predicate([1, 2, 3]) is False assert numpy_array_predicate("string") is False From 3b432d86223c9f64b5b737c6842d6d1551eca45b Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 19:38:43 +0000 Subject: [PATCH 2/3] feat: add serialization and schema dumping utilities for NumPy and enhance logging capabilities --- pyproject.toml | 4 + sqlspec/utils/deprecation.py | 91 +++----- sqlspec/utils/logging.py | 11 +- sqlspec/utils/serializers/__init__.py | 324 +------------------------- sqlspec/utils/serializers/_numpy.py | 62 +++++ sqlspec/utils/serializers/_schema.py | 226 ++++++++++++++++++ tests/unit/core/test_cache.py | 22 +- 7 files changed, 352 insertions(+), 388 deletions(-) create mode 100644 sqlspec/utils/serializers/_numpy.py create mode 100644 sqlspec/utils/serializers/_schema.py diff --git a/pyproject.toml b/pyproject.toml index c89bf466f..cd9e7484a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,6 +220,10 @@ include = [ "sqlspec/utils/sync_tools.py", # Synchronous utility functions "sqlspec/utils/type_guards.py", # Type guard utilities "sqlspec/utils/fixtures.py", # File fixture loading + "sqlspec/utils/config_tools.py", # Configuration utilities + "sqlspec/utils/deprecation.py", # Deprecation helpers + "sqlspec/utils/dispatch.py", # Dispatch helpers + "sqlspec/utils/logging.py", # Logging helpers "sqlspec/utils/serializers/**/*.py", # Serialization helpers package "sqlspec/utils/type_converters.py", # Adapter type converters "sqlspec/utils/correlation.py", # Correlation context helpers diff --git a/sqlspec/utils/deprecation.py b/sqlspec/utils/deprecation.py index 908731192..132785228 100644 --- a/sqlspec/utils/deprecation.py +++ b/sqlspec/utils/deprecation.py @@ -6,7 +6,7 @@ import inspect from collections.abc import Callable -from functools import 
WRAPPER_ASSIGNMENTS, WRAPPER_UPDATES +from functools import wraps from typing import Generic, Literal, cast from warnings import warn @@ -104,40 +104,13 @@ def deprecated( ) -class _DeprecatedWrapper(Generic[P, T]): - __slots__ = ("__dict__", "_alternative", "_func", "_info", "_kind", "_pending", "_removal_in", "_version") - - def __init__( - self, - func: Callable[P, T], - *, - version: str, - removal_in: str | None, - alternative: str | None, - info: str | None, - pending: bool, - kind: Literal["function", "method", "classmethod", "property"] | None, - ) -> None: - self._func = func - self._version = version - self._removal_in = removal_in - self._alternative = alternative - self._info = info - self._pending = pending - self._kind = kind - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - kind = cast("DeprecatedKind", self._kind or ("method" if inspect.ismethod(self._func) else "function")) - warn_deprecation( - version=self._version, - deprecated_name=self._func.__name__, - info=self._info, - alternative=self._alternative, - pending=self._pending, - removal_in=self._removal_in, - kind=kind, - ) - return self._func(*args, **kwargs) +def _infer_deprecated_kind(func: Callable[P, T]) -> DeprecatedKind: + if inspect.ismethod(func): + return "method" + qualname = getattr(func, "__qualname__", "") + if inspect.isfunction(func) and "." 
in qualname and "<locals>" not in qualname: + return "method" + return "function" class _DeprecatedFactory(Generic[P, T]): @@ -161,30 +134,24 @@ def __init__( self._kind: Literal["function", "method", "classmethod", "property"] | None = kind def __call__(self, func: Callable[P, T]) -> Callable[P, T]: - wrapper = _DeprecatedWrapper( - func, - version=self._version, - removal_in=self._removal_in, - alternative=self._alternative, - info=self._info, - pending=self._pending, - kind=self._kind, - ) - return cast("Callable[P, T]", _copy_wrapper_metadata(wrapper, func)) - - -def _copy_wrapper_metadata(wrapper: "_DeprecatedWrapper[P, T]", func: "Callable[P, T]") -> "_DeprecatedWrapper[P, T]": - assignments: tuple[str, ...] = tuple(WRAPPER_ASSIGNMENTS) - updates: tuple[str, ...] = tuple(WRAPPER_UPDATES) - wrapper_dict = wrapper.__dict__ - wrapper_dict.update(getattr(func, "__dict__", {})) - for attr in assignments: - if hasattr(func, attr): - wrapper_dict[attr] = getattr(func, attr) - wrapper_dict["__wrapped__"] = func - for attr in updates: - if attr == "__dict__": - continue - if hasattr(func, attr): - wrapper_dict[attr] = getattr(func, attr) - return wrapper + version = self._version + removal_in = self._removal_in + alternative = self._alternative + info = self._info + pending = self._pending + kind = cast("DeprecatedKind", self._kind or _infer_deprecated_kind(func)) + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + warn_deprecation( + version=version, + deprecated_name=func.__name__, + info=info, + alternative=alternative, + pending=pending, + removal_in=removal_in, + kind=kind, + ) + return func(*args, **kwargs) + + return wrapper diff --git a/sqlspec/utils/logging.py b/sqlspec/utils/logging.py index 3522ad941..e51566f10 100644 --- a/sqlspec/utils/logging.py +++ b/sqlspec/utils/logging.py @@ -117,17 +117,16 @@ def format(self, record: LogRecord) -> str: Returns: JSON formatted log entry """ + record_dict = record.__dict__ log_entry = { "timestamp":
self.formatTime(record, self.datefmt), "level": record.levelname, - "logger": record.name, + "logger": cast("str | None", record_dict.get("name")), "message": record.getMessage(), - "module": record.module, - "function": record.funcName, - "line": record.lineno, + "module": cast("str | None", record_dict.get("module")), + "function": cast("str | None", record_dict.get("funcName")), + "line": cast("int | None", record_dict.get("lineno")), } - - record_dict = record.__dict__ correlation_id = cast("str | None", record_dict.get("correlation_id")) or get_correlation_id() if correlation_id: log_entry["correlation_id"] = correlation_id diff --git a/sqlspec/utils/serializers/__init__.py b/sqlspec/utils/serializers/__init__.py index b671d48f9..d9f1c47aa 100644 --- a/sqlspec/utils/serializers/__init__.py +++ b/sqlspec/utils/serializers/__init__.py @@ -1,31 +1,17 @@ -"""Serialization utilities for SQLSpec. - -Provides the canonical public serialization surface, schema dump helpers, -optional NumPy hooks, and serializer-cache instrumentation. 
-""" - -import os -from functools import partial -from threading import RLock -from typing import TYPE_CHECKING, Any, Final, Literal, cast, overload - -from sqlspec.typing import NUMPY_INSTALLED, UNSET, ArrowReturnFormat, attrs_asdict -from sqlspec.utils.arrow_helpers import convert_dict_to_arrow -from sqlspec.utils.serializers._json import decode_json as _decode_json -from sqlspec.utils.serializers._json import encode_json as _encode_json -from sqlspec.utils.type_guards import ( - dataclass_to_dict, - has_dict_attribute, - is_attrs_instance, - is_dataclass_instance, - is_dict, - is_msgspec_struct, - is_pydantic_model, +"""Serialization utilities for SQLSpec.""" + +from sqlspec.utils.serializers._json import decode_json as from_json +from sqlspec.utils.serializers._json import encode_json as to_json +from sqlspec.utils.serializers._numpy import numpy_array_dec_hook, numpy_array_enc_hook, numpy_array_predicate +from sqlspec.utils.serializers._schema import ( + SchemaSerializer, + get_collection_serializer, + get_serializer_metrics, + reset_serializer_cache, + schema_dump, + serialize_collection, ) -if TYPE_CHECKING: - from collections.abc import Callable, Iterable - __all__ = ( "SchemaSerializer", "from_json", @@ -39,289 +25,3 @@ "serialize_collection", "to_json", ) - -DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" -_PRIMITIVE_TYPES: Final[tuple[type[Any], ...]] = (str, bytes, int, float, bool) -_NUMPY_DECODER_SENTINEL: Final[object] = object() - - -def _is_truthy(value: "str | None") -> bool: - if value is None: - return False - normalized = value.strip().lower() - return normalized in {"1", "true", "yes", "on"} - - -def _metrics_enabled() -> bool: - return _is_truthy(os.getenv(DEBUG_ENV_FLAG)) - - -class _SerializerCacheMetrics: - __slots__ = ("hits", "max_size", "misses", "size") - - def __init__(self) -> None: - self.hits = 0 - self.misses = 0 - self.size = 0 - self.max_size = 0 - - def record_hit(self, cache_size: int) -> None: - if not 
_metrics_enabled(): - return - self.hits += 1 - self.size = cache_size - self.max_size = max(self.max_size, cache_size) - - def record_miss(self, cache_size: int) -> None: - if not _metrics_enabled(): - return - self.misses += 1 - self.size = cache_size - self.max_size = max(self.max_size, cache_size) - - def reset(self) -> None: - self.hits = 0 - self.misses = 0 - self.size = 0 - self.max_size = 0 - - def snapshot(self) -> "dict[str, int]": - return { - "hits": self.hits if _metrics_enabled() else 0, - "misses": self.misses if _metrics_enabled() else 0, - "max_size": self.max_size if _metrics_enabled() else 0, - "size": self.size if _metrics_enabled() else 0, - } - - -@overload -def to_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... - - -@overload -def to_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... - - -def to_json(data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data to JSON string or bytes.""" - return _encode_json(data, as_bytes=as_bytes) - - -@overload -def from_json(data: str) -> Any: ... - - -@overload -def from_json(data: bytes, *, decode_bytes: bool = ...) -> Any: ... - - -def from_json(data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode JSON string or bytes to Python objects.""" - return _decode_json(data, decode_bytes=decode_bytes) - - -def numpy_array_enc_hook(value: Any) -> Any: - """Encode NumPy arrays and scalars to JSON-compatible values.""" - if not NUMPY_INSTALLED: - return value - - import numpy as np - - if isinstance(value, np.ndarray): - return value.tolist() - if isinstance(value, np.generic): - return value.item() - return value - - -def numpy_array_dec_hook(target_or_value: Any, value: Any = _NUMPY_DECODER_SENTINEL) -> Any: - """Decode JSON list payloads into NumPy arrays. - - Supports both direct one-argument usage and Litestar's - ``(target_type, value)`` decoder contract. 
- """ - if value is _NUMPY_DECODER_SENTINEL: - raw_value = target_or_value - should_decode = True - else: - raw_value = value - should_decode = numpy_array_predicate(target_or_value) - - if not NUMPY_INSTALLED: - return raw_value - if not should_decode or not isinstance(raw_value, list): - return raw_value - - import numpy as np - - try: - return np.array(raw_value) - except Exception: - return raw_value - - -def numpy_array_predicate(value_or_target: Any) -> bool: - """Check whether a value or target type represents a NumPy array.""" - if not NUMPY_INSTALLED: - return False - - import numpy as np - - if isinstance(value_or_target, type): - try: - return issubclass(value_or_target, np.ndarray) - except TypeError: - return False - return isinstance(value_or_target, np.ndarray) - - -class SchemaSerializer: - """Serializer pipeline that caches conversions for repeated schema dumps.""" - - __slots__ = ("_dump", "_key") - - def __init__(self, key: "tuple[type[Any] | None, bool]", dump: "Callable[[Any], dict[str, Any]]") -> None: - self._key = key - self._dump = dump - - @property - def key(self) -> "tuple[type[Any] | None, bool]": - return self._key - - def dump_one(self, item: Any) -> "dict[str, Any]": - return self._dump(item) - - def dump_many(self, items: "Iterable[Any]") -> "list[dict[str, Any]]": - return [self._dump(item) for item in items] - - def to_arrow( - self, items: "Iterable[Any]", *, return_format: "ArrowReturnFormat" = "table", batch_size: int | None = None - ) -> Any: - payload = self.dump_many(items) - return convert_dict_to_arrow(payload, return_format=return_format, batch_size=batch_size) - - -_SERIALIZER_LOCK: RLock = RLock() -_SCHEMA_SERIALIZERS: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} -_SERIALIZER_METRICS = _SerializerCacheMetrics() - - -def _make_serializer_key(sample: Any, exclude_unset: bool) -> "tuple[type[Any] | None, bool]": - if sample is None or isinstance(sample, dict): - return (None, exclude_unset) - return 
(type(sample), exclude_unset) - - -def _dump_identity_dict(value: Any) -> "dict[str, Any]": - return cast("dict[str, Any]", value) - - -def _dump_msgspec_fields(value: Any) -> "dict[str, Any]": - return {field_name: value.__getattribute__(field_name) for field_name in value.__struct_fields__} - - -def _dump_msgspec_excluding_unset(value: Any) -> "dict[str, Any]": - return { - field_name: field_value - for field_name in value.__struct_fields__ - if (field_value := value.__getattribute__(field_name)) != UNSET - } - - -def _dump_dataclass(value: Any, *, exclude_unset: bool) -> "dict[str, Any]": - return dataclass_to_dict(value, exclude_empty=exclude_unset) - - -def _dump_pydantic(value: Any, *, exclude_unset: bool) -> "dict[str, Any]": - return cast("dict[str, Any]", value.model_dump(exclude_unset=exclude_unset)) - - -def _dump_attrs(value: Any) -> "dict[str, Any]": - return attrs_asdict(value, recurse=True) - - -def _dump_dict_attr(value: Any) -> "dict[str, Any]": - return dict(value.__dict__) - - -def _dump_mapping(value: Any) -> "dict[str, Any]": - return dict(value) - - -def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], dict[str, Any]]": - if sample is None or isinstance(sample, dict): - return _dump_identity_dict - if is_dataclass_instance(sample): - return cast("Callable[[Any], dict[str, Any]]", partial(_dump_dataclass, exclude_unset=exclude_unset)) - if is_pydantic_model(sample): - return cast("Callable[[Any], dict[str, Any]]", partial(_dump_pydantic, exclude_unset=exclude_unset)) - if is_msgspec_struct(sample): - if exclude_unset: - return _dump_msgspec_excluding_unset - return _dump_msgspec_fields - if is_attrs_instance(sample): - return _dump_attrs - if has_dict_attribute(sample): - return _dump_dict_attr - return _dump_mapping - - -def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "SchemaSerializer": - """Return cached serializer pipeline for the provided sample object.""" - key = 
_make_serializer_key(sample, exclude_unset) - with _SERIALIZER_LOCK: - pipeline = _SCHEMA_SERIALIZERS.get(key) - if pipeline is not None: - _SERIALIZER_METRICS.record_hit(len(_SCHEMA_SERIALIZERS)) - return pipeline - - dump = _build_dump_function(sample, exclude_unset) - pipeline = SchemaSerializer(key, dump) - _SCHEMA_SERIALIZERS[key] = pipeline - _SERIALIZER_METRICS.record_miss(len(_SCHEMA_SERIALIZERS)) - return pipeline - - -def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) -> "list[Any]": - """Serialize a collection using cached pipelines keyed by item type.""" - serialized: list[Any] = [] - cache: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} - - for item in items: - if isinstance(item, _PRIMITIVE_TYPES) or item is None or isinstance(item, dict): - serialized.append(item) - continue - - key = _make_serializer_key(item, exclude_unset) - pipeline = cache.get(key) - if pipeline is None: - pipeline = get_collection_serializer(item, exclude_unset=exclude_unset) - cache[key] = pipeline - serialized.append(pipeline.dump_one(item)) - return serialized - - -def reset_serializer_cache() -> None: - """Clear cached serializer pipelines.""" - with _SERIALIZER_LOCK: - _SCHEMA_SERIALIZERS.clear() - _SERIALIZER_METRICS.reset() - - -def get_serializer_metrics() -> "dict[str, int]": - """Return cache metrics aligned with the core pipeline counters.""" - with _SERIALIZER_LOCK: - metrics = _SERIALIZER_METRICS.snapshot() - metrics["size"] = len(_SCHEMA_SERIALIZERS) - return metrics - - -def schema_dump(data: Any, *, exclude_unset: bool = True) -> Any: - """Dump a schema model or dict to a plain representation.""" - if is_dict(data): - return data - if isinstance(data, _PRIMITIVE_TYPES) or data is None: - return data - - serializer = get_collection_serializer(data, exclude_unset=exclude_unset) - return serializer.dump_one(data) diff --git a/sqlspec/utils/serializers/_numpy.py b/sqlspec/utils/serializers/_numpy.py new file mode 100644 index 
000000000..9ab026c45 --- /dev/null +++ b/sqlspec/utils/serializers/_numpy.py @@ -0,0 +1,62 @@ +"""NumPy serialization helpers for ``sqlspec.utils.serializers``.""" + +from typing import Any, Final + +from sqlspec.typing import NUMPY_INSTALLED + +_NUMPY_DECODER_SENTINEL: Final[object] = object() + + +def numpy_array_enc_hook(value: Any) -> Any: + """Encode NumPy arrays and scalars to JSON-compatible values.""" + if not NUMPY_INSTALLED: + return value + + import numpy as np + + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, np.generic): + return value.item() + return value + + +def numpy_array_dec_hook(target_or_value: Any, value: Any = _NUMPY_DECODER_SENTINEL) -> Any: + """Decode JSON list payloads into NumPy arrays. + + Supports both direct one-argument usage and Litestar's + ``(target_type, value)`` decoder contract. + """ + if value is _NUMPY_DECODER_SENTINEL: + raw_value = target_or_value + should_decode = True + else: + raw_value = value + should_decode = numpy_array_predicate(target_or_value) + + if not NUMPY_INSTALLED: + return raw_value + if not should_decode or not isinstance(raw_value, list): + return raw_value + + import numpy as np + + try: + return np.array(raw_value) + except Exception: + return raw_value + + +def numpy_array_predicate(value_or_target: Any) -> bool: + """Check whether a value or target type represents a NumPy array.""" + if not NUMPY_INSTALLED: + return False + + import numpy as np + + if isinstance(value_or_target, type): + try: + return issubclass(value_or_target, np.ndarray) + except TypeError: + return False + return isinstance(value_or_target, np.ndarray) diff --git a/sqlspec/utils/serializers/_schema.py b/sqlspec/utils/serializers/_schema.py new file mode 100644 index 000000000..543ef37c7 --- /dev/null +++ b/sqlspec/utils/serializers/_schema.py @@ -0,0 +1,226 @@ +"""Schema dumping and cache helpers for ``sqlspec.utils.serializers``.""" + +import os +from functools import partial +from threading 
import RLock +from typing import TYPE_CHECKING, Any, Final, cast + +from sqlspec.typing import UNSET, ArrowReturnFormat, attrs_asdict +from sqlspec.utils.arrow_helpers import convert_dict_to_arrow +from sqlspec.utils.type_guards import ( + dataclass_to_dict, + has_dict_attribute, + is_attrs_instance, + is_dataclass_instance, + is_dict, + is_msgspec_struct, + is_pydantic_model, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + +DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" +_PRIMITIVE_TYPES: Final[tuple[type[Any], ...]] = (str, bytes, int, float, bool) + + +def _is_truthy(value: "str | None") -> bool: + if value is None: + return False + normalized = value.strip().lower() + return normalized in {"1", "true", "yes", "on"} + + +def _metrics_enabled() -> bool: + return _is_truthy(os.getenv(DEBUG_ENV_FLAG)) + + +class _SerializerCacheMetrics: + __slots__ = ("hits", "max_size", "misses", "size") + + def __init__(self) -> None: + self.hits = 0 + self.misses = 0 + self.size = 0 + self.max_size = 0 + + def record_hit(self, cache_size: int) -> None: + if not _metrics_enabled(): + return + self.hits += 1 + self.size = cache_size + self.max_size = max(self.max_size, cache_size) + + def record_miss(self, cache_size: int) -> None: + if not _metrics_enabled(): + return + self.misses += 1 + self.size = cache_size + self.max_size = max(self.max_size, cache_size) + + def reset(self) -> None: + self.hits = 0 + self.misses = 0 + self.size = 0 + self.max_size = 0 + + def snapshot(self) -> "dict[str, int]": + return { + "hits": self.hits if _metrics_enabled() else 0, + "misses": self.misses if _metrics_enabled() else 0, + "max_size": self.max_size if _metrics_enabled() else 0, + "size": self.size if _metrics_enabled() else 0, + } + + +class SchemaSerializer: + """Serializer pipeline that caches conversions for repeated schema dumps.""" + + __slots__ = ("_dump", "_key") + + def __init__(self, key: "tuple[type[Any] | None, bool]", dump: 
"Callable[[Any], dict[str, Any]]") -> None: + self._key = key + self._dump = dump + + @property + def key(self) -> "tuple[type[Any] | None, bool]": + return self._key + + def dump_one(self, item: Any) -> "dict[str, Any]": + return self._dump(item) + + def dump_many(self, items: "Iterable[Any]") -> "list[dict[str, Any]]": + return [self._dump(item) for item in items] + + def to_arrow( + self, items: "Iterable[Any]", *, return_format: "ArrowReturnFormat" = "table", batch_size: int | None = None + ) -> Any: + payload = self.dump_many(items) + return convert_dict_to_arrow(payload, return_format=return_format, batch_size=batch_size) + + +_SERIALIZER_LOCK: RLock = RLock() +_SCHEMA_SERIALIZERS: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} +_SERIALIZER_METRICS = _SerializerCacheMetrics() + + +def _make_serializer_key(sample: Any, exclude_unset: bool) -> "tuple[type[Any] | None, bool]": + if sample is None or isinstance(sample, dict): + return (None, exclude_unset) + return (type(sample), exclude_unset) + + +def _dump_identity_dict(value: Any) -> "dict[str, Any]": + return cast("dict[str, Any]", value) + + +def _dump_msgspec_fields(value: Any) -> "dict[str, Any]": + return {field_name: value.__getattribute__(field_name) for field_name in value.__struct_fields__} + + +def _dump_msgspec_excluding_unset(value: Any) -> "dict[str, Any]": + return { + field_name: field_value + for field_name in value.__struct_fields__ + if (field_value := value.__getattribute__(field_name)) != UNSET + } + + +def _dump_dataclass(value: Any, *, exclude_unset: bool) -> "dict[str, Any]": + return dataclass_to_dict(value, exclude_empty=exclude_unset) + + +def _dump_pydantic(value: Any, *, exclude_unset: bool) -> "dict[str, Any]": + return cast("dict[str, Any]", value.model_dump(exclude_unset=exclude_unset)) + + +def _dump_attrs(value: Any) -> "dict[str, Any]": + return attrs_asdict(value, recurse=True) + + +def _dump_dict_attr(value: Any) -> "dict[str, Any]": + return 
dict(value.__dict__) + + +def _dump_mapping(value: Any) -> "dict[str, Any]": + return dict(value) + + +def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], dict[str, Any]]": + if sample is None or isinstance(sample, dict): + return _dump_identity_dict + if is_dataclass_instance(sample): + return cast("Callable[[Any], dict[str, Any]]", partial(_dump_dataclass, exclude_unset=exclude_unset)) + if is_pydantic_model(sample): + return cast("Callable[[Any], dict[str, Any]]", partial(_dump_pydantic, exclude_unset=exclude_unset)) + if is_msgspec_struct(sample): + if exclude_unset: + return _dump_msgspec_excluding_unset + return _dump_msgspec_fields + if is_attrs_instance(sample): + return _dump_attrs + if has_dict_attribute(sample): + return _dump_dict_attr + return _dump_mapping + + +def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "SchemaSerializer": + """Return cached serializer pipeline for the provided sample object.""" + key = _make_serializer_key(sample, exclude_unset) + with _SERIALIZER_LOCK: + pipeline = _SCHEMA_SERIALIZERS.get(key) + if pipeline is not None: + _SERIALIZER_METRICS.record_hit(len(_SCHEMA_SERIALIZERS)) + return pipeline + + dump = _build_dump_function(sample, exclude_unset) + pipeline = SchemaSerializer(key, dump) + _SCHEMA_SERIALIZERS[key] = pipeline + _SERIALIZER_METRICS.record_miss(len(_SCHEMA_SERIALIZERS)) + return pipeline + + +def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) -> "list[Any]": + """Serialize a collection using cached pipelines keyed by item type.""" + serialized: list[Any] = [] + cache: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} + + for item in items: + if isinstance(item, _PRIMITIVE_TYPES) or item is None or isinstance(item, dict): + serialized.append(item) + continue + + key = _make_serializer_key(item, exclude_unset) + pipeline = cache.get(key) + if pipeline is None: + pipeline = get_collection_serializer(item, 
exclude_unset=exclude_unset) + cache[key] = pipeline + serialized.append(pipeline.dump_one(item)) + return serialized + + +def reset_serializer_cache() -> None: + """Clear cached serializer pipelines.""" + with _SERIALIZER_LOCK: + _SCHEMA_SERIALIZERS.clear() + _SERIALIZER_METRICS.reset() + + +def get_serializer_metrics() -> "dict[str, int]": + """Return cache metrics aligned with the core pipeline counters.""" + with _SERIALIZER_LOCK: + metrics = _SERIALIZER_METRICS.snapshot() + metrics["size"] = len(_SCHEMA_SERIALIZERS) + return metrics + + +def schema_dump(data: Any, *, exclude_unset: bool = True) -> Any: + """Dump a schema model or dict to a plain representation.""" + if is_dict(data): + return data + if isinstance(data, _PRIMITIVE_TYPES) or data is None: + return data + + serializer = get_collection_serializer(data, exclude_unset=exclude_unset) + return serializer.dump_one(data) diff --git a/tests/unit/core/test_cache.py b/tests/unit/core/test_cache.py index 635494d98..92661f7ac 100644 --- a/tests/unit/core/test_cache.py +++ b/tests/unit/core/test_cache.py @@ -16,9 +16,10 @@ across the entire SQLSpec system. 
""" +import logging import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -590,16 +591,21 @@ def test_reset_cache_stats_function() -> None: assert multi_stats.misses == 0 -def test_log_cache_stats_function() -> None: +def test_log_cache_stats_function(caplog: pytest.LogCaptureFixture) -> None: """Test logging cache statistics.""" - with patch("sqlspec.core.cache.get_logger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger + reset_cache_stats() + caplog.set_level(logging.DEBUG, logger="sqlspec.cache") - log_cache_stats() + log_cache_stats() - mock_get_logger.assert_called_once_with("sqlspec.cache") - mock_logger.log.assert_called_once() + records = [record for record in caplog.records if record.name == "sqlspec.cache" and record.msg == "cache.stats"] + assert len(records) == 1 + extra_fields = getattr(records[0], "extra_fields", {}) + assert isinstance(extra_fields, dict) + stats = extra_fields.get("stats") + assert isinstance(stats, dict) + assert "default" in stats + assert "namespaced" in stats def test_namespaced_cache_interface() -> None: From 6235f222c871838b9a7568b5ad8229454c26b281 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 21:20:25 +0000 Subject: [PATCH 3/3] fix: address serializer mypyc follow-ups --- sqlspec/extensions/litestar/plugin.py | 4 ++-- sqlspec/utils/deprecation.py | 2 +- sqlspec/utils/serializers/_json.py | 14 +++++++++++++- sqlspec/utils/serializers/_numpy.py | 3 +++ sqlspec/utils/serializers/_schema.py | 9 +++++++++ tests/unit/utils/test_serializers.py | 7 +++++-- 6 files changed, 33 insertions(+), 6 deletions(-) diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index c92abd1cc..a2371f89a 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -424,10 +424,10 @@ def store_sqlspec_in_state() -> None: 
app_config.type_encoders = encoders_dict if app_config.type_decoders is None: - app_config.type_decoders = [(numpy_array_predicate, numpy_array_dec_hook)] # type: ignore[list-item] + app_config.type_decoders = [(numpy_array_predicate, numpy_array_dec_hook)] else: decoders_list = list(app_config.type_decoders) - decoders_list.append((numpy_array_predicate, numpy_array_dec_hook)) # type: ignore[arg-type] + decoders_list.append((numpy_array_predicate, numpy_array_dec_hook)) app_config.type_decoders = decoders_list if self._correlation_headers: diff --git a/sqlspec/utils/deprecation.py b/sqlspec/utils/deprecation.py index 132785228..f93c28e92 100644 --- a/sqlspec/utils/deprecation.py +++ b/sqlspec/utils/deprecation.py @@ -139,7 +139,7 @@ def __call__(self, func: Callable[P, T]) -> Callable[P, T]: alternative = self._alternative info = self._info pending = self._pending - kind = cast("DeprecatedKind", self._kind or _infer_deprecated_kind(func)) + kind: DeprecatedKind = self._kind or _infer_deprecated_kind(func) @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: diff --git a/sqlspec/utils/serializers/_json.py b/sqlspec/utils/serializers/_json.py index 81b75f19f..b5bc911b9 100644 --- a/sqlspec/utils/serializers/_json.py +++ b/sqlspec/utils/serializers/_json.py @@ -1,6 +1,5 @@ """Private JSON serialization engine for ``sqlspec.utils.serializers``.""" -# ruff: noqa: PLC2801 import contextlib import datetime import enum @@ -22,6 +21,19 @@ from sqlspec.utils.type_guards import dataclass_to_dict, is_attrs_instance, is_dataclass_instance, is_msgspec_struct from sqlspec.utils.uuids import UUID_UTILS_INSTALLED, _load_uuid_utils +__all__ = ( + "BaseJSONSerializer", + "JSONSerializer", + "MsgspecSerializer", + "OrjsonSerializer", + "StandardLibSerializer", + "convert_date_to_iso", + "convert_datetime_to_gmt_iso", + "decode_json", + "encode_json", + "get_default_serializer", +) + def _get_uuid_utils_type() -> "type[Any] | None": if not UUID_UTILS_INSTALLED: diff --git 
a/sqlspec/utils/serializers/_numpy.py b/sqlspec/utils/serializers/_numpy.py index 9ab026c45..40d448588 100644 --- a/sqlspec/utils/serializers/_numpy.py +++ b/sqlspec/utils/serializers/_numpy.py @@ -4,6 +4,9 @@ from sqlspec.typing import NUMPY_INSTALLED +__all__ = ("numpy_array_dec_hook", "numpy_array_enc_hook", "numpy_array_predicate") + + _NUMPY_DECODER_SENTINEL: Final[object] = object() diff --git a/sqlspec/utils/serializers/_schema.py b/sqlspec/utils/serializers/_schema.py index 543ef37c7..4878dc60f 100644 --- a/sqlspec/utils/serializers/_schema.py +++ b/sqlspec/utils/serializers/_schema.py @@ -20,6 +20,15 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable +__all__ = ( + "SchemaSerializer", + "get_collection_serializer", + "get_serializer_metrics", + "reset_serializer_cache", + "schema_dump", + "serialize_collection", +) + DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" _PRIMITIVE_TYPES: Final[tuple[type[Any], ...]] = (str, bytes, int, float, bool) diff --git a/tests/unit/utils/test_serializers.py b/tests/unit/utils/test_serializers.py index cdf6b6648..e57ca6763 100644 --- a/tests/unit/utils/test_serializers.py +++ b/tests/unit/utils/test_serializers.py @@ -497,9 +497,12 @@ def test_error_messages_are_helpful() -> None: def test_to_json_embeds_pydantic_models_as_objects() -> None: """Pydantic models should normalize to plain objects, not JSON strings.""" - pydantic = pytest.importorskip("pydantic") + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic not installed") - class Payload(pydantic.BaseModel): + class Payload(BaseModel): identifier: int label: str