diff --git a/pyproject.toml b/pyproject.toml index 4a67af3c0..cd9e7484a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,7 +220,11 @@ include = [ "sqlspec/utils/sync_tools.py", # Synchronous utility functions "sqlspec/utils/type_guards.py", # Type guard utilities "sqlspec/utils/fixtures.py", # File fixture loading - "sqlspec/utils/serializers.py", # Serialization helpers + "sqlspec/utils/config_tools.py", # Configuration utilities + "sqlspec/utils/deprecation.py", # Deprecation helpers + "sqlspec/utils/dispatch.py", # Dispatch helpers + "sqlspec/utils/logging.py", # Logging helpers + "sqlspec/utils/serializers/**/*.py", # Serialization helpers package "sqlspec/utils/type_converters.py", # Adapter type converters "sqlspec/utils/correlation.py", # Correlation context helpers "sqlspec/utils/portal.py", # Thread portal utilities diff --git a/sqlspec/_serialization.py b/sqlspec/_serialization.py deleted file mode 100644 index fe4af788e..000000000 --- a/sqlspec/_serialization.py +++ /dev/null @@ -1,356 +0,0 @@ -"""Enhanced serialization module with byte-aware encoding and class-based architecture. - -Provides a Protocol-based serialization system that users can extend. -Supports msgspec, orjson, and standard library JSON with automatic fallback. - -Features optional numpy array serialization when numpy is installed. -Arrays are automatically converted to lists during JSON encoding. 
-""" - -import contextlib -import datetime -import enum -import json -import uuid as _uuid_mod -from abc import ABC, abstractmethod -from decimal import Decimal -from typing import Any, Final, Literal, Protocol, overload - -from sqlspec._typing import NUMPY_INSTALLED, UUID_UTILS_INSTALLED -from sqlspec.core.filters import OffsetPagination -from sqlspec.typing import MSGSPEC_INSTALLED, ORJSON_INSTALLED, PYDANTIC_INSTALLED, BaseModel - - -def _get_uuid_utils_type() -> "type[Any] | None": - if not UUID_UTILS_INSTALLED: - return None - try: - import uuid_utils as _uuid_utils_mod # pyright: ignore[reportMissingImports] - except ImportError: - return None - else: - return _uuid_utils_mod.UUID # type: ignore[no-any-return,unused-ignore] - - -_UUID_UTILS_TYPE: "type[Any] | None" = _get_uuid_utils_type() - - -def _type_to_string(value: Any) -> Any: # pragma: no cover - """Convert special types to strings for JSON serialization. - - Handles datetime, date, enums, Decimal, Pydantic models, and numpy arrays. - - Args: - value: Value to convert. - - Returns: - Serializable representation of the value (string, list, dict, etc.). - - Raises: - TypeError: If value cannot be serialized. 
- """ - if isinstance(value, datetime.datetime): - return convert_datetime_to_gmt_iso(value) - if isinstance(value, datetime.date): - return convert_date_to_iso(value) - if isinstance(value, Decimal): - return float(value) - if isinstance(value, enum.Enum): - return str(value.value) - if PYDANTIC_INSTALLED and isinstance(value, BaseModel): - return value.model_dump_json() - if isinstance(value, _uuid_mod.UUID): - return str(value) - if _UUID_UTILS_TYPE is not None and isinstance(value, _UUID_UTILS_TYPE): - return str(value) - if isinstance(value, OffsetPagination): - return {"items": value.items, "limit": value.limit, "offset": value.offset, "total": value.total} - if NUMPY_INSTALLED: - import numpy as np - - if isinstance(value, np.ndarray): - return value.tolist() - try: - return str(value) - except Exception as exc: - msg = f"Cannot serialize {type(value).__name__}" - raise TypeError(msg) from exc - - -class JSONSerializer(Protocol): - """Protocol for JSON serialization implementations. - - Users can implement this protocol to create custom serializers. - """ - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data to JSON. - - Args: - data: Data to encode. - as_bytes: Whether to return bytes instead of string. - - Returns: - JSON string or bytes depending on as_bytes parameter. - """ - ... - - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode from JSON. - - Args: - data: JSON string or bytes to decode. - decode_bytes: Whether to decode bytes input. - - Returns: - Decoded Python object. - """ - ... - - -class BaseJSONSerializer(ABC): - """Base class for JSON serializers with common functionality.""" - - __slots__ = () - - @abstractmethod - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data to JSON.""" - ... - - @abstractmethod - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode from JSON.""" - ... 
- - -# Module-level singleton fallback serializers for performance -# These avoid creating new instances on every fallback call in MsgspecSerializer -_orjson_fallback: "OrjsonSerializer | None" = None -_stdlib_fallback: "StandardLibSerializer | None" = None - - -def _get_orjson_fallback() -> "OrjsonSerializer": - """Get singleton OrjsonSerializer instance for fallback use.""" - global _orjson_fallback - if _orjson_fallback is None: - _orjson_fallback = OrjsonSerializer() - return _orjson_fallback - - -def _get_stdlib_fallback() -> "StandardLibSerializer": - """Get singleton StandardLibSerializer instance for fallback use.""" - global _stdlib_fallback - if _stdlib_fallback is None: - _stdlib_fallback = StandardLibSerializer() - return _stdlib_fallback - - -class MsgspecSerializer(BaseJSONSerializer): - """Msgspec-based JSON serializer.""" - - __slots__ = ("_decoder", "_encoder") - - def __init__(self) -> None: - """Initialize msgspec encoder and decoder.""" - from msgspec.json import Decoder, Encoder - - self._encoder: Final[Encoder] = Encoder(enc_hook=_type_to_string) - self._decoder: Final[Decoder] = Decoder() - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data using msgspec.""" - try: - if as_bytes: - return self._encoder.encode(data) - return self._encoder.encode(data).decode("utf-8") - except (TypeError, ValueError): - if ORJSON_INSTALLED: - return _get_orjson_fallback().encode(data, as_bytes=as_bytes) - return _get_stdlib_fallback().encode(data, as_bytes=as_bytes) - - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode data using msgspec.""" - if isinstance(data, bytes): - if decode_bytes: - try: - return self._decoder.decode(data) - except (TypeError, ValueError): - if ORJSON_INSTALLED: - return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) - return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) - return data - - try: - return 
self._decoder.decode(data.encode("utf-8")) - except (TypeError, ValueError): - if ORJSON_INSTALLED: - return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) - return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) - - -class OrjsonSerializer(BaseJSONSerializer): - """Orjson-based JSON serializer with native datetime/UUID support. - - Automatically enables numpy serialization if numpy is installed. - """ - - __slots__ = () - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data using orjson. - - Args: - data: Data to encode. - as_bytes: Whether to return bytes instead of string. - - Returns: - JSON string or bytes depending on as_bytes parameter. - """ - from orjson import ( - OPT_NAIVE_UTC, # pyright: ignore[reportUnknownVariableType] - OPT_SERIALIZE_UUID, # pyright: ignore[reportUnknownVariableType] - ) - from orjson import dumps as _orjson_dumps # pyright: ignore[reportMissingImports] - - options = OPT_NAIVE_UTC | OPT_SERIALIZE_UUID - - if NUMPY_INSTALLED: - from orjson import OPT_SERIALIZE_NUMPY # pyright: ignore[reportUnknownVariableType] - - options |= OPT_SERIALIZE_NUMPY - - result = _orjson_dumps(data, default=_type_to_string, option=options) - return result if as_bytes else result.decode("utf-8") - - def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode data using orjson.""" - from orjson import loads as _orjson_loads # pyright: ignore[reportMissingImports] - - if isinstance(data, bytes): - if decode_bytes: - return _orjson_loads(data) - return data - return _orjson_loads(data) - - -class StandardLibSerializer(BaseJSONSerializer): - """Standard library JSON serializer as fallback.""" - - __slots__ = () - - def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data using standard library json.""" - json_str = json.dumps(data, default=_type_to_string) - return json_str.encode("utf-8") if as_bytes else json_str - - def decode(self, 
data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode data using standard library json.""" - if isinstance(data, bytes): - if decode_bytes: - return json.loads(data.decode("utf-8")) - return data - return json.loads(data) - - -_default_serializer: JSONSerializer | None = None - - -def get_default_serializer() -> JSONSerializer: - """Get the default serializer based on available libraries. - - Priority: msgspec > orjson > stdlib - - Returns: - The best available JSON serializer. - """ - global _default_serializer - - if _default_serializer is None: - if MSGSPEC_INSTALLED: - with contextlib.suppress(ImportError): - _default_serializer = MsgspecSerializer() - - if _default_serializer is None and ORJSON_INSTALLED: - with contextlib.suppress(ImportError): - _default_serializer = OrjsonSerializer() - - if _default_serializer is None: - _default_serializer = StandardLibSerializer() - - assert _default_serializer is not None - return _default_serializer - - -@overload -def encode_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... # pragma: no cover - - -@overload -def encode_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... # pragma: no cover - - -def encode_json(data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode to JSON, optionally returning bytes. - - Args: - data: The data to encode. - as_bytes: Whether to return bytes instead of string. - - Returns: - JSON string or bytes depending on as_bytes parameter. - """ - return get_default_serializer().encode(data, as_bytes=as_bytes) - - -def decode_json(data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode from JSON string or bytes efficiently. - - Args: - data: JSON string or bytes to decode. - decode_bytes: Whether to decode bytes input. - - Returns: - Decoded Python object. 
- """ - return get_default_serializer().decode(data, decode_bytes=decode_bytes) - - -def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str: # pragma: no cover - """Handle datetime serialization for nested timestamps. - - Args: - dt: The datetime to convert. - - Returns: - The ISO formatted datetime string. - """ - if not dt.tzinfo: - dt = dt.replace(tzinfo=datetime.timezone.utc) - return dt.isoformat().replace("+00:00", "Z") - - -def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover - """Handle datetime serialization for nested timestamps. - - Args: - dt: The date to convert. - - Returns: - The ISO formatted date string. - """ - return dt.isoformat() - - -__all__ = ( - "BaseJSONSerializer", - "JSONSerializer", - "MsgspecSerializer", - "OrjsonSerializer", - "StandardLibSerializer", - "convert_date_to_iso", - "convert_datetime_to_gmt_iso", - "decode_json", - "encode_json", - "get_default_serializer", -) diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index c92abd1cc..a2371f89a 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -424,10 +424,10 @@ def store_sqlspec_in_state() -> None: app_config.type_encoders = encoders_dict if app_config.type_decoders is None: - app_config.type_decoders = [(numpy_array_predicate, numpy_array_dec_hook)] # type: ignore[list-item] + app_config.type_decoders = [(numpy_array_predicate, numpy_array_dec_hook)] else: decoders_list = list(app_config.type_decoders) - decoders_list.append((numpy_array_predicate, numpy_array_dec_hook)) # type: ignore[arg-type] + decoders_list.append((numpy_array_predicate, numpy_array_dec_hook)) app_config.type_decoders = decoders_list if self._correlation_headers: diff --git a/sqlspec/utils/deprecation.py b/sqlspec/utils/deprecation.py index 908731192..f93c28e92 100644 --- a/sqlspec/utils/deprecation.py +++ b/sqlspec/utils/deprecation.py @@ -6,7 +6,7 @@ import inspect from collections.abc import 
Callable -from functools import WRAPPER_ASSIGNMENTS, WRAPPER_UPDATES +from functools import wraps from typing import Generic, Literal, cast from warnings import warn @@ -104,40 +104,13 @@ def deprecated( ) -class _DeprecatedWrapper(Generic[P, T]): - __slots__ = ("__dict__", "_alternative", "_func", "_info", "_kind", "_pending", "_removal_in", "_version") - - def __init__( - self, - func: Callable[P, T], - *, - version: str, - removal_in: str | None, - alternative: str | None, - info: str | None, - pending: bool, - kind: Literal["function", "method", "classmethod", "property"] | None, - ) -> None: - self._func = func - self._version = version - self._removal_in = removal_in - self._alternative = alternative - self._info = info - self._pending = pending - self._kind = kind - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - kind = cast("DeprecatedKind", self._kind or ("method" if inspect.ismethod(self._func) else "function")) - warn_deprecation( - version=self._version, - deprecated_name=self._func.__name__, - info=self._info, - alternative=self._alternative, - pending=self._pending, - removal_in=self._removal_in, - kind=kind, - ) - return self._func(*args, **kwargs) +def _infer_deprecated_kind(func: Callable[P, T]) -> DeprecatedKind: + if inspect.ismethod(func): + return "method" + qualname = getattr(func, "__qualname__", "") + if inspect.isfunction(func) and "." 
in qualname and "<locals>" not in qualname: + return "method" + return "function" class _DeprecatedFactory(Generic[P, T]): @@ -161,30 +134,24 @@ def __init__( self._kind: Literal["function", "method", "classmethod", "property"] | None = kind def __call__(self, func: Callable[P, T]) -> Callable[P, T]: - wrapper = _DeprecatedWrapper( - func, - version=self._version, - removal_in=self._removal_in, - alternative=self._alternative, - info=self._info, - pending=self._pending, - kind=self._kind, - ) - return cast("Callable[P, T]", _copy_wrapper_metadata(wrapper, func)) - - -def _copy_wrapper_metadata(wrapper: "_DeprecatedWrapper[P, T]", func: "Callable[P, T]") -> "_DeprecatedWrapper[P, T]": - assignments: tuple[str, ...] = tuple(WRAPPER_ASSIGNMENTS) - updates: tuple[str, ...] = tuple(WRAPPER_UPDATES) - wrapper_dict = wrapper.__dict__ - wrapper_dict.update(getattr(func, "__dict__", {})) - for attr in assignments: - if hasattr(func, attr): - wrapper_dict[attr] = getattr(func, attr) - wrapper_dict["__wrapped__"] = func - for attr in updates: - if attr == "__dict__": - continue - if hasattr(func, attr): - wrapper_dict[attr] = getattr(func, attr) - return wrapper + version = self._version + removal_in = self._removal_in + alternative = self._alternative + info = self._info + pending = self._pending + kind: DeprecatedKind = self._kind or _infer_deprecated_kind(func) + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + warn_deprecation( + version=version, + deprecated_name=func.__name__, + info=info, + alternative=alternative, + pending=pending, + removal_in=removal_in, + kind=kind, + ) + return func(*args, **kwargs) + + return wrapper diff --git a/sqlspec/utils/logging.py b/sqlspec/utils/logging.py index a8a9754ac..e51566f10 100644 --- a/sqlspec/utils/logging.py +++ b/sqlspec/utils/logging.py @@ -9,9 +9,9 @@ from logging import LogRecord from typing import TYPE_CHECKING, Any, cast -from sqlspec._serialization import encode_json from sqlspec.utils.correlation import 
CorrelationContext from sqlspec.utils.correlation import correlation_id_var as _correlation_id_var +from sqlspec.utils.serializers import to_json if TYPE_CHECKING: from contextvars import ContextVar @@ -117,17 +117,16 @@ def format(self, record: LogRecord) -> str: Returns: JSON formatted log entry """ + record_dict = record.__dict__ log_entry = { "timestamp": self.formatTime(record, self.datefmt), "level": record.levelname, - "logger": record.name, + "logger": cast("str | None", record_dict.get("name")), "message": record.getMessage(), - "module": record.module, - "function": record.funcName, - "line": record.lineno, + "module": cast("str | None", record_dict.get("module")), + "function": cast("str | None", record_dict.get("funcName")), + "line": cast("int | None", record_dict.get("lineno")), } - - record_dict = record.__dict__ correlation_id = cast("str | None", record_dict.get("correlation_id")) or get_correlation_id() if correlation_id: log_entry["correlation_id"] = correlation_id @@ -155,7 +154,7 @@ def format(self, record: LogRecord) -> str: if record.exc_info: log_entry["exception"] = self.formatException(record.exc_info) - return encode_json(log_entry) + return to_json(log_entry) class CorrelationIDFilter(logging.Filter): diff --git a/sqlspec/utils/serializers/__init__.py b/sqlspec/utils/serializers/__init__.py new file mode 100644 index 000000000..d9f1c47aa --- /dev/null +++ b/sqlspec/utils/serializers/__init__.py @@ -0,0 +1,27 @@ +"""Serialization utilities for SQLSpec.""" + +from sqlspec.utils.serializers._json import decode_json as from_json +from sqlspec.utils.serializers._json import encode_json as to_json +from sqlspec.utils.serializers._numpy import numpy_array_dec_hook, numpy_array_enc_hook, numpy_array_predicate +from sqlspec.utils.serializers._schema import ( + SchemaSerializer, + get_collection_serializer, + get_serializer_metrics, + reset_serializer_cache, + schema_dump, + serialize_collection, +) + +__all__ = ( + "SchemaSerializer", + 
"from_json", + "get_collection_serializer", + "get_serializer_metrics", + "numpy_array_dec_hook", + "numpy_array_enc_hook", + "numpy_array_predicate", + "reset_serializer_cache", + "schema_dump", + "serialize_collection", + "to_json", +) diff --git a/sqlspec/utils/serializers/_json.py b/sqlspec/utils/serializers/_json.py new file mode 100644 index 000000000..b5bc911b9 --- /dev/null +++ b/sqlspec/utils/serializers/_json.py @@ -0,0 +1,298 @@ +"""Private JSON serialization engine for ``sqlspec.utils.serializers``.""" + +import contextlib +import datetime +import enum +import json +import uuid as uuid_mod +from abc import ABC, abstractmethod +from decimal import Decimal +from typing import Any, Final, Literal, Protocol, overload + +from sqlspec.core.filters import OffsetPagination +from sqlspec.typing import ( + MSGSPEC_INSTALLED, + NUMPY_INSTALLED, + ORJSON_INSTALLED, + PYDANTIC_INSTALLED, + BaseModel, + attrs_asdict, +) +from sqlspec.utils.type_guards import dataclass_to_dict, is_attrs_instance, is_dataclass_instance, is_msgspec_struct +from sqlspec.utils.uuids import UUID_UTILS_INSTALLED, _load_uuid_utils + +__all__ = ( + "BaseJSONSerializer", + "JSONSerializer", + "MsgspecSerializer", + "OrjsonSerializer", + "StandardLibSerializer", + "convert_date_to_iso", + "convert_datetime_to_gmt_iso", + "decode_json", + "encode_json", + "get_default_serializer", +) + + +def _get_uuid_utils_type() -> "type[Any] | None": + if not UUID_UTILS_INSTALLED: + return None + module = _load_uuid_utils() + if module is None: + return None + return module.UUID # type: ignore[no-any-return] + + +_UUID_UTILS_TYPE: "type[Any] | None" = _get_uuid_utils_type() + + +def convert_datetime_to_gmt_iso(value: datetime.datetime) -> str: + """Normalize datetime values to ISO 8601 strings.""" + if value.tzinfo is None: + value = value.replace(tzinfo=datetime.timezone.utc) + return value.isoformat().replace("+00:00", "Z") + + +def convert_date_to_iso(value: datetime.date) -> str: + """Normalize date 
values to ISO 8601 strings.""" + return value.isoformat() + + +def _dump_pydantic_model(value: Any) -> Any: + if hasattr(value, "model_dump"): + return value.model_dump() + return value.dict() + + +def _dump_msgspec_struct(value: Any) -> "dict[str, Any]": + return {field_name: value.__getattribute__(field_name) for field_name in value.__struct_fields__} + + +def _normalize_numpy_value(value: Any) -> Any: + if not NUMPY_INSTALLED: + return value + + import numpy as np + + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, np.generic): + return value.item() + return value + + +def _normalize_supported_value(value: Any) -> Any: + """Convert supported non-native values into JSON-compatible objects.""" + if isinstance(value, datetime.datetime): + return convert_datetime_to_gmt_iso(value) + if isinstance(value, datetime.date): + return convert_date_to_iso(value) + if isinstance(value, uuid_mod.UUID): + return str(value) + if _UUID_UTILS_TYPE is not None and isinstance(value, _UUID_UTILS_TYPE): + return str(value) + if isinstance(value, Decimal): + return float(value) + if isinstance(value, enum.Enum): + return value.value + if isinstance(value, OffsetPagination): + return {"items": value.items, "limit": value.limit, "offset": value.offset, "total": value.total} + if PYDANTIC_INSTALLED and isinstance(value, BaseModel): + return _dump_pydantic_model(value) + if is_dataclass_instance(value): + return dataclass_to_dict(value) + if is_attrs_instance(value): + return attrs_asdict(value, recurse=True) + if is_msgspec_struct(value): + return _dump_msgspec_struct(value) + numpy_value = _normalize_numpy_value(value) + if numpy_value is not value: + return numpy_value + + msg = f"unsupported JSON value: {type(value).__name__}" + raise TypeError(msg) + + +def _is_explicit_unsupported_error(exc: Exception) -> bool: + return "unsupported json value" in str(exc).lower() + + +class JSONSerializer(Protocol): + """Protocol for JSON serializer 
implementations.""" + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + """Encode Python data into JSON.""" + ... + + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + """Decode JSON into Python data.""" + ... + + +class BaseJSONSerializer(ABC): + """Base class shared by JSON serializer implementations.""" + + __slots__ = () + + @abstractmethod + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + """Encode Python data into JSON.""" + ... + + @abstractmethod + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + """Decode JSON into Python data.""" + ... + + +_orjson_fallback: "OrjsonSerializer | None" = None +_stdlib_fallback: "StandardLibSerializer | None" = None +_default_serializer: JSONSerializer | None = None + + +def _get_orjson_fallback() -> "OrjsonSerializer": + global _orjson_fallback + if _orjson_fallback is None: + _orjson_fallback = OrjsonSerializer() + return _orjson_fallback + + +def _get_stdlib_fallback() -> "StandardLibSerializer": + global _stdlib_fallback + if _stdlib_fallback is None: + _stdlib_fallback = StandardLibSerializer() + return _stdlib_fallback + + +class MsgspecSerializer(BaseJSONSerializer): + """Msgspec-based JSON serializer.""" + + __slots__ = ("_decoder", "_encoder") + + def __init__(self) -> None: + from msgspec.json import Decoder, Encoder + + self._encoder: Final[Encoder] = Encoder(enc_hook=_normalize_supported_value) + self._decoder: Final[Decoder] = Decoder() + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + try: + encoded = self._encoder.encode(data) + except TypeError as exc: + if _is_explicit_unsupported_error(exc): + raise + if ORJSON_INSTALLED: + return _get_orjson_fallback().encode(data, as_bytes=as_bytes) + return _get_stdlib_fallback().encode(data, as_bytes=as_bytes) + except ValueError: + if ORJSON_INSTALLED: + return _get_orjson_fallback().encode(data, as_bytes=as_bytes) + return 
_get_stdlib_fallback().encode(data, as_bytes=as_bytes) + return encoded if as_bytes else encoded.decode("utf-8") + + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + if isinstance(data, bytes): + if not decode_bytes: + return data + try: + return self._decoder.decode(data) + except (TypeError, ValueError): + if ORJSON_INSTALLED: + return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) + return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) + try: + return self._decoder.decode(data.encode("utf-8")) + except (TypeError, ValueError): + if ORJSON_INSTALLED: + return _get_orjson_fallback().decode(data, decode_bytes=decode_bytes) + return _get_stdlib_fallback().decode(data, decode_bytes=decode_bytes) + + +class OrjsonSerializer(BaseJSONSerializer): + """Orjson-based JSON serializer.""" + + __slots__ = () + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + from orjson import OPT_NAIVE_UTC, OPT_SERIALIZE_UUID + from orjson import dumps as orjson_dumps # pyright: ignore[reportMissingImports] + + options = OPT_NAIVE_UTC | OPT_SERIALIZE_UUID + if NUMPY_INSTALLED: + from orjson import OPT_SERIALIZE_NUMPY + + options |= OPT_SERIALIZE_NUMPY + + try: + encoded = orjson_dumps(data, default=_normalize_supported_value, option=options) + except TypeError as exc: + if _is_explicit_unsupported_error(exc): + raise + if "type is not json serializable" in str(exc).lower(): + unsupported_msg = "unsupported JSON value" + raise TypeError(unsupported_msg) from exc + raise + return encoded if as_bytes else encoded.decode("utf-8") + + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + from orjson import loads as orjson_loads # pyright: ignore[reportMissingImports] + + if isinstance(data, bytes): + if not decode_bytes: + return data + return orjson_loads(data) + return orjson_loads(data) + + +class StandardLibSerializer(BaseJSONSerializer): + """Standard library JSON serializer 
fallback.""" + + __slots__ = () + + def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes: + encoded = json.dumps(data, default=_normalize_supported_value) + return encoded.encode("utf-8") if as_bytes else encoded + + def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any: + if isinstance(data, bytes): + if not decode_bytes: + return data + return json.loads(data.decode("utf-8")) + return json.loads(data) + + +def get_default_serializer() -> JSONSerializer: + """Return the best available JSON serializer.""" + global _default_serializer + + if _default_serializer is None: + if MSGSPEC_INSTALLED: + with contextlib.suppress(ImportError): + _default_serializer = MsgspecSerializer() + if _default_serializer is None and ORJSON_INSTALLED: + with contextlib.suppress(ImportError): + _default_serializer = OrjsonSerializer() + if _default_serializer is None: + _default_serializer = StandardLibSerializer() + + assert _default_serializer is not None + return _default_serializer + + +@overload +def encode_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... + + +@overload +def encode_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... 
+ + +def encode_json(data: Any, *, as_bytes: bool = False) -> str | bytes: + """Encode Python data into JSON.""" + return get_default_serializer().encode(data, as_bytes=as_bytes) + + +def decode_json(data: str | bytes, *, decode_bytes: bool = True) -> Any: + """Decode JSON input into Python data.""" + return get_default_serializer().decode(data, decode_bytes=decode_bytes) diff --git a/sqlspec/utils/serializers/_numpy.py b/sqlspec/utils/serializers/_numpy.py new file mode 100644 index 000000000..40d448588 --- /dev/null +++ b/sqlspec/utils/serializers/_numpy.py @@ -0,0 +1,65 @@ +"""NumPy serialization helpers for ``sqlspec.utils.serializers``.""" + +from typing import Any, Final + +from sqlspec.typing import NUMPY_INSTALLED + +__all__ = ("numpy_array_dec_hook", "numpy_array_enc_hook", "numpy_array_predicate") + + +_NUMPY_DECODER_SENTINEL: Final[object] = object() + + +def numpy_array_enc_hook(value: Any) -> Any: + """Encode NumPy arrays and scalars to JSON-compatible values.""" + if not NUMPY_INSTALLED: + return value + + import numpy as np + + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, np.generic): + return value.item() + return value + + +def numpy_array_dec_hook(target_or_value: Any, value: Any = _NUMPY_DECODER_SENTINEL) -> Any: + """Decode JSON list payloads into NumPy arrays. + + Supports both direct one-argument usage and Litestar's + ``(target_type, value)`` decoder contract. 
+ """ + if value is _NUMPY_DECODER_SENTINEL: + raw_value = target_or_value + should_decode = True + else: + raw_value = value + should_decode = numpy_array_predicate(target_or_value) + + if not NUMPY_INSTALLED: + return raw_value + if not should_decode or not isinstance(raw_value, list): + return raw_value + + import numpy as np + + try: + return np.array(raw_value) + except Exception: + return raw_value + + +def numpy_array_predicate(value_or_target: Any) -> bool: + """Check whether a value or target type represents a NumPy array.""" + if not NUMPY_INSTALLED: + return False + + import numpy as np + + if isinstance(value_or_target, type): + try: + return issubclass(value_or_target, np.ndarray) + except TypeError: + return False + return isinstance(value_or_target, np.ndarray) diff --git a/sqlspec/utils/serializers.py b/sqlspec/utils/serializers/_schema.py similarity index 59% rename from sqlspec/utils/serializers.py rename to sqlspec/utils/serializers/_schema.py index 5c1de63b9..4878dc60f 100644 --- a/sqlspec/utils/serializers.py +++ b/sqlspec/utils/serializers/_schema.py @@ -1,16 +1,11 @@ -"""Serialization utilities for SQLSpec. - -Provides JSON helpers, serializer pipelines, optional dependency hooks, -and cache instrumentation aligned with the core pipeline counters. 
-""" +"""Schema dumping and cache helpers for ``sqlspec.utils.serializers``.""" import os from functools import partial from threading import RLock -from typing import TYPE_CHECKING, Any, Final, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Final, cast -from sqlspec._serialization import decode_json, encode_json -from sqlspec.typing import NUMPY_INSTALLED, UNSET, ArrowReturnFormat, attrs_asdict +from sqlspec.typing import UNSET, ArrowReturnFormat, attrs_asdict from sqlspec.utils.arrow_helpers import convert_dict_to_arrow from sqlspec.utils.type_guards import ( dataclass_to_dict, @@ -27,18 +22,14 @@ __all__ = ( "SchemaSerializer", - "from_json", "get_collection_serializer", "get_serializer_metrics", - "numpy_array_dec_hook", - "numpy_array_enc_hook", - "numpy_array_predicate", "reset_serializer_cache", "schema_dump", "serialize_collection", - "to_json", ) + DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" _PRIMITIVE_TYPES: Final[tuple[type[Any], ...]] = (str, bytes, int, float, bool) @@ -92,157 +83,6 @@ def snapshot(self) -> "dict[str, int]": } -@overload -def to_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... - - -@overload -def to_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... - - -def to_json(data: Any, *, as_bytes: bool = False) -> str | bytes: - """Encode data to JSON string or bytes. - - Args: - data: Data to encode. - as_bytes: Whether to return bytes instead of string for optimal performance. - - Returns: - JSON string or bytes representation based on as_bytes parameter. - """ - if as_bytes: - return encode_json(data, as_bytes=True) - return encode_json(data, as_bytes=False) - - -@overload -def from_json(data: str) -> Any: ... - - -@overload -def from_json(data: bytes, *, decode_bytes: bool = ...) -> Any: ... - - -def from_json(data: str | bytes, *, decode_bytes: bool = True) -> Any: - """Decode JSON string or bytes to Python object. - - Args: - data: JSON string or bytes to decode. 
- decode_bytes: Whether to decode bytes input (vs passing through). - - Returns: - Decoded Python object. - """ - if isinstance(data, bytes): - return decode_json(data, decode_bytes=decode_bytes) - return decode_json(data) - - -def numpy_array_enc_hook(value: Any) -> Any: - """Encode NumPy array to JSON-compatible list. - - Converts NumPy ndarrays to Python lists for JSON serialization. - Gracefully handles cases where NumPy is not installed by returning - the original value unchanged. - - Args: - value: Value to encode (checked for ndarray type). - - Returns: - List representation if value is ndarray, original value otherwise. - - Example: - >>> import numpy as np - >>> arr = np.array([1.0, 2.0, 3.0]) - >>> numpy_array_enc_hook(arr) - [1.0, 2.0, 3.0] - - >>> # Multi-dimensional arrays work automatically - >>> arr_2d = np.array([[1, 2], [3, 4]]) - >>> numpy_array_enc_hook(arr_2d) - [[1, 2], [3, 4]] - """ - if not NUMPY_INSTALLED: - return value - - import numpy as np - - if isinstance(value, np.ndarray): - return value.tolist() - return value - - -def numpy_array_dec_hook(value: Any) -> Any: - """Decode list to NumPy array. - - Converts Python lists to NumPy arrays when appropriate. - Works best with typed schemas (Pydantic, msgspec) that expect ndarray. - - Args: - value: List to potentially convert to ndarray. - - Returns: - NumPy array if conversion successful, original value otherwise. - - Note: - Dtype is inferred by NumPy and may differ from original array. - For explicit dtype control, construct arrays manually in application code. 
- - Example: - >>> numpy_array_dec_hook([1.0, 2.0, 3.0]) - array([1., 2., 3.]) - - >>> # Returns original value if NumPy not installed - >>> # (when NUMPY_INSTALLED is False) - >>> numpy_array_dec_hook([1, 2, 3]) - [1, 2, 3] - """ - if not NUMPY_INSTALLED: - return value - - import numpy as np - - if isinstance(value, list): - try: - return np.array(value) - except Exception: - return value - return value - - -def numpy_array_predicate(value: Any) -> bool: - """Check if value is NumPy array instance. - - Type checker for decoder registration in framework plugins. - Returns False when NumPy is not installed. - - Args: - value: Value to type-check. - - Returns: - True if value is ndarray, False otherwise. - - Example: - >>> import numpy as np - >>> numpy_array_predicate(np.array([1, 2, 3])) - True - - >>> numpy_array_predicate([1, 2, 3]) - False - - >>> # Returns False when NumPy not installed - >>> # (when NUMPY_INSTALLED is False) - >>> numpy_array_predicate([1, 2, 3]) - False - """ - if not NUMPY_INSTALLED: - return False - - import numpy as np - - return isinstance(value, np.ndarray) - - class SchemaSerializer: """Serializer pipeline that caches conversions for repeated schema dumps.""" @@ -285,11 +125,15 @@ def _dump_identity_dict(value: Any) -> "dict[str, Any]": def _dump_msgspec_fields(value: Any) -> "dict[str, Any]": - return {f: value.__getattribute__(f) for f in value.__struct_fields__} + return {field_name: value.__getattribute__(field_name) for field_name in value.__struct_fields__} def _dump_msgspec_excluding_unset(value: Any) -> "dict[str, Any]": - return {f: field_value for f in value.__struct_fields__ if (field_value := value.__getattribute__(f)) != UNSET} + return { + field_name: field_value + for field_name in value.__struct_fields__ + if (field_value := value.__getattribute__(field_name)) != UNSET + } def _dump_dataclass(value: Any, *, exclude_unset: bool) -> "dict[str, Any]": @@ -315,7 +159,6 @@ def _dump_mapping(value: Any) -> "dict[str, Any]": 
def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], dict[str, Any]]": if sample is None or isinstance(sample, dict): return _dump_identity_dict - if is_dataclass_instance(sample): return cast("Callable[[Any], dict[str, Any]]", partial(_dump_dataclass, exclude_unset=exclude_unset)) if is_pydantic_model(sample): @@ -324,19 +167,15 @@ def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], d if exclude_unset: return _dump_msgspec_excluding_unset return _dump_msgspec_fields - if is_attrs_instance(sample): return _dump_attrs - if has_dict_attribute(sample): return _dump_dict_attr - return _dump_mapping def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "SchemaSerializer": """Return cached serializer pipeline for the provided sample object.""" - key = _make_serializer_key(sample, exclude_unset) with _SERIALIZER_LOCK: pipeline = _SCHEMA_SERIALIZERS.get(key) @@ -353,7 +192,6 @@ def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "Sc def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) -> "list[Any]": """Serialize a collection using cached pipelines keyed by item type.""" - serialized: list[Any] = [] cache: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} @@ -373,7 +211,6 @@ def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) def reset_serializer_cache() -> None: """Clear cached serializer pipelines.""" - with _SERIALIZER_LOCK: _SCHEMA_SERIALIZERS.clear() _SERIALIZER_METRICS.reset() @@ -381,7 +218,6 @@ def reset_serializer_cache() -> None: def get_serializer_metrics() -> "dict[str, int]": """Return cache metrics aligned with the core pipeline counters.""" - with _SERIALIZER_LOCK: metrics = _SERIALIZER_METRICS.snapshot() metrics["size"] = len(_SCHEMA_SERIALIZERS) @@ -389,18 +225,9 @@ def get_serializer_metrics() -> "dict[str, int]": def schema_dump(data: Any, *, exclude_unset: bool = True) -> Any: - """Dump a 
schema model or dict to a plain representation. - - Args: - data: Schema model instance or dictionary to dump. - exclude_unset: Whether to exclude unset fields (for models that support it). - - Returns: - A plain representation of the schema model or value. - """ + """Dump a schema model or dict to a plain representation.""" if is_dict(data): return data - if isinstance(data, _PRIMITIVE_TYPES) or data is None: return data diff --git a/tests/unit/core/test_cache.py b/tests/unit/core/test_cache.py index 635494d98..92661f7ac 100644 --- a/tests/unit/core/test_cache.py +++ b/tests/unit/core/test_cache.py @@ -16,9 +16,10 @@ across the entire SQLSpec system. """ +import logging import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -590,16 +591,21 @@ def test_reset_cache_stats_function() -> None: assert multi_stats.misses == 0 -def test_log_cache_stats_function() -> None: +def test_log_cache_stats_function(caplog: pytest.LogCaptureFixture) -> None: """Test logging cache statistics.""" - with patch("sqlspec.core.cache.get_logger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger + reset_cache_stats() + caplog.set_level(logging.DEBUG, logger="sqlspec.cache") - log_cache_stats() + log_cache_stats() - mock_get_logger.assert_called_once_with("sqlspec.cache") - mock_logger.log.assert_called_once() + records = [record for record in caplog.records if record.name == "sqlspec.cache" and record.msg == "cache.stats"] + assert len(records) == 1 + extra_fields = getattr(records[0], "extra_fields", {}) + assert isinstance(extra_fields, dict) + stats = extra_fields.get("stats") + assert isinstance(stats, dict) + assert "default" in stats + assert "namespaced" in stats def test_namespaced_cache_interface() -> None: diff --git a/tests/unit/core/test_type_conversion.py b/tests/unit/core/test_type_conversion.py index 5b3f7d0c7..b474399db 100644 --- 
a/tests/unit/core/test_type_conversion.py +++ b/tests/unit/core/test_type_conversion.py @@ -11,7 +11,7 @@ import pytest -import sqlspec._serialization +import sqlspec.utils.serializers._json as json_serialization from sqlspec.core import ( BaseTypeConverter, convert_decimal, @@ -240,7 +240,7 @@ def test_convert_json_avoids_serializer_dispatch(monkeypatch: pytest.MonkeyPatch def fail_get_default_serializer() -> None: raise AssertionError("convert_json should not call serializer selection") - monkeypatch.setattr(sqlspec._serialization, "get_default_serializer", fail_get_default_serializer) + monkeypatch.setattr(json_serialization, "get_default_serializer", fail_get_default_serializer) result = convert_json('{"key": "value"}') diff --git a/tests/unit/utils/test_serialization.py b/tests/unit/utils/test_serialization.py index 07dce221e..c5991039f 100644 --- a/tests/unit/utils/test_serialization.py +++ b/tests/unit/utils/test_serialization.py @@ -1,7 +1,7 @@ -"""Tests for enhanced serialization functionality. +"""Tests for public JSON serialization functionality. -Tests for the byte-aware serialization system, including performance -improvements and compatibility with msgspec/orjson fallback patterns. +Tests for the byte-aware serialization system through the canonical +``sqlspec.utils.serializers`` surface. """ import json @@ -10,7 +10,8 @@ import pytest -from sqlspec._serialization import decode_json, encode_json +from sqlspec.utils.serializers import from_json as decode_json +from sqlspec.utils.serializers import to_json as encode_json def test_encode_json_as_string() -> None: diff --git a/tests/unit/utils/test_serializers.py b/tests/unit/utils/test_serializers.py index e22e351e2..e57ca6763 100644 --- a/tests/unit/utils/test_serializers.py +++ b/tests/unit/utils/test_serializers.py @@ -1,7 +1,6 @@ """Tests for sqlspec.utils.serializers module. -Tests for JSON serialization utilities that are re-exported from sqlspec._serialization. 
-Covers all serialization scenarios including edge cases and type handling. +Tests for the canonical JSON serialization surface and contract-level regressions. """ import json @@ -455,7 +454,7 @@ def test_parametrized_round_trip(test_input: Any) -> None: def test_imports_work_correctly() -> None: - """Test that the imports from _serialization module work correctly.""" + """Test that the canonical serializer imports round-trip correctly.""" assert callable(to_json) assert callable(from_json) @@ -495,6 +494,30 @@ def test_error_messages_are_helpful() -> None: assert any(word in error_msg for word in ["json", "decode", "parse", "invalid", "expect", "malformed"]) +def test_to_json_embeds_pydantic_models_as_objects() -> None: + """Pydantic models should normalize to plain objects, not JSON strings.""" + + try: + from pydantic import BaseModel + except ImportError: + pytest.skip("pydantic not installed") + + class Payload(BaseModel): + identifier: int + label: str + + result = to_json({"payload": Payload(identifier=1, label="alpha")}) + + assert json.loads(result) == {"payload": {"identifier": 1, "label": "alpha"}} + + +def test_to_json_raises_type_error_for_unsupported_objects() -> None: + """Unsupported objects should fail explicitly instead of stringifying.""" + + with pytest.raises(TypeError, match="unsupported"): + to_json({"payload": object()}) + + numpy_available = pytest.importorskip("numpy", reason="NumPy not installed") @@ -608,6 +631,18 @@ def test_numpy_array_dec_hook_non_list() -> None: assert numpy_array_dec_hook(None) is None +@pytest.mark.skipif(not numpy_available, reason="NumPy not installed") +def test_numpy_array_dec_hook_supports_litestar_decoder_signature() -> None: + """The decoder hook should accept Litestar's ``(target_type, value)`` contract.""" + + import numpy as np + + decoded = numpy_array_dec_hook(np.ndarray, [1.0, 2.0, 3.0]) + + assert isinstance(decoded, np.ndarray) + assert np.array_equal(decoded, np.array([1.0, 2.0, 3.0])) + + 
@pytest.mark.skipif(not numpy_available, reason="NumPy not installed") def test_numpy_array_predicate_basic() -> None: """Test NumPy array predicate for type checking.""" @@ -615,6 +650,7 @@ def test_numpy_array_predicate_basic() -> None: arr = np.array([1, 2, 3]) assert numpy_array_predicate(arr) is True + assert numpy_array_predicate(np.ndarray) is True assert numpy_array_predicate([1, 2, 3]) is False assert numpy_array_predicate("string") is False