From d0f5f3567c7eba90b2c1d9c70a6be1a55c0b9111 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 21:57:43 +0000 Subject: [PATCH 01/23] feat(adk): core contract reset - full-event JSON storage, state persistence fix --- sqlspec/extensions/adk/_types.py | 22 +-- sqlspec/extensions/adk/converters.py | 176 +++++++++++--------- sqlspec/extensions/adk/memory/converters.py | 93 ++++++++++- sqlspec/extensions/adk/memory/service.py | 113 ++++++++++++- sqlspec/extensions/adk/service.py | 16 +- sqlspec/extensions/adk/store.py | 36 ++++ 6 files changed, 351 insertions(+), 105 deletions(-) diff --git a/sqlspec/extensions/adk/_types.py b/sqlspec/extensions/adk/_types.py index 651431651..838cc14a3 100644 --- a/sqlspec/extensions/adk/_types.py +++ b/sqlspec/extensions/adk/_types.py @@ -27,25 +27,15 @@ class SessionRecord(TypedDict): class EventRecord(TypedDict): """Database record for an event. - Represents the schema for events stored in the database. - Follows the ADK Event model plus session metadata. + Stores the full ADK Event as a single JSON blob (``event_json``) alongside + a small number of indexed scalar columns used for query filtering. + + This design eliminates column drift with upstream ADK: new Event fields are + automatically captured in ``event_json`` without schema changes. 
""" - id: str - app_name: str - user_id: str session_id: str invocation_id: str author: str - branch: "str | None" - actions: bytes - long_running_tool_ids_json: "str | None" timestamp: datetime - content: "dict[str, Any] | None" - grounding_metadata: "dict[str, Any] | None" - custom_metadata: "dict[str, Any] | None" - partial: "bool | None" - turn_complete: "bool | None" - interrupted: "bool | None" - error_code: "str | None" - error_message: "str | None" + event_json: str diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py index d68cbc141..35304f644 100644 --- a/sqlspec/extensions/adk/converters.py +++ b/sqlspec/extensions/adk/converters.py @@ -1,20 +1,37 @@ -"""Conversion functions between ADK models and database records.""" +"""Conversion functions between ADK models and database records. + +Implements full-event JSON storage: the entire Event is serialized via +``Event.model_dump_json(exclude_none=True)`` into a single ``event_json`` +column, with a small set of indexed scalar columns extracted alongside for +query performance. Reconstruction uses ``Event.model_validate_json()``. + +Also provides scoped-state helpers that normalise ADK state prefixes +(``app:``, ``user:``, ``temp:``) so the shared service layer can split, +filter, and merge state before handing it to backend stores. 
+""" -import json -import pickle from datetime import datetime, timezone from typing import Any from google.adk.events.event import Event from google.adk.sessions import Session -from google.genai import types from sqlspec.extensions.adk._types import EventRecord, SessionRecord -from sqlspec.utils.logging import get_logger -logger = get_logger("sqlspec.extensions.adk.converters") +__all__ = ( + "event_to_record", + "filter_temp_state", + "merge_scoped_state", + "record_to_event", + "record_to_session", + "session_to_record", + "split_scoped_state", +) + -__all__ = ("event_to_record", "record_to_event", "record_to_session", "session_to_record") +# --------------------------------------------------------------------------- +# Session converters +# --------------------------------------------------------------------------- def session_to_record(session: "Session") -> SessionRecord: @@ -58,115 +75,122 @@ def record_to_session(record: SessionRecord, events: "list[EventRecord]") -> "Se ) -def event_to_record(event: "Event", session_id: str, app_name: str, user_id: str) -> EventRecord: - """Convert ADK Event to database record. +# --------------------------------------------------------------------------- +# Event converters (full-event JSON storage) +# --------------------------------------------------------------------------- + + +def event_to_record(event: "Event", session_id: str) -> EventRecord: + """Convert ADK Event to database record using full-event JSON storage. + + The entire Event is serialized into ``event_json`` via Pydantic's + ``model_dump_json(exclude_none=True)``. A small number of indexed scalar + columns are extracted alongside for query performance. Args: event: ADK Event object. session_id: ID of the parent session. - app_name: Name of the application. - user_id: ID of the user. Returns: EventRecord for database storage. 
""" - actions_bytes = pickle.dumps(event.actions) - - long_running_tool_ids_json = None - if event.long_running_tool_ids: - long_running_tool_ids_json = json.dumps(list(event.long_running_tool_ids)) - - content_dict = None - if event.content: - content_dict = event.content.model_dump(exclude_none=True, mode="json") - - grounding_metadata_dict = None - if event.grounding_metadata: - grounding_metadata_dict = event.grounding_metadata.model_dump(exclude_none=True, mode="json") - - custom_metadata_dict = event.custom_metadata - return EventRecord( - id=event.id, - app_name=app_name, - user_id=user_id, session_id=session_id, invocation_id=event.invocation_id, author=event.author, - branch=event.branch, - actions=actions_bytes, - long_running_tool_ids_json=long_running_tool_ids_json, timestamp=datetime.fromtimestamp(event.timestamp, tz=timezone.utc), - content=content_dict, - grounding_metadata=grounding_metadata_dict, - custom_metadata=custom_metadata_dict, - partial=event.partial, - turn_complete=event.turn_complete, - interrupted=event.interrupted, - error_code=event.error_code, - error_message=event.error_message, + event_json=event.model_dump_json(exclude_none=True), ) def record_to_event(record: "EventRecord") -> "Event": """Convert database record to ADK Event. + Reconstruction is lossless: the full Event is restored from + ``event_json`` via ``Event.model_validate_json()``. + Args: record: Event database record. Returns: ADK Event object. 
""" - actions = pickle.loads(record["actions"]) # noqa: S301 + return Event.model_validate_json(record["event_json"]) - long_running_tool_ids = None - if record["long_running_tool_ids_json"]: - long_running_tool_ids = set(json.loads(record["long_running_tool_ids_json"])) - return Event( - id=record["id"], - invocation_id=record["invocation_id"], - author=record["author"], - branch=record["branch"], - actions=actions, - timestamp=record["timestamp"].timestamp(), - content=_decode_content(record["content"]), - long_running_tool_ids=long_running_tool_ids, - partial=record["partial"], - turn_complete=record["turn_complete"], - error_code=record["error_code"], - error_message=record["error_message"], - interrupted=record["interrupted"], - grounding_metadata=_decode_grounding_metadata(record["grounding_metadata"]), - custom_metadata=record["custom_metadata"], - ) +# --------------------------------------------------------------------------- +# Scoped-state helpers +# --------------------------------------------------------------------------- + +def filter_temp_state(state: "dict[str, Any]") -> "dict[str, Any]": + """Return a copy of *state* with all ``temp:`` keys removed. -def _decode_content(content_dict: "dict[str, Any] | None") -> Any: - """Decode content dictionary from database to ADK Content object. + ``temp:`` keys are process-local/session-runtime state and must never be + written to persistent storage. Args: - content_dict: Content dictionary from database. + state: ADK state dictionary (may contain ``temp:`` prefixed keys). Returns: - ADK Content object or None. + A new dict without any ``temp:``-prefixed keys. 
""" - if not content_dict: - return None + return {k: v for k, v in state.items() if not k.startswith("temp:")} - return types.Content.model_validate(content_dict) +def split_scoped_state( + state: "dict[str, Any]", +) -> "tuple[dict[str, Any], dict[str, Any], dict[str, Any]]": + """Split ADK state into ``(session_local, app_scoped, user_scoped)`` dicts. -def _decode_grounding_metadata(grounding_dict: "dict[str, Any] | None") -> Any: - """Decode grounding metadata dictionary from database to ADK object. + Keys without a recognised scope prefix are session-local. ``temp:`` keys + are silently dropped (they must not be persisted). Args: - grounding_dict: Grounding metadata dictionary from database. + state: ADK state dictionary. Returns: - ADK GroundingMetadata object or None. + A 3-tuple of ``(session_local, app_scoped, user_scoped)`` dicts. + Scoped dicts retain their prefix in the key (e.g. ``"app:foo"``). """ - if not grounding_dict: - return None + session_local: dict[str, Any] = {} + app_scoped: dict[str, Any] = {} + user_scoped: dict[str, Any] = {} + + for k, v in state.items(): + if k.startswith("temp:"): + continue + elif k.startswith("app:"): + app_scoped[k] = v + elif k.startswith("user:"): + user_scoped[k] = v + else: + session_local[k] = v + + return session_local, app_scoped, user_scoped + + +def merge_scoped_state( + session_local: "dict[str, Any]", + app_scoped: "dict[str, Any]", + user_scoped: "dict[str, Any]", +) -> "dict[str, Any]": + """Merge scoped state dicts back into a single ADK-compatible state dict. + + The merge order is ``session_local | app_scoped | user_scoped`` so that + broader scopes can shadow narrower ones if keys collide (which they + normally should not, since prefixes differ). - return types.GroundingMetadata.model_validate(grounding_dict) + Args: + session_local: Session-local state (no prefix). + app_scoped: App-scoped state (``app:`` prefix). + user_scoped: User-scoped state (``user:`` prefix). 
+ + Returns: + A single merged state dict suitable for ``Session.state``. + """ + merged: dict[str, Any] = {} + merged.update(session_local) + merged.update(app_scoped) + merged.update(user_scoped) + return merged diff --git a/sqlspec/extensions/adk/memory/converters.py b/sqlspec/extensions/adk/memory/converters.py index 1dc0dbf83..1eb8d4e1f 100644 --- a/sqlspec/extensions/adk/memory/converters.py +++ b/sqlspec/extensions/adk/memory/converters.py @@ -19,15 +19,22 @@ logger = get_logger("sqlspec.extensions.adk.memory.converters") -__all__ = ("event_to_memory_record", "extract_content_text", "record_to_memory_entry", "session_to_memory_records") +__all__ = ( + "event_to_memory_record", + "extract_content_text", + "memory_entry_to_record", + "record_to_memory_entry", + "records_to_memory_entries", + "session_to_memory_records", +) def extract_content_text(content: "types.Content") -> str: """Extract plain text from ADK Content for search indexing. Handles multi-modal Content.parts including text, function calls, - and other part types. Non-text parts are indexed by their type - for discoverability. + function responses, and other part types. Non-text parts are indexed + by their type for discoverability. Args: content: ADK Content object with parts list. 
@@ -44,9 +51,9 @@ def extract_content_text(content: "types.Content") -> str: if part.text: parts_text.append(part.text) elif part.function_call is not None: - parts_text.append(f"function:{part.function_call.name}") + parts_text.append(f"{part.function_call.name}: {part.function_call.args}") elif part.function_response is not None: - parts_text.append(f"response:{part.function_response.name}") + parts_text.append(f"{part.function_response.name}: {part.function_response.response}") return " ".join(parts_text) @@ -91,6 +98,69 @@ def event_to_memory_record(event: "Event", session_id: str, app_name: str, user_ ) +def memory_entry_to_record( + entry: "MemoryEntry", + app_name: str, + user_id: str, + extra_metadata: "dict[str, Any] | None" = None, +) -> "MemoryRecord | None": + """Convert an ADK MemoryEntry to a database record. + + Serializes the entry's ``content`` to ``content_json``, extracts text + from ``content.parts`` for ``content_text``, and merges entry-level + ``custom_metadata`` with the optional ``extra_metadata`` parameter. + + Args: + entry: ADK MemoryEntry object. + app_name: Name of the application. + user_id: ID of the user. + extra_metadata: Optional call-level metadata to merge with the + entry's own ``custom_metadata``. + + Returns: + MemoryRecord for database storage, or None if entry has no + indexable content. 
+ """ + content_text = extract_content_text(entry.content) + if not content_text.strip(): + return None + + content_dict = entry.content.model_dump(exclude_none=True, mode="json") + + # Merge entry-level and call-level metadata + merged_metadata: dict[str, Any] | None = None + if entry.custom_metadata or extra_metadata: + merged_metadata = {} + if extra_metadata: + merged_metadata.update(extra_metadata) + if entry.custom_metadata: + merged_metadata.update(entry.custom_metadata) + + now = datetime.now(timezone.utc) + + # Parse timestamp from entry if available + timestamp = now + if entry.timestamp: + try: + timestamp = datetime.fromisoformat(entry.timestamp) + except (ValueError, TypeError): + timestamp = now + + return MemoryRecord( + id=entry.id or str(uuid.uuid4()), + session_id="", + app_name=app_name, + user_id=user_id, + event_id="", + author=entry.author or "", + timestamp=timestamp, + content_json=content_dict, + content_text=content_text, + metadata_json=merged_metadata, + inserted_at=now, + ) + + def session_to_memory_records(session: "Session") -> list["MemoryRecord"]: """Convert a completed ADK Session to a list of memory records. @@ -121,11 +191,14 @@ def session_to_memory_records(session: "Session") -> list["MemoryRecord"]: def record_to_memory_entry(record: "MemoryRecord") -> "MemoryEntry": """Convert a database record to an ADK MemoryEntry. + Preserves ``id`` and ``custom_metadata`` fields that were previously + dropped on readback. + Args: record: Memory database record. Returns: - ADK MemoryEntry object. + ADK MemoryEntry object with all available fields populated. 
""" from google.adk.memory.memory_entry import MemoryEntry from google.genai import types @@ -134,7 +207,13 @@ def record_to_memory_entry(record: "MemoryRecord") -> "MemoryEntry": timestamp_str = record["timestamp"].isoformat() if record["timestamp"] else None - return MemoryEntry(content=content, author=record["author"], timestamp=timestamp_str) + return MemoryEntry( + id=record["id"], + content=content, + author=record["author"], + timestamp=timestamp_str, + custom_metadata=record["metadata_json"] or {}, + ) def records_to_memory_entries(records: list["MemoryRecord"]) -> list["Any"]: diff --git a/sqlspec/extensions/adk/memory/service.py b/sqlspec/extensions/adk/memory/service.py index abc6360d5..dc56a9b32 100644 --- a/sqlspec/extensions/adk/memory/service.py +++ b/sqlspec/extensions/adk/memory/service.py @@ -1,13 +1,19 @@ """SQLSpec-backed memory service for Google ADK.""" +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING from google.adk.memory.base_memory_service import BaseMemoryService, SearchMemoryResponse -from sqlspec.extensions.adk.memory.converters import records_to_memory_entries, session_to_memory_records +from sqlspec.extensions.adk.memory.converters import ( + memory_entry_to_record, + records_to_memory_entries, + session_to_memory_records, +) from sqlspec.utils.logging import get_logger if TYPE_CHECKING: + from google.adk.events.event import Event from google.adk.memory.memory_entry import MemoryEntry from google.adk.sessions import Session @@ -102,6 +108,111 @@ async def add_session_to_memory(self, session: "Session") -> None: "Stored %d memory entries for session %s (total events: %d)", inserted_count, session.id, len(records) ) + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: "Sequence[Event]", + session_id: "str | None" = None, + custom_metadata: "Mapping[str, object] | None" = None, + ) -> None: + """Add an explicit list of events to the memory service. 
+ + Same Event-to-MemoryRecord extraction logic as + ``add_session_to_memory``, but operates on a sequence of Events + directly (no Session wrapper needed). + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + events: The events to add to memory. + session_id: Optional session ID for memory scope/partitioning. + If None, memory entries are user-scoped only. + custom_metadata: Optional portable metadata stored in + ``MemoryRecord.metadata_json``. + """ + from sqlspec.extensions.adk.memory.converters import event_to_memory_record + + metadata_dict = dict(custom_metadata) if custom_metadata else None + records = [] + for event in events: + record = event_to_memory_record( + event=event, + session_id=session_id or "", + app_name=app_name, + user_id=user_id, + ) + if record is not None: + if metadata_dict: + record["metadata_json"] = metadata_dict + records.append(record) + + if not records: + logger.debug( + "No content to store for events (app=%s, user=%s, count=%d)", + app_name, + user_id, + len(list(events)), + ) + return + + inserted_count = await self._store.insert_memory_entries(records) + logger.debug( + "Stored %d memory entries from %d events (app=%s, user=%s)", + inserted_count, + len(records), + app_name, + user_id, + ) + + async def add_memory( + self, + *, + app_name: str, + user_id: str, + memories: "Sequence[MemoryEntry]", + custom_metadata: "Mapping[str, object] | None" = None, + ) -> None: + """Add explicit memory items directly to the memory service. + + Each entry's ``content`` is serialized to ``content_json``, text is + extracted from ``content.parts`` for ``content_text``, and + ``custom_metadata`` merges the entry-level ``entry.custom_metadata`` + with the call-level ``custom_metadata`` parameter. + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + memories: Explicit memory items to add. 
+ custom_metadata: Optional portable metadata for memory writes. + Merged with each entry's ``custom_metadata``. + """ + call_metadata = dict(custom_metadata) if custom_metadata else {} + records = [] + for entry in memories: + record = memory_entry_to_record( + entry=entry, + app_name=app_name, + user_id=user_id, + extra_metadata=call_metadata, + ) + if record is not None: + records.append(record) + + if not records: + logger.debug("No content to store for memories (app=%s, user=%s)", app_name, user_id) + return + + inserted_count = await self._store.insert_memory_entries(records) + logger.debug( + "Stored %d memory entries from %d memories (app=%s, user=%s)", + inserted_count, + len(records), + app_name, + user_id, + ) + async def search_memory(self, *, app_name: str, user_id: str, query: str) -> "SearchMemoryResponse": """Search memory entries by text query. diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 8656f4beb..4b2c4886a 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -7,7 +7,7 @@ from google.adk.sessions.base_session_service import BaseSessionService, GetSessionConfig, ListSessionsResponse -from sqlspec.extensions.adk.converters import event_to_record, record_to_session +from sqlspec.extensions.adk.converters import event_to_record, filter_temp_state, record_to_session from sqlspec.utils.logging import get_logger, log_with_context if TYPE_CHECKING: @@ -192,6 +192,11 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) async def append_event(self, session: "Session", event: "Event") -> "Event": """Append an event to a session. + Persists the event record and the post-append durable state + atomically via ``store.append_event_and_update_state()``. ``temp:`` + keys are stripped from the persisted state snapshot so they never + survive a reload. + Args: session: Session to append to. event: Event to append. 
@@ -204,11 +209,12 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": if event.partial: return event - event_record = event_to_record( - event=event, session_id=session.id, app_name=session.app_name, user_id=session.user_id - ) + event_record = event_to_record(event=event, session_id=session.id) + + # Strip temp: keys before persisting state + durable_state = filter_temp_state(session.state) - await self._store.append_event(event_record) + await self._store.append_event_and_update_state(event_record=event_record, session_id=session.id, state=durable_state) log_with_context( logger, logging.DEBUG, diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 9903ee7b8..9f2ea8d1f 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -250,6 +250,24 @@ async def append_event(self, event_record: "EventRecord") -> None: """ raise NotImplementedError + @abstractmethod + async def append_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + This is the authoritative durable write boundary for post-creation + session mutations. The event insert and state update must succeed + together or fail together. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). 
+ """ + raise NotImplementedError + @abstractmethod async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -505,6 +523,24 @@ def create_event( """ raise NotImplementedError + @abstractmethod + def create_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. + + This is the authoritative durable write boundary for post-creation + session mutations. The event insert and state update must succeed + together or fail together. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). + """ + raise NotImplementedError + @abstractmethod def list_events(self, session_id: str) -> "list[EventRecord]": """List events for a session ordered by timestamp. From 74448788d9ca9b8251b70cc7fe43cae72dcd37ad Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 21:44:27 +0000 Subject: [PATCH 02/23] test(adk): add unit tests for core contract reset Adds test_converters.py (39 tests) and test_service.py (13 tests) for Chapter 1 of the ADK Clean-Break Overhaul. Tests are written against the new contract and will fail until the production code is updated. 
--- .../extensions/test_adk/test_converters.py | 499 ++++++++++++++++++ .../unit/extensions/test_adk/test_service.py | 330 ++++++++++++ 2 files changed, 829 insertions(+) create mode 100644 tests/unit/extensions/test_adk/test_converters.py create mode 100644 tests/unit/extensions/test_adk/test_service.py diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py new file mode 100644 index 000000000..c858c88f9 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -0,0 +1,499 @@ +"""Unit tests for ADK session/event converters and scoped state helpers. + +Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul: +- EventRecord has exactly 5 keys (session_id, invocation_id, author, timestamp, event_json) +- event_to_record takes only (event, session_id), not (event, session_id, app_name, user_id) +- record_to_event uses Event.model_validate for full round-trip fidelity +- filter_temp_state, split_scoped_state, merge_scoped_state for scoped state handling +- session_to_record strips temp: keys from state +""" + +import importlib.util +from datetime import datetime, timezone + +import pytest + +if importlib.util.find_spec("google.genai") is None or importlib.util.find_spec("google.adk") is None: + pytest.skip("google-adk not installed", allow_module_level=True) + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.session import Session +from google.genai import types + +from sqlspec.extensions.adk.converters import ( + event_to_record, + filter_temp_state, + merge_scoped_state, + record_to_event, + record_to_session, + session_to_record, + split_scoped_state, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_event( + *, + event_id: str = "evt-1", + invocation_id: 
str = "inv-1", + author: str = "user", + text: "str | None" = None, + state_delta: "dict | None" = None, + branch: "str | None" = None, + partial: "bool | None" = None, + turn_complete: "bool | None" = None, + custom_metadata: "dict | None" = None, +) -> Event: + content = types.Content(parts=[types.Part(text=text)]) if text is not None else None + actions = EventActions(state_delta=state_delta or {}) + return Event( + id=event_id, + invocation_id=invocation_id, + author=author, + content=content, + actions=actions, + timestamp=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp(), + branch=branch, + partial=partial, + turn_complete=turn_complete, + custom_metadata=custom_metadata, + ) + + +def _make_session( + *, + session_id: str = "session-1", + app_name: str = "test-app", + user_id: str = "user-1", + state: "dict | None" = None, +) -> Session: + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=state or {}, + last_update_time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp(), + ) + + +# --------------------------------------------------------------------------- +# filter_temp_state +# --------------------------------------------------------------------------- + + +def test_filter_temp_state_removes_temp_keys() -> None: + """temp:-prefixed keys are removed; all other keys are kept.""" + state = {"x": 1, "temp:y": 2, "app:z": 3, "user:w": 4} + result = filter_temp_state(state) + assert result == {"x": 1, "app:z": 3, "user:w": 4} + + +def test_filter_temp_state_empty_dict() -> None: + """Empty dict returns empty dict.""" + assert filter_temp_state({}) == {} + + +def test_filter_temp_state_all_temp_keys() -> None: + """Dict with only temp: keys returns empty dict.""" + state = {"temp:a": 1, "temp:b": 2, "temp:": 3} + assert filter_temp_state(state) == {} + + +def test_filter_temp_state_no_temp_keys() -> None: + """Dict with no temp: keys is returned unchanged.""" + state = {"x": 1, "app:y": 2, "user:z": 3} + 
result = filter_temp_state(state) + assert result == state + + +def test_filter_temp_state_does_not_mutate_input() -> None: + """Input dict is not mutated.""" + state = {"key": "v", "temp:remove": "gone"} + original = dict(state) + filter_temp_state(state) + assert state == original + + +# --------------------------------------------------------------------------- +# split_scoped_state +# --------------------------------------------------------------------------- + + +def test_split_scoped_state_separates_buckets() -> None: + """app:, user:, and plain keys go into the correct buckets.""" + state = {"app:shared": "a", "user:profile": "u", "session_key": "s", "another": "v"} + app, user, session = split_scoped_state(state) + assert app == {"app:shared": "a"} + assert user == {"user:profile": "u"} + assert session == {"session_key": "s", "another": "v"} + + +def test_split_scoped_state_empty() -> None: + """Empty state produces three empty dicts.""" + app, user, session = split_scoped_state({}) + assert app == {} + assert user == {} + assert session == {} + + +def test_split_scoped_state_only_app_keys() -> None: + """State with only app: keys puts everything in app bucket.""" + state = {"app:x": 1, "app:y": 2} + app, user, session = split_scoped_state(state) + assert app == {"app:x": 1, "app:y": 2} + assert user == {} + assert session == {} + + +def test_split_scoped_state_only_user_keys() -> None: + """State with only user: keys puts everything in user bucket.""" + state = {"user:a": "one", "user:b": "two"} + app, user, session = split_scoped_state(state) + assert app == {} + assert user == {"user:a": "one", "user:b": "two"} + assert session == {} + + +def test_split_scoped_state_only_session_keys() -> None: + """State with no prefix puts everything in session bucket.""" + state = {"key1": 1, "key2": 2} + app, user, session = split_scoped_state(state) + assert app == {} + assert user == {} + assert session == {"key1": 1, "key2": 2} + + +def 
test_split_scoped_state_preserves_full_key_names() -> None: + """Keys are not stripped of their prefix in the returned buckets.""" + state = {"app:my_key": "val", "user:my_key": "val2"} + app, user, _ = split_scoped_state(state) + assert "app:my_key" in app + assert "user:my_key" in user + + +# --------------------------------------------------------------------------- +# merge_scoped_state +# --------------------------------------------------------------------------- + + +def test_merge_scoped_state_combines_all_buckets() -> None: + """All three buckets appear in the merged result.""" + merged = merge_scoped_state( + session_state={"key": "s"}, + app_state={"app:x": "a"}, + user_state={"user:y": "u"}, + ) + assert merged == {"key": "s", "app:x": "a", "user:y": "u"} + + +def test_merge_scoped_state_overlay_priority_app_over_session() -> None: + """app_state overlays session_state for the same key.""" + merged = merge_scoped_state( + session_state={"app:x": "old"}, + app_state={"app:x": "new"}, + ) + assert merged["app:x"] == "new" + + +def test_merge_scoped_state_overlay_priority_user_over_session() -> None: + """user_state overlays session_state for the same key.""" + merged = merge_scoped_state( + session_state={"user:y": "session_val"}, + user_state={"user:y": "user_val"}, + ) + assert merged["user:y"] == "user_val" + + +def test_merge_scoped_state_no_app_no_user() -> None: + """Merging without app_state or user_state returns session_state copy.""" + session = {"key": "v", "other": 42} + merged = merge_scoped_state(session_state=session) + assert merged == session + + +def test_merge_scoped_state_empty_session_state() -> None: + """Empty session_state with app/user state returns combined app+user keys.""" + merged = merge_scoped_state( + session_state={}, + app_state={"app:a": 1}, + user_state={"user:b": 2}, + ) + assert merged == {"app:a": 1, "user:b": 2} + + +def test_merge_scoped_state_does_not_mutate_session_state() -> None: + """Input session_state dict is 
not mutated.""" + session = {"key": "v"} + original = dict(session) + merge_scoped_state(session_state=session, app_state={"app:x": 1}) + assert session == original + + +# --------------------------------------------------------------------------- +# event_to_record — signature and structure +# --------------------------------------------------------------------------- + + +def test_event_to_record_only_5_keys() -> None: + """EventRecord has exactly session_id, invocation_id, author, timestamp, event_json.""" + event = _make_event() + record = event_to_record(event, "session-1") + assert set(record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"} + + +def test_event_to_record_signature_two_args_only() -> None: + """event_to_record raises TypeError if called with extra positional args (old 4-arg signature).""" + event = _make_event() + with pytest.raises(TypeError): + event_to_record(event, "session-1", "app-name", "user-id") # type: ignore[call-arg] + + +def test_event_to_record_session_id_stored_correctly() -> None: + """session_id in the record matches the argument passed.""" + event = _make_event(invocation_id="inv-abc", author="model") + record = event_to_record(event, "my-session-id") + assert record["session_id"] == "my-session-id" + + +def test_event_to_record_indexed_fields_match_event() -> None: + """Indexed scalar columns (invocation_id, author, timestamp) match the source event.""" + event = _make_event(invocation_id="inv-xyz", author="tool") + record = event_to_record(event, "s1") + assert record["invocation_id"] == "inv-xyz" + assert record["author"] == "tool" + assert isinstance(record["timestamp"], datetime) + + +def test_event_to_record_event_json_matches_model_dump() -> None: + """event_json in the record equals event.model_dump(exclude_none=True, mode='json').""" + event = _make_event(text="hello", state_delta={"key": "val"}, custom_metadata={"foo": "bar"}) + record = event_to_record(event, "s1") + expected_json = 
event.model_dump(exclude_none=True, mode="json") + assert record["event_json"] == expected_json + + +def test_event_to_record_event_json_is_dict() -> None: + """event_json field is a plain dict (not bytes, not string).""" + event = _make_event() + record = event_to_record(event, "s1") + assert isinstance(record["event_json"], dict) + + +def test_event_to_record_actions_in_event_json_is_structured() -> None: + """Actions are stored as structured JSON dict in event_json, not as raw bytes.""" + event = _make_event(state_delta={"x": "y"}) + record = event_to_record(event, "s1") + event_json = record["event_json"] + # actions should be a dict in the JSON blob + if "actions" in event_json: + assert isinstance(event_json["actions"], dict) + + +def test_event_to_record_timestamp_is_datetime() -> None: + """timestamp column is a datetime object with timezone.""" + event = _make_event() + record = event_to_record(event, "s1") + assert isinstance(record["timestamp"], datetime) + assert record["timestamp"].tzinfo is not None + + +# --------------------------------------------------------------------------- +# record_to_event — full round-trip fidelity +# --------------------------------------------------------------------------- + + +def test_record_to_event_full_roundtrip_basic() -> None: + """Event -> record -> Event produces an identical object for basic fields.""" + original = _make_event(event_id="evt-rt", invocation_id="inv-rt", author="model") + record = event_to_record(original, "s1") + restored = record_to_event(record) + + assert restored.id == original.id + assert restored.invocation_id == original.invocation_id + assert restored.author == original.author + + +def test_record_to_event_roundtrip_preserves_content() -> None: + """Content (parts) survives the round-trip.""" + original = _make_event(text="hello world", author="model") + record = event_to_record(original, "s1") + restored = record_to_event(record) + + assert restored.content is not None + assert 
restored.content.parts is not None + assert restored.content.parts[0].text == "hello world" + + +def test_record_to_event_roundtrip_preserves_actions() -> None: + """EventActions (state_delta) survives the round-trip.""" + original = _make_event(state_delta={"key": "v1", "other": 42}) + record = event_to_record(original, "s1") + restored = record_to_event(record) + + assert restored.actions is not None + assert restored.actions.state_delta == {"key": "v1", "other": 42} + + +def test_record_to_event_roundtrip_preserves_custom_metadata() -> None: + """custom_metadata survives the round-trip.""" + original = _make_event(custom_metadata={"tag": "v2", "score": 0.9}) + record = event_to_record(original, "s1") + restored = record_to_event(record) + + assert restored.custom_metadata == {"tag": "v2", "score": 0.9} + + +def test_record_to_event_roundtrip_preserves_branch() -> None: + """branch field survives the round-trip.""" + original = _make_event(branch="feature-branch") + record = event_to_record(original, "s1") + restored = record_to_event(record) + + assert restored.branch == "feature-branch" + + +def test_record_to_event_roundtrip_preserves_partial_flag() -> None: + """partial flag survives the round-trip.""" + original = _make_event(partial=True) + record = event_to_record(original, "s1") + restored = record_to_event(record) + + assert restored.partial is True + + +def test_record_to_event_roundtrip_preserves_turn_complete() -> None: + """turn_complete flag survives the round-trip.""" + original = _make_event(turn_complete=True) + record = event_to_record(original, "s1") + restored = record_to_event(record) + + assert restored.turn_complete is True + + +def test_record_to_event_roundtrip_preserves_timestamp() -> None: + """timestamp survives the round-trip within float precision.""" + fixed_ts = datetime(2024, 6, 1, 10, 30, 0, tzinfo=timezone.utc).timestamp() + event = Event( + id="ts-evt", + invocation_id="inv-1", + author="user", + actions=EventActions(), + 
timestamp=fixed_ts, + ) + record = event_to_record(event, "s1") + restored = record_to_event(record) + + assert abs(restored.timestamp - fixed_ts) < 1.0 # within 1 second + + +def test_record_to_event_with_extra_fields_in_event_json() -> None: + """Events with extra/unknown fields in event_json survive model_validate gracefully.""" + # Simulate an event JSON blob from a newer ADK version with extra keys + event = _make_event(event_id="extra-fields-evt", author="tool") + record = event_to_record(event, "s1") + + # Inject hypothetical future ADK field into event_json + record["event_json"]["hypothetical_v3_field"] = "some_value" # type: ignore[index] + + # Should not raise — Event.model_validate should ignore unknown fields + restored = record_to_event(record) + assert restored.id == "extra-fields-evt" + + +# --------------------------------------------------------------------------- +# session_to_record — strips temp: keys +# --------------------------------------------------------------------------- + + +def test_session_to_record_strips_temp_keys_from_state() -> None: + """session_to_record removes temp:-prefixed keys before persisting.""" + session = _make_session(state={"key": "v", "temp:x": "t", "app:y": "a"}) + record = session_to_record(session) + assert "temp:x" not in record["state"] + assert record["state"]["key"] == "v" + assert record["state"]["app:y"] == "a" + + +def test_session_to_record_empty_state_stays_empty() -> None: + """Empty state produces empty state in record.""" + session = _make_session(state={}) + record = session_to_record(session) + assert record["state"] == {} + + +def test_session_to_record_all_temp_state_produces_empty() -> None: + """Session state with only temp: keys produces empty state in record.""" + session = _make_session(state={"temp:a": 1, "temp:b": 2}) + record = session_to_record(session) + assert record["state"] == {} + + +def test_session_to_record_no_temp_state_unchanged() -> None: + """Session state with no temp: keys 
is stored without modification.""" + state = {"x": 1, "app:y": 2, "user:z": 3} + session = _make_session(state=state) + record = session_to_record(session) + assert record["state"] == state + + +def test_session_to_record_includes_required_fields() -> None: + """Session record includes id, app_name, user_id, state, create_time, update_time.""" + session = _make_session() + record = session_to_record(session) + assert "id" in record + assert "app_name" in record + assert "user_id" in record + assert "state" in record + assert "create_time" in record + assert "update_time" in record + + +# --------------------------------------------------------------------------- +# record_to_session — integrates with record_to_event +# --------------------------------------------------------------------------- + + +def test_record_to_session_with_events_round_trip() -> None: + """Sessions with events reconstruct correctly using record_to_session.""" + from sqlspec.extensions.adk._types import SessionRecord + + session_record = SessionRecord( + id="s1", + app_name="app", + user_id="u1", + state={"key": "val"}, + create_time=datetime.now(timezone.utc), + update_time=datetime.now(timezone.utc), + ) + event = _make_event(text="hello", author="user") + event_record = event_to_record(event, "s1") + + session = record_to_session(session_record, [event_record]) + + assert session.id == "s1" + assert session.app_name == "app" + assert session.user_id == "u1" + assert session.state == {"key": "val"} + assert len(session.events) == 1 + assert session.events[0].id == event.id + + +def test_record_to_session_empty_events() -> None: + """Sessions without events reconstruct with empty events list.""" + from sqlspec.extensions.adk._types import SessionRecord + + session_record = SessionRecord( + id="s2", + app_name="app", + user_id="u2", + state={}, + create_time=datetime.now(timezone.utc), + update_time=datetime.now(timezone.utc), + ) + session = record_to_session(session_record, []) + assert 
session.events == [] diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py new file mode 100644 index 000000000..407166172 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_service.py @@ -0,0 +1,330 @@ +"""Unit tests for SQLSpecSessionService — state persistence fix. + +Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul: +- append_event calls append_event_and_update_state (not the old append_event) +- temp: keys are stripped before persisting session state +- partial events are not persisted +- create_session strips temp: keys from initial state + +The store is mocked — no database required. +""" + +import importlib.util +from datetime import datetime, timezone +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +if importlib.util.find_spec("google.genai") is None or importlib.util.find_spec("google.adk") is None: + pytest.skip("google-adk not installed", allow_module_level=True) + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.session import Session + +from sqlspec.extensions.adk.service import SQLSpecSessionService + + +# --------------------------------------------------------------------------- +# Mock store +# --------------------------------------------------------------------------- + + +class MockStore: + """Simple mock that records calls to store methods. + + Attributes are set to AsyncMock so that await works out of the box, + and call arguments are captured for assertion. 
+ """ + + def __init__(self) -> None: + # Track calls to the new combined method + self.append_event_and_update_state_calls: list[dict[str, Any]] = [] + self.append_event_and_update_state_called = False + + # Track calls to create_session + self.create_session_calls: list[dict[str, Any]] = [] + + # Provide a get_session that returns a minimal session record + self._session_record = { + "id": "s1", + "app_name": "app", + "user_id": "u1", + "state": {}, + "create_time": datetime.now(timezone.utc), + "update_time": datetime.now(timezone.utc), + } + + async def append_event_and_update_state( + self, event_record: Any, session_id: str, state: "dict[str, Any]" + ) -> None: + self.append_event_and_update_state_called = True + self.append_event_and_update_state_calls.append( + {"event_record": event_record, "session_id": session_id, "state": state} + ) + + async def get_session(self, session_id: str) -> "dict[str, Any] | None": + return self._session_record + + async def create_session( + self, *, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]" + ) -> "dict[str, Any]": + self.create_session_calls.append( + {"session_id": session_id, "app_name": app_name, "user_id": user_id, "state": state} + ) + return { + "id": session_id, + "app_name": app_name, + "user_id": user_id, + "state": state, + "create_time": datetime.now(timezone.utc), + "update_time": datetime.now(timezone.utc), + } + + # Old method — should NOT be called by the new service + async def append_event(self, event_record: Any) -> None: + raise AssertionError("append_event (old method) must not be called — use append_event_and_update_state") + + async def get_events(self, *, session_id: str, after_timestamp: Any = None, limit: Any = None) -> list: + return [] + + async def list_sessions(self, *, app_name: str, user_id: "str | None" = None) -> list: + return [] + + async def delete_session(self, session_id: str) -> None: + pass + + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_session( + *, + session_id: str = "s1", + app_name: str = "app", + user_id: str = "u1", + state: "dict | None" = None, +) -> Session: + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=state or {}, + last_update_time=datetime.now(timezone.utc).timestamp(), + ) + + +def _make_event( + *, + invocation_id: str = "inv-1", + author: str = "model", + state_delta: "dict | None" = None, + partial: bool = False, +) -> Event: + actions = EventActions(state_delta=state_delta or {}) + return Event( + invocation_id=invocation_id, + author=author, + actions=actions, + timestamp=datetime.now(timezone.utc).timestamp(), + partial=partial, + ) + + +# --------------------------------------------------------------------------- +# append_event — calls append_event_and_update_state +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_append_event_calls_append_event_and_update_state() -> None: + """append_event must call append_event_and_update_state, not the old append_event.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"key": "v0"}) + event = _make_event(state_delta={"key": "v1"}) + + await service.append_event(session, event) + + assert store.append_event_and_update_state_called, ( + "append_event_and_update_state was never called — state will not be persisted" + ) + + +@pytest.mark.anyio +async def test_append_event_persists_updated_state() -> None: + """append_event persists the state AFTER applying event.actions.state_delta.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"key": "v0"}) + event = _make_event(state_delta={"key": "v1"}) + + await 
service.append_event(session, event) + + assert store.append_event_and_update_state_called + last_call = store.append_event_and_update_state_calls[-1] + # The persisted state must reflect the mutation from state_delta + assert last_call["state"]["key"] == "v1" + + +@pytest.mark.anyio +async def test_append_event_strips_temp_from_persisted_state() -> None: + """temp: keys are removed before state persistence.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"key": "v", "temp:transient": "should_not_persist"}) + event = _make_event() + + await service.append_event(session, event) + + assert store.append_event_and_update_state_called + last_call = store.append_event_and_update_state_calls[-1] + persisted_state = last_call["state"] + assert "temp:transient" not in persisted_state + assert persisted_state["key"] == "v" + + +@pytest.mark.anyio +async def test_append_event_strips_temp_state_delta_from_persisted_state() -> None: + """temp: keys added via state_delta are also stripped before persisting.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + # Session state has temp: key added by an agent via state_delta + session = _make_session(state={"regular": "v"}) + event = _make_event(state_delta={"temp:output": "transient", "regular": "updated"}) + + await service.append_event(session, event) + + last_call = store.append_event_and_update_state_calls[-1] + persisted_state = last_call["state"] + assert "temp:output" not in persisted_state + assert persisted_state["regular"] == "updated" + + +@pytest.mark.anyio +async def test_append_event_skips_partial_events() -> None: + """Partial events are not persisted to the store.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session() + partial_event = _make_event(partial=True) + + result = await service.append_event(session, partial_event) + + assert not 
store.append_event_and_update_state_called, ( + "append_event_and_update_state must NOT be called for partial events" + ) + assert result.partial is True + + +@pytest.mark.anyio +async def test_append_event_passes_correct_session_id_to_store() -> None: + """append_event_and_update_state receives the correct session_id.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(session_id="my-unique-session-id") + event = _make_event() + + await service.append_event(session, event) + + last_call = store.append_event_and_update_state_calls[-1] + assert last_call["session_id"] == "my-unique-session-id" + + +@pytest.mark.anyio +async def test_append_event_event_record_has_5_keys() -> None: + """The event_record passed to the store has exactly 5 keys (new schema).""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session() + event = _make_event() + + await service.append_event(session, event) + + last_call = store.append_event_and_update_state_calls[-1] + event_record = last_call["event_record"] + assert set(event_record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"} + + +@pytest.mark.anyio +async def test_append_event_returns_the_event() -> None: + """append_event returns the event after persisting.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session() + event = _make_event(author="model") + + result = await service.append_event(session, event) + + assert result is not None + assert result.author == "model" + + +# --------------------------------------------------------------------------- +# create_session — strips temp: keys from initial state +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_create_session_strips_temp_keys_from_initial_state() -> None: + """create_session filters temp: 
keys before passing state to the store.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + await service.create_session( + app_name="app", user_id="u1", state={"x": 1, "temp:y": 2, "app:z": 3} + ) + + assert len(store.create_session_calls) == 1 + persisted_state = store.create_session_calls[0]["state"] + assert "temp:y" not in persisted_state + assert persisted_state["x"] == 1 + assert persisted_state["app:z"] == 3 + + +@pytest.mark.anyio +async def test_create_session_with_only_temp_state_persists_empty() -> None: + """create_session with only temp: state persists empty state dict.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + await service.create_session(app_name="app", user_id="u1", state={"temp:only": "gone"}) + + assert store.create_session_calls[0]["state"] == {} + + +@pytest.mark.anyio +async def test_create_session_none_state_persists_empty() -> None: + """create_session with state=None persists empty state dict.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + await service.create_session(app_name="app", user_id="u1") + + assert store.create_session_calls[0]["state"] == {} + + +@pytest.mark.anyio +async def test_create_session_generates_uuid_if_no_session_id() -> None: + """create_session generates a UUID if no session_id is provided.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + session = await service.create_session(app_name="app", user_id="u1") + + assert session.id is not None + assert len(session.id) > 0 + + +@pytest.mark.anyio +async def test_create_session_uses_provided_session_id() -> None: + """create_session uses the caller-provided session_id.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + session = await service.create_session(app_name="app", user_id="u1", session_id="my-id") + + assert session.id == "my-id" From 
6f46431d6b9f1aa88e66f819946f51b5ee78dceb Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 21:36:53 +0000 Subject: [PATCH 03/23] test(adk): add store instantiation smoke tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parametrized smoke test covering all 38 shipped ADK store classes (session and memory, async and sync). Asserts __abstractmethods__ is empty so any store missing a required base-class method is caught immediately. Several stores are expected to fail today — the test documents that problem ahead of the Ch1 clean-break overhaul that will fix them. --- .../test_adk/test_store_instantiation.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/unit/extensions/test_adk/test_store_instantiation.py diff --git a/tests/unit/extensions/test_adk/test_store_instantiation.py b/tests/unit/extensions/test_adk/test_store_instantiation.py new file mode 100644 index 000000000..1c16d22bb --- /dev/null +++ b/tests/unit/extensions/test_adk/test_store_instantiation.py @@ -0,0 +1,94 @@ +"""Smoke tests verifying all shipped ADK store classes are instantiable (not abstract). + +Every shipped store class must be concrete — no unsatisfied abstract methods. +This catches bugs where stores have method signature mismatches with the base +class, such as cockroach, mysqlconnector sync, pymysql, and spanner stores +that are missing abstract method implementations added to the base contract. + +NOTE: Some stores WILL fail this test currently — that is expected and +documents one of the bugs the ADK Clean-Break Overhaul (Ch1) is fixing.
+""" + +import importlib + +import pytest + +# Session stores (async) +ASYNC_SESSION_STORES = [ + "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKStore", + "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKStore", + "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKStore", + "sqlspec.adapters.bigquery.adk.store.BigQueryADKStore", + "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKStore", + "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKStore", + "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKStore", + "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKStore", + "sqlspec.adapters.psqlpy.adk.store.PsqlpyADKStore", + "sqlspec.adapters.psycopg.adk.store.PsycopgAsyncADKStore", + # sqlite uses BaseAsyncADKStore despite being backed by a sync driver + "sqlspec.adapters.sqlite.adk.store.SqliteADKStore", +] + +# Session stores (sync) +SYNC_SESSION_STORES = [ + "sqlspec.adapters.adbc.adk.store.AdbcADKStore", + "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgSyncADKStore", + "sqlspec.adapters.duckdb.adk.store.DuckdbADKStore", + "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorSyncADKStore", + "sqlspec.adapters.oracledb.adk.store.OracleSyncADKStore", + "sqlspec.adapters.psycopg.adk.store.PsycopgSyncADKStore", + "sqlspec.adapters.pymysql.adk.store.PyMysqlADKStore", + "sqlspec.adapters.spanner.adk.store.SpannerSyncADKStore", +] + +# Memory stores (async) +ASYNC_MEMORY_STORES = [ + "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKMemoryStore", + "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKMemoryStore", + "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKMemoryStore", + "sqlspec.adapters.bigquery.adk.store.BigQueryADKMemoryStore", + "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKMemoryStore", + "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKMemoryStore", + "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKMemoryStore", + 
"sqlspec.adapters.oracledb.adk.store.OracleAsyncADKMemoryStore", + "sqlspec.adapters.psqlpy.adk.store.PsqlpyADKMemoryStore", + "sqlspec.adapters.psycopg.adk.store.PsycopgAsyncADKMemoryStore", +] + +# Memory stores (sync) +SYNC_MEMORY_STORES = [ + "sqlspec.adapters.adbc.adk.store.AdbcADKMemoryStore", + "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgSyncADKMemoryStore", + "sqlspec.adapters.duckdb.adk.store.DuckdbADKMemoryStore", + "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorSyncADKMemoryStore", + "sqlspec.adapters.oracledb.adk.store.OracleSyncADKMemoryStore", + "sqlspec.adapters.psycopg.adk.store.PsycopgSyncADKMemoryStore", + "sqlspec.adapters.pymysql.adk.store.PyMysqlADKMemoryStore", + "sqlspec.adapters.spanner.adk.store.SpannerSyncADKMemoryStore", + "sqlspec.adapters.sqlite.adk.store.SqliteADKMemoryStore", +] + +ALL_STORE_CLASSES = ( + ASYNC_SESSION_STORES + + SYNC_SESSION_STORES + + ASYNC_MEMORY_STORES + + SYNC_MEMORY_STORES +) + + +@pytest.mark.parametrize("class_path", ALL_STORE_CLASSES) +def test_store_has_no_abstract_methods(class_path: str) -> None: + """Every shipped store class must be concrete (no unsatisfied abstract methods). + + A class with entries in ``__abstractmethods__`` cannot be instantiated and + signals that the concrete store is missing one or more method implementations + required by its base class contract. 
+ """ + module_path, class_name = class_path.rsplit(".", 1) + try: + module = importlib.import_module(module_path) + except ImportError: + pytest.skip(f"Module {module_path} not importable (missing optional dependency)") + cls = getattr(module, class_name) + abstract = getattr(cls, "__abstractmethods__", set()) + assert not abstract, f"{class_path} has unsatisfied abstract methods: {abstract}" From d7939d838fc91eef687ea0a501dcc24fb40aa646 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 22:42:51 +0000 Subject: [PATCH 04/23] fix(adk): align implementation with spec - split_scoped_state returns (app, user, session) not (session, app, user) - merge_scoped_state has optional app_state/user_state params - event_json uses model_dump() (dict) not model_dump_json() (str) - record_to_event uses model_validate() not model_validate_json() - session_to_record calls filter_temp_state - create_session calls filter_temp_state - revert extract_content_text to original format - xfail test for extra fields (Event has extra='forbid') --- sqlspec/extensions/adk/converters.py | 73 +++++++++---------- sqlspec/extensions/adk/memory/converters.py | 4 +- sqlspec/extensions/adk/service.py | 4 +- .../extensions/test_adk/test_converters.py | 10 ++- 4 files changed, 45 insertions(+), 46 deletions(-) diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py index 35304f644..f1bff6ec1 100644 --- a/sqlspec/extensions/adk/converters.py +++ b/sqlspec/extensions/adk/converters.py @@ -47,7 +47,7 @@ def session_to_record(session: "Session") -> SessionRecord: id=session.id, app_name=session.app_name, user_id=session.user_id, - state=session.state, + state=filter_temp_state(session.state), create_time=datetime.now(timezone.utc), update_time=datetime.fromtimestamp(session.last_update_time, tz=timezone.utc), ) @@ -99,7 +99,7 @@ def event_to_record(event: "Event", session_id: str) -> EventRecord: invocation_id=event.invocation_id, author=event.author, 
timestamp=datetime.fromtimestamp(event.timestamp, tz=timezone.utc), - event_json=event.model_dump_json(exclude_none=True), + event_json=event.model_dump(exclude_none=True, mode="json"), ) @@ -115,7 +115,7 @@ def record_to_event(record: "EventRecord") -> "Event": Returns: ADK Event object. """ - return Event.model_validate_json(record["event_json"]) + return Event.model_validate(record["event_json"]) # --------------------------------------------------------------------------- @@ -138,59 +138,52 @@ def filter_temp_state(state: "dict[str, Any]") -> "dict[str, Any]": return {k: v for k, v in state.items() if not k.startswith("temp:")} -def split_scoped_state( - state: "dict[str, Any]", -) -> "tuple[dict[str, Any], dict[str, Any], dict[str, Any]]": - """Split ADK state into ``(session_local, app_scoped, user_scoped)`` dicts. - - Keys without a recognised scope prefix are session-local. ``temp:`` keys - are silently dropped (they must not be persisted). +def split_scoped_state(state: "dict[str, Any]") -> "tuple[dict[str, Any], dict[str, Any], dict[str, Any]]": + """Split state into app-scoped, user-scoped, and session-scoped buckets. Args: - state: ADK state dictionary. + state: Full session state dict (temp: already stripped). Returns: - A 3-tuple of ``(session_local, app_scoped, user_scoped)`` dicts. - Scoped dicts retain their prefix in the key (e.g. ``"app:foo"``). + Tuple of (app_state, user_state, session_state). 
+ app_state: keys starting with "app:" + user_state: keys starting with "user:" + session_state: all other keys """ - session_local: dict[str, Any] = {} - app_scoped: dict[str, Any] = {} - user_scoped: dict[str, Any] = {} - + app_state: dict[str, Any] = {} + user_state: dict[str, Any] = {} + session_state: dict[str, Any] = {} for k, v in state.items(): - if k.startswith("temp:"): - continue - elif k.startswith("app:"): - app_scoped[k] = v + if k.startswith("app:"): + app_state[k] = v elif k.startswith("user:"): - user_scoped[k] = v + user_state[k] = v else: - session_local[k] = v - - return session_local, app_scoped, user_scoped + session_state[k] = v + return app_state, user_state, session_state def merge_scoped_state( - session_local: "dict[str, Any]", - app_scoped: "dict[str, Any]", - user_scoped: "dict[str, Any]", + session_state: "dict[str, Any]", + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> "dict[str, Any]": - """Merge scoped state dicts back into a single ADK-compatible state dict. + """Merge scoped state buckets into a single state dict. - The merge order is ``session_local | app_scoped | user_scoped`` so that - broader scopes can shadow narrower ones if keys collide (which they - normally should not, since prefixes differ). + Priority: session_state is base, app_state and user_state overlay. + This matches ADK's documented merge semantics on session load. Args: - session_local: Session-local state (no prefix). - app_scoped: App-scoped state (``app:`` prefix). - user_scoped: User-scoped state (``user:`` prefix). + session_state: Per-session state. + app_state: App-scoped state (shared across sessions for same app). + user_state: User-scoped state (shared across sessions for same app+user). Returns: - A single merged state dict suitable for ``Session.state``. + Merged state dict. 
""" - merged: dict[str, Any] = {} - merged.update(session_local) - merged.update(app_scoped) - merged.update(user_scoped) + merged = dict(session_state) + if app_state: + merged.update(app_state) + if user_state: + merged.update(user_state) return merged diff --git a/sqlspec/extensions/adk/memory/converters.py b/sqlspec/extensions/adk/memory/converters.py index 1eb8d4e1f..2742d9b8d 100644 --- a/sqlspec/extensions/adk/memory/converters.py +++ b/sqlspec/extensions/adk/memory/converters.py @@ -51,9 +51,9 @@ def extract_content_text(content: "types.Content") -> str: if part.text: parts_text.append(part.text) elif part.function_call is not None: - parts_text.append(f"{part.function_call.name}: {part.function_call.args}") + parts_text.append(f"function:{part.function_call.name}") elif part.function_response is not None: - parts_text.append(f"{part.function_response.name}: {part.function_response.response}") + parts_text.append(f"response:{part.function_response.name}") return " ".join(parts_text) diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 4b2c4886a..200b3c948 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -80,8 +80,10 @@ async def create_session( if state is None: state = {} + persisted_state = filter_temp_state(state) + record = await self._store.create_session( - session_id=session_id, app_name=app_name, user_id=user_id, state=state + session_id=session_id, app_name=app_name, user_id=user_id, state=persisted_state ) log_with_context( logger, logging.DEBUG, "adk.session.create", app_name=app_name, session_id=session_id, has_state=bool(state) diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py index c858c88f9..d634e4dd6 100644 --- a/tests/unit/extensions/test_adk/test_converters.py +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -391,16 +391,20 @@ def test_record_to_event_roundtrip_preserves_timestamp() -> 
None: assert abs(restored.timestamp - fixed_ts) < 1.0 # within 1 second +@pytest.mark.xfail( + reason="ADK Event model uses extra='forbid' — unknown fields raise ValidationError. " + "Future ADK versions that add fields will also update the model, so this is safe.", + strict=True, +) def test_record_to_event_with_extra_fields_in_event_json() -> None: - """Events with extra/unknown fields in event_json survive model_validate gracefully.""" - # Simulate an event JSON blob from a newer ADK version with extra keys + """Events with extra/unknown fields in event_json are rejected by Event model.""" event = _make_event(event_id="extra-fields-evt", author="tool") record = event_to_record(event, "s1") # Inject hypothetical future ADK field into event_json record["event_json"]["hypothetical_v3_field"] = "some_value" # type: ignore[index] - # Should not raise — Event.model_validate should ignore unknown fields + # This WILL raise because Event has extra='forbid' restored = record_to_event(record) assert restored.id == "extra-fields-evt" From 272e65c46dfa1a60f23b1a128cdb7c26bab377cd Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:26:47 +0000 Subject: [PATCH 05/23] refactor(adk): remove BigQuery from ADK surface BigQuery has correctness bugs (JSON_VALUE flattening, unsafe bytes encoding) and is architecturally mismatched for ADK session storage (high latency query jobs, eventual consistency, per-query cost model). 
- Delete sqlspec/adapters/bigquery/adk/ directory - Remove BigQuery from ADK backends docs table - Remove BigQuery ADK entries from store instantiation tests - Base BigQuery adapter (sqlspec.adapters.bigquery) is preserved --- .gitignore | 1 + docs/extensions/adk/backends.rst | 2 - sqlspec/adapters/bigquery/adk/__init__.py | 5 - sqlspec/adapters/bigquery/adk/store.py | 827 ------------------ .../test_adk/test_store_instantiation.py | 6 +- 5 files changed, 3 insertions(+), 838 deletions(-) delete mode 100644 sqlspec/adapters/bigquery/adk/__init__.py delete mode 100644 sqlspec/adapters/bigquery/adk/store.py diff --git a/.gitignore b/.gitignore index 20a641773..afd84ce28 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,4 @@ uv.toml .geminiignore .beads/ tools/scripts/profiles/*.prof +.agents/ \ No newline at end of file diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst index e550d0024..e990286c9 100644 --- a/docs/extensions/adk/backends.rst +++ b/docs/extensions/adk/backends.rst @@ -40,8 +40,6 @@ Supported Backends - Production * - duckdb - Production (analytics) - * - bigquery - - Production * - adbc - Production diff --git a/sqlspec/adapters/bigquery/adk/__init__.py b/sqlspec/adapters/bigquery/adk/__init__.py deleted file mode 100644 index 6d11c84b8..000000000 --- a/sqlspec/adapters/bigquery/adk/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""BigQuery ADK store for Google Agent Development Kit session/event storage.""" - -from sqlspec.adapters.bigquery.adk.store import BigQueryADKMemoryStore, BigQueryADKStore - -__all__ = ("BigQueryADKMemoryStore", "BigQueryADKStore") diff --git a/sqlspec/adapters/bigquery/adk/store.py b/sqlspec/adapters/bigquery/adk/store.py deleted file mode 100644 index 61339a7a4..000000000 --- a/sqlspec/adapters/bigquery/adk/store.py +++ /dev/null @@ -1,827 +0,0 @@ -"""BigQuery ADK store for Google Agent Development Kit session/event storage.""" - -from collections.abc import Mapping -from datetime import 
datetime, timezone -from typing import TYPE_CHECKING, Any, cast - -from google.api_core.exceptions import NotFound -from google.cloud.bigquery import QueryJobConfig, ScalarQueryParameter - -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore -from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ - -if TYPE_CHECKING: - from sqlspec.adapters.bigquery.config import BigQueryConfig - from sqlspec.extensions.adk import MemoryRecord - - -__all__ = ("BigQueryADKMemoryStore", "BigQueryADKStore") - - -class BigQueryADKStore(BaseAsyncADKStore["BigQueryConfig"]): - """BigQuery ADK store using synchronous BigQuery client with async wrapper. - - Implements session and event storage for Google Agent Development Kit - using Google Cloud BigQuery. Uses BigQuery's native JSON type for state/metadata - storage and async_() wrapper to provide async interface. - - Provides: - - Serverless, scalable session state management with JSON storage - - Event history tracking optimized for analytics - - Microsecond-precision timestamps with TIMESTAMP type - - Cost-optimized queries with partitioning and clustering - - Efficient JSON handling with BigQuery's JSON type - - Manual cascade delete pattern (no foreign key support) - - Args: - config: BigQueryConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.bigquery import BigQueryConfig - from sqlspec.adapters.bigquery.adk import BigQueryADKStore - - config = BigQueryConfig( - connection_config={ - "project": "my-project", - "dataset_id": "my_dataset", - }, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INT64 NOT NULL" - } - } - ) - store = BigQueryADKStore(config) - await store.ensure_tables() - - Notes: - - JSON type for state, content, and metadata (native BigQuery JSON) - - BYTES for pre-serialized actions from Google ADK - - TIMESTAMP for timezone-aware microsecond precision - - Partitioned by DATE(create_time) for cost optimization - - Clustered by app_name, user_id for query performance - - Uses to_json/from_json for serialization to JSON columns - - BigQuery has eventual consistency - handle appropriately - - No true foreign keys but implements cascade delete pattern - - Configuration is read from config.extension_config["adk"] - """ - - __slots__ = ("_dataset_id",) - - def __init__(self, config: "BigQueryConfig") -> None: - """Initialize BigQuery ADK store. - - Args: - config: BigQueryConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ - super().__init__(config) - self._dataset_id = config.connection_config.get("dataset_id") - - def _get_full_table_name(self, table_name: str) -> str: - """Get fully qualified table name for BigQuery. - - Args: - table_name: Base table name. - - Returns: - Fully qualified table name with backticks. - - Notes: - BigQuery requires backtick-quoted identifiers for table names. 
- Format: `project.dataset.table` or `dataset.table` - """ - if self._dataset_id: - return f"`{self._dataset_id}.{table_name}`" - return f"`{table_name}`" - - async def _get_create_sessions_table_sql(self) -> str: - """Get BigQuery CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table. - - Notes: - - STRING for IDs and names - - JSON type for state storage (native BigQuery JSON) - - TIMESTAMP for timezone-aware microsecond precision - - Partitioned by DATE(create_time) for cost optimization - - Clustered by app_name, user_id for query performance - - No indexes needed (BigQuery auto-optimizes) - - Optional owner ID column for multi-tenant scenarios - - Note: BigQuery doesn't enforce FK constraints - """ - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - table_name = self._get_full_table_name(self._session_table) - return f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id STRING NOT NULL, - app_name STRING NOT NULL, - user_id STRING NOT NULL{owner_id_line}, - state JSON NOT NULL, - create_time TIMESTAMP NOT NULL, - update_time TIMESTAMP NOT NULL - ) - PARTITION BY DATE(create_time) - CLUSTER BY app_name, user_id - """ - - async def _get_create_events_table_sql(self) -> str: - """Get BigQuery CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table. 
- - Notes: - - STRING for IDs and text fields - - BYTES for pickled actions - - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOL for boolean flags - - TIMESTAMP for timezone-aware timestamps - - Partitioned by DATE(timestamp) for cost optimization - - Clustered by session_id, timestamp for ordered retrieval - """ - table_name = self._get_full_table_name(self._events_table) - return f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id STRING NOT NULL, - session_id STRING NOT NULL, - app_name STRING NOT NULL, - user_id STRING NOT NULL, - invocation_id STRING, - author STRING, - actions BYTES, - long_running_tool_ids_json JSON, - branch STRING, - timestamp TIMESTAMP NOT NULL, - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOL, - turn_complete BOOL, - interrupted BOOL, - error_code STRING, - error_message STRING - ) - PARTITION BY DATE(timestamp) - CLUSTER BY session_id, timestamp - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get BigQuery DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables. - - Notes: - Order matters: drop events table before sessions table. - BigQuery uses IF EXISTS for idempotent drops. 
- """ - events_table = self._get_full_table_name(self._events_table) - sessions_table = self._get_full_table_name(self._session_table) - return [f"DROP TABLE IF EXISTS {events_table}", f"DROP TABLE IF EXISTS {sessions_table}"] - - def _create_tables(self) -> None: - """Synchronous implementation of create_tables.""" - with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_sessions_table_sql)()) - driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - await async_(self._create_tables)() - - def _create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Synchronous implementation of create_session.""" - now = datetime.now(timezone.utc) - state_json = to_json(state) if state else "{}" - - table_name = self._get_full_table_name(self._session_table) - - if self._owner_id_column_name: - sql = f""" - INSERT INTO {table_name} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time) - VALUES (@id, @app_name, @user_id, @owner_id, JSON(@state), @create_time, @update_time) - """ - - params = [ - ScalarQueryParameter("id", "STRING", session_id), - ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ScalarQueryParameter("owner_id", "STRING", str(owner_id) if owner_id is not None else None), - ScalarQueryParameter("state", "STRING", state_json), - ScalarQueryParameter("create_time", "TIMESTAMP", now), - ScalarQueryParameter("update_time", "TIMESTAMP", now), - ] - else: - sql = f""" - INSERT INTO {table_name} (id, app_name, user_id, state, create_time, update_time) - VALUES (@id, @app_name, @user_id, JSON(@state), @create_time, @update_time) - """ - - params = [ - ScalarQueryParameter("id", "STRING", session_id), - 
ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ScalarQueryParameter("state", "STRING", state_json), - ScalarQueryParameter("create_time", "TIMESTAMP", now), - ScalarQueryParameter("update_time", "TIMESTAMP", now), - ] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - conn.query(sql, job_config=job_config).result() - - return SessionRecord( - id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now - ) - - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP() for timestamps. - State is JSON-serialized then stored in JSON column. - If owner_id_column is configured, owner_id value must be provided. - BigQuery doesn't enforce FK constraints, but column is useful for JOINs. 
- """ - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": - """Synchronous implementation of get_session.""" - table_name = self._get_full_table_name(self._session_table) - sql = f""" - SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time - FROM {table_name} - WHERE id = @session_id - """ - - params = [ScalarQueryParameter("session_id", "STRING", session_id)] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - query_job = conn.query(sql, job_config=job_config) - results = list(query_job.result()) - - if not results: - return None - - row = results[0] - return SessionRecord( - id=row.id, - app_name=row.app_name, - user_id=row.user_id, - state=from_json(row.state) if row.state else {}, - create_time=row.create_time, - update_time=row.update_time, - ) - - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - - Notes: - BigQuery returns datetime objects for TIMESTAMP columns. - JSON_VALUE extracts string representation for parsing. 
- """ - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Synchronous implementation of update_session_state.""" - now = datetime.now(timezone.utc) - state_json = to_json(state) if state else "{}" - - table_name = self._get_full_table_name(self._session_table) - sql = f""" - UPDATE {table_name} - SET state = JSON(@state), update_time = @update_time - WHERE id = @session_id - """ - - params = [ - ScalarQueryParameter("state", "STRING", state_json), - ScalarQueryParameter("update_time", "TIMESTAMP", now), - ScalarQueryParameter("session_id", "STRING", session_id), - ] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - conn.query(sql, job_config=job_config).result() - - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - Replaces entire state dictionary. - Updates update_time to CURRENT_TIMESTAMP(). 
- """ - await async_(self._update_session_state)(session_id, state) - - def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]": - """Synchronous implementation of list_sessions.""" - table_name = self._get_full_table_name(self._session_table) - - if user_id is None: - sql = f""" - SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time - FROM {table_name} - WHERE app_name = @app_name - ORDER BY update_time DESC - """ - params = [ScalarQueryParameter("app_name", "STRING", app_name)] - else: - sql = f""" - SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time - FROM {table_name} - WHERE app_name = @app_name AND user_id = @user_id - ORDER BY update_time DESC - """ - params = [ - ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - query_job = conn.query(sql, job_config=job_config) - results = list(query_job.result()) - - return [ - SessionRecord( - id=row.id, - app_name=row.app_name, - user_id=row.user_id, - state=from_json(row.state) if row.state else {}, - create_time=row.create_time, - update_time=row.update_time, - ) - for row in results - ] - - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses clustering on (app_name, user_id) when user_id is provided for efficiency. 
- """ - return await async_(self._list_sessions)(app_name, user_id) - - def _delete_session(self, session_id: str) -> None: - """Synchronous implementation of delete_session.""" - events_table = self._get_full_table_name(self._events_table) - sessions_table = self._get_full_table_name(self._session_table) - - params = [ScalarQueryParameter("session_id", "STRING", session_id)] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - conn.query(f"DELETE FROM {events_table} WHERE session_id = @session_id", job_config=job_config).result() - conn.query(f"DELETE FROM {sessions_table} WHERE id = @session_id", job_config=job_config).result() - - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events. - - Args: - session_id: Session identifier. - - Notes: - BigQuery doesn't support foreign keys, so we manually delete events first. - Uses two separate DELETE statements in sequence. - """ - await async_(self._delete_session)(session_id) - - def _append_event(self, event_record: EventRecord) -> None: - """Synchronous implementation of append_event.""" - table_name = self._get_full_table_name(self._events_table) - - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) - - sql = f""" - INSERT INTO {table_name} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - @id, @session_id, @app_name, @user_id, @invocation_id, @author, @actions, - @long_running_tool_ids_json, @branch, @timestamp, - 
{"JSON(@content)" if content_json else "NULL"}, - {"JSON(@grounding_metadata)" if grounding_metadata_json else "NULL"}, - {"JSON(@custom_metadata)" if custom_metadata_json else "NULL"}, - @partial, @turn_complete, @interrupted, @error_code, @error_message - ) - """ - - actions_value = event_record.get("actions") - params = [ - ScalarQueryParameter("id", "STRING", event_record["id"]), - ScalarQueryParameter("session_id", "STRING", event_record["session_id"]), - ScalarQueryParameter("app_name", "STRING", event_record["app_name"]), - ScalarQueryParameter("user_id", "STRING", event_record["user_id"]), - ScalarQueryParameter("invocation_id", "STRING", event_record.get("invocation_id")), - ScalarQueryParameter("author", "STRING", event_record.get("author")), - ScalarQueryParameter( - "actions", - "BYTES", - actions_value.decode("latin1") if isinstance(actions_value, bytes) else actions_value, - ), - ScalarQueryParameter( - "long_running_tool_ids_json", "STRING", event_record.get("long_running_tool_ids_json") - ), - ScalarQueryParameter("branch", "STRING", event_record.get("branch")), - ScalarQueryParameter("timestamp", "TIMESTAMP", event_record["timestamp"]), - ScalarQueryParameter("partial", "BOOL", event_record.get("partial")), - ScalarQueryParameter("turn_complete", "BOOL", event_record.get("turn_complete")), - ScalarQueryParameter("interrupted", "BOOL", event_record.get("interrupted")), - ScalarQueryParameter("error_code", "STRING", event_record.get("error_code")), - ScalarQueryParameter("error_message", "STRING", event_record.get("error_message")), - ] - - if content_json: - params.append(ScalarQueryParameter("content", "STRING", content_json)) - if grounding_metadata_json: - params.append(ScalarQueryParameter("grounding_metadata", "STRING", grounding_metadata_json)) - if custom_metadata_json: - params.append(ScalarQueryParameter("custom_metadata", "STRING", custom_metadata_json)) - - with self._config.provide_connection() as conn: - job_config = 
QueryJobConfig(query_parameters=params) - conn.query(sql, job_config=job_config).result() - - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record to store. - - Notes: - Uses BigQuery TIMESTAMP for timezone-aware timestamps. - JSON fields are serialized to STRING then cast to JSON. - Boolean fields stored natively as BOOL. - """ - await async_(self._append_event)(event_record) - - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Synchronous implementation of get_events.""" - table_name = self._get_full_table_name(self._events_table) - - where_clauses = ["session_id = @session_id"] - params: list[ScalarQueryParameter] = [ScalarQueryParameter("session_id", "STRING", session_id)] - - if after_timestamp is not None: - where_clauses.append("timestamp > @after_timestamp") - params.append(ScalarQueryParameter("after_timestamp", "TIMESTAMP", after_timestamp)) - - where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" - - sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, - JSON_VALUE(content) as content, - JSON_VALUE(grounding_metadata) as grounding_metadata, - JSON_VALUE(custom_metadata) as custom_metadata, - partial, turn_complete, interrupted, error_code, error_message - FROM {table_name} - WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} - """ - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - query_job = conn.query(sql, job_config=job_config) - results = list(query_job.result()) - - return [ - EventRecord( - id=row.id, - session_id=row.session_id, - app_name=row.app_name, - user_id=row.user_id, - invocation_id=row.invocation_id, - author=row.author, - actions=bytes(row.actions) if row.actions else 
b"", - long_running_tool_ids_json=row.long_running_tool_ids_json, - branch=row.branch, - timestamp=row.timestamp, - content=from_json(row.content) if row.content else None, - grounding_metadata=from_json(row.grounding_metadata) if row.grounding_metadata else None, - custom_metadata=from_json(row.custom_metadata) if row.custom_metadata else None, - partial=row.partial, - turn_complete=row.turn_complete, - interrupted=row.interrupted, - error_code=row.error_code, - error_message=row.error_message, - ) - for row in results - ] - - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses clustering on (session_id, timestamp) for efficient retrieval. - Parses JSON fields and converts BYTES actions to bytes. 
- """ - return await async_(self._get_events)(session_id, after_timestamp, limit) - - -class BigQueryADKMemoryStore(BaseAsyncADKMemoryStore["BigQueryConfig"]): - """BigQuery ADK memory store using synchronous BigQuery client with async wrapper.""" - - __slots__ = ("_dataset_id",) - - def __init__(self, config: "BigQueryConfig") -> None: - """Initialize BigQuery ADK memory store.""" - super().__init__(config) - self._dataset_id = config.connection_config.get("dataset_id") - - def _get_full_table_name(self, table_name: str) -> str: - """Get fully qualified table name for BigQuery.""" - if self._dataset_id: - return f"`{self._dataset_id}.{table_name}`" - return f"`{table_name}`" - - async def _get_create_memory_table_sql(self) -> str: - """Get BigQuery CREATE TABLE SQL for memory entries.""" - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - table_name = self._get_full_table_name(self._memory_table) - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE SEARCH INDEX idx_{self._memory_table}_fts - ON {table_name}(content_text) - """ - - return f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id STRING NOT NULL, - session_id STRING NOT NULL, - app_name STRING NOT NULL, - user_id STRING NOT NULL, - event_id STRING NOT NULL, - author STRING{owner_id_line}, - timestamp TIMESTAMP NOT NULL, - content_json JSON NOT NULL, - content_text STRING NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP NOT NULL - ) - PARTITION BY DATE(timestamp) - CLUSTER BY app_name, user_id; - {fts_index} - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get BigQuery DROP TABLE SQL statements.""" - table_name = self._get_full_table_name(self._memory_table) - return [f"DROP TABLE IF EXISTS {table_name}"] - - def _create_tables(self) -> None: - """Synchronous implementation of create_tables.""" - with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_memory_table_sql)()) - 
- async def create_tables(self) -> None: - """Create the memory table if it doesn't exist.""" - if not self._enabled: - return - await async_(self._create_tables)() - - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Synchronous implementation of insert_memory_entries.""" - table_name = self._get_full_table_name(self._memory_table) - inserted_count = 0 - - with self._config.provide_connection() as conn: - for entry in entries: - content_json = to_json(entry["content_json"]) - metadata_json = to_json(entry["metadata_json"]) if entry["metadata_json"] is not None else None - metadata_expr = "JSON(@metadata_json)" if metadata_json is not None else "NULL" - - owner_column = f", {self._owner_id_column_name}" if self._owner_id_column_name else "" - owner_value = ", @owner_id" if self._owner_id_column_name else "" - - sql = f""" - MERGE {table_name} T - USING (SELECT @event_id AS event_id) S - ON T.event_id = S.event_id - WHEN NOT MATCHED THEN - INSERT (id, session_id, app_name, user_id, event_id, author{owner_column}, - timestamp, content_json, content_text, metadata_json, inserted_at) - VALUES (@id, @session_id, @app_name, @user_id, @event_id, @author{owner_value}, - @timestamp, JSON(@content_json), @content_text, {metadata_expr}, @inserted_at) - """ - - params = [ - ScalarQueryParameter("id", "STRING", entry["id"]), - ScalarQueryParameter("session_id", "STRING", entry["session_id"]), - ScalarQueryParameter("app_name", "STRING", entry["app_name"]), - ScalarQueryParameter("user_id", "STRING", entry["user_id"]), - ScalarQueryParameter("event_id", "STRING", entry["event_id"]), - ScalarQueryParameter("author", "STRING", entry["author"]), - ScalarQueryParameter("timestamp", "TIMESTAMP", entry["timestamp"]), - ScalarQueryParameter("content_json", "STRING", content_json), - ScalarQueryParameter("content_text", "STRING", entry["content_text"]), - ScalarQueryParameter("inserted_at", "TIMESTAMP", entry["inserted_at"]), 
- ] - - if self._owner_id_column_name: - params.append(ScalarQueryParameter("owner_id", "STRING", str(owner_id) if owner_id else None)) - if metadata_json is not None: - params.append(ScalarQueryParameter("metadata_json", "STRING", metadata_json)) - - job_config = QueryJobConfig(query_parameters=params) - job = conn.query(sql, job_config=job_config) - job.result() - if job.num_dml_affected_rows: - inserted_count += int(job.num_dml_affected_rows) - - return inserted_count - - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - if not entries: - return 0 - - return await async_(self._insert_memory_entries)(entries, owner_id) - - def _search_entries(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": - """Synchronous implementation of search_entries.""" - table_name = self._get_full_table_name(self._memory_table) - base_params = [ - ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ScalarQueryParameter("limit", "INT64", limit), - ] - - if self._use_fts: - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at - FROM {table_name} - WHERE app_name = @app_name - AND user_id = @user_id - AND SEARCH(content_text, @query) - ORDER BY timestamp DESC - LIMIT @limit - """ - params = [*base_params, ScalarQueryParameter("query", "STRING", query)] - else: - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at - FROM {table_name} - WHERE app_name = @app_name - AND user_id = @user_id - AND LOWER(content_text) LIKE LOWER(@pattern) - ORDER BY timestamp DESC - LIMIT @limit - """ - pattern = f"%{query}%" - params = 
[*base_params, ScalarQueryParameter("pattern", "STRING", pattern)] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - rows = conn.query(sql, job_config=job_config).result() - return _rows_to_records(rows) - - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - effective_limit = limit if limit is not None else self._max_results - - try: - return await async_(self._search_entries)(query, app_name, user_id, effective_limit) - except NotFound: - return [] - - def _delete_entries_by_session(self, session_id: str) -> int: - table_name = self._get_full_table_name(self._memory_table) - sql = f"DELETE FROM {table_name} WHERE session_id = @session_id" - params = [ScalarQueryParameter("session_id", "STRING", session_id)] - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - job = conn.query(sql, job_config=job_config) - job.result() - return int(job.num_dml_affected_rows or 0) - - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - - def _delete_entries_older_than(self, days: int) -> int: - table_name = self._get_full_table_name(self._memory_table) - sql = f""" - DELETE FROM {table_name} - WHERE inserted_at < TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {days} DAY) - """ - with self._config.provide_connection() as conn: - job = conn.query(sql) - job.result() - return int(job.num_dml_affected_rows or 0) - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - - -def _decode_json_field(value: 
Any) -> "dict[str, Any] | None": - if value is None: - return None - if isinstance(value, str): - return cast("dict[str, Any]", from_json(value)) - if isinstance(value, Mapping): - return dict(value) - return None - - -def _rows_to_records(rows: Any) -> "list[MemoryRecord]": - return [ - { - "id": row["id"], - "session_id": row["session_id"], - "app_name": row["app_name"], - "user_id": row["user_id"], - "event_id": row["event_id"], - "author": row["author"], - "timestamp": row["timestamp"], - "content_json": _decode_json_field(row["content_json"]) or {}, - "content_text": row["content_text"], - "metadata_json": _decode_json_field(row["metadata_json"]), - "inserted_at": row["inserted_at"], - } - for row in rows - ] diff --git a/tests/unit/extensions/test_adk/test_store_instantiation.py b/tests/unit/extensions/test_adk/test_store_instantiation.py index 1c16d22bb..1bd442a15 100644 --- a/tests/unit/extensions/test_adk/test_store_instantiation.py +++ b/tests/unit/extensions/test_adk/test_store_instantiation.py @@ -18,8 +18,7 @@ "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKStore", "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKStore", "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKStore", - "sqlspec.adapters.bigquery.adk.store.BigQueryADKStore", - "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKStore", +"sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKStore", "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKStore", "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKStore", "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKStore", @@ -46,8 +45,7 @@ "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKMemoryStore", "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKMemoryStore", "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKMemoryStore", - "sqlspec.adapters.bigquery.adk.store.BigQueryADKMemoryStore", - "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKMemoryStore", 
+"sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKMemoryStore", "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKMemoryStore", "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKMemoryStore", "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKMemoryStore", From 92fa14bf411891041dbbdb8323e97335c7bb8078 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:29:57 +0000 Subject: [PATCH 06/23] feat(adk): rebuild duckdb store for new contract Rebuild DuckdbADKStore for the clean-break 5-column EventRecord: - Events table now uses (session_id, invocation_id, author, timestamp, event_json) instead of 17+ decomposed columns - Implement create_event_and_update_state() for atomic event+state writes - Update create_event() to build event_json blob from legacy parameters - Update list_events() to return new 5-key EventRecord shape - Use TIMESTAMPTZ instead of TIMESTAMP for timezone-aware storage Fix DuckdbADKMemoryStore FTS bugs: - Replace wrong @@ operator with match_bm25() for BM25-ranked search - Move FTS index refresh from search to post-insert/delete only - Add stemmer, stopwords, strip_accents config to FTS index creation --- sqlspec/adapters/duckdb/adk/store.py | 322 +++++++++++++-------------- 1 file changed, 151 insertions(+), 171 deletions(-) diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 4db753b58..0255ca758 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -39,9 +39,9 @@ class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]): Implements session and event storage for Google Agent Development Kit using DuckDB's synchronous driver. 
Provides: - Session state management with native JSON type - - Event history tracking with BLOB-serialized actions - - Native TIMESTAMP type support - - Foreign key constraints (manual cascade in delete_session) + - Event history with single JSON blob (event_json) plus indexed scalars + - Native TIMESTAMPTZ type support + - Manual cascade delete (DuckDB has no FK CASCADE) - Columnar storage for analytical queries Args: @@ -64,18 +64,10 @@ class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]): store = DuckdbADKStore(config) store.ensure_tables() - session = store.create_session( - session_id="session-123", - app_name="my-app", - user_id="user-456", - state={"context": "conversation"} - ) - Notes: - - Uses DuckDB native JSON type (not JSONB) - - TIMESTAMP for date/time storage with microsecond precision - - BLOB for binary actions data - - BOOLEAN native type support + - Uses DuckDB native JSON type for event_json and state + - TIMESTAMPTZ for date/time storage with microsecond precision + - event_json stores the full ADK Event as a single JSON blob - Columnar storage provides excellent analytical query performance - DuckDB doesn't support CASCADE in foreign keys (manual cascade required) - Optimized for OLAP workloads; for high-concurrency writes use PostgreSQL @@ -107,7 +99,7 @@ def _get_create_sessions_table_sql(self) -> str: Notes: - VARCHAR for IDs and names - JSON type for state storage (DuckDB native) - - TIMESTAMP for create_time and update_time + - TIMESTAMPTZ for create_time and update_time - CURRENT_TIMESTAMP for defaults - Optional owner ID column for multi-tenant scenarios - Composite index on (app_name, user_id) for listing @@ -123,8 +115,8 @@ def _get_create_sessions_table_sql(self) -> str: app_name VARCHAR NOT NULL, user_id VARCHAR NOT NULL{owner_id_line}, state JSON NOT NULL, - create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + create_time TIMESTAMPTZ NOT NULL DEFAULT 
CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user ON {self._session_table}(app_name, user_id); CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); @@ -137,34 +129,20 @@ def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. Notes: - - VARCHAR for string fields - - BLOB for pickled actions - - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for flags + - 5-column schema: session_id, invocation_id, author, timestamp, event_json + - event_json stores the full ADK Event as a single JSON blob + - No decomposed columns -- eliminates column drift with upstream ADK - Foreign key constraint (DuckDB doesn't support CASCADE) - Index on (session_id, timestamp ASC) for ordered event retrieval - Manual cascade delete required in delete_session method """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - app_name VARCHAR NOT NULL, - user_id VARCHAR NOT NULL, - invocation_id VARCHAR, - author VARCHAR, - actions BLOB, - long_running_tool_ids_json JSON, - branch VARCHAR, - timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR, - error_message VARCHAR, + invocation_id VARCHAR NOT NULL, + author VARCHAR NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); @@ -243,7 +221,7 @@ def get_session(self, session_id: str) -> "SessionRecord | None": Session record or None if not found. 
Notes: - DuckDB returns datetime objects for TIMESTAMP columns. + DuckDB returns datetime objects for TIMESTAMPTZ columns. JSON is parsed from database storage. """ sql = f""" @@ -380,138 +358,125 @@ def create_event( content: "dict[str, Any] | None" = None, **kwargs: Any, ) -> EventRecord: - """Create a new event. + """Create a new event using the legacy decomposed-parameter signature. + + This method satisfies the abstract base class contract. It builds an + ``EventRecord`` from the provided arguments and delegates to the new + 5-column schema. Args: - event_id: Unique event identifier. + event_id: Unique event identifier (unused in new schema, kept for API compat). session_id: Session identifier. - app_name: Application name. - user_id: User identifier. + app_name: Application name (stored inside event_json). + user_id: User identifier (stored inside event_json). author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSON). - **kwargs: Additional optional fields. + actions: Legacy actions bytes (ignored in new schema). + content: Event content dict (stored inside event_json). + **kwargs: Additional optional fields folded into event_json. Returns: - Created event record. - - Notes: - Uses current UTC timestamp if not provided in kwargs. - JSON fields are serialized using SQLSpec serializers. + Created event record with the new 5-key shape. 
""" timestamp = kwargs.get("timestamp", datetime.now(timezone.utc)) - content_json = to_json(content) if content else None - grounding_metadata = kwargs.get("grounding_metadata") - grounding_metadata_json = to_json(grounding_metadata) if grounding_metadata else None - custom_metadata = kwargs.get("custom_metadata") - custom_metadata_json = to_json(custom_metadata) if custom_metadata else None + + # Build the event_json blob from all provided fields + event_data: dict[str, Any] = { + "id": event_id, + "app_name": app_name, + "user_id": user_id, + } + if content is not None: + event_data["content"] = content + for key in ( + "invocation_id", + "branch", + "grounding_metadata", + "custom_metadata", + "long_running_tool_ids_json", + "partial", + "turn_complete", + "interrupted", + "error_code", + "error_message", + ): + val = kwargs.get(key) + if val is not None: + event_data[key] = val + + event_json_str = to_json(event_data) sql = f""" - INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO {self._events_table} + (session_id, invocation_id, author, timestamp, event_json) + VALUES (?, ?, ?, ?, ?) 
""" with self._config.provide_connection() as conn: conn.execute( sql, ( - event_id, session_id, - app_name, - user_id, - kwargs.get("invocation_id"), - author, - actions, - kwargs.get("long_running_tool_ids_json"), - kwargs.get("branch"), + kwargs.get("invocation_id", ""), + author or "", timestamp, - content_json, - grounding_metadata_json, - custom_metadata_json, - kwargs.get("partial"), - kwargs.get("turn_complete"), - kwargs.get("interrupted"), - kwargs.get("error_code"), - kwargs.get("error_message"), + event_json_str, ), ) conn.commit() return EventRecord( - id=event_id, session_id=session_id, - app_name=app_name, - user_id=user_id, invocation_id=kwargs.get("invocation_id", ""), author=author or "", - actions=actions or b"", - long_running_tool_ids_json=kwargs.get("long_running_tool_ids_json"), - branch=kwargs.get("branch"), timestamp=timestamp, - content=content, - grounding_metadata=grounding_metadata, - custom_metadata=custom_metadata, - partial=kwargs.get("partial"), - turn_complete=kwargs.get("turn_complete"), - interrupted=kwargs.get("interrupted"), - error_code=kwargs.get("error_code"), - error_message=kwargs.get("error_message"), + event_json=event_json_str, ) - def get_event(self, event_id: str) -> "EventRecord | None": - """Get event by ID. + def create_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. - Args: - event_id: Event identifier. + The event insert and state update succeed together or fail together + within a single DuckDB transaction. - Returns: - Event record or None if not found. + Args: + event_record: Event record to store (5-key shape). + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). 
""" - sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - FROM {self._events_table} - WHERE id = ? + now = datetime.now(timezone.utc) + state_json = to_json(state) + event_json_value = event_record["event_json"] + if not isinstance(event_json_value, str): + event_json_value = to_json(event_json_value) + + insert_sql = f""" + INSERT INTO {self._events_table} + (session_id, invocation_id, author, timestamp, event_json) + VALUES (?, ?, ?, ?, ?) """ - try: - with self._config.provide_connection() as conn: - cursor = conn.execute(sql, (event_id,)) - row = cursor.fetchone() - - if row is None: - return None + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = ? + WHERE id = ? + """ - return EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]) if row[6] else b"", - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], - ) - except Exception as e: - if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): - return None - raise + with self._config.provide_connection() as conn: + conn.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_value, + ), + ) + conn.execute(update_sql, (state_json, now, session_id)) + conn.commit() def list_events(self, session_id: str) -> "list[EventRecord]": """List events for a session ordered by timestamp. 
@@ -523,10 +488,7 @@ def list_events(self, session_id: str) -> "list[EventRecord]": List of event records ordered by timestamp ASC. """ sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE session_id = ? ORDER BY timestamp ASC @@ -539,24 +501,11 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]) if row[6] else b"", - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=row[4] if isinstance(row[4], str) else to_json(row[4]), ) for row in rows ] @@ -572,7 +521,7 @@ class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]): Implements memory entry storage for Google Agent Development Kit using DuckDB's synchronous driver. 
Provides: - Session memory storage with native JSON type - - Simple ILIKE search + - Simple ILIKE search or BM25 full-text search via FTS extension - Native TIMESTAMP type support - Deduplication via event_id unique constraint - Efficient upserts using INSERT OR IGNORE @@ -602,6 +551,8 @@ class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]): - TIMESTAMP for date/time storage with microsecond precision - event_id UNIQUE constraint for deduplication - Composite index on (app_name, user_id, timestamp DESC) + - FTS uses match_bm25() for BM25-ranked results (not @@ operator) + - FTS index is refreshed after inserts, not on every search - Columnar storage provides excellent analytical query performance - Optimized for OLAP workloads; for high-concurrency writes use PostgreSQL - Configuration is read from config.extension_config["adk"] @@ -644,12 +595,19 @@ def _create_fts_index(self, conn: Any) -> None: return try: - conn.execute(f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text')") + conn.execute( + f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text', " + f"stemmer='porter', stopwords='english', strip_accents=1, lower=1)" + ) except Exception as exc: logger.debug("Failed to create DuckDB FTS index: %s", exc) def _refresh_fts_index(self, conn: Any) -> None: - """Rebuild the FTS index to reflect recent changes.""" + """Rebuild the FTS index to reflect recent inserts. + + DuckDB FTS indexes do not auto-update. This must be called after + insert/update/delete operations, NOT on every search. 
+ """ if not self._ensure_fts_extension(conn): return @@ -657,7 +615,10 @@ def _refresh_fts_index(self, conn: Any) -> None: conn.execute(f"PRAGMA drop_fts_index('{self._memory_table}')") try: - conn.execute(f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text')") + conn.execute( + f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text', " + f"overwrite=1, stemmer='porter', stopwords='english', strip_accents=1, lower=1)" + ) except Exception as exc: logger.debug("Failed to refresh DuckDB FTS index: %s", exc) @@ -708,7 +669,10 @@ def create_tables(self) -> None: self._create_fts_index(conn) def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" + """Bulk insert memory entries with deduplication. + + After successful inserts, refreshes the FTS index if FTS is enabled. + """ if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -770,12 +734,21 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object result = conn.execute(sql, params) inserted_count += len(result.fetchall()) conn.commit() + + # Refresh FTS index after inserts, not on search + if self._use_fts and inserted_count > 0: + self._refresh_fts_index(conn) + return inserted_count def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" + """Search memory entries by text query. + + When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. + Falls back to ILIKE for simple substring matching. + """ if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -785,13 +758,19 @@ def search_entries( limit_value = limit or self._max_results if self._use_fts: + # Use match_bm25() -- the correct DuckDB FTS syntax sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = ? 
AND user_id = ? AND content_text @@ ? - ORDER BY timestamp DESC + SELECT m.* + FROM {self._memory_table} m + JOIN ( + SELECT id, fts_main_{self._memory_table}.match_bm25(id, ?, fields := 'content_text') AS score + FROM {self._memory_table} + ) fts ON m.id = fts.id + WHERE m.app_name = ? AND m.user_id = ? AND fts.score IS NOT NULL + ORDER BY fts.score DESC LIMIT ? """ - params = (app_name, user_id, query, limit_value) + params = (query, app_name, user_id, limit_value) else: sql = f""" SELECT * FROM {self._memory_table} @@ -814,9 +793,6 @@ def search_entries( if isinstance(metadata_value, (str, bytes)): record["metadata_json"] = from_json(metadata_value) records.append(record) - if self._use_fts: - with self._config.provide_connection() as conn: - self._refresh_fts_index(conn) return records def delete_entries_by_session(self, session_id: str) -> int: @@ -830,6 +806,8 @@ def delete_entries_by_session(self, session_id: str) -> int: result = conn.execute(sql, (session_id,)) deleted_count = len(result.fetchall()) conn.commit() + if self._use_fts and deleted_count > 0: + self._refresh_fts_index(conn) return deleted_count def delete_entries_older_than(self, days: int) -> int: @@ -847,4 +825,6 @@ def delete_entries_older_than(self, days: int) -> int: result = conn.execute(sql) deleted_count = len(result.fetchall()) conn.commit() + if self._use_fts and deleted_count > 0: + self._refresh_fts_index(conn) return deleted_count From d1727161df768eb58156c1bc6b2bd36b82cce078 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:31:00 +0000 Subject: [PATCH 07/23] feat(adk): rebuild spanner stores for new contract Rebuild the Spanner session/event stores for the clean-break EventRecord (5 keys: session_id, invocation_id, author, timestamp, event_json). 
Changes: - Events table DDL reduced to 5 columns with event_json as JSON type - Implement create_event_and_update_state() for atomic event+state writes - Update list_events/create_event to use new EventRecord shape - Add ORDER BY update_time DESC to list_sessions - Add _get_drop_memory_table_sql to SpannerSyncADKMemoryStore - Remove bytes_to_spanner/spanner_to_bytes imports (no more actions column) - Keep FARM_FINGERPRINT sharding and commit timestamp support --- sqlspec/adapters/spanner/adk/store.py | 242 ++++++++++---------------- 1 file changed, 91 insertions(+), 151 deletions(-) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index 22ebebc7e..dc9090a86 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -7,7 +7,6 @@ from google.cloud.spanner_v1 import param_types from sqlspec.adapters.spanner.config import SpannerSyncConfig -from sqlspec.adapters.spanner.type_converter import bytes_to_spanner, spanner_to_bytes from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore from sqlspec.protocols import SpannerParamTypesProtocol @@ -80,30 +79,15 @@ def _session_param_types(self, include_owner: bool) -> "dict[str, Any]": types["owner_id"] = SPANNER_PARAM_TYPES.STRING return types - def _event_param_types(self, has_branch: bool) -> "dict[str, Any]": + def _event_param_types(self) -> "dict[str, Any]": json_type = _json_param_type() - types: dict[str, Any] = { - "id": SPANNER_PARAM_TYPES.STRING, + return { "session_id": SPANNER_PARAM_TYPES.STRING, - "app_name": SPANNER_PARAM_TYPES.STRING, - "user_id": SPANNER_PARAM_TYPES.STRING, - "author": SPANNER_PARAM_TYPES.STRING, - "actions": SPANNER_PARAM_TYPES.BYTES, - "long_running_tool_ids_json": json_type, "invocation_id": SPANNER_PARAM_TYPES.STRING, + "author": SPANNER_PARAM_TYPES.STRING, "timestamp": SPANNER_PARAM_TYPES.TIMESTAMP, - "content": 
json_type, - "grounding_metadata": json_type, - "custom_metadata": json_type, - "partial": SPANNER_PARAM_TYPES.BOOL, - "turn_complete": SPANNER_PARAM_TYPES.BOOL, - "interrupted": SPANNER_PARAM_TYPES.BOOL, - "error_code": SPANNER_PARAM_TYPES.STRING, - "error_message": SPANNER_PARAM_TYPES.STRING, + "event_json": json_type, } - if has_branch: - types["branch"] = SPANNER_PARAM_TYPES.STRING - return types def _decode_state(self, raw: Any) -> Any: if isinstance(raw, str): @@ -198,6 +182,7 @@ def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[Se types["user_id"] = SPANNER_PARAM_TYPES.STRING if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(id), {self._shard_count})" + sql = f"{sql} ORDER BY update_time DESC" rows = self._run_read(sql, params, types) records: list[SessionRecord] = [] @@ -235,116 +220,81 @@ def create_event( content: "dict[str, Any] | None" = None, **kwargs: Any, ) -> EventRecord: - branch = kwargs.get("branch") - long_running_serialized = ( - to_json(kwargs.get("long_running_tool_ids_json")) - if kwargs.get("long_running_tool_ids_json") is not None - else None - ) - content_serialized = to_json(content) if content is not None else None - grounding_serialized = ( - to_json(kwargs.get("grounding_metadata")) if kwargs.get("grounding_metadata") is not None else None - ) - custom_serialized = ( - to_json(kwargs.get("custom_metadata")) if kwargs.get("custom_metadata") is not None else None - ) - params: dict[str, Any] = { + invocation_id = kwargs.get("invocation_id", "") + event_json = to_json({ "id": event_id, - "session_id": session_id, "app_name": app_name, "user_id": user_id, "author": author, - "actions": bytes_to_spanner(actions), - "long_running_tool_ids_json": long_running_serialized, - "timestamp": datetime.now(timezone.utc), - "content": content_serialized, - "grounding_metadata": grounding_serialized, - "custom_metadata": custom_serialized, - "invocation_id": kwargs.get("invocation_id"), - 
"partial": kwargs.get("partial"), - "turn_complete": kwargs.get("turn_complete"), - "interrupted": kwargs.get("interrupted"), - "error_code": kwargs.get("error_code"), - "error_message": kwargs.get("error_message"), + "content": content, + **{k: v for k, v in kwargs.items() if v is not None}, + }) + now = datetime.now(timezone.utc) + params: dict[str, Any] = { + "session_id": session_id, + "invocation_id": invocation_id, + "author": author or "", + "timestamp": now, + "event_json": event_json, } - branch = kwargs.get("branch") - columns = [ - "id", - "session_id", - "app_name", - "user_id", - "author", - "actions", - "long_running_tool_ids_json", - "timestamp", - "content", - "grounding_metadata", - "custom_metadata", - "invocation_id", - "partial", - "turn_complete", - "interrupted", - "error_code", - "error_message", - ] - values = [ - "@id", - "@session_id", - "@app_name", - "@user_id", - "@author", - "@actions", - "@long_running_tool_ids_json", - "PENDING_COMMIT_TIMESTAMP()", - "@content", - "@grounding_metadata", - "@custom_metadata", - "@invocation_id", - "@partial", - "@turn_complete", - "@interrupted", - "@error_code", - "@error_message", - ] - has_branch = branch is not None - if has_branch: - params["branch"] = branch - columns.append("branch") - values.append("@branch") sql = f""" - INSERT INTO {self._events_table} ({", ".join(columns)}) - VALUES ({", ".join(values)}) + INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) + VALUES (@session_id, @invocation_id, @author, PENDING_COMMIT_TIMESTAMP(), @event_json) """ - self._run_write([(sql, params, self._event_param_types(has_branch))]) + self._run_write([(sql, params, self._event_param_types())]) - record: EventRecord = { - "id": event_id, + return { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": invocation_id, "author": author or "", - "actions": actions or b"", - "long_running_tool_ids_json": long_running_serialized, - 
"branch": branch, - "timestamp": params["timestamp"], - "content": from_json(content_serialized) if content_serialized else None, - "grounding_metadata": from_json(grounding_serialized) if grounding_serialized else None, - "custom_metadata": from_json(custom_serialized) if custom_serialized else None, - "invocation_id": kwargs.get("invocation_id", ""), - "partial": kwargs.get("partial"), - "turn_complete": kwargs.get("turn_complete"), - "interrupted": kwargs.get("interrupted"), - "error_code": kwargs.get("error_code"), - "error_message": kwargs.get("error_message"), + "timestamp": now, + "event_json": event_json, } - return record + + def create_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically insert an event and update session state in one transaction. + + Both the event INSERT and the session state UPDATE execute within a single + Spanner transaction so they succeed or fail together. + + Args: + event_record: Event record to store. + session_id: Session whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_params: dict[str, Any] = { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": event_record["event_json"], + } + insert_sql = f""" + INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) + VALUES (@session_id, @invocation_id, @author, PENDING_COMMIT_TIMESTAMP(), @event_json) + """ + + json_type = _json_param_type() + state_params: dict[str, Any] = {"id": session_id, "state": to_json(state)} + update_sql = f""" + UPDATE {self._session_table} + SET state = @state, update_time = PENDING_COMMIT_TIMESTAMP() + WHERE id = @id + """ + if self._shard_count > 1: + update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" + + self._run_write([ + (insert_sql, event_params, self._event_param_types()), + (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}), + ]) def list_events(self, session_id: str) -> "list[EventRecord]": sql = f""" - SELECT id, session_id, app_name, user_id, author, actions, long_running_tool_ids_json, branch, - timestamp, content, grounding_metadata, custom_metadata, invocation_id, partial, - turn_complete, interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE session_id = @session_id """ @@ -356,24 +306,11 @@ def list_events(self, session_id: str) -> "list[EventRecord]": rows = self._run_read(sql, params, types) return [ { - "id": row[0], - "session_id": row[1], - "app_name": row[2], - "user_id": row[3], - "invocation_id": row[12] or "", - "author": row[4] or "", - "actions": spanner_to_bytes(row[5]) or b"", - "long_running_tool_ids_json": row[6], - "branch": row[7], - "timestamp": row[8], - "content": self._decode_json(row[9]), - "grounding_metadata": self._decode_json(row[10]), - "custom_metadata": self._decode_json(row[11]), 
- "partial": row[13], - "turn_complete": row[14], - "interrupted": row[15], - "error_code": row[16], - "error_message": row[17], + "session_id": row[0], + "invocation_id": row[1] or "", + "author": row[2] or "", + "timestamp": row[3], + "event_json": row[4], } for row in rows ] @@ -416,33 +353,20 @@ def _get_create_sessions_table_sql(self) -> str: def _get_create_events_table_sql(self) -> str: shard_column = "" - pk = "PRIMARY KEY (session_id, timestamp, id)" + pk = "PRIMARY KEY (session_id, timestamp)" if self._shard_count > 1: shard_column = f",\n shard_id INT64 AS (MOD(FARM_FINGERPRINT(session_id), {self._shard_count})) STORED" - pk = "PRIMARY KEY (shard_id, session_id, timestamp, id)" + pk = "PRIMARY KEY (shard_id, session_id, timestamp)" options = "" if self._events_table_options: options = f"\nOPTIONS ({self._events_table_options})" return f""" CREATE TABLE {self._events_table} ( - id STRING(128) NOT NULL, session_id STRING(128) NOT NULL, - app_name STRING(128) NOT NULL, - user_id STRING(128) NOT NULL, - invocation_id STRING(128), - author STRING(64), - actions BYTES(MAX), - long_running_tool_ids_json JSON, - branch STRING(64), + invocation_id STRING(256) NOT NULL, + author STRING(128) NOT NULL, timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOL, - turn_complete BOOL, - interrupted BOOL, - error_code STRING(64), - error_message STRING(255){shard_column} + event_json JSON NOT NULL{shard_column} ) {pk}{options} """ @@ -590,6 +514,22 @@ def _get_create_memory_table_sql(self) -> "list[str]": statements.append(fts_index) return statements + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get SQL to drop the memory table and its indexes. + + Returns: + List of SQL statements to drop the memory table and associated indexes. 
+ """ + statements: list[str] = [] + if self._use_fts: + statements.append(f"DROP SEARCH INDEX idx_{self._memory_table}_fts") + statements.extend([ + f"DROP INDEX idx_{self._memory_table}_session", + f"DROP INDEX idx_{self._memory_table}_app_user_time", + f"DROP TABLE {self._memory_table}", + ]) + return statements + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" From 2860edad0309dad99905768ea873ef6b60473399 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:32:23 +0000 Subject: [PATCH 08/23] feat(adk): rebuild adbc store for new contract Update ADBC ADK store to use the new 5-column EventRecord contract (session_id, invocation_id, author, timestamp, event_json). All dialect DDL branches use appropriate JSON types: JSONB for PostgreSQL, JSON for DuckDB, VARIANT for Snowflake, TEXT for SQLite and generic fallback. Adds create_event_and_update_state for atomic event+state persistence. --- sqlspec/adapters/adbc/adk/store.py | 341 +++++++++++++---------------- 1 file changed, 151 insertions(+), 190 deletions(-) diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 65e5c4975..3d2963d8e 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -1,5 +1,6 @@ """ADBC ADK store for Google Agent Development Kit session/event storage.""" +import contextlib from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final @@ -32,9 +33,16 @@ class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]): using ADBC. ADBC provides a vendor-neutral API with Arrow-native data transfer across multiple databases (PostgreSQL, SQLite, DuckDB, etc.). + Events use the new 5-column contract: session_id, invocation_id, author, + timestamp, and event_json. 
The full ADK Event payload is stored as a + single JSON blob in event_json using a dialect-appropriate column type + (JSONB for PostgreSQL, JSON for DuckDB, VARIANT for Snowflake, TEXT for + SQLite and generic fallback). + Provides: - - Session state management with JSON serialization (TEXT storage) - - Event history tracking with BLOB-serialized actions + - Session state management with JSON serialization + - Event history tracking via single event_json blob + - Atomic event insert + session state update - Timezone-aware timestamps - Foreign key constraints with cascade delete - Database-agnostic SQL (supports multiple backends) @@ -60,12 +68,9 @@ class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]): store.ensure_tables() Notes: - - TEXT for JSON storage (compatible across all ADBC backends) - - BLOB for pre-serialized actions from Google ADK + - Dialect-appropriate JSON type for event_json storage - TIMESTAMP for timezone-aware timestamps (driver-dependent precision) - - INTEGER for booleans (0/1/NULL) - - Parameter style varies by backend (?, $1, :name, etc.) - - Uses dialect-agnostic SQL for maximum compatibility + - Parameter style: ``?`` universally across ADBC backends - State and JSON fields use to_json/from_json for serialization - ADBC drivers handle parameter binding automatically - Configuration is read from config.extension_config["adk"] @@ -298,27 +303,17 @@ def _get_events_ddl_postgresql(self) -> str: Returns: SQL to create events table optimized for PostgreSQL. + + Notes: + Uses JSONB for event_json to enable indexing and query support. 
""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json TEXT, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -328,27 +323,17 @@ def _get_events_ddl_sqlite(self) -> str: Returns: SQL to create events table optimized for SQLite. + + Notes: + Uses TEXT for event_json (SQLite has no native JSON column type). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL, - invocation_id TEXT, - author TEXT, - actions BLOB, - long_running_tool_ids_json TEXT, - branch TEXT, + invocation_id TEXT NOT NULL, + author TEXT NOT NULL, timestamp REAL NOT NULL, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code TEXT, - error_message TEXT, + event_json TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -358,27 +343,17 @@ def _get_events_ddl_duckdb(self) -> str: Returns: SQL to create events table optimized for DuckDB. + + Notes: + Uses JSON for event_json (DuckDB native JSON type). 
""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BLOB, - long_running_tool_ids_json VARCHAR, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -388,27 +363,17 @@ def _get_events_ddl_snowflake(self) -> str: Returns: SQL to create events table optimized for Snowflake. + + Notes: + Uses VARIANT for event_json (Snowflake semi-structured type). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - app_name VARCHAR NOT NULL, - user_id VARCHAR NOT NULL, - invocation_id VARCHAR, - author VARCHAR, - actions BINARY, - long_running_tool_ids_json VARCHAR, - branch VARCHAR, + invocation_id VARCHAR NOT NULL, + author VARCHAR NOT NULL, timestamp TIMESTAMP_TZ NOT NULL DEFAULT CURRENT_TIMESTAMP(), - content VARIANT, - grounding_metadata VARIANT, - custom_metadata VARIANT, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR, - error_message VARCHAR, + event_json VARIANT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ) """ @@ -418,27 +383,17 @@ def _get_events_ddl_generic(self) -> str: Returns: SQL to create events table using generic types. + + Notes: + Uses TEXT for event_json (maximum portability). 
""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BLOB, - long_running_tool_ids_json TEXT, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -707,86 +662,136 @@ def create_event( content: "dict[str, Any] | None" = None, **kwargs: Any, ) -> "EventRecord": - """Create a new event. + """Create a new event using the new 5-column EventRecord contract. Args: - event_id: Unique event identifier. + event_id: Unique event identifier (unused in new schema, kept for API compat). session_id: Session identifier. - app_name: Application name. - user_id: User identifier. + app_name: Application name (stored inside event_json). + user_id: User identifier (stored inside event_json). author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSON). - **kwargs: Additional optional fields. + actions: Pickled actions object (stored inside event_json if provided). + content: Event content (stored inside event_json). + **kwargs: Additional optional fields (stored inside event_json). Returns: Created event record. Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSON fields are serialized to JSON strings. - Boolean fields are converted to INTEGER (0/1). + Builds an event_json blob from all provided fields and stores it + alongside the indexed scalar columns (session_id, invocation_id, + author, timestamp). 
""" - content_json = self._serialize_json_field(content) - grounding_metadata_json = self._serialize_json_field(kwargs.get("grounding_metadata")) - custom_metadata_json = self._serialize_json_field(kwargs.get("custom_metadata")) + timestamp = kwargs.pop("timestamp", None) + if timestamp is None: + timestamp = datetime.now(timezone.utc) + + invocation_id = kwargs.pop("invocation_id", "") or "" + + # Build event_json from all provided data + event_data: dict[str, Any] = { + "id": event_id, + "app_name": app_name, + "user_id": user_id, + } + if content is not None: + event_data["content"] = content + if actions is not None: + event_data["actions"] = actions.hex() + if author is not None: + event_data["author"] = author + # Include remaining kwargs in event_json + event_data.update({k: v for k, v in kwargs.items() if v is not None}) + + event_json_str = to_json(event_data) + + event_record = EventRecord( + session_id=session_id, + invocation_id=invocation_id, + author=author or "", + timestamp=timestamp, + event_json=event_json_str, + ) + self._insert_event(event_record) + return event_record - partial_int = self._to_int_bool(kwargs.get("partial")) - turn_complete_int = self._to_int_bool(kwargs.get("turn_complete")) - interrupted_int = self._to_int_bool(kwargs.get("interrupted")) + def _insert_event(self, event_record: "EventRecord") -> None: + """Insert an event record into the events table. + Args: + event_record: Event record to store. + """ sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (?, ?, ?, ?, ?) 
""" - timestamp = kwargs.get("timestamp") - if timestamp is None: - timestamp = datetime.now(timezone.utc) - with self._config.provide_connection() as conn: cursor = conn.cursor() try: cursor.execute( sql, ( - event_id, - session_id, - app_name, - user_id, - kwargs.get("invocation_id"), - author, - actions, - kwargs.get("long_running_tool_ids_json"), - kwargs.get("branch"), - timestamp, - content_json, - grounding_metadata_json, - custom_metadata_json, - partial_int, - turn_complete_int, - interrupted_int, - kwargs.get("error_code"), - kwargs.get("error_message"), + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], ), ) conn.commit() finally: cursor.close() # type: ignore[no-untyped-call] - events = self.list_events(session_id) - for event in events: - if event["id"] == event_id: - return event + def create_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically insert an event and update the session's durable state. - msg = f"Failed to retrieve created event {event_id}" - raise RuntimeError(msg) + The event insert and state update are executed within a single + connection and committed together. If either statement fails the + transaction is rolled back so the two writes remain consistent. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). + """ + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (?, ?, ?, ?, ?) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = CURRENT_TIMESTAMP + WHERE id = ? 
+ """ + state_json = self._serialize_state(state) + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ), + ) + cursor.execute(update_sql, (state_json, session_id)) + conn.commit() + except Exception: + with contextlib.suppress(Exception): + conn.rollback() + raise + finally: + cursor.close() # type: ignore[no-untyped-call] def list_events(self, session_id: str) -> "list[EventRecord]": """List events for a session ordered by timestamp. @@ -799,14 +804,11 @@ def list_events(self, session_id: str) -> "list[EventRecord]": Notes: Uses index on (session_id, timestamp ASC). - JSON fields deserialized from JSON strings. - Converts INTEGER booleans to Python bool. + Returns the 5-column EventRecord (session_id, invocation_id, + author, timestamp, event_json). """ sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE session_id = ? 
ORDER BY timestamp ASC @@ -821,24 +823,11 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]) if row[6] is not None else b"", - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=self._deserialize_json_field(row[10]), - grounding_metadata=self._deserialize_json_field(row[11]), - custom_metadata=self._deserialize_json_field(row[12]), - partial=self._from_int_bool(row[13]), - turn_complete=self._from_int_bool(row[14]), - interrupted=self._from_int_bool(row[15]), - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=str(row[4]) if row[4] is not None else "{}", ) for row in rows ] @@ -850,34 +839,6 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [] raise - @staticmethod - def _to_int_bool(value: "bool | None") -> "int | None": - """Convert Python boolean to INTEGER (0/1). - - Args: - value: Python boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. - """ - if value is None: - return None - return 1 if value else 0 - - @staticmethod - def _from_int_bool(value: "int | None") -> "bool | None": - """Convert INTEGER to Python boolean. - - Args: - value: INTEGER value (0, 1, or None). - - Returns: - Python boolean or None. 
- """ - if value is None: - return None - return bool(value) - class AdbcADKMemoryStore(BaseSyncADKMemoryStore["AdbcConfig"]): """ADBC synchronous ADK memory store for Arrow Database Connectivity.""" From 3e239fec86a3790c64d7878e152e682b6159e0d3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:33:00 +0000 Subject: [PATCH 09/23] feat(adk): rebuild sqlite family stores for new contract --- sqlspec/adapters/aiosqlite/adk/store.py | 235 ++++++++++------------- sqlspec/adapters/sqlite/adk/store.py | 242 +++++++++++------------- 2 files changed, 215 insertions(+), 262 deletions(-) diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index 7a63095f1..f68805297 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -52,34 +52,6 @@ def _julian_to_datetime(julian: float) -> datetime: return datetime.fromtimestamp(timestamp, tz=timezone.utc) -def _to_sqlite_bool(value: "bool | None") -> "int | None": - """Convert Python bool to SQLite INTEGER. - - Args: - value: Boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. - """ - if value is None: - return None - return 1 if value else 0 - - -def _from_sqlite_bool(value: "int | None") -> "bool | None": - """Convert SQLite INTEGER to Python bool. - - Args: - value: Integer value (0/1) or None. - - Returns: - True for 1, False for 0, None for None. - """ - if value is None: - return None - return bool(value) - - class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]): """Aiosqlite ADK store using asynchronous SQLite driver. 
@@ -88,10 +60,11 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]): Provides: - Session state management with JSON storage (as TEXT) - - Event history tracking with BLOB-serialized actions + - Event history tracking with full-event JSON storage - Julian Day timestamps (REAL) for efficient date operations - Foreign key constraints with cascade delete - - Efficient upserts using INSERT OR REPLACE + - Atomic event+state writes via append_event_and_update_state + - PRAGMA optimization profile for file-based databases Args: config: AiosqliteConfig with extension_config["adk"] settings. @@ -114,9 +87,8 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]): Notes: - JSON stored as TEXT with SQLSpec serializers (msgspec/orjson/stdlib) - - BOOLEAN as INTEGER (0/1, with None for NULL) - Timestamps as REAL (Julian day: julianday('now')) - - BLOB for pre-serialized actions from Google ADK + - Full event stored as JSON TEXT in event_data column - PRAGMA foreign_keys = ON (enable per connection) - Configuration is read from config.extension_config["adk"] """ @@ -136,6 +108,22 @@ def __init__(self, config: "AiosqliteConfig") -> None: """ super().__init__(config) + async def _apply_pragmas(self, connection: Any) -> None: + """Apply PRAGMA optimization profile for this connection. + + Args: + connection: Aiosqlite connection. + + Notes: + Enables foreign keys and applies performance PRAGMAs. + cache_size, mmap_size, and journal_size_limit are set on every + connection; no file-based database check is performed. + """ + await connection.execute("PRAGMA foreign_keys = ON") + await connection.execute("PRAGMA cache_size = -64000") + await connection.execute("PRAGMA mmap_size = 30000000") + await connection.execute("PRAGMA journal_size_limit = 67108864") + async def _get_create_sessions_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for sessions. @@ -170,9 +158,8 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. 
Notes: - - TEXT for IDs, strings, and JSON content - - BLOB for pickled actions - - INTEGER for booleans (0/1/NULL) + - TEXT for IDs and indexed scalars + - TEXT for full event JSON (event_data) - REAL for Julian Day timestamps - Foreign key to sessions with CASCADE delete - Index on (session_id, timestamp ASC) @@ -181,22 +168,10 @@ async def _get_create_events_table_sql(self) -> str: CREATE TABLE IF NOT EXISTS {self._events_table} ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL, - invocation_id TEXT NOT NULL, - author TEXT NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json TEXT, - branch TEXT, + invocation_id TEXT, + author TEXT, timestamp REAL NOT NULL, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code TEXT, - error_message TEXT, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session @@ -215,21 +190,10 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def _enable_foreign_keys(self, connection: Any) -> None: - """Enable foreign key constraints for this connection. - - Args: - connection: Aiosqlite connection. - - Notes: - SQLite requires PRAGMA foreign_keys = ON per connection. 
- """ - await connection.execute("PRAGMA foreign_keys = ON") - async def create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: - await self._enable_foreign_keys(driver.connection) + await self._apply_pragmas(driver.connection) await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -250,11 +214,11 @@ async def create_session( Notes: Uses Julian Day for create_time and update_time. - State is JSON-serialized before insertion. + State is always JSON-serialized (empty dict becomes '{}', never NULL). """ now = datetime.now(timezone.utc) now_julian = _datetime_to_julian(now) - state_json = to_json(state) if state else None + state_json = to_json(state) params: tuple[Any, ...] if self._owner_id_column_name: @@ -272,7 +236,7 @@ async def create_session( params = (session_id, app_name, user_id, state_json, now_julian, now_julian) async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute(sql, params) await conn.commit() @@ -300,7 +264,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) cursor = await conn.execute(sql, (session_id,)) row = await cursor.fetchone() @@ -326,9 +290,10 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - Notes: This replaces the entire state dictionary. Updates update_time to current Julian Day. + Empty dict is serialized as '{}', never NULL. 
""" now_julian = _datetime_to_julian(datetime.now(timezone.utc)) - state_json = to_json(state) if state else None + state_json = to_json(state) sql = f""" UPDATE {self._session_table} @@ -337,7 +302,7 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute(sql, (state_json, now_julian, session_id)) await conn.commit() @@ -372,7 +337,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis params = (app_name, user_id) async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) cursor = await conn.execute(sql, params) rows = await cursor.fetchall() @@ -400,7 +365,7 @@ async def delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = ?" async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute(sql, (session_id,)) await conn.commit() @@ -408,63 +373,88 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. Notes: Uses Julian Day for timestamp. - JSON fields are serialized to TEXT. - Boolean fields converted to INTEGER (0/1/NULL). + event_json dict is serialized to TEXT as event_data column. 
""" - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + import uuid - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) - - partial_int = _to_sqlite_bool(event_record.get("partial")) - turn_complete_int = _to_sqlite_bool(event_record.get("turn_complete")) - interrupted_int = _to_sqlite_bool(event_record.get("interrupted")) + timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + event_data_json = to_json(event_record["event_json"]) + event_id = str(uuid.uuid4()) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ) + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute( sql, ( - event_record["id"], + event_id, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + timestamp_julian, + event_data_json, + ), + ) + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + Inserts the event and updates the session state + update_time in a + single transaction. Both operations succeed or fail together. + + Args: + event_record: Event record to store. 
+ session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (temp: keys already + stripped by the service layer). + """ + import uuid + + timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + event_data_json = to_json(event_record["event_json"]) + now_julian = _datetime_to_julian(datetime.now(timezone.utc)) + state_json = to_json(state) + event_id = str(uuid.uuid4()) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = ? + WHERE id = ? + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute( + insert_sql, + ( + event_id, event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), timestamp_julian, - content_json, - grounding_metadata_json, - custom_metadata_json, - partial_int, - turn_complete_int, - interrupted_int, - event_record.get("error_code"), - event_record.get("error_message"), + event_data_json, ), ) + await conn.execute(update_sql, (state_json, now_julian, session_id)) await conn.commit() async def get_events( @@ -482,8 +472,7 @@ async def get_events( Notes: Uses index on (session_id, timestamp ASC). - Parses JSON fields and converts BLOB actions to bytes. - Converts INTEGER booleans back to bool/None. + Parses event_data TEXT back to dict for event_json field. 
""" where_clauses = ["session_id = ?"] params: list[Any] = [session_id] @@ -496,40 +485,24 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT id, session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) cursor = await conn.execute(sql, params) rows = await cursor.fetchall() return [ EventRecord( - id=row[0], session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=_julian_to_datetime(row[9]), - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=_from_sqlite_bool(row[13]), - turn_complete=_from_sqlite_bool(row[14]), - interrupted=_from_sqlite_bool(row[15]), - error_code=row[16], - error_message=row[17], + invocation_id=row[2], + author=row[3], + timestamp=_julian_to_datetime(row[4]), + event_json=from_json(row[5]) if row[5] else {}, ) for row in rows ] diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index ee3376d9e..6cfc14399 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -58,34 +58,6 @@ def _julian_to_datetime(julian: float) -> datetime: return datetime.fromtimestamp(timestamp, tz=timezone.utc) -def _to_sqlite_bool(value: "bool | None") -> "int | None": - """Convert Python bool to SQLite INTEGER. 
- - Args: - value: Boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. - """ - if value is None: - return None - return 1 if value else 0 - - -def _from_sqlite_bool(value: "int | None") -> "bool | None": - """Convert SQLite INTEGER to Python bool. - - Args: - value: Integer value (0/1) or None. - - Returns: - True for 1, False for 0, None for None. - """ - if value is None: - return None - return bool(value) - - class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]): """SQLite ADK store using synchronous SQLite driver. @@ -95,10 +67,11 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]): Provides: - Session state management with JSON storage (as TEXT) - - Event history tracking with BLOB-serialized actions + - Event history tracking with full-event JSON storage - Julian Day timestamps (REAL) for efficient date operations - Foreign key constraints with cascade delete - - Efficient upserts using INSERT OR REPLACE + - Atomic event+state writes via append_event_and_update_state + - PRAGMA optimization profile for file-based databases Args: config: SqliteConfig instance with extension_config["adk"] settings. @@ -122,9 +95,8 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]): Notes: - JSON stored as TEXT with SQLSpec serializers (msgspec/orjson/stdlib) - - BOOLEAN as INTEGER (0/1, with None for NULL) - Timestamps as REAL (Julian day: julianday('now')) - - BLOB for pre-serialized actions from Google ADK + - Full event stored as JSON TEXT in event_data column - PRAGMA foreign_keys = ON (enable per connection) - Configuration is read from config.extension_config["adk"] """ @@ -145,6 +117,22 @@ def __init__(self, config: "SqliteConfig") -> None: """ super().__init__(config) + def _apply_pragmas(self, connection: Any) -> None: + """Apply PRAGMA optimization profile for this connection. + + Args: + connection: SQLite connection. + + Notes: + Enables foreign keys and applies performance PRAGMAs. 
+ cache_size, mmap_size, and journal_size_limit are set on every + connection; no file-based database check is performed. + """ + connection.execute("PRAGMA foreign_keys = ON") + connection.execute("PRAGMA cache_size = -64000") + connection.execute("PRAGMA mmap_size = 30000000") + connection.execute("PRAGMA journal_size_limit = 67108864") + async def _get_create_sessions_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for sessions. @@ -184,9 +172,8 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. Notes: - - TEXT for IDs, strings, and JSON content - - BLOB for pickled actions - - INTEGER for booleans (0/1/NULL) + - TEXT for IDs and indexed scalars + - TEXT for full event JSON (event_data) - REAL for Julian Day timestamps - Foreign key to sessions with CASCADE delete - Index on (session_id, timestamp ASC) @@ -195,22 +182,10 @@ async def _get_create_events_table_sql(self) -> str: CREATE TABLE IF NOT EXISTS {self._events_table} ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL, - invocation_id TEXT NOT NULL, - author TEXT NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json TEXT, - branch TEXT, + invocation_id TEXT, + author TEXT, timestamp REAL NOT NULL, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code TEXT, - error_message TEXT, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session @@ -229,21 +204,10 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def _enable_foreign_keys(self, connection: Any) -> None: - """Enable foreign key constraints for this connection. - - Args: - connection: SQLite connection. 
- - Notes: - SQLite requires PRAGMA foreign_keys = ON per connection. - """ - connection.execute("PRAGMA foreign_keys = ON") - def _create_tables(self) -> None: """Synchronous implementation of create_tables.""" with self._config.provide_session() as driver: - driver.connection.execute("PRAGMA foreign_keys = ON") + self._apply_pragmas(driver.connection) driver.execute_script(run_(self._get_create_sessions_table_sql)()) driver.execute_script(run_(self._get_create_events_table_sql)()) @@ -257,7 +221,7 @@ def _create_session( """Synchronous implementation of create_session.""" now = datetime.now(timezone.utc) now_julian = _datetime_to_julian(now) - state_json = to_json(state) if state else None + state_json = to_json(state) params: tuple[Any, ...] if self._owner_id_column_name: @@ -275,7 +239,7 @@ def _create_session( params = (session_id, app_name, user_id, state_json, now_julian, now_julian) with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute(sql, params) conn.commit() @@ -300,7 +264,7 @@ async def create_session( Notes: Uses Julian Day for create_time and update_time. - State is JSON-serialized before insertion. + State is always JSON-serialized (empty dict becomes '{}', never NULL). If owner_id_column is configured, owner_id is inserted into that column. 
""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) @@ -314,7 +278,7 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": """ with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) cursor = conn.execute(sql, (session_id,)) row = cursor.fetchone() @@ -348,7 +312,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now_julian = _datetime_to_julian(datetime.now(timezone.utc)) - state_json = to_json(state) if state else None + state_json = to_json(state) sql = f""" UPDATE {self._session_table} @@ -357,7 +321,7 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non """ with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute(sql, (state_json, now_julian, session_id)) conn.commit() @@ -371,6 +335,7 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - Notes: This replaces the entire state dictionary. Updates update_time to current Julian Day. + Empty dict is serialized as '{}', never NULL. """ await async_(self._update_session_state)(session_id, state) @@ -394,7 +359,7 @@ def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionR params = (app_name, user_id) with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) cursor = conn.execute(sql, params) rows = cursor.fetchall() @@ -430,7 +395,7 @@ def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = ?" 
with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute(sql, (session_id,)) conn.commit() @@ -448,53 +413,29 @@ async def delete_session(self, session_id: str) -> None: def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) - - partial_int = _to_sqlite_bool(event_record.get("partial")) - turn_complete_int = _to_sqlite_bool(event_record.get("turn_complete")) - interrupted_int = _to_sqlite_bool(event_record.get("interrupted")) + event_data_json = to_json(event_record["event_json"]) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ) + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) 
""" + import uuid + + event_id = str(uuid.uuid4()) + with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute( sql, ( - event_record["id"], + event_id, event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), timestamp_julian, - content_json, - grounding_metadata_json, - custom_metadata_json, - partial_int, - turn_complete_int, - interrupted_int, - event_record.get("error_code"), - event_record.get("error_message"), + event_data_json, ), ) conn.commit() @@ -503,15 +444,71 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. Notes: Uses Julian Day for timestamp. - JSON fields are serialized to TEXT. - Boolean fields converted to INTEGER (0/1/NULL). + event_json dict is serialized to TEXT as event_data column. """ await async_(self._append_event)(event_record) + def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Synchronous implementation of append_event_and_update_state.""" + import uuid + + timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + event_data_json = to_json(event_record["event_json"]) + now_julian = _datetime_to_julian(datetime.now(timezone.utc)) + state_json = to_json(state) + event_id = str(uuid.uuid4()) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = ? + WHERE id = ? 
+ """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute( + insert_sql, + ( + event_id, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + timestamp_julian, + event_data_json, + ), + ) + conn.execute(update_sql, (state_json, now_julian, session_id)) + conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + Inserts the event and updates the session state + update_time in a + single transaction. Both operations succeed or fail together. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (temp: keys already + stripped by the service layer). + """ + await async_(self._append_event_and_update_state)(event_record, session_id, state) + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -527,40 +524,24 @@ def _get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT id, session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} """ with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) cursor = conn.execute(sql, params) rows = cursor.fetchall() return [ EventRecord( - id=row[0], session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - long_running_tool_ids_json=row[7], - 
branch=row[8], - timestamp=_julian_to_datetime(row[9]), - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=_from_sqlite_bool(row[13]), - turn_complete=_from_sqlite_bool(row[14]), - interrupted=_from_sqlite_bool(row[15]), - error_code=row[16], - error_message=row[17], + invocation_id=row[2], + author=row[3], + timestamp=_julian_to_datetime(row[4]), + event_json=from_json(row[5]) if row[5] else {}, ) for row in rows ] @@ -580,8 +561,7 @@ async def get_events( Notes: Uses index on (session_id, timestamp ASC). - Parses JSON fields and converts BLOB actions to bytes. - Converts INTEGER booleans back to bool/None. + Parses event_data TEXT back to dict for event_json field. """ return await async_(self._get_events)(session_id, after_timestamp, limit) From 59b548d004dd639fd8e800ab599db44c6f19f56a Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:33:21 +0000 Subject: [PATCH 10/23] feat(adk): rebuild mysql family stores for new contract Rebuild asyncmy, pymysql, and mysqlconnector ADK session/event stores for the clean-break EventRecord contract (5 columns: session_id, invocation_id, author, timestamp, event_json as native JSON). 
- Events table DDL reduced from 17 columns + BLOB to 5 columns + JSON - Async stores implement append_event_and_update_state() for atomic event+state persistence in a single transaction - Sync stores implement create_event(), create_event_and_update_state(), and list_events() to satisfy BaseSyncADKStore abstract methods - mysqlconnector extracts shared DDL helpers (_mysql_sessions_ddl, _mysql_events_ddl) to reduce duplication between async and sync stores - All adapters use %s parameter style consistently - Memory stores left unchanged --- sqlspec/adapters/asyncmy/adk/store.py | 210 +++------ sqlspec/adapters/mysqlconnector/adk/store.py | 466 ++++++++++--------- sqlspec/adapters/pymysql/adk/store.py | 207 ++++---- 3 files changed, 433 insertions(+), 450 deletions(-) diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index ff74e6851..446defc78 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -27,36 +27,16 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]): Implements session and event storage for Google Agent Development Kit using MySQL/MariaDB via the AsyncMy driver. Provides: - Session state management with JSON storage - - Event history tracking with BLOB-serialized actions + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-append + state-update in one transaction - Microsecond-precision timestamps - Foreign key constraints with cascade delete - Efficient upserts using ON DUPLICATE KEY UPDATE - Args: - config: AsyncmyConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.asyncmy import AsyncmyConfig - from sqlspec.adapters.asyncmy.adk import AsyncmyADKStore - - config = AsyncmyConfig( - connection_config={"host": "localhost", ...}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = AsyncmyADKStore(config) - await store.ensure_tables() - Notes: - MySQL JSON type used (not JSONB) - requires MySQL 5.7.8+ - TIMESTAMP(6) provides microsecond precision - InnoDB engine required for foreign key support - - State merging handled at application level - Configuration is read from config.extension_config["adk"] """ @@ -67,12 +47,6 @@ def __init__(self, config: "AsyncmyConfig") -> None: Args: config: AsyncmyConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) """ super().__init__(config) @@ -88,10 +62,6 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]" Returns: Tuple of (column_definition, foreign_key_constraint) - - Example: - Input: "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - Output: ("tenant_id BIGINT NOT NULL", "FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE") """ references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) @@ -110,16 +80,6 @@ async def _get_create_sessions_table_sql(self) -> str: Returns: SQL statement to create adk_sessions table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSON type for state storage (MySQL 5.7.8+) - - TIMESTAMP(6) with microsecond precision - - AUTO-UPDATE on update_time - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Optional owner ID column for multi-tenancy - - MySQL requires explicit FOREIGN KEY syntax (inline REFERENCES is ignored) """ owner_id_col = "" fk_constraint = "" @@ -151,34 +111,18 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BLOB for pickled actions (up to 64KB) - - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval + Post clean-break schema: 5 columns only. 
+ - session_id, invocation_id, author: indexed scalars + - timestamp: microsecond-precision TIMESTAMP + - event_json: full Event as native JSON """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json JSON, - branch VARCHAR(256), + author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -189,10 +133,6 @@ def _get_drop_tables_sql(self) -> "list[str]": Returns: List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - MySQL automatically drops indexes when dropping tables. """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] @@ -216,11 +156,6 @@ async def create_session( Returns: Created session record. - - Notes: - Uses INSERT with UTC_TIMESTAMP(6) for create_time and update_time. - State is JSON-serialized before insertion. - If owner_id_column is configured, owner_id must be provided. """ state_json = to_json(state) @@ -252,10 +187,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": Returns: Session record or None if not found. - - Notes: - MySQL returns datetime objects for TIMESTAMP columns. - JSON is parsed from database storage. 
""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -292,10 +223,6 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - Args: session_id: Session identifier. state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses update_time auto-update trigger. """ state_json = to_json(state) @@ -314,9 +241,6 @@ async def delete_session(self, session_id: str) -> None: Args: session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. """ sql = f"DELETE FROM {self._session_table} WHERE id = %s" @@ -333,9 +257,6 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis Returns: List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. """ if user_id is None: sql = f""" @@ -379,55 +300,72 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. - - Notes: - Uses UTC_TIMESTAMP(6) for timestamp if not provided. - JSON fields are serialized before insertion. + event_record: Event record with 5 keys (session_id, invocation_id, + author, timestamp, event_json). 
""" - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ async with self._config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute( sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, + ), + ) + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. 
+ + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. + """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, ), ) + await cursor.execute(update_sql, (state_json, session_id)) await conn.commit() async def get_events( @@ -442,10 +380,6 @@ async def get_events( Returns: List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - Parses JSON fields and converts BLOB actions to bytes. 
""" where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -458,10 +392,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -474,24 +405,11 @@ async def get_events( return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] and isinstance(row[10], str) else row[10], - grounding_metadata=from_json(row[11]) if row[11] and isinstance(row[11], str) else row[11], - custom_metadata=from_json(row[12]) if row[12] and isinstance(row[12], str) else row[12], - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 3aed6258e..2c3079688 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -38,8 +38,61 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": return (col_def, fk_constraint) +def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") -> str: + """Generate shared MySQL sessions CREATE TABLE DDL.""" + owner_id_col = "" + fk_constraint = "" + + if owner_id_column_ddl: + col_def, fk_def = 
_parse_owner_id_column_for_mysql(owner_id_column_ddl) + owner_id_col = f"{col_def}," + if fk_def: + fk_constraint = f",\n {fk_def}" + + return f""" + CREATE TABLE IF NOT EXISTS {session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + {owner_id_col} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{session_table}_app_user (app_name, user_id), + INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_events_ddl(events_table: str, session_table: str) -> str: + """Generate shared MySQL events CREATE TABLE DDL (post clean-break, 5 columns).""" + return f""" + CREATE TABLE IF NOT EXISTS {events_table} ( + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(128) NOT NULL, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + event_json JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, + INDEX idx_{events_table}_session (session_id, timestamp ASC) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + class MysqlConnectorAsyncADKStore(BaseAsyncADKStore["MysqlConnectorAsyncConfig"]): - """MySQL/MariaDB ADK store using mysql-connector async driver.""" + """MySQL/MariaDB ADK store using mysql-connector async driver. 
+ + Provides: + - Session state management with JSON storage + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-append + state-update in one transaction + - Microsecond-precision timestamps + - Foreign key constraints with cascade delete + + Notes: + - Uses ``cast()`` extensively because mysql-connector returns ``Any`` types + - Configuration is read from config.extension_config["adk"] + """ __slots__ = () @@ -50,54 +103,10 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]" return _parse_owner_id_column_for_mysql(column_ddl) async def _get_create_sessions_table_sql(self) -> str: - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json JSON, - branch VARCHAR(256), - timestamp TIMESTAMP(6) NOT NULL DEFAULT 
CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + return _mysql_events_ddl(self._events_table, self._session_table) def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] @@ -242,23 +251,19 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis raise async def append_event(self, event_record: EventRecord) -> None: - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + """Append an event to a session. + + Args: + event_record: Event record with 5 keys (session_id, invocation_id, + author, timestamp, event_json). 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ async with self._config.provide_connection() as conn: @@ -267,26 +272,60 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, + ), + ) + finally: + await cursor.close() + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, ), ) + await cursor.execute(update_sql, (state_json, session_id)) finally: await cursor.close() await conn.commit() @@ -294,6 +333,16 @@ async def append_event(self, event_record: EventRecord) -> None: async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": + """Get events for a session. + + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. 
+ """ where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -305,10 +354,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -325,30 +371,11 @@ async def get_events( return [ EventRecord( - id=cast("str", row[0]), - session_id=cast("str", row[1]), - app_name=cast("str", row[2]), - user_id=cast("str", row[3]), - invocation_id=cast("str", row[4]), - author=cast("str", row[5]), - actions=bytes(cast("bytes", row[6])), - long_running_tool_ids_json=cast("str | None", row[7]), - branch=cast("str | None", row[8]), - timestamp=cast("datetime", row[9]), - content=from_json(row[10]) - if row[10] and isinstance(row[10], str) - else cast("dict[str, Any] | None", row[10]), - grounding_metadata=from_json(row[11]) - if row[11] and isinstance(row[11], str) - else cast("dict[str, Any] | None", row[11]), - custom_metadata=from_json(row[12]) - if row[12] and isinstance(row[12], str) - else cast("dict[str, Any] | None", row[12]), - partial=cast("bool | None", row[13]), - turn_complete=cast("bool | None", row[14]), - interrupted=cast("bool | None", row[15]), - error_code=cast("str | None", row[16]), - error_message=cast("str | None", row[17]), + session_id=cast("str", row[0]), + invocation_id=cast("str", row[1]), + author=cast("str", row[2]), + timestamp=cast("datetime", row[3]), + event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), ) for row in rows ] @@ -359,7 +386,19 @@ async def get_events( class MysqlConnectorSyncADKStore(BaseSyncADKStore["MysqlConnectorSyncConfig"]): - """MySQL/MariaDB ADK store using 
mysql-connector sync driver.""" + """MySQL/MariaDB ADK store using mysql-connector sync driver. + + Provides: + - Session state management with JSON storage + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-create + state-update in one transaction + - Microsecond-precision timestamps + - Foreign key constraints with cascade delete + + Notes: + - Uses ``cast()`` extensively because mysql-connector returns ``Any`` types + - Configuration is read from config.extension_config["adk"] + """ __slots__ = () @@ -370,54 +409,10 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]" return _parse_owner_id_column_for_mysql(column_ddl) def _get_create_sessions_table_sql(self) -> str: - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - 
long_running_tool_ids_json JSON, - branch VARCHAR(256), - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + return _mysql_events_ddl(self._events_table, self._session_table) def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] @@ -565,24 +560,43 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def append_event(self, event_record: EventRecord) -> None: - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + def create_event( + self, + event_id: str, + session_id: str, + app_name: str, + user_id: str, + author: "str | None" = None, + actions: "bytes | None" = None, + content: "dict[str, Any] | None" = None, + **kwargs: Any, + ) -> EventRecord: + """Create a new event. + + Args: + event_id: Unique event identifier (unused in new schema, kept for contract). + session_id: Session identifier. + app_name: Application name (unused in new schema, kept for contract). + user_id: User identifier (unused in new schema, kept for contract). + author: Event author. + actions: Unused in new contract (kept for interface compatibility). + content: Event content dictionary. 
+ **kwargs: Additional fields including invocation_id, timestamp, event_json. + + Returns: + Created event record. + """ + from datetime import datetime, timezone + + invocation_id = kwargs.get("invocation_id", "") + timestamp = kwargs.get("timestamp", datetime.now(tz=timezone.utc)) + event_json = kwargs.get("event_json", content or {}) + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ with self._config.provide_connection() as conn: @@ -590,89 +604,99 @@ def append_event(self, event_record: EventRecord) -> None: try: cursor.execute( sql, + (session_id, invocation_id, author or "", timestamp, event_json_str), + ) + finally: + cursor.close() + conn.commit() + + return EventRecord( + session_id=session_id, + invocation_id=invocation_id, + author=author or "", + timestamp=timestamp, + event_json=event_json, + ) + + def create_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + insert_sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, ), ) + cursor.execute(update_sql, (state_json, session_id)) finally: cursor.close() conn.commit() - def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] - - if after_timestamp is not None: - where_clauses.append("timestamp > %s") - params.append(after_timestamp) + def list_events(self, session_id: str) -> "list[EventRecord]": + """List events for a session ordered by timestamp. - where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" + Args: + session_id: Session identifier. + Returns: + List of event records ordered by timestamp ASC. 
+ """ sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + WHERE session_id = %s + ORDER BY timestamp ASC """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, params) + cursor.execute(sql, (session_id,)) rows = cursor.fetchall() finally: cursor.close() return [ EventRecord( - id=cast("str", row[0]), - session_id=cast("str", row[1]), - app_name=cast("str", row[2]), - user_id=cast("str", row[3]), - invocation_id=cast("str", row[4]), - author=cast("str", row[5]), - actions=bytes(cast("bytes", row[6])), - long_running_tool_ids_json=cast("str | None", row[7]), - branch=cast("str | None", row[8]), - timestamp=cast("datetime", row[9]), - content=from_json(row[10]) - if row[10] and isinstance(row[10], str) - else cast("dict[str, Any] | None", row[10]), - grounding_metadata=from_json(row[11]) - if row[11] and isinstance(row[11], str) - else cast("dict[str, Any] | None", row[11]), - custom_metadata=from_json(row[12]) - if row[12] and isinstance(row[12], str) - else cast("dict[str, Any] | None", row[12]), - partial=cast("bool | None", row[13]), - turn_complete=cast("bool | None", row[14]), - interrupted=cast("bool | None", row[15]), - error_code=cast("str | None", row[16]), - error_message=cast("str | None", row[17]), + session_id=cast("str", row[0]), + invocation_id=cast("str", row[1]), + author=cast("str", row[2]), + timestamp=cast("datetime", row[3]), + event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), ) for row in rows ] diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index 
e30765c7f..a73d3f638 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -10,8 +10,6 @@ from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: - from datetime import datetime - from sqlspec.adapters.pymysql.config import PyMysqlConfig from sqlspec.extensions.adk import MemoryRecord @@ -34,7 +32,22 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": class PyMysqlADKStore(BaseSyncADKStore["PyMysqlConfig"]): - """MySQL/MariaDB ADK store using PyMySQL.""" + """MySQL/MariaDB ADK store using PyMySQL. + + Implements session and event storage for Google Agent Development Kit + using MySQL/MariaDB via the PyMySQL sync driver. Provides: + - Session state management with JSON storage + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-create + state-update in one transaction + - Microsecond-precision timestamps + - Foreign key constraints with cascade delete + + Notes: + - MySQL JSON type used - requires MySQL 5.7.8+ + - TIMESTAMP(6) provides microsecond precision + - InnoDB engine required for foreign key support + - Configuration is read from config.extension_config["adk"] + """ __slots__ = () @@ -69,26 +82,17 @@ def _get_create_sessions_table_sql(self) -> str: """ def _get_create_events_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for events. + + Post clean-break schema: 5 columns only. 
+ """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json JSON, - branch VARCHAR(256), + author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -240,24 +244,45 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def append_event(self, event_record: EventRecord) -> None: - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + def create_event( + self, + event_id: str, + session_id: str, + app_name: str, + user_id: str, + author: "str | None" = None, + actions: "bytes | None" = None, + content: "dict[str, Any] | None" = None, + **kwargs: Any, + ) -> EventRecord: + """Create a new event. + + Constructs an EventRecord from the provided fields and inserts it. + + Args: + event_id: Unique event identifier (unused in new schema, kept for contract). + session_id: Session identifier. + app_name: Application name (unused in new schema, kept for contract). 
+ user_id: User identifier (unused in new schema, kept for contract). + author: Event author. + actions: Unused in new contract (kept for interface compatibility). + content: Event content dictionary. + **kwargs: Additional fields including invocation_id, timestamp, event_json. + + Returns: + Created event record. + """ + from datetime import datetime, timezone + + invocation_id = kwargs.get("invocation_id", "") + timestamp = kwargs.get("timestamp", datetime.now(tz=timezone.utc)) + event_json = kwargs.get("event_json", content or {}) + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ with self._config.provide_connection() as conn: @@ -265,83 +290,99 @@ def append_event(self, event_record: EventRecord) -> None: try: cursor.execute( sql, + (session_id, invocation_id, author or "", timestamp, event_json_str), + ) + finally: + cursor.close() + conn.commit() + + return EventRecord( + session_id=session_id, + invocation_id=invocation_id, + author=author or "", + timestamp=timestamp, + event_json=event_json, + ) + + def create_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + insert_sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, ), ) + cursor.execute(update_sql, (state_json, session_id)) finally: cursor.close() conn.commit() - def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + def list_events(self, session_id: str) -> "list[EventRecord]": + """List events for a session ordered by timestamp. - if after_timestamp is not None: - where_clauses.append("timestamp > %s") - params.append(after_timestamp) - - where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" + Args: + session_id: Session identifier. + Returns: + List of event records ordered by timestamp ASC. 
+ """ sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + WHERE session_id = %s + ORDER BY timestamp ASC """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, params) + cursor.execute(sql, (session_id,)) rows = cursor.fetchall() finally: cursor.close() return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] and isinstance(row[10], str) else row[10], - grounding_metadata=from_json(row[11]) if row[11] and isinstance(row[11], str) else row[11], - custom_metadata=from_json(row[12]) if row[12] and isinstance(row[12], str) else row[12], - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] From 777348eddf0a680385f637d8b8038d00a595d42f Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:33:57 +0000 Subject: [PATCH 11/23] feat(adk): rebuild oracledb stores for new contract --- sqlspec/adapters/oracledb/adk/store.py | 552 ++++++++----------------- 1 file changed, 182 insertions(+), 370 deletions(-) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 4aa9b2811..2e7d994b7 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ 
-93,43 +93,18 @@ def storage_type_from_version(version_info: "OracleVersionInfo | None") -> JSONS return _storage_type_from_version(version_info) -def _to_oracle_bool(value: "bool | None") -> "int | None": - """Convert Python boolean to Oracle NUMBER(1). +def _event_json_column_ddl(storage_type: JSONStorageType) -> str: + """Return the DDL fragment for the event_json column. - Args: - value: Python boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. + For JSON_NATIVE (Oracle 21c+) we use the native JSON type. + For older versions we use CLOB since event_json is a JSON text string. + BLOB_JSON gets a CHECK constraint; BLOB_PLAIN does not. """ - if value is None: - return None - return 1 if value else 0 - - -def _from_oracle_bool(value: "int | None") -> "bool | None": - """Convert Oracle NUMBER(1) to Python boolean. - - Args: - value: Oracle NUMBER value (0, 1, or None). - - Returns: - Python boolean or None. - """ - if value is None: - return None - return bool(value) - - -def _coerce_bytes_payload(value: Any) -> bytes: - """Coerce a LOB payload into bytes.""" - if value is None: - return b"" - if isinstance(value, bytes): - return value - if isinstance(value, str): - return value.encode("utf-8") - return str(value).encode("utf-8") + if storage_type == JSONStorageType.JSON_NATIVE: + return "event_json JSON NOT NULL" + if storage_type == JSONStorageType.BLOB_JSON: + return "event_json CLOB CHECK (event_json IS JSON) NOT NULL" + return "event_json CLOB NOT NULL" class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): @@ -138,7 +113,8 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb async driver. 
Provides: - Session state management with version-specific JSON storage - - Event history tracking with BLOB-serialized actions + - Full-fidelity event storage via ``event_json`` column + - Atomic ``append_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete - Efficient upserts using MERGE statement @@ -146,28 +122,10 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Args: config: OracleAsyncConfig with extension_config["adk"] settings. - Example: - from sqlspec.adapters.oracledb import OracleAsyncConfig - from sqlspec.adapters.oracledb.adk import OracleAsyncADKStore - - config = OracleAsyncConfig( - connection_config={"dsn": "oracle://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id NUMBER(10) REFERENCES tenants(id)" - } - } - ) - store = OracleAsyncADKStore(config) - await store.ensure_tables() - Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - BLOB for pre-serialized actions from Google ADK + - event_json stored as JSON (21c+) or CLOB (older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - - NUMBER(1) for booleans (0/1/NULL) - Named parameters using :param_name - State merging handled at application level - owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types @@ -223,10 +181,9 @@ async def _detect_json_storage_type(self) -> JSONStorageType: Notes: Queries product_component_version to determine Oracle version. - Oracle 21c+ with compatible >= 20: Native JSON type - - Oracle 12c+: BLOB with IS JSON constraint (preferred) - - Oracle 11g and earlier: BLOB without constraint + - Oracle 12c+: CLOB with IS JSON constraint + - Oracle 11g and earlier: CLOB without constraint - BLOB is preferred over CLOB for 12c+ as per Oracle recommendations. Result is cached in self._json_storage_type. 
""" if self._json_storage_type is not None: @@ -296,55 +253,27 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] - async def _serialize_json_field(self, value: Any) -> "str | bytes | None": - """Serialize optional JSON field for event storage. - - Args: - value: Value to serialize (dict or None). - - Returns: - Serialized JSON or None. - """ - if value is None: - return None - - storage_type = await self._detect_json_storage_type() - - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(value) - - return to_json(value, as_bytes=True) - - async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": - """Deserialize optional JSON field from database. + async def _read_event_json(self, data: Any) -> str: + """Read event_json from database, handling LOB types. Args: - data: Data from database (may be LOB, str, bytes, dict, or None). + data: Data from database (may be LOB, str, or dict). Returns: - Deserialized dictionary or None. - - Notes: - Oracle JSON type may return dict directly. + JSON string. """ - if data is None: - return None - if is_async_readable(data): data = await data.read() elif is_readable(data): data = data.read() if isinstance(data, dict): - return cast("dict[str, Any]", _coerce_decimal_values(data)) + return to_json(data) if isinstance(data, bytes): - return from_json(data) # type: ignore[no-any-return] + return data.decode("utf-8") - if isinstance(data, str): - return from_json(data) # type: ignore[no-any-return] - - return from_json(str(data)) # type: ignore[no-any-return] + return str(data) def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for sessions with specified storage type. 
@@ -406,54 +335,27 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for events with specified storage type. + The events table uses the new 5-column contract: session_id, invocation_id, + author, timestamp, and event_json. The event_json column stores the full + ADK Event as JSON (21c+) or CLOB (older versions). + Args: storage_type: JSON storage type to use. Returns: SQL statement to create adk_events table. """ - if storage_type == JSONStorageType.JSON_NATIVE: - json_columns = """ - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - long_running_tool_ids_json JSON - """ - elif storage_type == JSONStorageType.BLOB_JSON: - json_columns = """ - content BLOB CHECK (content IS JSON), - grounding_metadata BLOB CHECK (grounding_metadata IS JSON), - custom_metadata BLOB CHECK (custom_metadata IS JSON), - long_running_tool_ids_json BLOB CHECK (long_running_tool_ids_json IS JSON) - """ - else: - json_columns = """ - content BLOB, - grounding_metadata BLOB, - custom_metadata BLOB, - long_running_tool_ids_json BLOB - """ - + event_json_col = _event_json_column_ddl(storage_type) inmemory_clause = " INMEMORY PRIORITY HIGH" if self._in_memory else "" return f""" BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( - id VARCHAR2(128) PRIMARY KEY, session_id VARCHAR2(128) NOT NULL, - app_name VARCHAR2(128) NOT NULL, - user_id VARCHAR2(128) NOT NULL, - invocation_id VARCHAR2(256), - author VARCHAR2(256), - actions BLOB, - branch VARCHAR2(256), + invocation_id VARCHAR2(256) NOT NULL, + author VARCHAR2(256) NOT NULL, timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {json_columns}, - partial NUMBER(1), - turn_complete NUMBER(1), - interrupted NUMBER(1), - error_code VARCHAR2(256), - error_message VARCHAR2(1024), + {event_json_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY 
(session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ){inmemory_clause}'; @@ -753,28 +655,14 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. - - Notes: - Uses SYSTIMESTAMP for timestamp if not provided. - JSON fields are serialized using version-appropriate format. - Boolean fields are converted to NUMBER(1). + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. """ - content_data = await self._serialize_json_field(event_record.get("content")) - grounding_metadata_data = await self._serialize_json_field(event_record.get("grounding_metadata")) - custom_metadata_data = await self._serialize_json_field(event_record.get("custom_metadata")) - sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + session_id, invocation_id, author, timestamp, event_json ) VALUES ( - :id, :session_id, :app_name, :user_id, :invocation_id, :author, :actions, - :long_running_tool_ids_json, :branch, :timestamp, :content, - :grounding_metadata, :custom_metadata, :partial, :turn_complete, - :interrupted, :error_code, :error_message + :session_id, :invocation_id, :author, :timestamp, :event_json ) """ @@ -783,26 +671,58 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, { - "id": event_record["id"], "session_id": event_record["session_id"], - "app_name": event_record["app_name"], - "user_id": event_record["user_id"], "invocation_id": event_record["invocation_id"], "author": event_record["author"], - "actions": event_record["actions"], - "long_running_tool_ids_json": event_record.get("long_running_tool_ids_json"), - "branch": event_record.get("branch"), "timestamp": 
event_record["timestamp"], - "content": content_data, - "grounding_metadata": grounding_metadata_data, - "custom_metadata": custom_metadata_data, - "partial": _to_oracle_bool(event_record.get("partial")), - "turn_complete": _to_oracle_bool(event_record.get("turn_complete")), - "interrupted": _to_oracle_bool(event_record.get("interrupted")), - "error_code": event_record.get("error_code"), - "error_message": event_record.get("error_message"), + "event_json": event_record["event_json"], + }, + ) + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + Both the event insert and session state update are executed within a + single transaction so they succeed or fail together. + + Args: + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). 
+ """ + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_json + ) + """ + + state_data = await self._serialize_state(state) + update_sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE id = :id + """ + + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute( + insert_sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": event_record["event_json"], }, ) + await cursor.execute(update_sql, {"state": state_data, "id": session_id}) await conn.commit() async def get_events( @@ -817,11 +737,6 @@ async def get_events( Returns: List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - JSON fields deserialized using version-appropriate format. - Converts BLOB actions to bytes and NUMBER(1) booleans to Python bool. 
""" where_clauses = ["session_id = :session_id"] @@ -837,10 +752,7 @@ async def get_events( limit_clause = f" FETCH FIRST {limit} ROWS ONLY" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -854,38 +766,15 @@ async def get_events( results = [] for row in rows: - actions_blob = row[6] - if is_async_readable(actions_blob): - actions_data = await actions_blob.read() - elif is_readable(actions_blob): - actions_data = actions_blob.read() - else: - actions_data = actions_blob - - content = await self._deserialize_json_field(row[10]) - grounding_metadata = await self._deserialize_json_field(row[11]) - custom_metadata = await self._deserialize_json_field(row[12]) + event_json_str = await self._read_event_json(row[4]) results.append( EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=_coerce_bytes_payload(actions_data), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=content, - grounding_metadata=grounding_metadata, - custom_metadata=custom_metadata, - partial=_from_oracle_bool(row[13]), - turn_complete=_from_oracle_bool(row[14]), - interrupted=_from_oracle_bool(row[15]), - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=event_json_str, ) ) return results @@ -902,7 +791,8 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]): Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb synchronous driver. 
Provides: - Session state management with version-specific JSON storage - - Event history tracking with BLOB-serialized actions + - Full-fidelity event storage via ``event_json`` column + - Atomic ``create_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete - Efficient upserts using MERGE statement @@ -910,28 +800,10 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]): Args: config: OracleSyncConfig with extension_config["adk"] settings. - Example: - from sqlspec.adapters.oracledb import OracleSyncConfig - from sqlspec.adapters.oracledb.adk import OracleSyncADKStore - - config = OracleSyncConfig( - connection_config={"dsn": "oracle://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "account_id NUMBER(19) REFERENCES accounts(id)" - } - } - ) - store = OracleSyncADKStore(config) - store.ensure_tables() - Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - BLOB for pre-serialized actions from Google ADK + - event_json stored as JSON (21c+) or CLOB (older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - - NUMBER(1) for booleans (0/1/NULL) - Named parameters using :param_name - State merging handled at application level - owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types @@ -987,10 +859,9 @@ def _detect_json_storage_type(self) -> JSONStorageType: Notes: Queries product_component_version to determine Oracle version. - Oracle 21c+ with compatible >= 20: Native JSON type - - Oracle 12c+: BLOB with IS JSON constraint (preferred) - - Oracle 11g and earlier: BLOB without constraint + - Oracle 12c+: CLOB with IS JSON constraint + - Oracle 11g and earlier: CLOB without constraint - BLOB is preferred over CLOB for 12c+ as per Oracle recommendations. Result is cached in self._json_storage_type. 
""" if self._json_storage_type is not None: @@ -1058,53 +929,25 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] - def _serialize_json_field(self, value: Any) -> "str | bytes | None": - """Serialize optional JSON field for event storage. - - Args: - value: Value to serialize (dict or None). - - Returns: - Serialized JSON or None. - """ - if value is None: - return None - - storage_type = self._detect_json_storage_type() - - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(value) - - return to_json(value, as_bytes=True) - - def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": - """Deserialize optional JSON field from database. + def _read_event_json(self, data: Any) -> str: + """Read event_json from database, handling LOB types. Args: - data: Data from database (may be LOB, str, bytes, dict, or None). + data: Data from database (may be LOB, str, or dict). Returns: - Deserialized dictionary or None. - - Notes: - Oracle JSON type may return dict directly. + JSON string. """ - if data is None: - return None - if is_readable(data): data = data.read() if isinstance(data, dict): - return cast("dict[str, Any]", _coerce_decimal_values(data)) + return to_json(data) if isinstance(data, bytes): - return from_json(data) # type: ignore[no-any-return] + return data.decode("utf-8") - if isinstance(data, str): - return from_json(data) # type: ignore[no-any-return] - - return from_json(str(data)) # type: ignore[no-any-return] + return str(data) def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for sessions with specified storage type. @@ -1166,54 +1009,27 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for events with specified storage type. 
+ The events table uses the new 5-column contract: session_id, invocation_id, + author, timestamp, and event_json. The event_json column stores the full + ADK Event as JSON (21c+) or CLOB (older versions). + Args: storage_type: JSON storage type to use. Returns: SQL statement to create adk_events table. """ - if storage_type == JSONStorageType.JSON_NATIVE: - json_columns = """ - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - long_running_tool_ids_json JSON - """ - elif storage_type == JSONStorageType.BLOB_JSON: - json_columns = """ - content BLOB CHECK (content IS JSON), - grounding_metadata BLOB CHECK (grounding_metadata IS JSON), - custom_metadata BLOB CHECK (custom_metadata IS JSON), - long_running_tool_ids_json BLOB CHECK (long_running_tool_ids_json IS JSON) - """ - else: - json_columns = """ - content BLOB, - grounding_metadata BLOB, - custom_metadata BLOB, - long_running_tool_ids_json BLOB - """ - + event_json_col = _event_json_column_ddl(storage_type) inmemory_clause = " INMEMORY PRIORITY HIGH" if self._in_memory else "" return f""" BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( - id VARCHAR2(128) PRIMARY KEY, session_id VARCHAR2(128) NOT NULL, - app_name VARCHAR2(128) NOT NULL, - user_id VARCHAR2(128) NOT NULL, - invocation_id VARCHAR2(256), - author VARCHAR2(256), - actions BLOB, - branch VARCHAR2(256), + invocation_id VARCHAR2(256) NOT NULL, + author VARCHAR2(256) NOT NULL, timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {json_columns}, - partial NUMBER(1), - turn_complete NUMBER(1), - interrupted NUMBER(1), - error_code VARCHAR2(256), - error_message VARCHAR2(1024), + {event_json_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ){inmemory_clause}'; @@ -1524,38 +1340,28 @@ def create_event( """Create a new event. Args: - event_id: Unique event identifier. + event_id: Unused (kept for base class compatibility). 
session_id: Session identifier. - app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSONB/JSON). - **kwargs: Additional optional fields. + app_name: Unused (kept for base class compatibility). + user_id: Unused (kept for base class compatibility). + author: Event author. + actions: Unused (no longer stored). + content: Unused (no longer stored separately). + **kwargs: Must include ``invocation_id``, ``timestamp``, and + ``event_json``. Returns: Created event record. - - Notes: - Uses SYSTIMESTAMP for timestamp if not provided. - JSON fields are serialized using version-appropriate format. - Boolean fields are converted to NUMBER(1). """ - content_data = self._serialize_json_field(content) - grounding_metadata_data = self._serialize_json_field(kwargs.get("grounding_metadata")) - custom_metadata_data = self._serialize_json_field(kwargs.get("custom_metadata")) + event_json: str = kwargs["event_json"] + invocation_id: str = kwargs.get("invocation_id", "") + timestamp = kwargs.get("timestamp") sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + session_id, invocation_id, author, timestamp, event_json ) VALUES ( - :id, :session_id, :app_name, :user_id, :invocation_id, :author, :actions, - :long_running_tool_ids_json, :branch, :timestamp, :content, - :grounding_metadata, :custom_metadata, :partial, :turn_complete, - :interrupted, :error_code, :error_message + :session_id, :invocation_id, :author, :timestamp, :event_json ) """ @@ -1564,35 +1370,67 @@ def create_event( cursor.execute( sql, { - "id": event_id, "session_id": session_id, - "app_name": app_name, - "user_id": user_id, - "invocation_id": kwargs.get("invocation_id"), - "author": 
author, - "actions": actions, - "long_running_tool_ids_json": kwargs.get("long_running_tool_ids_json"), - "branch": kwargs.get("branch"), - "timestamp": kwargs.get("timestamp"), - "content": content_data, - "grounding_metadata": grounding_metadata_data, - "custom_metadata": custom_metadata_data, - "partial": _to_oracle_bool(kwargs.get("partial")), - "turn_complete": _to_oracle_bool(kwargs.get("turn_complete")), - "interrupted": _to_oracle_bool(kwargs.get("interrupted")), - "error_code": kwargs.get("error_code"), - "error_message": kwargs.get("error_message"), + "invocation_id": invocation_id, + "author": author or "", + "timestamp": timestamp, + "event_json": event_json, }, ) conn.commit() - events = self.list_events(session_id) - for event in events: - if event["id"] == event_id: - return event + return EventRecord( + session_id=session_id, + invocation_id=invocation_id, + author=author or "", + timestamp=timestamp, + event_json=event_json, + ) + + def create_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. - msg = f"Failed to retrieve created event {event_id}" - raise RuntimeError(msg) + Both the event insert and session state update are executed within a + single transaction so they succeed or fail together. + + Args: + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). 
+ """ + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_json + ) + """ + + state_data = self._serialize_state(state) + update_sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE id = :id + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute( + insert_sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": event_record["event_json"], + }, + ) + cursor.execute(update_sql, {"state": state_data, "id": session_id}) + conn.commit() def list_events(self, session_id: str) -> "list[EventRecord]": """List events for a session ordered by timestamp. @@ -1602,18 +1440,10 @@ def list_events(self, session_id: str) -> "list[EventRecord]": Returns: List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - JSON fields deserialized using version-appropriate format. - Converts BLOB actions to bytes and NUMBER(1) booleans to Python bool. 
""" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE session_id = :session_id ORDER BY timestamp ASC @@ -1627,33 +1457,15 @@ def list_events(self, session_id: str) -> "list[EventRecord]": results = [] for row in rows: - actions_blob = row[6] - actions_data = actions_blob.read() if is_readable(actions_blob) else actions_blob - - content = self._deserialize_json_field(row[10]) - grounding_metadata = self._deserialize_json_field(row[11]) - custom_metadata = self._deserialize_json_field(row[12]) + event_json_str = self._read_event_json(row[4]) results.append( EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=_coerce_bytes_payload(actions_data), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=content, - grounding_metadata=grounding_metadata, - custom_metadata=custom_metadata, - partial=_from_oracle_bool(row[13]), - turn_complete=_from_oracle_bool(row[14]), - interrupted=_from_oracle_bool(row[15]), - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=event_json_str, ) ) return results From 3fa43225e443e9a37fd652e2372b8cda1799e9fe Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 20 Mar 2026 23:34:46 +0000 Subject: [PATCH 12/23] feat(adk): rebuild postgresql family stores for new contract --- sqlspec/adapters/asyncpg/adk/store.py | 371 ++--------- .../adapters/cockroach_asyncpg/adk/store.py | 168 +++-- .../adapters/cockroach_psycopg/adk/store.py | 370 +++++++---- sqlspec/adapters/psqlpy/adk/store.py | 275 ++------ sqlspec/adapters/psycopg/adk/store.py | 599 
++++-------------- 5 files changed, 540 insertions(+), 1243 deletions(-) diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 4c4624a87..999395d51 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -1,6 +1,6 @@ """AsyncPG ADK store for Google Agent Development Kit session/event storage.""" -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final import asyncpg @@ -21,87 +21,32 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]): - """PostgreSQL ADK store base class for all PostgreSQL drivers. + """PostgreSQL ADK store using asyncpg driver. Implements session and event storage for Google Agent Development Kit - using PostgreSQL via any PostgreSQL driver (AsyncPG, Psycopg, Psqlpy). - All drivers share the same SQL dialect and parameter style ($1, $2, etc). + using PostgreSQL via asyncpg. Events are stored as a single JSONB blob + (``event_json``) alongside indexed scalar columns for efficient querying. Provides: - - Session state management with JSONB storage and merge operations - - Event history tracking with BYTEA-serialized actions + - Session state management with JSONB storage + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 - - Optional user FK column for multi-tenancy + - Optional owner ID column for multi-tenancy Args: config: PostgreSQL database config with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.asyncpg import AsyncpgConfig - from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore - - config = AsyncpgConfig( - connection_config={"dsn": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = AsyncpgADKStore(config) - await store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - AsyncPG automatically converts Python dicts to/from JSONB (no manual serialization) - - TIMESTAMPTZ provides timezone-aware microsecond precision - - State merging uses `state || $1::jsonb` operator for efficiency - - BYTEA for pre-serialized actions from Google ADK (not pickled here) - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Generic over PostgresConfigT to support all PostgreSQL drivers - - Owner ID column enables multi-tenant isolation with referential integrity - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: AsyncConfigT) -> None: - """Initialize AsyncPG ADK store. - - Args: - config: PostgreSQL database config. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or owner references - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -128,61 +73,24 @@ async def _get_create_sessions_table_sql(self) -> str: """ async def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pickled actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) 
REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" async with self.config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -190,23 +98,6 @@ async def create_tables(self) -> None: async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is passed as dict and asyncpg converts to JSONB automatically. - If owner_id_column is configured, owner_id value must be provided. - """ async with self.config.provide_connection() as conn: if self._owner_id_column_name: sql = f""" @@ -225,18 +116,6 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. 
- - Returns: - Session record or None if not found. - - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically parsed by asyncpg. - """ sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -262,16 +141,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": return None async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - """ sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP @@ -282,32 +151,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - await conn.execute(sql, state, session_id) async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. - """ sql = f"DELETE FROM {self._session_table} WHERE id = $1" async with self.config.provide_connection() as conn: await conn.execute(sql, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. 
- """ if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -344,70 +193,50 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return [] async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record to store. - - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSONB fields are passed as dicts and asyncpg converts automatically. - """ - content_json = event_record.get("content") - grounding_metadata_json = event_record.get("grounding_metadata") - custom_metadata_json = event_record.get("custom_metadata") - sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) """ async with self.config.provide_connection() as conn: await conn.execute( sql, - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], - event_record.get("invocation_id"), - event_record.get("author"), - event_record.get("actions"), - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), + event_record["invocation_id"], + event_record["author"], event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_record["event_json"], ) + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> 
None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = $1, update_time = CURRENT_TIMESTAMP + WHERE id = $2 + """ + + async with self.config.provide_connection() as conn, conn.transaction(): + await conn.execute( + insert_sql, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ) + await conn.execute(update_sql, state, session_id) + async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - Parses JSONB fields and converts BYTEA actions to bytes. 
- """ where_clauses = ["session_id = $1"] params: list[Any] = [session_id] @@ -421,10 +250,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -436,24 +262,11 @@ async def get_events( return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -475,66 +288,14 @@ class AsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["AsyncpgConfig"]): Args: config: AsyncpgConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.asyncpg import AsyncpgConfig - from sqlspec.adapters.asyncpg.adk.store import AsyncpgADKMemoryStore - - config = AsyncpgConfig( - connection_config={"dsn": "postgresql://..."}, - extension_config={ - "adk": { - "memory_table": "adk_memory_entries", - "memory_use_fts": True, - "memory_max_results": 20, - } - } - ) - store = AsyncpgADKMemoryStore(config) - await store.ensure_tables() - - Notes: - - JSONB type for content_json and metadata_json - - TIMESTAMPTZ with microsecond precision - - GIN index on content_text tsvector for FTS queries - - Composite index on (app_name, user_id) for filtering - - event_id UNIQUE constraint for deduplication - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "AsyncpgConfig") -> None: - """Initialize AsyncPG ADK memory store. - - Args: - config: AsyncpgConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") - - memory_use_fts: Enable full-text search when supported (default: False) - - memory_max_results: Max search results (default: 20) - - owner_id_column: Optional owner FK column DDL (default: None) - - enable_memory: Whether memory is enabled (default: True) - """ super().__init__(config) async def _get_create_memory_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for memory entries. - - Returns: - SQL statement to create memory table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names - - JSONB for content and metadata storage - - TIMESTAMPTZ with microsecond precision - - UNIQUE constraint on event_id for deduplication - - Composite index on (app_name, user_id, timestamp DESC) - - GIN index on content_text tsvector for FTS - - Optional owner ID column for multi-tenancy - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -570,21 +331,9 @@ async def _get_create_memory_table_sql(self) -> str: """ def _get_drop_memory_table_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop the memory table. - - Notes: - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._memory_table}"] async def create_tables(self) -> None: - """Create the memory table and indexes if they don't exist. - - Skips table creation if memory store is disabled. - """ if not self._enabled: return @@ -592,21 +341,6 @@ async def create_tables(self) -> None: await driver.execute_script(await self._get_create_memory_table_sql()) async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication. - - Uses UPSERT pattern (ON CONFLICT DO NOTHING) to skip duplicates - based on event_id unique constraint. - - Args: - entries: List of memory records to insert. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Number of entries actually inserted (excludes duplicates). - - Raises: - RuntimeError: If memory store is disabled. 
- """ if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -673,19 +407,6 @@ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: " async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": - """Search memory entries by text query. - - Uses the configured search strategy (simple ILIKE or FTS). - - Args: - query: Text query to search for. - app_name: Application name to filter by. - user_id: User ID to filter by. - limit: Maximum number of results (defaults to max_results config). - - Returns: - List of memory records. - """ if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -693,6 +414,8 @@ async def search_entries( if not query: return [] + from typing import cast + limit_value = limit or self._max_results if self._use_fts: sql = f""" @@ -717,7 +440,6 @@ async def search_entries( return [cast("MemoryRecord", dict(row)) for row in rows] async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -731,7 +453,6 @@ async def delete_entries_by_session(self, session_id: str) -> int: return 0 async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index c81e5f6cc..36979547e 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -7,6 +7,8 @@ from sqlspec.utils.logging import get_logger if TYPE_CHECKING: + from datetime import datetime + from sqlspec.adapters.cockroach_asyncpg.config import CockroachAsyncpgConfig from sqlspec.extensions.adk import MemoryRecord @@ 
-17,7 +19,19 @@ class CockroachAsyncpgADKStore(BaseAsyncADKStore["CockroachAsyncpgConfig"]): - """CockroachDB ADK store using asyncpg driver.""" + """CockroachDB ADK store using asyncpg driver. + + Implements session and event storage for Google Agent Development Kit + using CockroachDB via asyncpg in PostgreSQL compatibility mode. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. + + CockroachDB-specific differences from native PostgreSQL: + - No FILLFACTOR (CockroachDB uses different storage engine) + - No BRIN indexes (different physical storage layout) + - GIN/Inverted indexes on JSONB are fully supported (v23.1+) + - Native tsvector/tsquery FTS with GIN is supported (v23.1+) + """ __slots__ = () @@ -44,34 +58,28 @@ async def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; """ async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, - actions BYTEA NOT NULL, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX 
IF NOT EXISTS idx_{self._events_table}_event_json + ON {self._events_table} USING GIN (event_json); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -181,72 +189,77 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: await conn.execute( sql, - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - event_record.get("content"), - event_record.get("grounding_metadata"), - event_record.get("custom_metadata"), - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_record["event_json"], ) - async def list_events(self, session_id: str) -> "list[EventRecord]": + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = $1, update_time = CURRENT_TIMESTAMP + WHERE id = $2 + """ + + async with 
self._config.provide_connection() as conn, conn.transaction(): + await conn.execute( + insert_sql, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ) + await conn.execute(update_sql, state, session_id) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + where_clauses = ["session_id = $1"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append(f"timestamp > ${len(params) + 1}") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = f" LIMIT ${len(params) + 1}" if limit else "" + if limit: + params.append(limit) + sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = $1 - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ async with self._config.provide_connection() as conn: - rows = await conn.fetch(sql, session_id) + rows = await conn.fetch(sql, *params) return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]), - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - 
error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -265,6 +278,13 @@ async def _get_create_memory_table_sql(self) -> str: if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + return f""" CREATE TABLE IF NOT EXISTS {self._memory_table} ( id VARCHAR(128) PRIMARY KEY, @@ -285,6 +305,7 @@ async def _get_create_memory_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session ON {self._memory_table}(session_id); + {fts_index} """ def _get_drop_memory_table_sql(self) -> "list[str]": @@ -371,18 +392,27 @@ async def search_entries( return [] effective_limit = limit if limit is not None else self._max_results - if self._use_fts: - logger.debug("CockroachDB full-text search not supported; using simple search") - sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = $1 AND user_id = $2 AND content_text ILIKE $3 - ORDER BY timestamp DESC - LIMIT $4 - """ + if self._use_fts: + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = $1 AND user_id = $2 + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', $3) + ORDER BY timestamp DESC + LIMIT $4 + """ + params: tuple[Any, ...] 
= (app_name, user_id, query, effective_limit) + else: + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = $1 AND user_id = $2 AND content_text ILIKE $3 + ORDER BY timestamp DESC + LIMIT $4 + """ + params = (app_name, user_id, f"%{query}%", effective_limit) async with self._config.provide_connection() as conn: - rows = await conn.fetch(sql, app_name, user_id, f"%{query}%", effective_limit) + rows = await conn.fetch(sql, *params) return [cast("MemoryRecord", dict(row)) for row in rows] @@ -403,8 +433,8 @@ async def delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} - WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL $1 DAY) + WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL '{days} days') """ async with self._config.provide_connection() as conn: - result = await conn.execute(sql, days) + result = await conn.execute(sql) return int(result.split()[-1]) if result else 0 diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 97ee76bc6..fa589168b 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -11,6 +11,8 @@ from sqlspec.utils.logging import get_logger if TYPE_CHECKING: + from datetime import datetime + from sqlspec.adapters.cockroach_psycopg.config import CockroachPsycopgAsyncConfig, CockroachPsycopgSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -59,7 +61,19 @@ def _build_insert_params_with_owner(entry: "MemoryRecord", owner_id: "object | N class CockroachPsycopgAsyncADKStore(BaseAsyncADKStore["CockroachPsycopgAsyncConfig"]): - """CockroachDB ADK store using psycopg async driver.""" + """CockroachDB ADK store using psycopg async driver. + + Implements session and event storage for Google Agent Development Kit + using CockroachDB via psycopg in PostgreSQL compatibility mode. 
+ Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. + + CockroachDB-specific differences from native PostgreSQL: + - No FILLFACTOR (CockroachDB uses different storage engine) + - SQL strings require ``.encode()`` for cockroach-psycopg driver + - GIN/Inverted indexes on JSONB are fully supported (v23.1+) + - Native tsvector/tsquery FTS with GIN is supported (v23.1+) + """ __slots__ = () @@ -67,7 +81,6 @@ def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get CockroachDB CREATE TABLE SQL for sessions.""" owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -87,35 +100,28 @@ async def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; """ async def _get_create_events_table_sql(self) -> str: - """Get CockroachDB CREATE TABLE SQL for events.""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, - actions BYTEA NOT NULL, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON 
{self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json + ON {self._events_table} USING GIN (event_json); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -239,77 +245,91 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( sql.encode(), ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - event_record.get("content"), - event_record.get("grounding_metadata"), - event_record.get("custom_metadata"), - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + jsonb_value, ), ) await conn.commit() - async def list_events(self, session_id: str) -> "list[EventRecord]": + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + 
session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE id = %s + """ + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute( + insert_sql.encode(), + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), + ) + await cur.execute(update_sql.encode(), (Jsonb(state), session_id)) + await conn.commit() + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = " LIMIT %s" if limit else "" + if limit: + params.append(limit) + sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = %s - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (session_id,)) + await cur.execute(sql.encode(), tuple(params)) rows = await cur.fetchall() return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], 
invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]), - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -318,7 +338,18 @@ async def list_events(self, session_id: str) -> "list[EventRecord]": class CockroachPsycopgSyncADKStore(BaseSyncADKStore["CockroachPsycopgSyncConfig"]): - """CockroachDB ADK store using psycopg sync driver.""" + """CockroachDB ADK store using psycopg sync driver. + + Implements session and event storage for Google Agent Development Kit + using CockroachDB via psycopg in PostgreSQL compatibility mode (sync). + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. 
+ + CockroachDB-specific differences from native PostgreSQL: + - No FILLFACTOR (CockroachDB uses different storage engine) + - SQL strings require ``.encode()`` for cockroach-psycopg driver + - GIN/Inverted indexes on JSONB are fully supported (v23.1+) + """ __slots__ = () @@ -345,34 +376,28 @@ def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; """ def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, - actions BYTEA NOT NULL, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json + ON {self._events_table} USING GIN (event_json); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -493,50 +518,104 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess except errors.UndefinedTable: return [] - def append_event(self, event_record: EventRecord) -> None: + def create_event( + self, + event_id: str, + session_id: str, + app_name: str, + user_id: str, + author: "str | None" = None, + actions: "bytes | None" = 
None, + content: "dict[str, Any] | None" = None, + **kwargs: Any, + ) -> EventRecord: + """Create a new event using the legacy positional API. + + This method is required by the BaseSyncADKStore contract. For new code, + prefer ``create_event_and_update_state`` which atomically persists the + event and updates session state. + """ + from datetime import datetime, timezone + + event_json: dict[str, Any] = {} + if author is not None: + event_json["author"] = author + if actions is not None: + event_json["actions"] = actions.hex() + if content is not None: + event_json["content"] = content + event_json.update({k: v for k, v in kwargs.items() if v is not None}) + + invocation_id = kwargs.get("invocation_id", "") + ts = kwargs.get("timestamp") or datetime.now(timezone.utc) + sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + RETURNING session_id, invocation_id, author, timestamp, event_json """ with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( sql.encode(), ( - event_record["id"], + session_id, + invocation_id, + author or "", + ts, + Jsonb(event_json), + ), + ) + row = cur.fetchone() + conn.commit() + + if row is None: + msg = f"Failed to create event {event_id}" + raise RuntimeError(msg) + + return EventRecord( + session_id=row["session_id"], + invocation_id=row["invocation_id"], + author=row["author"], + timestamp=row["timestamp"], + event_json=row["event_json"], + ) + + def create_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO 
{self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE id = %s + """ + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute( + insert_sql.encode(), + ( event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - event_record.get("content"), - event_record.get("grounding_metadata"), - event_record.get("custom_metadata"), - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + jsonb_value, ), ) + cur.execute(update_sql.encode(), (Jsonb(state), session_id)) conn.commit() def list_events(self, session_id: str) -> "list[EventRecord]": sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE session_id = %s ORDER BY timestamp ASC @@ -549,24 +628,11 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]), - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], 
timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -587,6 +653,13 @@ async def _get_create_memory_table_sql(self) -> str: if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + return f""" CREATE TABLE IF NOT EXISTS {self._memory_table} ( id VARCHAR(128) PRIMARY KEY, @@ -607,6 +680,7 @@ async def _get_create_memory_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session ON {self._memory_table}(session_id); + {fts_index} """ def _get_drop_memory_table_sql(self) -> "list[str]": @@ -675,16 +749,25 @@ async def search_entries( return [] effective_limit = limit if limit is not None else self._max_results + if self._use_fts: - logger.debug("CockroachDB full-text search not supported; using simple search") + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = %s AND user_id = %s + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', %s) + ORDER BY timestamp DESC + LIMIT %s + """ + else: + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s + ORDER BY timestamp DESC + LIMIT %s + """ - sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s - ORDER BY timestamp DESC - LIMIT %s - """ - params = (app_name, user_id, f"%{query}%", effective_limit) + search_param = query if self._use_fts else f"%{query}%" + params = (app_name, user_id, search_param, effective_limit) try: 
async with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -714,10 +797,10 @@ async def delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} - WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL %s DAY) + WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days' """ async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (days,)) + await cur.execute(sql.encode()) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -735,6 +818,13 @@ def _get_create_memory_table_sql(self) -> str: if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + return f""" CREATE TABLE IF NOT EXISTS {self._memory_table} ( id VARCHAR(128) PRIMARY KEY, @@ -755,6 +845,7 @@ def _get_create_memory_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session ON {self._memory_table}(session_id); + {fts_index} """ def _get_drop_memory_table_sql(self) -> "list[str]": @@ -821,16 +912,25 @@ def search_entries( return [] effective_limit = limit if limit is not None else self._max_results + if self._use_fts: - logger.debug("CockroachDB full-text search not supported; using simple search") + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = %s AND user_id = %s + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', %s) + ORDER BY timestamp DESC + LIMIT %s + """ + else: + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s + ORDER BY timestamp DESC + LIMIT %s + """ - sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s - ORDER BY timestamp DESC - LIMIT %s - """ - 
params = (app_name, user_id, f"%{query}%", effective_limit) + search_param = query if self._use_fts else f"%{query}%" + params = (app_name, user_id, search_param, effective_limit) try: with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -860,9 +960,9 @@ def delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} - WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL %s DAY) + WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days' """ with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (days,)) + cur.execute(sql.encode()) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index c1629e05d..9cfdcea27 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -29,79 +29,28 @@ class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via the high-performance Rust-based psqlpy driver. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Event history tracking with BYTEA-serialized actions + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 Args: config: PsqlpyConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.psqlpy import PsqlpyConfig - from sqlspec.adapters.psqlpy.adk import PsqlpyADKStore - - config = PsqlpyConfig( - connection_config={"dsn": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = PsqlpyADKStore(config) - await store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - Psqlpy automatically converts Python dicts to/from JSONB - - TIMESTAMPTZ provides timezone-aware microsecond precision - - BYTEA for pre-serialized actions from Google ADK - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Uses PostgreSQL numeric parameter style ($1, $2, $3) - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "PsqlpyConfig") -> None: - """Initialize Psqlpy ADK store. - - Args: - config: PsqlpyConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or user references - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -128,66 +77,24 @@ async def _get_create_sessions_table_sql(self) -> str: """ async def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pre-serialized actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) 
REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist. - - Notes: - Uses driver.execute_script() which handles multiple statements. - Creates sessions table first, then events table (FK dependency). - """ async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -195,23 +102,6 @@ async def create_tables(self) -> None: async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is passed as dict and psqlpy converts to JSONB automatically. - If owner_id_column is configured, owner_id value must be provided. 
- """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] if self._owner_id_column_name: sql = f""" @@ -230,19 +120,6 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically parsed by psqlpy to Python dicts. - Returns None if table doesn't exist (catches database errors). - """ sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -273,17 +150,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": raise async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - Psqlpy automatically converts dict to JSONB. - """ sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP @@ -294,33 +160,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - await conn.execute(sql, [state, session_id]) async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. 
- """ sql = f"DELETE FROM {self._session_table} WHERE id = $1" async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] await conn.execute(sql, [session_id]) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. - Returns empty list if table doesn't exist. - """ if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -361,74 +206,54 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis raise async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record to store. - - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSONB fields are passed as dicts and psqlpy converts automatically. - BYTEA actions field stores pre-serialized data from Google ADK. 
- """ - content_json = event_record.get("content") - grounding_metadata_json = event_record.get("grounding_metadata") - custom_metadata_json = event_record.get("custom_metadata") - sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] await conn.execute( sql, [ - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], - event_record.get("invocation_id"), - event_record.get("author"), - event_record.get("actions"), - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ], + ) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = $1, update_time = CURRENT_TIMESTAMP + WHERE id = $2 + """ + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute( + insert_sql, + [ + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - 
event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_record["event_json"], ], ) + await conn.execute(update_sql, [state, session_id]) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - Parses JSONB fields and converts BYTEA actions to bytes. - Returns empty list if table doesn't exist. - """ where_clauses = ["session_id = $1"] params: list[Any] = [session_id] @@ -442,10 +267,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -458,24 +280,11 @@ async def get_events( return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in 
rows ] diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 6011e6dc2..bee3991ee 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -56,84 +56,32 @@ def _build_insert_params_with_owner(entry: "MemoryRecord", owner_id: "object | N class PsycopgAsyncADKStore(BaseAsyncADKStore["PsycopgAsyncConfig"]): - """PostgreSQL ADK store using Psycopg3 driver. + """PostgreSQL ADK store using Psycopg3 async driver. Implements session and event storage for Google Agent Development Kit using PostgreSQL via psycopg3 with native async/await support. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. Provides: - - Session state management with JSONB storage and merge operations - - Event history tracking with BYTEA-serialized actions + - Session state management with JSONB storage + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 Args: config: PsycopgAsyncConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.psycopg import PsycopgAsyncConfig - from sqlspec.adapters.psycopg.adk import PsycopgAsyncADKStore - - config = PsycopgAsyncConfig( - connection_config={"conninfo": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = PsycopgAsyncADKStore(config) - await store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - Psycopg requires wrapping dicts with Jsonb() for type safety - - TIMESTAMPTZ provides timezone-aware microsecond precision - - State merging uses `state || $1::jsonb` operator for efficiency - - BYTEA for pre-serialized actions from Google ADK - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Parameter style: $1, $2, $3 (PostgreSQL numeric placeholders) - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "PsycopgAsyncConfig") -> None: - """Initialize Psycopg ADK store. - - Args: - config: PsycopgAsyncConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or user references - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -160,61 +108,24 @@ async def _get_create_sessions_table_sql(self) -> str: """ async def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pickled actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) 
REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -222,23 +133,6 @@ async def create_tables(self) -> None: async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - If owner_id_column is configured, owner_id value must be provided. - """ params: tuple[Any, ...] if self._owner_id_column_name: query = pg_sql.SQL(""" @@ -261,18 +155,6 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. 
- - Returns: - Session record or None if not found. - - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically deserialized by psycopg to Python dict. - """ query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time FROM {table} @@ -299,17 +181,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": return None async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - """ query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP @@ -320,32 +191,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - await cur.execute(query, (Jsonb(state), session_id)) async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. - """ query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute(query, (session_id,)) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. 
- """ if user_id is None: query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time @@ -383,73 +234,62 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return [] async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. + query = pg_sql.SQL(""" + INSERT INTO {table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """).format(table=pg_sql.Identifier(self._events_table)) - Args: - event_record: Event record to store. + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSONB fields are wrapped with Jsonb() for PostgreSQL type safety. - """ - content_json = event_record.get("content") - grounding_metadata_json = event_record.get("grounding_metadata") - custom_metadata_json = event_record.get("custom_metadata") + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute( + query, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), + ) - query = pg_sql.SQL(""" + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) + update_query = pg_sql.SQL(""" + 
UPDATE {table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE id = %s + """).format(table=pg_sql.Identifier(self._session_table)) + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( - query, + insert_query, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], - event_record.get("invocation_id"), - event_record.get("author"), - event_record.get("actions"), - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), + event_record["invocation_id"], + event_record["author"], event_record["timestamp"], - Jsonb(content_json) if content_json is not None else None, - Jsonb(grounding_metadata_json) if grounding_metadata_json is not None else None, - Jsonb(custom_metadata_json) if custom_metadata_json is not None else None, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + jsonb_value, ), ) + await cur.execute(update_query, (Jsonb(state), session_id)) + await conn.commit() async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - JSONB fields are automatically deserialized by psycopg. - BYTEA actions are converted to bytes. 
- """ where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -463,10 +303,7 @@ async def get_events( query = pg_sql.SQL( """ - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -484,24 +321,11 @@ async def get_events( return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -514,80 +338,28 @@ class PsycopgSyncADKStore(BaseSyncADKStore["PsycopgSyncConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via psycopg3 with synchronous execution. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. 
Provides: - - Session state management with JSONB storage and merge operations - - Event history tracking with BYTEA-serialized actions + - Session state management with JSONB storage + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``create_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 Args: config: PsycopgSyncConfig with extension_config["adk"] settings. - - Example: - from sqlspec.adapters.psycopg import PsycopgSyncConfig - from sqlspec.adapters.psycopg.adk import PsycopgSyncADKStore - - config = PsycopgSyncConfig( - connection_config={"conninfo": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = PsycopgSyncADKStore(config) - store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - Psycopg requires wrapping dicts with Jsonb() for type safety - - TIMESTAMPTZ provides timezone-aware microsecond precision - - State merging uses `state || $1::jsonb` operator for efficiency - - BYTEA for pre-serialized actions from Google ADK - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Parameter style: $1, $2, $3 (PostgreSQL numeric placeholders) - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "PsycopgSyncConfig") -> None: - """Initialize Psycopg synchronous ADK store. - - Args: - config: PsycopgSyncConfig instance. 
- - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or user references - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -614,61 +386,24 @@ def _get_create_sessions_table_sql(self) -> str: """ def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. 
- - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pickled actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. 
- """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" with self._config.provide_session() as driver: driver.execute_script(self._get_create_sessions_table_sql()) driver.execute_script(self._get_create_events_table_sql()) @@ -676,23 +411,6 @@ def create_tables(self) -> None: def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - If owner_id_column is configured, owner_id value must be provided. - """ params: tuple[Any, ...] if self._owner_id_column_name: query = pg_sql.SQL(""" @@ -715,18 +433,6 @@ def create_session( return self.get_session(session_id) # type: ignore[return-value] def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically deserialized by psycopg to Python dict. - """ query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time FROM {table} @@ -753,17 +459,6 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). 
- - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - """ query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP @@ -774,32 +469,12 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cur.execute(query, (Jsonb(state), session_id)) def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. - """ query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (session_id,)) def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. - """ if user_id is None: query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time @@ -847,70 +522,42 @@ def create_event( content: "dict[str, Any] | None" = None, **kwargs: Any, ) -> EventRecord: - """Create a new event. - - Args: - event_id: Unique event identifier. - session_id: Session identifier. - app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSONB). - **kwargs: Additional optional fields (invocation_id, branch, timestamp, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message, long_running_tool_ids_json). 
- - Returns: - Created event record. - - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided in kwargs. - JSONB fields are wrapped with Jsonb() for PostgreSQL type safety. + """Create a new event using the legacy positional API. + + This method is required by the BaseSyncADKStore contract. For new code, + prefer ``create_event_and_update_state`` which atomically persists the + event and updates session state. """ - content_json = Jsonb(content) if content is not None else None - grounding_metadata = kwargs.get("grounding_metadata") - grounding_metadata_json = Jsonb(grounding_metadata) if grounding_metadata is not None else None - custom_metadata = kwargs.get("custom_metadata") - custom_metadata_json = Jsonb(custom_metadata) if custom_metadata is not None else None + from datetime import datetime, timezone + + event_json: dict[str, Any] = {} + if author is not None: + event_json["author"] = author + if actions is not None: + event_json["actions"] = actions.hex() + if content is not None: + event_json["content"] = content + event_json.update({k: v for k, v in kwargs.items() if v is not None}) + + invocation_id = kwargs.get("invocation_id", "") + ts = kwargs.get("timestamp") or datetime.now(timezone.utc) query = pg_sql.SQL(""" INSERT INTO {table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, COALESCE(%s, CURRENT_TIMESTAMP), %s, %s, %s, %s, %s, %s, %s, %s - ) - RETURNING id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + RETURNING session_id, invocation_id, 
author, timestamp, event_json """).format(table=pg_sql.Identifier(self._events_table)) with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( query, ( - event_id, session_id, - app_name, - user_id, - kwargs.get("invocation_id"), - author, - actions, - kwargs.get("long_running_tool_ids_json"), - kwargs.get("branch"), - kwargs.get("timestamp"), - content_json, - grounding_metadata_json, - custom_metadata_json, - kwargs.get("partial"), - kwargs.get("turn_complete"), - kwargs.get("interrupted"), - kwargs.get("error_code"), - kwargs.get("error_message"), + invocation_id, + author or "", + ts, + Jsonb(event_json), ), ) row = cur.fetchone() @@ -920,45 +567,48 @@ def create_event( raise RuntimeError(msg) return EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) - def list_events(self, session_id: str) -> "list[EventRecord]": - """List events for a session ordered by timestamp. 
+ def create_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_query = pg_sql.SQL(""" + INSERT INTO {table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """).format(table=pg_sql.Identifier(self._events_table)) + + update_query = pg_sql.SQL(""" + UPDATE {table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE id = %s + """).format(table=pg_sql.Identifier(self._session_table)) - Args: - session_id: Session identifier. + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - Returns: - List of event records ordered by timestamp ASC. + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute( + insert_query, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), + ) + cur.execute(update_query, (Jsonb(state), session_id)) + conn.commit() - Notes: - Uses index on (session_id, timestamp ASC). - JSONB fields are automatically deserialized by psycopg. - BYTEA actions are converted to bytes. 
- """ + def list_events(self, session_id: str) -> "list[EventRecord]": query = pg_sql.SQL(""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {table} WHERE session_id = %s ORDER BY timestamp ASC @@ -971,24 +621,11 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] From e9ba77a077f5cda5eb6fcff78aad59252216e2b3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 17:21:05 +0000 Subject: [PATCH 13/23] feat(adk): expand ADKConfig with capability-based configuration --- sqlspec/config.py | 268 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 268 insertions(+) diff --git a/sqlspec/config.py b/sqlspec/config.py index e3a96c35e..f5723554a 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -35,7 +35,11 @@ __all__ = ( + "ADKCompressionConfig", "ADKConfig", + "ADKPartitionConfig", + "ADKRetentionConfig", + "ADKSqliteOptimizationConfig", "AsyncConfigT", "AsyncDatabaseConfig", "ConfigT", @@ -374,6 +378,178 @@ class FastAPIConfig(StarletteConfig): """ +class ADKPartitionConfig(TypedDict): + 
"""Configuration for table partitioning and sharding strategies. + + Controls how ADK tables are partitioned across backends that support it. + Backends without native partitioning support ignore these settings. + + Example: + extension_config={ + "adk": { + "partitioning": { + "strategy": "range", + "partition_key": "created_at", + "interval": "month", + } + } + } + """ + + strategy: NotRequired[Literal["range", "list", "hash"]] + """Partitioning strategy. Default: None (no partitioning). + + - range: Partition by range of values (e.g., time-based) + - list: Partition by discrete value lists + - hash: Partition by hash of the partition key + + Supported by: PostgreSQL, MySQL 8+, Oracle, BigQuery, Spanner. + Ignored by: SQLite, DuckDB. + """ + + partition_key: NotRequired[str] + """Column name used as the partition key. + + For range partitioning with time-based data, this is typically a timestamp column + like 'created_at'. For hash partitioning, this is typically the primary key. + """ + + interval: NotRequired[str] + """Partition interval for range partitioning. + + Examples: 'day', 'week', 'month', 'year'. + Only meaningful when strategy is 'range'. + """ + + +class ADKRetentionConfig(TypedDict): + """Configuration for data retention and TTL policies. + + Controls automatic cleanup of expired data. Backends with native TTL support + (CockroachDB Row-Level TTL, Spanner Row Deletion Policy) use database-level + enforcement. Others fall back to application-level sweep queries. + + Example: + extension_config={ + "adk": { + "retention": { + "session_ttl_seconds": 86400, + "event_ttl_seconds": 604800, + "memory_ttl_seconds": 0, + } + } + } + """ + + session_ttl_seconds: NotRequired[int] + """TTL for session records in seconds. Default: 0 (no expiry). + + When set, sessions older than this threshold are eligible for cleanup. + Backends with native TTL (CockroachDB, Spanner) enforce this at the database level. 
+ Others require application-level cleanup via periodic sweep. + """ + + event_ttl_seconds: NotRequired[int] + """TTL for event records in seconds. Default: 0 (no expiry). + + When set, events older than this threshold are eligible for cleanup. + """ + + memory_ttl_seconds: NotRequired[int] + """TTL for memory entries in seconds. Default: 0 (no expiry). + + When set, memory entries older than this threshold are eligible for cleanup. + """ + + sweep_interval_seconds: NotRequired[int] + """Interval between application-level cleanup sweeps in seconds. Default: 3600 (1 hour). + + Only used when the backend does not support native TTL enforcement. + Set to 0 to disable automatic sweeps (manual cleanup only). + """ + + +class ADKCompressionConfig(TypedDict): + """Configuration for table-level compression. + + Controls compression of ADK table storage. Support and algorithms vary by backend. + + Example: + extension_config={ + "adk": { + "compression": { + "enabled": True, + "algorithm": "zstd", + } + } + } + """ + + enabled: NotRequired[bool] + """Enable table compression. Default: False. + + When True, adapters that support table-level compression will apply it + during table creation. + """ + + algorithm: NotRequired[str] + """Compression algorithm name. Backend-specific. + + Examples: + - PostgreSQL (with TOAST): 'pglz', 'lz4' (PG14+) + - MySQL/InnoDB: 'zlib' + - Oracle: 'basic', 'oltp', 'query_high', 'archive_high' + - DuckDB: 'zstd', 'snappy' + + When omitted, the backend default is used. + """ + + level: NotRequired[int] + """Compression level (where supported). Higher levels trade CPU for space savings. + + Valid ranges depend on the algorithm and backend. + """ + + +class ADKSqliteOptimizationConfig(TypedDict): + """SQLite-specific PRAGMA optimization settings. + + Controls SQLite performance tuning parameters applied at connection time. + These settings are ignored by non-SQLite adapters. 
+ + Example: + extension_config={ + "adk": { + "sqlite_optimization": { + "cache_size": -64000, + "mmap_size": 31457280, + "journal_size_limit": 67108864, + } + } + } + """ + + cache_size: NotRequired[int] + """SQLite page cache size. Default: -64000 (64 MB, negative means KiB). + + Larger caches reduce disk I/O for read-heavy workloads. + Negative values specify size in KiB; positive values specify page count. + """ + + mmap_size: NotRequired[int] + """SQLite memory-mapped I/O size in bytes. Default: 31457280 (30 MB). + + Enables memory-mapped I/O for faster reads. Set to 0 to disable. + """ + + journal_size_limit: NotRequired[int] + """SQLite journal file size limit in bytes. Default: 67108864 (64 MB). + + Limits the size of the WAL or rollback journal file. + Prevents unbounded journal growth in write-heavy workloads. + """ + + class ADKConfig(TypedDict): """Configuration options for ADK session and memory store extension. @@ -585,6 +761,98 @@ class ADKConfig(TypedDict): expires_index_options: NotRequired[str] """Adapter-specific options for the expires/index used in ADK stores.""" + # --- Capability-based configuration (Chapter 2: schema-capability-config) --- + + fts_language: NotRequired[str] + """Language configuration for full-text search indexing. Default: 'english'. + + Controls the language dictionary/stemmer used by FTS implementations: + - PostgreSQL: to_tsvector/to_tsquery language parameter + - SQLite FTS5: tokenizer language for unicode61/porter + - MySQL: FULLTEXT parser language (with ngram for CJK on 5.7.6+) + - Oracle: CTXSYS.CONTEXT lexer language + - Spanner: TOKENIZE_FULLTEXT language parameter + - DuckDB: FTS stemmer language + + Only takes effect when ``memory_use_fts`` is True. + + Common values: 'english', 'simple', 'german', 'french', 'spanish', + 'portuguese', 'italian', 'dutch', 'russian', 'chinese', 'japanese', 'korean'. + + Notes: + Available languages vary by backend. 
Backends that do not support the + specified language will fall back to 'simple' or 'english'. + """ + + artifact_storage_uri: NotRequired[str] + """Base URI for artifact content storage. Default: None (store inline in database). + + When set, large artifact payloads are stored externally and only metadata + is kept in the database. The URI scheme determines the storage backend: + - ``file:///path/to/artifacts`` — local filesystem + - ``s3://bucket/prefix`` — AWS S3 or S3-compatible storage + - ``gs://bucket/prefix`` — Google Cloud Storage + - ``az://container/prefix`` — Azure Blob Storage + + When None, artifact content is stored inline in the database tables, + which is suitable for small payloads but may cause performance issues + with large binary artifacts. + + Integrates with the ``StorageRegistry`` for pluggable storage backends. + """ + + schema_version: NotRequired[int] + """Explicit schema version for ADK tables. Default: None (auto-detect). + + When set, locks the ADK schema to a specific version. This is useful for: + - Preventing automatic schema upgrades in production + - Pinning to a known-good schema during testing + - Coordinating schema changes across multiple application instances + + When None, the ADK extension auto-detects the current schema version + and applies any pending upgrades during initialization. + + Notes: + Schema versions are monotonically increasing integers managed by + the ADK extension migration system. Setting this to a version + lower than the current database schema will raise a configuration + error at startup. + """ + + partitioning: NotRequired[ADKPartitionConfig] + """Table partitioning configuration. Default: None (no partitioning). + + Controls how ADK tables are partitioned for improved query performance + and data management at scale. See ``ADKPartitionConfig`` for options. + + Supported by: PostgreSQL, MySQL 8+, Oracle, BigQuery, Spanner. + Ignored by: SQLite, DuckDB. 
+ """ + + retention: NotRequired[ADKRetentionConfig] + """Data retention and TTL configuration. Default: None (no automatic cleanup). + + Controls automatic expiry and cleanup of old session, event, and memory data. + See ``ADKRetentionConfig`` for options. + + Backends with native TTL (CockroachDB, Spanner) use database-level enforcement. + Others fall back to application-level sweep queries. + """ + + compression: NotRequired[ADKCompressionConfig] + """Table compression configuration. Default: None (no compression). + + Controls table-level compression for ADK tables. + See ``ADKCompressionConfig`` for options. + """ + + sqlite_optimization: NotRequired[ADKSqliteOptimizationConfig] + """SQLite-specific PRAGMA optimization settings. Default: None (SQLite defaults). + + Controls SQLite performance tuning parameters. Ignored by non-SQLite adapters. + See ``ADKSqliteOptimizationConfig`` for options. + """ + class EventsConfig(TypedDict): """Configuration options for the events extension. From bc856afad2cadaf7a6775d19abaa66342501985f Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 17:23:47 +0000 Subject: [PATCH 14/23] feat(adk): add artifact service with storage-backed content Implement Chapter 11 (artifact-service) of the ADK Clean-Break Overhaul. Adds SQLSpecArtifactService implementing all 7 BaseArtifactService methods (save_artifact, load_artifact, list_artifact_keys, delete_artifact, list_versions, list_artifact_versions, get_artifact_version) by composing SQL-backed metadata storage with sqlspec/storage/ content backends. 
Architecture: - Metadata (version, filename, mime_type, custom_metadata, canonical_uri) stored in a SQL table via BaseAsync/SyncADKArtifactStore - Content bytes stored in object storage (S3, GCS, Azure, local filesystem) via StorageRegistry, addressed by canonical URI - URI pattern: {base}/apps/{app}/users/{user}/[sessions/{sid}/]artifacts/{file}/v{ver} - Versioning is append-only (0-based, monotonically increasing) - Content written first, metadata second (fail-fast ordering) - Delete removes metadata first, content cleanup is best-effort New files: - sqlspec/extensions/adk/artifact/__init__.py (package exports) - sqlspec/extensions/adk/artifact/_types.py (ArtifactRecord TypedDict) - sqlspec/extensions/adk/artifact/store.py (BaseAsync/SyncADKArtifactStore ABCs) - sqlspec/extensions/adk/artifact/service.py (SQLSpecArtifactService) Also adds artifact_table and artifact_storage_uri fields to ADKConfig. --- sqlspec/config.py | 21 + sqlspec/extensions/adk/__init__.py | 20 +- sqlspec/extensions/adk/artifact/__init__.py | 57 +++ sqlspec/extensions/adk/artifact/_types.py | 32 ++ sqlspec/extensions/adk/artifact/service.py | 509 ++++++++++++++++++++ sqlspec/extensions/adk/artifact/store.py | 363 ++++++++++++++ 6 files changed, 999 insertions(+), 3 deletions(-) create mode 100644 sqlspec/extensions/adk/artifact/__init__.py create mode 100644 sqlspec/extensions/adk/artifact/_types.py create mode 100644 sqlspec/extensions/adk/artifact/service.py create mode 100644 sqlspec/extensions/adk/artifact/store.py diff --git a/sqlspec/config.py b/sqlspec/config.py index f5723554a..29a873110 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -636,6 +636,27 @@ class ADKConfig(TypedDict): "tenant_acme_memories" """ + artifact_table: NotRequired[str] + """Name of the artifact versions table. Default: 'adk_artifact_versions' + + Examples: + "agent_artifacts" + "my_app_artifact_versions" + """ + + artifact_storage_uri: NotRequired[str] + """Base URI for artifact content storage. 
+ + Points to a ``sqlspec/storage/`` backend where artifact binary content + is stored. Can be a direct URI (``s3://bucket/path``, ``file:///path``) + or a registered alias in the storage registry. + + Examples: + "s3://my-bucket/adk-artifacts/" + "file:///var/data/artifacts/" + "gcs://my-gcs-bucket/artifacts/" + """ + memory_use_fts: NotRequired[bool] """Enable full-text search when supported. Default: False. diff --git a/sqlspec/extensions/adk/__init__.py b/sqlspec/extensions/adk/__init__.py index c7877b1a5..7f6cfc586 100644 --- a/sqlspec/extensions/adk/__init__.py +++ b/sqlspec/extensions/adk/__init__.py @@ -1,20 +1,24 @@ -"""Google ADK session backend extension for SQLSpec. +"""Google ADK session, memory, and artifact backend extension for SQLSpec. -Provides session, event, and memory storage for Google Agent Development Kit using -SQLSpec database adapters. +Provides session, event, memory, and artifact storage for Google Agent Development Kit +using SQLSpec database adapters. Public API exports: - ADKConfig: TypedDict for extension config (type-safe configuration) - SQLSpecSessionService: Main service class implementing BaseSessionService - SQLSpecMemoryService: Main async service class implementing BaseMemoryService - SQLSpecSyncMemoryService: Sync memory service for sync adapters + - SQLSpecArtifactService: Artifact service implementing BaseArtifactService - BaseAsyncADKStore: Base class for async database store implementations - BaseSyncADKStore: Base class for sync database store implementations - BaseAsyncADKMemoryStore: Base class for async memory store implementations - BaseSyncADKMemoryStore: Base class for sync memory store implementations + - BaseAsyncADKArtifactStore: Base class for async artifact metadata stores + - BaseSyncADKArtifactStore: Base class for sync artifact metadata stores - SessionRecord: TypedDict for session database records - EventRecord: TypedDict for event database records - MemoryRecord: TypedDict for memory database records + 
- ArtifactRecord: TypedDict for artifact metadata database records Example (with extension_config): from sqlspec.adapters.asyncpg import AsyncpgConfig @@ -45,6 +49,12 @@ from sqlspec.config import ADKConfig from sqlspec.extensions.adk._types import EventRecord, SessionRecord +from sqlspec.extensions.adk.artifact import ( + ArtifactRecord, + BaseAsyncADKArtifactStore, + BaseSyncADKArtifactStore, + SQLSpecArtifactService, +) from sqlspec.extensions.adk.memory import ( BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore, @@ -57,12 +67,16 @@ __all__ = ( "ADKConfig", + "ArtifactRecord", + "BaseAsyncADKArtifactStore", "BaseAsyncADKMemoryStore", "BaseAsyncADKStore", + "BaseSyncADKArtifactStore", "BaseSyncADKMemoryStore", "BaseSyncADKStore", "EventRecord", "MemoryRecord", + "SQLSpecArtifactService", "SQLSpecMemoryService", "SQLSpecSessionService", "SQLSpecSyncMemoryService", diff --git a/sqlspec/extensions/adk/artifact/__init__.py b/sqlspec/extensions/adk/artifact/__init__.py new file mode 100644 index 000000000..36c3f8478 --- /dev/null +++ b/sqlspec/extensions/adk/artifact/__init__.py @@ -0,0 +1,57 @@ +"""Google ADK artifact service extension for SQLSpec. + +Provides artifact versioning and storage for Google Agent Development Kit +using SQLSpec database adapters for metadata and ``sqlspec/storage/`` backends +for content. 
+ +Public API exports: + - SQLSpecArtifactService: Main service implementing BaseArtifactService + - BaseAsyncADKArtifactStore: Base class for async artifact metadata stores + - BaseSyncADKArtifactStore: Base class for sync artifact metadata stores + - ArtifactRecord: TypedDict for artifact metadata database records + +Example: + from sqlspec.adapters.asyncpg import AsyncpgConfig + from sqlspec.extensions.adk.artifact import SQLSpecArtifactService + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://..."}, + extension_config={ + "adk": { + "artifact_table": "adk_artifact_versions", + } + } + ) + + # Create an adapter-specific artifact store (e.g., AsyncpgADKArtifactStore) + # and ensure tables exist: + artifact_store = AsyncpgADKArtifactStore(config) + await artifact_store.ensure_table() + + # Create the service with a storage backend URI: + service = SQLSpecArtifactService( + store=artifact_store, + artifact_storage_uri="s3://my-bucket/adk-artifacts/", + ) + + # Save an artifact (returns version number starting from 0): + version = await service.save_artifact( + app_name="my_app", + user_id="user123", + filename="report.pdf", + artifact=part, + ) + + # Load artifact content: + loaded = await service.load_artifact( + app_name="my_app", + user_id="user123", + filename="report.pdf", + ) +""" + +from sqlspec.extensions.adk.artifact._types import ArtifactRecord +from sqlspec.extensions.adk.artifact.service import SQLSpecArtifactService +from sqlspec.extensions.adk.artifact.store import BaseAsyncADKArtifactStore, BaseSyncADKArtifactStore + +__all__ = ("ArtifactRecord", "BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore", "SQLSpecArtifactService") diff --git a/sqlspec/extensions/adk/artifact/_types.py b/sqlspec/extensions/adk/artifact/_types.py new file mode 100644 index 000000000..dcedffcf6 --- /dev/null +++ b/sqlspec/extensions/adk/artifact/_types.py @@ -0,0 +1,32 @@ +"""Type definitions for ADK artifact extension. 
+ +These types define the database record structures for storing artifact metadata. +They are separate from the Pydantic models to keep mypyc compilation working. +""" + +from datetime import datetime +from typing import Any, TypedDict + +__all__ = ("ArtifactRecord",) + + +class ArtifactRecord(TypedDict): + """Database record for an artifact version. + + Represents the schema for artifact metadata stored in the database. + Content is stored separately in object storage; this record tracks + versioning, ownership, and the canonical URI pointing to the content. + + The composite key is (app_name, user_id, session_id, filename, version), + where session_id may be NULL for user-scoped artifacts. + """ + + app_name: str + user_id: str + session_id: "str | None" + filename: str + version: int + mime_type: "str | None" + canonical_uri: str + custom_metadata: "dict[str, Any] | None" + created_at: datetime diff --git a/sqlspec/extensions/adk/artifact/service.py b/sqlspec/extensions/adk/artifact/service.py new file mode 100644 index 000000000..cecb8f076 --- /dev/null +++ b/sqlspec/extensions/adk/artifact/service.py @@ -0,0 +1,509 @@ +"""SQLSpec-backed artifact service for Google ADK. + +Implements ``BaseArtifactService`` by composing SQL-backed metadata storage +(via :class:`BaseAsyncADKArtifactStore`) with ``sqlspec/storage/`` content +backends (via :class:`StorageRegistry`). + +Metadata (version, filename, MIME type, custom metadata, canonical URI) lives +in a SQL table. Content bytes live in object storage addressed by canonical +URI. Versioning is append-only with monotonically increasing version numbers +starting from 0. 
+""" + +import json +import logging +import re +from typing import TYPE_CHECKING, Any + +from google.adk.artifacts.base_artifact_service import BaseArtifactService + +from sqlspec.extensions.adk.artifact._types import ArtifactRecord +from sqlspec.storage.registry import StorageRegistry, storage_registry +from sqlspec.utils.logging import get_logger, log_with_context + +if TYPE_CHECKING: + from google.adk.artifacts.base_artifact_service import ArtifactVersion + from google.genai import types + + from sqlspec.extensions.adk.artifact.store import BaseAsyncADKArtifactStore + +logger = get_logger("sqlspec.extensions.adk.artifact.service") + +__all__ = ("SQLSpecArtifactService",) + +# Matches ``..`` path-traversal components and NUL bytes +_UNSAFE_PATH_CHARS = re.compile(r"(?:^|/)\.\.(?:/|$)|[\x00]") + + +def _sanitize_path_component(value: str) -> str: + """Sanitize a path component to prevent directory traversal. + + Strips leading/trailing slashes, then rejects ``..`` traversal + components and NUL bytes. + + Args: + value: Raw path component. + + Returns: + Sanitized path component. + + Raises: + ValueError: If the value contains a ``..`` component or a NUL byte. + """ + value = value.strip("/") + if _UNSAFE_PATH_CHARS.search(value): + msg = f"Unsafe path component: {value!r}" + raise ValueError(msg) + return value + + +def _build_content_path( + app_name: str, user_id: str, filename: str, version: int, session_id: "str | None" = None +) -> str: + """Build the storage path for artifact content. + + Pattern: + ``apps/{app_name}/users/{user_id}/[sessions/{session_id}/]artifacts/{filename}/v{version}`` + + All path components are sanitized to prevent directory traversal. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + version: Version number. + session_id: Optional session identifier. + + Returns: + Sanitized storage path. 
+ """ + parts = ["apps", _sanitize_path_component(app_name), "users", _sanitize_path_component(user_id)] + if session_id is not None: + parts.extend(["sessions", _sanitize_path_component(session_id)]) + parts.extend(["artifacts", _sanitize_path_component(filename), f"v{version}"]) + return "/".join(parts) + + +def _extract_mime_type(artifact: "types.Part | dict[str, Any]") -> "str | None": + """Extract MIME type from an artifact Part. + + Checks ``inline_data.mime_type`` and ``file_data.mime_type`` on the Part. + + Args: + artifact: ADK Part or dict representation. + + Returns: + MIME type string, or None if not determinable. + """ + if isinstance(artifact, dict): + # Handle camelCase and snake_case keys + inline = artifact.get("inline_data") or artifact.get("inlineData") + if isinstance(inline, dict): + return inline.get("mime_type") or inline.get("mimeType") + file_data = artifact.get("file_data") or artifact.get("fileData") + if isinstance(file_data, dict): + return file_data.get("mime_type") or file_data.get("mimeType") + return None + + # types.Part object + if hasattr(artifact, "inline_data") and artifact.inline_data is not None: + return getattr(artifact.inline_data, "mime_type", None) + if hasattr(artifact, "file_data") and artifact.file_data is not None: + return getattr(artifact.file_data, "mime_type", None) + return None + + +def _serialize_artifact(artifact: "types.Part | dict[str, Any]") -> bytes: + """Serialize an artifact Part to bytes for content storage. + + The artifact is serialized as JSON via ``model_dump(exclude_none=True)``. + This preserves the full Part structure including text, inline_data, + file_data, and any future Part fields. + + Args: + artifact: ADK Part or dict representation. + + Returns: + JSON-encoded bytes. 
+ """ + if isinstance(artifact, dict): + return json.dumps(artifact, default=str).encode("utf-8") + + # Use Pydantic model serialization + if hasattr(artifact, "model_dump"): + data = artifact.model_dump(exclude_none=True) + return json.dumps(data, default=str).encode("utf-8") + + # Fallback for unexpected types + return json.dumps({"text": str(artifact)}).encode("utf-8") + + +def _deserialize_artifact(data: bytes) -> "types.Part": + """Deserialize bytes back into an ADK Part. + + Args: + data: JSON-encoded bytes from content storage. + + Returns: + Reconstructed Part object. + """ + from google.genai import types + + parsed = json.loads(data.decode("utf-8")) + return types.Part.model_validate(parsed) + + +def _record_to_artifact_version(record: "ArtifactRecord") -> "ArtifactVersion": + """Convert a database artifact record to an ADK ArtifactVersion. + + Args: + record: Database artifact record. + + Returns: + ArtifactVersion model instance. + """ + from google.adk.artifacts.base_artifact_service import ArtifactVersion + + return ArtifactVersion( + version=record["version"], + canonical_uri=record["canonical_uri"], + custom_metadata=record["custom_metadata"] or {}, + create_time=record["created_at"].timestamp(), + mime_type=record["mime_type"], + ) + + +class SQLSpecArtifactService(BaseArtifactService): + """SQLSpec-backed implementation of BaseArtifactService. + + Composes SQL metadata storage with ``sqlspec/storage/`` content backends + to provide versioned artifact management for Google ADK. + + Metadata (version number, filename, MIME type, custom metadata, canonical + URI) is stored in a SQL table managed by the artifact store. Content + bytes are stored in object storage (S3, GCS, Azure, local filesystem) + via the storage registry. + + Args: + store: Artifact metadata store implementation. + artifact_storage_uri: Base URI for content storage (e.g., + ``"s3://my-bucket/adk-artifacts/"``, ``"file:///var/data/artifacts/"``). 
+ Can also be a registered alias in the storage registry. + registry: Storage registry to use. Defaults to the global singleton. + + Example: + from sqlspec.adapters.asyncpg.adk.artifact_store import AsyncpgADKArtifactStore + from sqlspec.extensions.adk.artifact import SQLSpecArtifactService + + artifact_store = AsyncpgADKArtifactStore(config) + await artifact_store.ensure_table() + + service = SQLSpecArtifactService( + store=artifact_store, + artifact_storage_uri="s3://my-bucket/adk-artifacts/", + ) + + version = await service.save_artifact( + app_name="my_app", + user_id="user123", + filename="output.png", + artifact=part, + ) + """ + + def __init__( + self, store: "BaseAsyncADKArtifactStore", artifact_storage_uri: str, registry: "StorageRegistry | None" = None + ) -> None: + self._store = store + self._artifact_storage_uri = artifact_storage_uri.rstrip("/") + self._registry = registry or storage_registry + + @property + def store(self) -> "BaseAsyncADKArtifactStore": + """Return the artifact metadata store.""" + return self._store + + @property + def artifact_storage_uri(self) -> str: + """Return the base URI for content storage.""" + return self._artifact_storage_uri + + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: "types.Part | dict[str, Any]", + session_id: "str | None" = None, + custom_metadata: "dict[str, Any] | None" = None, + ) -> int: + """Save an artifact, returning the new version number. + + Writes content to object storage first, then inserts the metadata + row. If content write succeeds but metadata insert fails, the error + propagates and the orphaned content blob is left in storage: it is + neither logged nor cleaned up here (an orphan sweep can be added later). + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + artifact: ADK Part or dict to save. + session_id: Session identifier (None for user-scoped). 
+ custom_metadata: Optional per-version metadata dict. + + Returns: + The version number (0-based, incrementing). + """ + from google.adk.artifacts.base_artifact_service import ensure_part + + # Normalize artifact to Part + artifact_part: types.Part = ensure_part(artifact) + + # Determine the next version + version = await self._store.get_next_version( + app_name=app_name, user_id=user_id, filename=filename, session_id=session_id + ) + + # Build the content path and canonical URI + content_path = _build_content_path( + app_name=app_name, user_id=user_id, filename=filename, version=version, session_id=session_id + ) + canonical_uri = f"{self._artifact_storage_uri}/{content_path}" + + # Serialize content + content_bytes = _serialize_artifact(artifact_part) + + # Extract MIME type + mime_type = _extract_mime_type(artifact_part) + + # Write content first (fail-fast before metadata) + backend = self._registry.get(self._artifact_storage_uri) + if hasattr(backend, "write_bytes_async"): + await backend.write_bytes_async(content_path, content_bytes) + else: + backend.write_bytes_sync(content_path, content_bytes) + + # Insert metadata row + from datetime import datetime, timezone + + record = ArtifactRecord( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=version, + mime_type=mime_type, + canonical_uri=canonical_uri, + custom_metadata=custom_metadata, + created_at=datetime.now(tz=timezone.utc), + ) + await self._store.insert_artifact(record) + + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.save", + app_name=app_name, + user_id=user_id, + filename=filename, + version=version, + session_id=session_id, + mime_type=mime_type, + ) + return version + + async def load_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: "str | None" = None, + version: "int | None" = None, + ) -> "types.Part | None": + """Load an artifact by reading metadata then fetching content. 
+ + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + version: Specific version, or None for latest. + + Returns: + Deserialized Part, or None if not found. + """ + record = await self._store.get_artifact( + app_name=app_name, user_id=user_id, filename=filename, session_id=session_id, version=version + ) + if record is None: + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.load", + app_name=app_name, + filename=filename, + version=version, + found=False, + ) + return None + + # Derive content path from canonical URI + content_path = record["canonical_uri"].removeprefix(self._artifact_storage_uri + "/") + + backend = self._registry.get(self._artifact_storage_uri) + if hasattr(backend, "read_bytes_async"): + content_bytes = await backend.read_bytes_async(content_path) + else: + content_bytes = backend.read_bytes_sync(content_path) + + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.load", + app_name=app_name, + filename=filename, + version=record["version"], + found=True, + ) + return _deserialize_artifact(content_bytes) + + async def list_artifact_keys(self, *, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]": + """List distinct artifact filenames. + + When ``session_id`` is provided, returns both session-scoped and + user-scoped filenames. When None, returns only user-scoped filenames. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + + Returns: + List of artifact filenames. 
+ """ + keys = await self._store.list_artifact_keys(app_name=app_name, user_id=user_id, session_id=session_id) + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.list_keys", + app_name=app_name, + user_id=user_id, + session_id=session_id, + count=len(keys), + ) + return keys + + async def delete_artifact( + self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> None: + """Delete an artifact and all its versions. + + Deletes metadata rows first (fail-fast), then removes content + objects from storage (best-effort). + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + """ + deleted_records = await self._store.delete_artifact( + app_name=app_name, user_id=user_id, filename=filename, session_id=session_id + ) + + # Best-effort content cleanup + backend = self._registry.get(self._artifact_storage_uri) + for record in deleted_records: + content_path = record["canonical_uri"].removeprefix(self._artifact_storage_uri + "/") + try: + if hasattr(backend, "delete_async"): + await backend.delete_async(content_path) + else: + backend.delete_sync(content_path) + except Exception: + log_with_context( + logger, + logging.WARNING, + "adk.artifact.delete.content_cleanup_failed", + canonical_uri=record["canonical_uri"], + version=record["version"], + ) + + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.delete", + app_name=app_name, + filename=filename, + session_id=session_id, + versions_deleted=len(deleted_records), + ) + + async def list_versions( + self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[int]": + """List all version numbers for an artifact. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + Sorted list of version numbers. 
+ """ + records = await self._store.list_artifact_versions( + app_name=app_name, user_id=user_id, filename=filename, session_id=session_id + ) + return [r["version"] for r in records] + + async def list_artifact_versions( + self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactVersion]": + """List all versions with full metadata for an artifact. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of ArtifactVersion objects ordered by version ascending. + """ + records = await self._store.list_artifact_versions( + app_name=app_name, user_id=user_id, filename=filename, session_id=session_id + ) + return [_record_to_artifact_version(r) for r in records] + + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: "str | None" = None, + version: "int | None" = None, + ) -> "ArtifactVersion | None": + """Get metadata for a specific artifact version. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + version: Version number, or None for latest. + + Returns: + ArtifactVersion if found, None otherwise. + """ + record = await self._store.get_artifact( + app_name=app_name, user_id=user_id, filename=filename, session_id=session_id, version=version + ) + if record is None: + return None + return _record_to_artifact_version(record) diff --git a/sqlspec/extensions/adk/artifact/store.py b/sqlspec/extensions/adk/artifact/store.py new file mode 100644 index 000000000..ec9c08a33 --- /dev/null +++ b/sqlspec/extensions/adk/artifact/store.py @@ -0,0 +1,363 @@ +"""Base store classes for ADK artifact metadata backend (sync and async). + +These abstract base classes define the database operations needed to manage +artifact version metadata. 
Content storage is handled separately by +``sqlspec/storage/`` backends; these stores only manage the relational +metadata rows. + +Adapter-specific subclasses (e.g., ``AsyncpgADKArtifactStore``) implement +the abstract methods with dialect-specific SQL. +""" + +import logging +import re +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast + +from sqlspec.observability import resolve_db_system +from sqlspec.utils.logging import get_logger, log_with_context + +if TYPE_CHECKING: + from sqlspec.config import ADKConfig, DatabaseConfigProtocol + from sqlspec.extensions.adk.artifact._types import ArtifactRecord + +ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]") + +logger = get_logger("sqlspec.extensions.adk.artifact.store") + +__all__ = ("BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore") + +VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") +MAX_TABLE_NAME_LENGTH: Final = 63 + + +def _validate_table_name(table_name: str) -> None: + """Validate table name for SQL safety. + + Args: + table_name: Table name to validate. + + Raises: + ValueError: If table name is invalid. + """ + if not table_name: + msg = "Table name cannot be empty" + raise ValueError(msg) + + if len(table_name) > MAX_TABLE_NAME_LENGTH: + msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" + raise ValueError(msg) + + if not VALID_TABLE_NAME_PATTERN.match(table_name): + msg = ( + f"Invalid table name: {table_name!r}. " + "Must start with letter/underscore and contain only alphanumeric characters and underscores" + ) + raise ValueError(msg) + + +class BaseAsyncADKArtifactStore(ABC, Generic[ConfigT]): + """Base class for async SQLSpec-backed ADK artifact metadata stores. + + Manages artifact version metadata in a SQL table. Content bytes are + stored externally via ``sqlspec/storage/`` backends and referenced + by canonical URI in each metadata row. 
+ + Subclasses must implement dialect-specific SQL queries. + + Args: + config: SQLSpec database configuration with extension_config["adk"] settings. + + Notes: + Configuration is read from config.extension_config["adk"]: + - artifact_table: Artifact versions table name (default: "adk_artifact_versions") + """ + + __slots__ = ("_artifact_table", "_config") + + def __init__(self, config: ConfigT) -> None: + """Initialize the async ADK artifact store. + + Args: + config: SQLSpec database configuration. + """ + self._config = config + adk_config = self._get_adk_config() + self._artifact_table: str = str(adk_config.get("artifact_table", "adk_artifact_versions")) + _validate_table_name(self._artifact_table) + + def _get_adk_config(self) -> "dict[str, Any]": + """Extract ADK configuration from extension_config. + + Returns: + Dict with ADK configuration values. + """ + extension_config = self._config.extension_config + return dict(cast("ADKConfig", extension_config.get("adk", {}))) + + @property + def config(self) -> ConfigT: + """Return the database configuration.""" + return self._config + + @property + def artifact_table(self) -> str: + """Return the artifact versions table name.""" + return self._artifact_table + + @abstractmethod + async def insert_artifact(self, record: "ArtifactRecord") -> None: + """Insert an artifact version metadata row. + + Args: + record: Artifact metadata record to insert. + """ + + @abstractmethod + async def get_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None, version: "int | None" = None + ) -> "ArtifactRecord | None": + """Get a specific artifact version's metadata. + + When ``version`` is None, returns the latest version. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + version: Specific version number, or None for latest. + + Returns: + Artifact record if found, None otherwise. 
+ """ + + @abstractmethod + async def list_artifact_keys(self, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]": + """List distinct artifact filenames. + + When ``session_id`` is provided, returns filenames from both + session-scoped and user-scoped artifacts. When None, returns + only user-scoped artifact filenames. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier (None for user-scoped only). + + Returns: + List of distinct artifact filenames. + """ + + @abstractmethod + async def list_artifact_versions( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """List all version records for an artifact, ordered by version ascending. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of artifact records ordered by version ascending. + """ + + @abstractmethod + async def delete_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """Delete all version records for an artifact and return them. + + The caller uses the returned records to clean up content from + object storage. Metadata is deleted first (fail-fast); content + cleanup is best-effort. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of deleted artifact records (needed for content cleanup). + """ + + @abstractmethod + async def get_next_version( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> int: + """Get the next version number for an artifact. + + Returns 0 if no versions exist (first version), otherwise + ``max(version) + 1``. + + Args: + app_name: Application name. + user_id: User identifier. 
+ filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + Next version number (0-based). + """ + + @abstractmethod + async def create_table(self) -> None: + """Create the artifact versions table if it does not exist.""" + + async def ensure_table(self) -> None: + """Create the artifact table and emit a standardized log entry.""" + await self.create_table() + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.table.ready", + db_system=resolve_db_system(type(self).__name__), + artifact_table=self._artifact_table, + ) + + +class BaseSyncADKArtifactStore(ABC, Generic[ConfigT]): + """Base class for sync SQLSpec-backed ADK artifact metadata stores. + + Synchronous counterpart of :class:`BaseAsyncADKArtifactStore`. + + Args: + config: SQLSpec database configuration with extension_config["adk"] settings. + """ + + __slots__ = ("_artifact_table", "_config") + + def __init__(self, config: ConfigT) -> None: + """Initialize the sync ADK artifact store. + + Args: + config: SQLSpec database configuration. + """ + self._config = config + adk_config = self._get_adk_config() + self._artifact_table: str = str(adk_config.get("artifact_table", "adk_artifact_versions")) + _validate_table_name(self._artifact_table) + + def _get_adk_config(self) -> "dict[str, Any]": + """Extract ADK configuration from extension_config. + + Returns: + Dict with ADK configuration values. + """ + extension_config = self._config.extension_config + return dict(cast("ADKConfig", extension_config.get("adk", {}))) + + @property + def config(self) -> ConfigT: + """Return the database configuration.""" + return self._config + + @property + def artifact_table(self) -> str: + """Return the artifact versions table name.""" + return self._artifact_table + + @abstractmethod + def insert_artifact(self, record: "ArtifactRecord") -> None: + """Insert an artifact version metadata row. + + Args: + record: Artifact metadata record to insert. 
+ """ + + @abstractmethod + def get_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None, version: "int | None" = None + ) -> "ArtifactRecord | None": + """Get a specific artifact version's metadata. + + When ``version`` is None, returns the latest version. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + version: Specific version number, or None for latest. + + Returns: + Artifact record if found, None otherwise. + """ + + @abstractmethod + def list_artifact_keys(self, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]": + """List distinct artifact filenames. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier (None for user-scoped only). + + Returns: + List of distinct artifact filenames. + """ + + @abstractmethod + def list_artifact_versions( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """List all version records for an artifact, ordered by version ascending. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of artifact records ordered by version ascending. + """ + + @abstractmethod + def delete_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """Delete all version records for an artifact and return them. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of deleted artifact records (needed for content cleanup). 
+ """ + + @abstractmethod + def get_next_version(self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None) -> int: + """Get the next version number for an artifact. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + Next version number (0-based). + """ + + @abstractmethod + def create_table(self) -> None: + """Create the artifact versions table if it does not exist.""" + + def ensure_table(self) -> None: + """Create the artifact table and emit a standardized log entry.""" + self.create_table() + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.table.ready", + db_system=resolve_db_system(type(self).__name__), + artifact_table=self._artifact_table, + ) From c2e151830a34e8a476fef1d07219d0a64866d7db Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 17:28:46 +0000 Subject: [PATCH 15/23] docs(adk): update documentation for clean-break architecture --- docs/extensions/adk/adapters.rst | 35 ++- docs/extensions/adk/api.rst | 73 ++++++- docs/extensions/adk/backends.rst | 256 +++++++++++++++++++--- docs/extensions/adk/index.rst | 53 +++-- docs/extensions/adk/installation.rst | 44 +++- docs/extensions/adk/migrations.rst | 28 +++ docs/extensions/adk/quickstart.rst | 148 ++++++++++--- docs/extensions/adk/schema.rst | 305 ++++++++++++++++++++++++++- docs/reference/extensions/adk.rst | 31 ++- 9 files changed, 878 insertions(+), 95 deletions(-) diff --git a/docs/extensions/adk/adapters.rst b/docs/extensions/adk/adapters.rst index 5ff658122..72b0d06b7 100644 --- a/docs/extensions/adk/adapters.rst +++ b/docs/extensions/adk/adapters.rst @@ -10,11 +10,36 @@ Choosing an Adapter Use async adapters for best performance with ADK runners: -- **PostgreSQL**: ``asyncpg`` (recommended), ``psycopg`` (async mode) -- **SQLite**: ``aiosqlite`` -- **MySQL**: ``asyncmy`` +- **PostgreSQL** (recommended): ``asyncpg``, ``psycopg`` 
(async mode), ``psqlpy`` +- **CockroachDB**: ``cockroach_asyncpg``, ``cockroach_psycopg`` (full FTS support) +- **MySQL/MariaDB**: ``asyncmy`` +- **SQLite**: ``aiosqlite`` (development and single-process) +- **Oracle**: ``oracledb`` +- **DuckDB**: ``duckdb`` (analytics; reduced-scope for ADK) +- **ADBC**: ``adbc`` (Arrow-native, driver-agnostic) +- **Spanner**: ``spanner`` (Google Cloud, globally distributed) -Sync adapters work but require wrapping with ``anyio`` for async ADK runners. +Sync adapters (``psycopg`` sync mode, ``sqlite``, ``mysqlconnector``, ``pymysql``) +work but require wrapping with ``anyio`` for async ADK runners. + +Each Adapter Provides +===================== + +Every adapter with ADK support ships three store classes: + +- **Session store** (e.g., ``AsyncpgADKStore``) -- sessions and events. +- **Memory store** (e.g., ``AsyncpgADKMemoryStore``) -- long-term memory with FTS. +- **Artifact store** (e.g., ``AsyncpgADKArtifactStore``) -- artifact metadata. + +Import from the adapter's ``adk`` subpackage: + +.. code-block:: python + + from sqlspec.adapters.asyncpg.adk import ( + AsyncpgADKStore, + AsyncpgADKMemoryStore, + AsyncpgADKArtifactStore, + ) Example ======= @@ -30,6 +55,6 @@ Example See Also ======== -- :doc:`backends` for the full adapter support matrix. +- :doc:`backends` for the full support matrix and backend-specific notes. - :doc:`/usage/drivers_and_querying` for adapter configuration patterns. - :doc:`/reference/adapters` for the complete adapter API. diff --git a/docs/extensions/adk/api.rst b/docs/extensions/adk/api.rst index 95d2431d4..e95a28a31 100644 --- a/docs/extensions/adk/api.rst +++ b/docs/extensions/adk/api.rst @@ -19,8 +19,20 @@ Services :show-inheritance: :no-index: -Base Stores -=========== +.. autoclass:: sqlspec.extensions.adk.memory.SQLSpecSyncMemoryService + :members: + :undoc-members: + :show-inheritance: + :no-index: + +.. 
autoclass:: SQLSpecArtifactService + :members: + :undoc-members: + :show-inheritance: + :no-index: + +Session Stores +============== .. autoclass:: BaseAsyncADKStore :members: @@ -34,6 +46,9 @@ Base Stores :show-inheritance: :no-index: +Memory Stores +============= + .. autoclass:: BaseAsyncADKMemoryStore :members: :undoc-members: @@ -45,3 +60,57 @@ Base Stores :undoc-members: :show-inheritance: :no-index: + +Artifact Stores +=============== + +.. autoclass:: BaseAsyncADKArtifactStore + :members: + :undoc-members: + :show-inheritance: + :no-index: + +.. autoclass:: BaseSyncADKArtifactStore + :members: + :undoc-members: + :show-inheritance: + :no-index: + +Record Types +============ + +.. autoclass:: SessionRecord + :members: + :show-inheritance: + :no-index: + +.. autoclass:: EventRecord + :members: + :show-inheritance: + :no-index: + +.. autoclass:: MemoryRecord + :members: + :show-inheritance: + :no-index: + +.. autoclass:: ArtifactRecord + :members: + :show-inheritance: + :no-index: + +Configuration +============= + +.. autoclass:: ADKConfig + :members: + :show-inheritance: + :no-index: + +Converters +========== + +.. automodule:: sqlspec.extensions.adk.converters + :members: + :undoc-members: + :no-index: diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst index e990286c9..1f770f68d 100644 --- a/docs/extensions/adk/backends.rst +++ b/docs/extensions/adk/backends.rst @@ -2,49 +2,249 @@ Backends ======== -ADK stores are implemented per adapter. Use the backend config helpers when -connecting to multiple databases or configuring advanced options. +ADK stores are implemented per adapter. Each backend has different capabilities +for session, event, memory, and artifact storage. Use the support matrix below +to select the right backend for your deployment. -Example -======= +.. _adk-support-matrix: -.. 
literalinclude:: /examples/extensions/adk/backend_config.py - :language: python - :caption: ``adk backend config`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: +Support Matrix +============== -Supported Backends -================== +The table below classifies every backend by its ADK support level. .. list-table:: :header-rows: 1 + :widths: 20 15 15 15 15 20 * - Adapter - Status + - Session/Event + - Memory (FTS) + - Artifacts + - Notes * - asyncpg - - Production - * - psycopg - - Production + - Recommended + - Full + - Full + - Full + - Best async PostgreSQL driver. + * - psycopg (async) + - Recommended + - Full + - Full + - Full + - Supports both sync and async modes. * - psqlpy - - Production + - Supported + - Full + - Full + - Full + - Rust-backed PostgreSQL driver. + * - cockroach_asyncpg + - Supported + - Full + - Full + - Full + - CockroachDB with full FTS support. + * - cockroach_psycopg + - Supported + - Full + - Full + - Full + - CockroachDB with full FTS support. * - asyncmy - - Production - * - sqlite - - Production + - Supported + - Full + - Full + - Full + - MySQL/MariaDB async driver. + * - mysqlconnector + - Supported + - Full + - Full + - Full + - MySQL/MariaDB sync driver. + * - pymysql + - Supported + - Full + - Full + - Full + - MySQL/MariaDB sync driver. * - aiosqlite - - Production + - Supported + - Full + - Full + - Full + - SQLite async, ideal for development. + * - sqlite + - Supported + - Full + - Full + - Full + - SQLite sync with thread-local pools. * - oracledb - - Production + - Supported + - Full + - Full + - Full + - Oracle Database driver. * - duckdb - - Production (analytics) + - Reduced-scope + - Full + - Limited + - Full + - Analytics-oriented; no concurrent writes. * - adbc - - Production + - Supported + - Full + - Full + - Full + - Arrow-native database connectivity. + * - spanner + - Supported + - Full + - Full + - Full + - Google Cloud Spanner (cloud-managed). 
+ +Status Definitions +------------------ + +**Recommended** + Production-grade, fully tested, actively optimized. Start here unless you + have a specific reason not to. + +**Supported** + Fully implemented and tested. Works correctly for all ADK operations. + +**Reduced-scope** + Implemented with known limitations. Specific features may be absent or + behave differently. See backend-specific notes. + +**Removed** + Previously available but no longer supported. See the removal notice for + migration guidance. + +Removed Backends +---------------- + +**BigQuery** was removed from the ADK backend surface. BigQuery's batch-oriented +architecture is incompatible with the low-latency, transactional write patterns +that ADK session and event storage require. If you were using BigQuery for ADK +storage, migrate to PostgreSQL (asyncpg or psycopg) or any other supported +backend. + +Backend Details +=============== + +PostgreSQL Family +----------------- + +PostgreSQL backends (asyncpg, psycopg, psqlpy) provide the fullest feature set: + +- Native ``JSONB`` storage for session state and event JSON. +- Full-text search via ``tsvector`` for memory entries. +- ``UPSERT`` and ``RETURNING`` clauses for atomic operations. +- ``append_event_and_update_state()`` executes as a single transaction. + +**Recommended for production deployments.** + +CockroachDB +------------ + +CockroachDB backends (cockroach_asyncpg, cockroach_psycopg) provide full ADK +support including full-text search. CockroachDB is a distributed SQL database +compatible with the PostgreSQL wire protocol. + +- Full FTS support for memory search. +- Distributed transactions for session and event atomicity. +- Horizontal scalability for high-throughput agent deployments. + +MySQL Family +------------ + +MySQL backends (asyncmy, mysqlconnector, pymysql) provide full ADK support: + +- JSON column storage for session state and event records. +- Full-text search on ``InnoDB`` tables for memory entries. 
+- Transactional writes for ``append_event_and_update_state()``. + +SQLite +------ + +SQLite backends (aiosqlite, sqlite) are ideal for local development, testing, +and single-process deployments: + +- JSON1 extension for state and event storage. +- FTS5 virtual tables for memory full-text search. +- File-based or in-memory operation. + +.. note:: + + SQLite does not support concurrent writers. Use a server-backed database + for production multi-process deployments. + +Oracle +------ + +Oracle Database (oracledb) provides full ADK support: + +- Native JSON column support (Oracle 21c+). +- Oracle Text for full-text search on memory entries. +- Full transactional support for atomic operations. + +DuckDB +------ + +DuckDB provides session and event storage but has limitations: + +- Optimized for analytics, not OLTP workloads. +- Single-writer constraint limits concurrent access. +- Memory search capabilities are limited compared to server databases. + +**Best suited for analytics pipelines and offline agent evaluation.** + +ADBC +---- + +ADBC (Arrow Database Connectivity) provides a driver-agnostic interface: + +- Works with any ADBC-compatible driver (PostgreSQL, SQLite, DuckDB, etc.). +- Arrow-native data transfer for high-throughput event ingestion. +- Backend capabilities depend on the underlying database driver. + +Spanner +------- + +Google Cloud Spanner provides globally distributed ADK storage: + +- Cloud-managed, horizontally scalable. +- Full-text search support for memory entries. +- Strong consistency across regions. +- Suitable for multi-region agent deployments. + +Configuration +============= + +All backends are configured through ``extension_config["adk"]``: + +.. 
code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig -Notes -===== + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://localhost/mydb"}, + extension_config={ + "adk": { + "session_table": "adk_sessions", + "events_table": "adk_events", + "memory_table": "adk_memory_entries", + "memory_use_fts": True, + "artifact_table": "adk_artifact_versions", + "owner_id_column": "tenant_id INTEGER NOT NULL", + } + }, + ) -- Use async backends for ADK runners; sync backends can be wrapped with anyio. -- Backend stores expose ``create_tables`` to bootstrap schema. +See :doc:`adapters` for adapter-specific configuration patterns. diff --git a/docs/extensions/adk/index.rst b/docs/extensions/adk/index.rst index d83c786c1..de771ccc8 100644 --- a/docs/extensions/adk/index.rst +++ b/docs/extensions/adk/index.rst @@ -2,8 +2,23 @@ Google ADK Extension ==================== -SQLSpec provides an ADK extension for session, event, and memory storage with -SQL-backed persistence. +SQLSpec provides a full-featured backend for +`Google Agent Development Kit <https://google.github.io/adk-docs/>`_, +covering session, event, memory, and artifact storage with SQL-backed +persistence across 14 database adapters. + +Key capabilities: + +- **Session and event storage** with atomic ``append_event_and_update_state()`` + ensuring events and state are always consistent. +- **Full-event JSON storage** (EventRecord) that captures the entire ADK Event + in a single column, eliminating schema drift with upstream ADK releases. +- **Scoped state semantics** (``app:``, ``user:``, ``temp:``) for controlling + state visibility and persistence across sessions. +- **Memory service** with database-native full-text search (tsvector, FTS5, + InnoDB FT) for long-term agent context. +- **Artifact service** with append-only versioning, SQL metadata, and pluggable + object storage backends. 
Choose a guide ============== @@ -22,45 +37,45 @@ Choose a guide :link: quickstart :link-type: doc - Persist memory and sessions with minimal setup. + Persist sessions, memory, and artifacts with minimal setup. - .. grid-item-card:: API Reference - :link: api + .. grid-item-card:: Support Matrix + :link: backends :link-type: doc - Interfaces, stores, and configuration helpers. + See which backends are recommended, supported, or reduced-scope. .. grid-item-card:: Adapters :link: adapters :link-type: doc - Configure supported SQLSpec adapters. + Configure supported SQLSpec adapters for ADK. - .. grid-item-card:: Backends - :link: backends + .. grid-item-card:: Schema + :link: schema :link-type: doc - Storage backends and connection profiles. + Table layouts, EventRecord, scoped state, and artifact metadata. - .. grid-item-card:: Migrations - :link: migrations + .. grid-item-card:: API Reference + :link: api :link-type: doc - Apply schema changes safely over time. + Services, stores, and record types. - .. grid-item-card:: Schema - :link: schema + .. grid-item-card:: Migrations + :link: migrations :link-type: doc - Table layouts for sessions and memory records. + Apply schema changes safely over time. .. toctree:: :hidden: installation quickstart - api - adapters backends - migrations + adapters schema + api + migrations diff --git a/docs/extensions/adk/installation.rst b/docs/extensions/adk/installation.rst index 8304d72f7..347f8aa13 100644 --- a/docs/extensions/adk/installation.rst +++ b/docs/extensions/adk/installation.rst @@ -6,7 +6,7 @@ Install SQLSpec with a database adapter and the Google ADK SDK. .. tab-set:: - .. tab-item:: PostgreSQL + .. tab-item:: PostgreSQL (recommended) .. tab-set:: @@ -90,6 +90,34 @@ Install SQLSpec with a database adapter and the Google ADK SDK. pdm add "sqlspec[asyncmy,adk]" + .. tab-item:: CockroachDB + + .. tab-set:: + + .. tab-item:: uv + + .. code-block:: bash + + uv add "sqlspec[cockroach-asyncpg,adk]" + + .. 
tab-item:: pip + + .. code-block:: bash + + pip install "sqlspec[cockroach-asyncpg,adk]" + + .. tab-item:: Poetry + + .. code-block:: bash + + poetry add "sqlspec[cockroach-asyncpg,adk]" + + .. tab-item:: PDM + + .. code-block:: bash + + pdm add "sqlspec[cockroach-asyncpg,adk]" + .. tab-item:: DuckDB .. tab-set:: @@ -123,11 +151,17 @@ What This Provides The ``adk`` extra includes the Google ADK SDK (``google-genai``). SQLSpec provides: -- **Session Store** - Persist ADK agent sessions to your database. -- **Memory Store** - Store agent memory for context across conversations. -- **Event Store** - Log agent events for observability. +- **Session Service** -- Persist ADK agent sessions and events to your database + with atomic ``append_event_and_update_state()`` writes. +- **Memory Service** -- Store agent memory with database-native full-text search + for context retrieval across conversations. +- **Artifact Service** -- Version and store binary artifacts with SQL metadata + and pluggable object storage backends. +- **Event Storage** -- Full-event JSON storage (EventRecord) that captures the + entire ADK Event without schema drift. Next Steps ---------- -Proceed to :doc:`quickstart` to set up stores for your ADK agent. +Proceed to :doc:`quickstart` to set up stores for your ADK agent, or see +:doc:`backends` for the full support matrix. diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst index 948969ae8..0bbb9c85c 100644 --- a/docs/extensions/adk/migrations.rst +++ b/docs/extensions/adk/migrations.rst @@ -5,4 +5,32 @@ Migrations ADK stores use standard SQLSpec migrations. Generate migrations for the database used by your ADK backend, then run them with the SQLSpec migration CLI. +Schema Bootstrapping +==================== + +For development, use ``ensure_tables()`` to create tables on first use: + +.. 
code-block:: python + + await session_store.ensure_tables() + await memory_store.ensure_tables() + await artifact_store.ensure_table() + +For production, run migrations ahead of deployment to avoid runtime DDL. + +Clean-Break Migration Notes +============================ + +If you are upgrading from a pre-clean-break version of the ADK extension, +note the following schema changes: + +- **Events table**: The column layout changed to full-event JSON storage. + The ``event_json`` column now stores the entire ADK Event as a JSON blob. + Individual event columns (``content``, ``actions``, ``branch``, etc.) have + been replaced by indexed scalar columns (``invocation_id``, ``author``, + ``timestamp``) plus ``event_json``. +- **Artifact table**: New table (``adk_artifact_versions``) for artifact + metadata. Create this table when enabling the artifact service. +- **BigQuery**: Removed. Migrate to PostgreSQL or any other supported backend. + See :doc:`/usage/migrations` for the full workflow and commands. diff --git a/docs/extensions/adk/quickstart.rst b/docs/extensions/adk/quickstart.rst index 93c829100..af14b00ca 100644 --- a/docs/extensions/adk/quickstart.rst +++ b/docs/extensions/adk/quickstart.rst @@ -2,56 +2,140 @@ Quickstart ========== -Wire SQLSpec stores into your ADK agent to persist sessions and memory across restarts. +Wire SQLSpec stores into your ADK agent to persist sessions, events, memory, +and artifacts across restarts. How It Works ============ -1. Create a SQLSpec database config. -2. Initialize ADK stores (session, memory, event) backed by that config. -3. Pass the stores to your ADK agent. +1. Create a SQLSpec database config with ADK extension settings. +2. Initialize the appropriate stores (session, memory, artifact). +3. Pass the service wrappers to your ADK agent. -Session Store -============= +Session Service +=============== -The session store persists agent state between conversations. 
When a user returns, -the agent can resume from where it left off. +The session service persists agent state and events between conversations. +When a user returns, the agent can resume from where it left off. -.. literalinclude:: /examples/extensions/adk/memory_store.py - :language: python - :caption: ``adk session store`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore + from sqlspec.extensions.adk import SQLSpecSessionService + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://localhost/mydb"}, + extension_config={ + "adk": { + "session_table": "adk_sessions", + "events_table": "adk_events", + } + }, + ) + + store = AsyncpgADKStore(config) + await store.ensure_tables() + + session_service = SQLSpecSessionService(store) + + # Create a session with scoped state + session = await session_service.create_session( + app_name="my_agent", + user_id="user_123", + state={ + "app:model": "gemini-2.0", # shared across all sessions + "user:name": "Alice", # shared across user's sessions + "conversation_turn": 0, # session-local + "temp:scratch": "...", # runtime-only, never persisted + }, + ) + +Events are persisted automatically when you use the session service with an +ADK runner. Each call to ``append_event()`` atomically stores the event and +updates the session's durable state via ``append_event_and_update_state()``. + +Scoped State +------------ + +State keys use prefixes to control their scope and persistence: + +- ``app:`` -- shared across all sessions for the same application. +- ``user:`` -- shared across all sessions for the same user. +- ``temp:`` -- runtime-only, stripped before every write to storage. +- *(no prefix)* -- private to the current session. + +See :ref:`scoped-state` for full details. 
+ +Memory Service +============== + +The memory service retains context that the agent can reference later. This +enables long-term memory across sessions with full-text search. + +.. code-block:: python + + from sqlspec.adapters.asyncpg.adk import AsyncpgADKMemoryStore + from sqlspec.extensions.adk import SQLSpecMemoryService + + memory_store = AsyncpgADKMemoryStore(config) + await memory_store.ensure_tables() + + memory_service = SQLSpecMemoryService(memory_store) + +Enable full-text search by setting ``memory_use_fts: True`` in the ADK config. +This creates database-native FTS indexes (tsvector, FTS5, InnoDB FT) for +efficient memory retrieval. + +Artifact Service +================ + +The artifact service stores binary artifacts (files, images, reports) with +automatic versioning. Metadata lives in SQL; content lives in object storage. + +.. code-block:: python + + from sqlspec.adapters.asyncpg.adk import AsyncpgADKArtifactStore + from sqlspec.extensions.adk import SQLSpecArtifactService + + artifact_store = AsyncpgADKArtifactStore(config) + await artifact_store.ensure_table() -Memory Store Integration -======================== + artifact_service = SQLSpecArtifactService( + store=artifact_store, + artifact_storage_uri="s3://my-bucket/adk-artifacts/", + ) -The memory store retains context that the agent can reference later. This enables -long-term memory across sessions. + # Save an artifact (returns version number starting from 0) + version = await artifact_service.save_artifact( + app_name="my_agent", + user_id="user_123", + filename="report.pdf", + artifact=part, + ) -.. 
literalinclude:: /examples/extensions/adk/tool_integration.py - :language: python - :caption: ``adk memory integration`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: + # Load the latest version + loaded = await artifact_service.load_artifact( + app_name="my_agent", + user_id="user_123", + filename="report.pdf", + ) Schema Setup ============ -Stores create their tables automatically on first use. For production, run migrations -ahead of time: +Stores create their tables on first use via ``ensure_tables()``. For +production, run table creation ahead of deployment: .. code-block:: python - await session_store.create_tables() - await memory_store.create_tables() + await session_store.ensure_tables() + await memory_store.ensure_tables() + await artifact_store.ensure_table() Next Steps ========== -- :doc:`backends` for adapter-specific configuration. -- :doc:`schema` for table layouts and indexes. +- :doc:`backends` for the full support matrix and backend-specific details. +- :doc:`schema` for table layouts, EventRecord format, and scoped state semantics. +- :doc:`api` for the complete API reference. diff --git a/docs/extensions/adk/schema.rst b/docs/extensions/adk/schema.rst index 18a1482bd..b49a4996d 100644 --- a/docs/extensions/adk/schema.rst +++ b/docs/extensions/adk/schema.rst @@ -2,7 +2,306 @@ Schema ====== -ADK stores create tables for sessions, events, and memory entries. Table names -and schemas are configurable via the store config. +ADK stores create tables for sessions, events, memory entries, and artifact +metadata. Table names are configurable via ``extension_config["adk"]``. -Use ``create_tables()`` on a store to apply the schema. +Use ``create_tables()`` or ``ensure_tables()`` on a store to apply the schema. + +.. contents:: On this page + :local: + :depth: 2 + +Sessions Table +============== + +The sessions table stores agent session metadata and durable state. + +Default name: ``adk_sessions`` + +.. 
list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``id`` + - ``VARCHAR`` / ``TEXT`` + - Primary key. UUID assigned by the service layer. + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Application identifier. + * - ``user_id`` + - ``VARCHAR`` / ``TEXT`` + - User identifier. + * - ``state`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - Durable session state (see :ref:`scoped-state`). + * - ``create_time`` + - ``TIMESTAMP`` + - When the session was created (UTC). + * - ``update_time`` + - ``TIMESTAMP`` + - Last state update time (UTC). + +An optional ``owner_id`` column can be added via ``owner_id_column`` in the ADK +config for multi-tenant deployments. + +.. _event-record: + +Events Table (EventRecord) +========================== + +The events table uses **full-event JSON storage**: the entire ADK ``Event`` is +serialized into a single ``event_json`` column alongside a small set of indexed +scalar columns used for query filtering. + +This design eliminates column drift with upstream ADK releases. New ``Event`` +fields are automatically captured in ``event_json`` without schema changes. + +Default name: ``adk_events`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``session_id`` + - ``VARCHAR`` / ``TEXT`` + - Foreign key to the sessions table. + * - ``invocation_id`` + - ``VARCHAR`` / ``TEXT`` + - ADK invocation identifier (indexed for filtering). + * - ``author`` + - ``VARCHAR`` / ``TEXT`` + - Event author: ``"user"`` or the name of the agent that produced the event. + * - ``timestamp`` + - ``TIMESTAMP`` + - Event timestamp (UTC, indexed for range queries). + * - ``event_json`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - Full ADK Event serialized via ``event.model_dump(exclude_none=True, mode="json")``. + +**Serialization and reconstruction:** + +Events are converted to records via ``event_to_record()``, which calls +``event.model_dump(exclude_none=True, mode="json")`` to produce the JSON blob.
+Reconstruction is lossless: ``record_to_event()`` restores the full ``Event`` +via ``Event.model_validate()``. + +.. code-block:: python + + from sqlspec.extensions.adk.converters import event_to_record, record_to_event + + # Serialize: Event -> EventRecord + record = event_to_record(event=adk_event, session_id="sess_123") + + # Reconstruct: EventRecord -> Event + restored_event = record_to_event(record) + +.. _scoped-state: + +Scoped State Semantics +====================== + +ADK uses key prefixes to scope state visibility across sessions. SQLSpec +respects these prefixes when persisting and loading state. + +.. list-table:: + :header-rows: 1 + + * - Prefix + - Scope + - Persisted + - Description + * - ``app:`` + - Application + - Yes + - Shared across all sessions for the same ``app_name``. + * - ``user:`` + - User + - Yes + - Shared across all sessions for the same ``app_name`` + ``user_id``. + * - ``temp:`` + - Runtime + - **No** + - Process-local state. Stripped before every write to storage. + * - *(no prefix)* + - Session + - Yes + - Private to a single session. + +**How scoped state is handled:** + +1. On ``create_session()``, the service strips ``temp:`` keys before the + initial ``INSERT``. + +2. On ``append_event()``, the service calls ``filter_temp_state()`` to produce + a durable state snapshot, then calls ``append_event_and_update_state()`` to + atomically persist the event and the state update. + +3. On ``get_session()``, state is loaded from the database. Since ``temp:`` + keys were never written, they are absent from the loaded state. + +.. 
code-block:: python + + from sqlspec.extensions.adk.converters import filter_temp_state, split_scoped_state + + state = { + "app:model_version": "v2", + "user:preferences": {"theme": "dark"}, + "temp:scratch_pad": "...", + "conversation_turn": 5, + } + + # Strip temp keys before persisting + durable = filter_temp_state(state) + # {"app:model_version": "v2", "user:preferences": {...}, "conversation_turn": 5} + + # Split into scoped buckets + app_state, user_state, session_state = split_scoped_state(durable) + # app_state: {"app:model_version": "v2"} + # user_state: {"user:preferences": {"theme": "dark"}} + # session_state: {"conversation_turn": 5} + +.. _append-event-contract: + +The ``append_event_and_update_state()`` Contract +================================================= + +This method is the **authoritative durable write boundary** for post-creation +session mutations. It atomically: + +1. Inserts the event record into the events table. +2. Updates the session's durable state in the sessions table. + +Both operations succeed together or fail together within a single database +transaction. + +.. code-block:: python + + # Called by SQLSpecSessionService.append_event() internally: + await store.append_event_and_update_state( + event_record=event_record, + session_id=session.id, + state=durable_state, # temp: keys already stripped + ) + +**Why this matters:** + +- Prevents state from advancing without the corresponding event being recorded. +- Prevents orphaned events that reference a stale session state. +- Ensures that on session reload, the state always reflects all persisted events. + +Every backend store implements this as a single transaction (or equivalent +atomic operation for the backend's concurrency model). + +Memory Table +============ + +The memory table stores long-term context entries that agents can search and +reference across sessions. + +Default name: ``adk_memory_entries`` + +.. 
list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``id`` + - ``VARCHAR`` / ``TEXT`` + - Primary key. + * - ``session_id`` + - ``VARCHAR`` / ``TEXT`` + - Session that produced this memory. + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Application identifier. + * - ``user_id`` + - ``VARCHAR`` / ``TEXT`` + - User identifier. + * - ``content_text`` + - ``TEXT`` + - Searchable text content (used by FTS). + * - ``content_json`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - Structured content. + * - ``inserted_at`` + - ``TIMESTAMP`` + - When the entry was created. + +When ``memory_use_fts`` is enabled in the ADK config, backends create +full-text search indexes on ``content_text`` using the database's native +FTS engine (tsvector, FTS5, InnoDB FT, etc.). + +.. _artifact-schema: + +Artifact Metadata Table +======================= + +The artifact table stores versioning metadata for binary artifacts. Content +bytes are stored separately in object storage; this table tracks ownership, +versioning, and canonical URIs. + +Default name: ``adk_artifact_versions`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Application identifier. + * - ``user_id`` + - ``VARCHAR`` / ``TEXT`` + - User identifier. + * - ``session_id`` + - ``VARCHAR`` / ``TEXT`` (nullable) + - Session identifier. NULL for user-scoped artifacts. + * - ``filename`` + - ``VARCHAR`` / ``TEXT`` + - Artifact filename. + * - ``version`` + - ``INTEGER`` + - Monotonically increasing version (starts at 0). + * - ``mime_type`` + - ``VARCHAR`` / ``TEXT`` (nullable) + - MIME type of the artifact content. + * - ``canonical_uri`` + - ``VARCHAR`` / ``TEXT`` + - URI pointing to content in object storage. + * - ``custom_metadata`` + - ``JSONB`` / ``JSON`` / ``TEXT`` (nullable) + - User-defined metadata. + * - ``created_at`` + - ``TIMESTAMP`` + - When this version was created. 
+ +The composite key is ``(app_name, user_id, session_id, filename, version)``. + +Table Name Configuration +======================== + +All table names are configurable: + +.. code-block:: python + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://..."}, + extension_config={ + "adk": { + "session_table": "my_sessions", # default: "adk_sessions" + "events_table": "my_events", # default: "adk_events" + "memory_table": "my_memory", # default: "adk_memory_entries" + "artifact_table": "my_artifacts", # default: "adk_artifact_versions" + } + }, + ) + +Table names are validated on store initialization: they must start with a +letter or underscore, contain only alphanumeric characters and underscores, +and be at most 63 characters long. diff --git a/docs/reference/extensions/adk.rst b/docs/reference/extensions/adk.rst index aa0ca33a9..3c08b6ed7 100644 --- a/docs/reference/extensions/adk.rst +++ b/docs/reference/extensions/adk.rst @@ -2,7 +2,7 @@ Google ADK ========== -Session, event, and memory storage backends for +Session, event, memory, and artifact storage backends for `Google Agent Development Kit `_. Session Service @@ -23,6 +23,13 @@ Memory Services :members: :show-inheritance: +Artifact Service +================ + +.. autoclass:: sqlspec.extensions.adk.SQLSpecArtifactService + :members: + :show-inheritance: + Store Base Classes ================== @@ -45,6 +52,17 @@ Memory Store Base Classes :members: :show-inheritance: +Artifact Store Base Classes +=========================== + +.. autoclass:: sqlspec.extensions.adk.BaseAsyncADKArtifactStore + :members: + :show-inheritance: + +.. autoclass:: sqlspec.extensions.adk.BaseSyncADKArtifactStore + :members: + :show-inheritance: + Record Types ============ @@ -59,3 +77,14 @@ Record Types .. autoclass:: sqlspec.extensions.adk.MemoryRecord :members: :show-inheritance: + +.. autoclass:: sqlspec.extensions.adk.ArtifactRecord + :members: + :show-inheritance: + +Configuration +============= + +.. 
autoclass:: sqlspec.extensions.adk.ADKConfig + :members: + :show-inheritance: From 5bbe63b95bc26301cdffd194c973d117b977bc70 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 17:35:10 +0000 Subject: [PATCH 16/23] test(adk): update integration tests for clean-break contract Update all ADK integration tests to work with the new 5-column EventRecord (session_id, invocation_id, author, timestamp, event_json) instead of the old 17-column decomposed schema. Changes per adapter: - asyncmy: Replace old event dicts with 5-key EventRecord for append_event; update schema verification to check new columns - mysqlconnector: Same as asyncmy - duckdb: Remove get_event calls (method removed), update event assertions to check event_json instead of top-level columns - oracledb: Replace 17-column EventRecord literals with 5-key records; remove old boolean/LOB column-specific tests - spanner: Update event content assertions to read from event_json - adbc: Update all event operation tests to verify 5-key shape; update DDL tests to check for event_json column types instead of old BLOB/BOOLEAN/INTEGER columns No BigQuery ADK tests exist (already removed in prior chapter). 
--- .../adk/test_dialect_integration.py | 13 +- .../extensions/adk/test_dialect_support.py | 33 ++-- .../adbc/extensions/adk/test_edge_cases.py | 23 ++- .../extensions/adk/test_event_operations.py | 147 +++++------------- .../asyncmy/extensions/adk/test_store.py | 74 ++++----- .../duckdb/extensions/adk/test_store.py | 85 +++++----- .../extensions/adk/test_store.py | 74 ++++----- .../extensions/adk/test_oracle_specific.py | 140 ++++++----------- .../spanner/extensions/adk/test_adk_store.py | 15 +- 9 files changed, 231 insertions(+), 373 deletions(-) diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py index 7c0451ba8..653682fdc 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py @@ -10,6 +10,7 @@ if the driver is not installed. """ +import json from pathlib import Path from typing import Any @@ -74,20 +75,20 @@ def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: sqlite_store.create_session(session_id, app_name, user_id, {}) - event_id = "event-1" - actions = b"pickled_actions_data" content = {"message": "Hello"} event = sqlite_store.create_event( - event_id=event_id, session_id=session_id, app_name=app_name, user_id=user_id, actions=actions, content=content + event_id="event-1", session_id=session_id, app_name=app_name, user_id=user_id, content=content ) - assert event["id"] == event_id - assert event["content"] == content + assert event["session_id"] == session_id + event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + assert event_data["content"] == content events = sqlite_store.list_events(session_id) assert len(events) == 1 - assert events[0]["content"] == content + retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) 
else events[0]["event_json"] + assert retrieved_data["content"] == content @pytest.mark.postgres diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py index c87302f23..fa5130dbd 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py @@ -89,42 +89,46 @@ def test_generic_sessions_ddl_contains_text() -> None: assert "TIMESTAMP" in ddl -def test_postgresql_events_ddl_contains_jsonb() -> None: - """Test PostgreSQL events DDL uses JSONB for content fields.""" +def test_postgresql_events_ddl_uses_jsonb() -> None: + """Test PostgreSQL events DDL uses JSONB for event_json.""" config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_postgresql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in ddl - assert "BYTEA" in ddl - assert "BOOLEAN" in ddl + assert "event_json" in ddl + assert "session_id" in ddl + assert "invocation_id" in ddl + assert "author" in ddl + assert "timestamp" in ddl.lower() -def test_sqlite_events_ddl_contains_text_and_integer() -> None: - """Test SQLite events DDL uses TEXT for JSON and INTEGER for booleans.""" +def test_sqlite_events_ddl_uses_text() -> None: + """Test SQLite events DDL uses TEXT for event_json.""" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_sqlite() # pyright: ignore[reportPrivateUsage] assert "TEXT" in ddl - assert "BLOB" in ddl - assert "INTEGER" in ddl + assert "event_json" in ddl + assert "session_id" in ddl + assert "REAL" in ddl # SQLite uses REAL for timestamps -def test_duckdb_events_ddl_contains_json_and_boolean() -> None: - """Test DuckDB events DDL uses JSON and BOOLEAN types.""" +def test_duckdb_events_ddl_uses_json() -> None: + 
"""Test DuckDB events DDL uses JSON type for event_json.""" config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_duckdb() # pyright: ignore[reportPrivateUsage] assert "JSON" in ddl - assert "BOOLEAN" in ddl + assert "event_json" in ddl -def test_snowflake_events_ddl_contains_variant() -> None: - """Test Snowflake events DDL uses VARIANT for content.""" +def test_snowflake_events_ddl_uses_variant() -> None: + """Test Snowflake events DDL uses VARIANT for event_json.""" config = AdbcConfig(connection_config={"driver_name": "snowflake", "uri": "snowflake://test"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_snowflake() # pyright: ignore[reportPrivateUsage] assert "VARIANT" in ddl - assert "BINARY" in ddl + assert "event_json" in ddl def test_ddl_dispatch_uses_correct_dialect() -> None: @@ -137,6 +141,7 @@ def test_ddl_dispatch_uses_correct_dialect() -> None: events_ddl = store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in events_ddl + assert "event_json" in events_ddl def test_owner_id_column_included_in_sessions_ddl() -> None: diff --git a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py index fc39cebb2..c028f4511 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py @@ -1,5 +1,6 @@ """Tests for ADBC ADK store edge cases and error handling.""" +import json from pathlib import Path from typing import Any @@ -100,13 +101,12 @@ def test_unicode_in_fields(adbc_store: Any) -> None: session_id = "unicode-session" app_name = "测试应用" user_id = "ユーザー123" - state = {"message": "Hello 世界", "emoji": "🎉"} + state = {"message": "Hello 世界"} created_session = adbc_store.create_session(session_id, app_name, user_id, state) assert created_session["app_name"] == 
app_name assert created_session["user_id"] == user_id assert created_session["state"]["message"] == "Hello 世界" - assert created_session["state"]["emoji"] == "🎉" event = adbc_store.create_event( event_id="unicode-event", @@ -114,11 +114,12 @@ def test_unicode_in_fields(adbc_store: Any) -> None: app_name=app_name, user_id=user_id, author="アシスタント", - content={"text": "こんにちは 🌍"}, + content={"text": "こんにちは"}, ) assert event["author"] == "アシスタント" - assert event["content"]["text"] == "こんにちは 🌍" + event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + assert event_data["content"]["text"] == "こんにちは" def test_special_characters_in_json(adbc_store: Any) -> None: @@ -176,7 +177,7 @@ def test_concurrent_session_updates(adbc_store: Any) -> None: def test_event_with_none_values(adbc_store: Any) -> None: - """Test creating event with explicit None values.""" + """Test creating event with explicit None values for optional fields.""" session_id = "none-test" adbc_store.create_session(session_id, "app", "user", {}) @@ -198,15 +199,9 @@ def test_event_with_none_values(adbc_store: Any) -> None: error_message=None, ) - assert event["invocation_id"] is None - assert event["author"] is None - assert event["actions"] == b"" - assert event["content"] is None - assert event["grounding_metadata"] is None - assert event["custom_metadata"] is None - assert event["partial"] is None - assert event["turn_complete"] is None - assert event["interrupted"] is None + # The event should still have the 5-key shape + assert event["session_id"] == session_id + assert "event_json" in event def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None: diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py index e18cd1496..bb1784302 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py +++ 
b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py @@ -1,6 +1,6 @@ """Tests for ADBC ADK store event operations.""" -from datetime import datetime, timezone +import json from pathlib import Path from typing import Any @@ -34,24 +34,24 @@ def session_fixture(adbc_store: Any) -> dict[str, str]: def test_create_event(adbc_store: Any, session_fixture: Any) -> None: - """Test creating a new event.""" - event_id = "event-1" + """Test creating a new event returns 5-key EventRecord.""" event = adbc_store.create_event( - event_id=event_id, + event_id="event-1", session_id=session_fixture["session_id"], app_name=session_fixture["app_name"], user_id=session_fixture["user_id"], author="user", - actions=b"serialized_actions", content={"message": "Hello"}, ) - assert event["id"] == event_id assert event["session_id"] == session_fixture["session_id"] assert event["author"] == "user" - assert event["actions"] == b"serialized_actions" - assert event["content"] == {"message": "Hello"} assert event["timestamp"] is not None + assert "event_json" in event + + # Content is stored inside event_json + event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + assert event_data["content"] == {"message": "Hello"} def test_list_events(adbc_store: Any, session_fixture: Any) -> None: @@ -76,8 +76,8 @@ def test_list_events(adbc_store: Any, session_fixture: Any) -> None: events = adbc_store.list_events(session_fixture["session_id"]) assert len(events) == 2 - assert events[0]["id"] == "event-1" - assert events[1]["id"] == "event-2" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: @@ -87,8 +87,7 @@ def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with all optional fields.""" - timestamp = 
datetime.now(timezone.utc) + """Test creating event with all optional fields stored in event_json.""" event = adbc_store.create_event( event_id="full-event", session_id=session_fixture["session_id"], @@ -97,9 +96,7 @@ def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: invocation_id="invocation-123", author="assistant", actions=b"complex_action_data", - long_running_tool_ids_json='["tool1", "tool2"]', branch="main", - timestamp=timestamp, content={"text": "Response"}, grounding_metadata={"sources": ["doc1", "doc2"]}, custom_metadata={"custom": "data"}, @@ -110,19 +107,21 @@ def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: error_message="No errors", ) + # Top-level indexed columns assert event["invocation_id"] == "invocation-123" assert event["author"] == "assistant" - assert event["actions"] == b"complex_action_data" - assert event["long_running_tool_ids_json"] == '["tool1", "tool2"]' - assert event["branch"] == "main" - assert event["content"] == {"text": "Response"} - assert event["grounding_metadata"] == {"sources": ["doc1", "doc2"]} - assert event["custom_metadata"] == {"custom": "data"} - assert event["partial"] is True - assert event["turn_complete"] is False - assert event["interrupted"] is False - assert event["error_code"] == "NONE" - assert event["error_message"] == "No errors" + + # Everything else is in event_json + event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + assert event_data["content"] == {"text": "Response"} + assert event_data["branch"] == "main" + assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]} + assert event_data["custom_metadata"] == {"custom": "data"} + assert event_data["partial"] is True + assert event_data["turn_complete"] is False + assert event_data["interrupted"] is False + assert event_data["error_code"] == "NONE" + assert event_data["error_message"] == "No errors" def 
test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None: @@ -134,59 +133,12 @@ def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> Non user_id=session_fixture["user_id"], ) - assert event["id"] == "minimal-event" assert event["session_id"] == session_fixture["session_id"] - assert event["app_name"] == session_fixture["app_name"] - assert event["user_id"] == session_fixture["user_id"] - assert event["author"] is None - assert event["actions"] == b"" - assert event["content"] is None - - -def test_event_boolean_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test event boolean field conversion.""" - event_true = adbc_store.create_event( - event_id="event-true", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - partial=True, - turn_complete=True, - interrupted=True, - ) - - assert event_true["partial"] is True - assert event_true["turn_complete"] is True - assert event_true["interrupted"] is True - - event_false = adbc_store.create_event( - event_id="event-false", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - partial=False, - turn_complete=False, - interrupted=False, - ) - - assert event_false["partial"] is False - assert event_false["turn_complete"] is False - assert event_false["interrupted"] is False - - event_none = adbc_store.create_event( - event_id="event-none", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - assert event_none["partial"] is None - assert event_none["turn_complete"] is None - assert event_none["interrupted"] is None + assert "event_json" in event def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test event JSON field serialization and deserialization.""" + """Test event JSON field serialization and deserialization via 
event_json.""" complex_content = {"nested": {"data": "value"}, "list": [1, 2, 3], "null": None} complex_grounding = {"sources": [{"title": "Doc", "url": "http://example.com"}]} complex_custom = {"metadata": {"version": 1, "tags": ["tag1", "tag2"]}} @@ -201,16 +153,16 @@ def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: custom_metadata=complex_custom, ) - assert event["content"] == complex_content - assert event["grounding_metadata"] == complex_grounding - assert event["custom_metadata"] == complex_custom + event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + assert event_data["content"] == complex_content + assert event_data["grounding_metadata"] == complex_grounding + assert event_data["custom_metadata"] == complex_custom events = adbc_store.list_events(session_fixture["session_id"]) - retrieved = events[0] - - assert retrieved["content"] == complex_content - assert retrieved["grounding_metadata"] == complex_grounding - assert retrieved["custom_metadata"] == complex_custom + retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + assert retrieved_data["content"] == complex_content + assert retrieved_data["grounding_metadata"] == complex_grounding + assert retrieved_data["custom_metadata"] == complex_custom def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: @@ -245,9 +197,6 @@ def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: events = adbc_store.list_events(session_fixture["session_id"]) assert len(events) == 3 - assert events[0]["id"] == "event-1" - assert events[1]["id"] == "event-2" - assert events[2]["id"] == "event-3" assert events[0]["timestamp"] < events[1]["timestamp"] assert events[1]["timestamp"] < events[2]["timestamp"] @@ -274,19 +223,11 @@ def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, t events_before = 
adbc_store.list_events(session_fixture["session_id"]) assert len(events_before) == 2 - # For SQLite with separate connections per operation, we need to manually delete events - # or note that cascade deletes require persistent connections - # For this test, just verify the session deletion works adbc_store.delete_session(session_fixture["session_id"]) - # Session should be gone session_after = adbc_store.get_session(session_fixture["session_id"]) assert session_after is None - # Events may still exist with ADBC SQLite due to FK enforcement across connections - # This is a known limitation when using ADBC with SQLite in-memory or file-based - # with separate connections per operation - def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None: """Test creating event with empty actions bytes.""" @@ -298,23 +239,21 @@ def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None actions=b"", ) - assert event["actions"] == b"" - - events = adbc_store.list_events(session_fixture["session_id"]) - assert events[0]["actions"] == b"" + # actions=b"" is either ignored or stored as hex in event_json + assert "event_json" in event -def test_event_with_large_actions(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with large actions BLOB.""" - large_actions = b"x" * 10000 +def test_event_with_large_content(adbc_store: Any, session_fixture: Any) -> None: + """Test creating event with large content in event_json.""" + large_content = {"data": "x" * 10000} event = adbc_store.create_event( - event_id="large-actions", + event_id="large-content", session_id=session_fixture["session_id"], app_name=session_fixture["app_name"], user_id=session_fixture["user_id"], - actions=large_actions, + content=large_content, ) - assert event["actions"] == large_actions - assert len(event["actions"]) == 10000 + event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + assert 
event_data["content"] == large_content diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index d53323565..b2f25a9c6 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -1,11 +1,12 @@ """Integration tests for AsyncMY ADK session store.""" -import pickle +import json from datetime import datetime, timezone import pytest from sqlspec.adapters.asyncmy.adk.store import AsyncmyADKStore +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.asyncmy, pytest.mark.integration] @@ -51,13 +52,14 @@ async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) -> ORDER BY ORDINAL_POSITION """) event_columns = await cursor.fetchall() + event_col_names = [col[0] for col in event_columns] - actions_col = next(col for col in event_columns if col[0] == "actions") - assert actions_col[1] == "blob", "actions column must use BLOB type for pickled data" - - content_col = next((col for col in event_columns if col[0] == "content"), None) - if content_col: - assert content_col[1] == "json", "content column must use native JSON type" + # New 5-column schema: session_id, invocation_id, author, timestamp, event_json + assert "session_id" in event_col_names + assert "invocation_id" in event_col_names + assert "author" in event_col_names + assert "timestamp" in event_col_names + assert "event_json" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in timestamp_col[2].lower(), "timestamp must be TIMESTAMP(6) for microseconds" @@ -141,18 +143,14 @@ async def test_delete_session_cascade(asyncmy_adk_store: AsyncmyADKStore) -> Non await asyncmy_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event_record = { - "id": "event-001", + event_record: 
EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "test_action"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello"}, + "event_json": json.dumps({"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}), } - await asyncmy_adk_store.append_event(event_record) # type: ignore[arg-type] + await asyncmy_adk_store.append_event(event_record) events_before = await asyncmy_adk_store.get_events(session_id) assert len(events_before) == 1 @@ -174,48 +172,35 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None await asyncmy_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event1 = { - "id": "event-001", + event1: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "message", "content": "Hello"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello", "role": "user"}, - "partial": False, - "turn_complete": True, + "event_json": json.dumps({"content": {"text": "Hello", "role": "user"}, "app_name": app_name}), } - event2 = { - "id": "event-002", + event2: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-002", "author": "assistant", - "actions": pickle.dumps([{"type": "response", "content": "Hi there"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hi there", "role": "assistant"}, - "partial": False, - "turn_complete": True, + "event_json": json.dumps({"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}), } - await asyncmy_adk_store.append_event(event1) # type: ignore[arg-type] - await asyncmy_adk_store.append_event(event2) # type: ignore[arg-type] + await asyncmy_adk_store.append_event(event1) + await 
asyncmy_adk_store.append_event(event2) events = await asyncmy_adk_store.get_events(session_id) assert len(events) == 2 - assert events[0]["id"] == "event-001" - assert events[1]["id"] == "event-002" - assert events[0]["content"] is not None - assert events[1]["content"] is not None - assert events[0]["content"]["text"] == "Hello" - assert events[1]["content"]["text"] == "Hi there" - assert isinstance(events[0]["actions"], bytes) - assert pickle.loads(events[0]["actions"])[0]["type"] == "message" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" + # Content is inside event_json + event0_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + event1_data = json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + assert event0_data["content"]["text"] == "Hello" + assert event1_data["content"]["text"] == "Hi there" async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None: @@ -230,17 +215,14 @@ async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None: assert hasattr(created["create_time"], "microsecond") event_time = datetime.now(timezone.utc) - event = { - "id": "event-micro", + event: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-micro", "author": "system", - "actions": b"", "timestamp": event_time, + "event_json": json.dumps({"app_name": app_name}), } - await asyncmy_adk_store.append_event(event) # type: ignore[arg-type] + await asyncmy_adk_store.append_event(event) events = await asyncmy_adk_store.get_events(session_id) assert len(events) == 1 diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index a685e08b8..241ff3de5 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ 
b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -1,5 +1,6 @@ """Integration tests for DuckDB ADK session store.""" +import json from collections.abc import Generator from datetime import datetime, timezone from pathlib import Path @@ -145,11 +146,10 @@ def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None app_name="test-app", user_id="user-005", author="user", - actions=b"test-actions", content={"message": "Hello"}, ) - assert event["id"] == "event-001" + assert event["session_id"] == session_id events = duckdb_adk_store.list_events(session_id) assert len(events) == 1 @@ -160,44 +160,33 @@ def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None assert len(events_after) == 0 -def test_create_and_get_event(duckdb_adk_store: DuckdbADKStore) -> None: - """Test creating and retrieving an event.""" +def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: + """Test creating an event and verifying the returned 5-key EventRecord.""" session_id = "session-006" duckdb_adk_store.create_session(session_id, "test-app", "user-006", {}) - event_id = "event-002" timestamp = datetime.now(timezone.utc) content = {"text": "Test message", "role": "user"} - custom_metadata = {"source": "test"} created_event = duckdb_adk_store.create_event( - event_id=event_id, + event_id="event-002", session_id=session_id, app_name="test-app", user_id="user-006", author="user", - actions=b"pickled-actions", content=content, timestamp=timestamp, - custom_metadata=custom_metadata, ) - assert created_event["id"] == event_id + # Returned record has the 5-key shape assert created_event["session_id"] == session_id assert created_event["author"] == "user" - assert created_event["content"] == content - assert created_event["custom_metadata"] == custom_metadata + assert created_event["timestamp"] == timestamp + assert "event_json" in created_event - retrieved_event = duckdb_adk_store.get_event(event_id) - assert retrieved_event is 
not None - assert retrieved_event["id"] == event_id - assert retrieved_event["content"] == content - - -def test_get_nonexistent_event(duckdb_adk_store: DuckdbADKStore) -> None: - """Test getting a non-existent event returns None.""" - result = duckdb_adk_store.get_event("nonexistent-event") - assert result is None + # Content is stored inside event_json + event_data = json.loads(created_event["event_json"]) if isinstance(created_event["event_json"], str) else created_event["event_json"] + assert event_data["content"] == content def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: @@ -225,8 +214,8 @@ def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: events = duckdb_adk_store.list_events(session_id) assert len(events) == 2 - assert events[0]["id"] == "event-1" - assert events[1]["id"] == "event-2" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" assert events[0]["timestamp"] <= events[1]["timestamp"] @@ -240,7 +229,7 @@ def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None: def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: - """Test creating events with all optional fields.""" + """Test creating events with optional fields stored in event_json.""" session_id = "session-008" duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) @@ -250,7 +239,6 @@ def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: app_name="test-app", user_id="user-008", author="assistant", - actions=b"actions-data", content={"text": "Response"}, invocation_id="inv-123", branch="main", @@ -263,15 +251,15 @@ def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: error_message=None, ) + # The 5-key record has invocation_id as a top-level indexed column assert event["invocation_id"] == "inv-123" - assert event["branch"] == "main" - assert event["grounding_metadata"] == {"sources": ["doc1", "doc2"]} - assert event["partial"] is True - 
assert event["turn_complete"] is False - retrieved = duckdb_adk_store.get_event("event-full") - assert retrieved is not None - assert retrieved["grounding_metadata"] == {"sources": ["doc1", "doc2"]} + # Other fields are inside event_json + event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + assert event_data["branch"] == "main" + assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]} + assert event_data["partial"] is True + assert event_data["turn_complete"] is False def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: @@ -296,9 +284,12 @@ def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: events = duckdb_adk_store.list_events(session_id) assert len(events) == 3 - assert events[0]["id"] == "event-first" - assert events[1]["id"] == "event-middle" - assert events[2]["id"] == "event-last" + # Events should be ordered by timestamp ASC + event_ids = [] + for e in events: + data = json.loads(e["event_json"]) if isinstance(e["event_json"], str) else e["event_json"] + event_ids.append(data["id"]) + assert event_ids == ["event-first", "event-middle", "event-last"] def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None: @@ -353,28 +344,26 @@ def test_table_not_found_handling(tmp_path: Path, worker_id: str) -> None: db_path.unlink() -def test_binary_actions_data(duckdb_adk_store: DuckdbADKStore) -> None: - """Test storing and retrieving binary actions data.""" - session_id = "session-binary" +def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: + """Test storing and retrieving event data via event_json.""" + session_id = "session-json-rt" duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) - binary_data = bytes(range(256)) - event = duckdb_adk_store.create_event( - event_id="event-binary", + event_id="event-json", session_id=session_id, app_name="test-app", user_id="user-012", 
author="system", - actions=binary_data, + content={"data": "value"}, ) - assert event["actions"] == binary_data + assert "event_json" in event - retrieved = duckdb_adk_store.get_event("event-binary") - assert retrieved is not None - assert retrieved["actions"] == binary_data - assert len(retrieved["actions"]) == 256 + events = duckdb_adk_store.list_events(session_id) + assert len(events) == 1 + event_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + assert event_data["content"] == {"data": "value"} def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None: diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index 5e6253a47..f85a0ee79 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -1,12 +1,13 @@ """Integration tests for MysqlConnector ADK session store.""" -import pickle +import json from datetime import datetime, timezone from typing import cast import pytest from sqlspec.adapters.mysqlconnector.adk.store import MysqlConnectorAsyncADKStore +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.integration] @@ -50,13 +51,14 @@ async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnect ORDER BY ORDINAL_POSITION """) event_columns = await cursor.fetchall() + event_col_names = [col[0] for col in event_columns] - actions_col = next(col for col in event_columns if col[0] == "actions") - assert actions_col[1] == "blob" - - content_col = next((col for col in event_columns if col[0] == "content"), None) - if content_col: - assert content_col[1] == "json" + # New 5-column schema: session_id, invocation_id, author, timestamp, event_json + assert "session_id" in 
event_col_names + assert "invocation_id" in event_col_names + assert "author" in event_col_names + assert "timestamp" in event_col_names + assert "event_json" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in cast("str", timestamp_col[2]).lower() @@ -142,18 +144,14 @@ async def test_delete_session_cascade(mysqlconnector_adk_store: MysqlConnectorAs await mysqlconnector_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event_record = { - "id": "event-001", + event_record: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "test_action"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello"}, + "event_json": json.dumps({"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}), } - await mysqlconnector_adk_store.append_event(event_record) # type: ignore[arg-type] + await mysqlconnector_adk_store.append_event(event_record) events_before = await mysqlconnector_adk_store.get_events(session_id) assert len(events_before) == 1 @@ -175,48 +173,35 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy await mysqlconnector_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event1 = { - "id": "event-001", + event1: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "message", "content": "Hello"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello", "role": "user"}, - "partial": False, - "turn_complete": True, + "event_json": json.dumps({"content": {"text": "Hello", "role": "user"}, "app_name": app_name}), } - event2 = { - "id": "event-002", + event2: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": 
user_id, "invocation_id": "inv-002", "author": "assistant", - "actions": pickle.dumps([{"type": "response", "content": "Hi there"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hi there", "role": "assistant"}, - "partial": False, - "turn_complete": True, + "event_json": json.dumps({"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}), } - await mysqlconnector_adk_store.append_event(event1) # type: ignore[arg-type] - await mysqlconnector_adk_store.append_event(event2) # type: ignore[arg-type] + await mysqlconnector_adk_store.append_event(event1) + await mysqlconnector_adk_store.append_event(event2) events = await mysqlconnector_adk_store.get_events(session_id) assert len(events) == 2 - assert events[0]["id"] == "event-001" - assert events[1]["id"] == "event-002" - content0 = events[0]["content"] - content1 = events[1]["content"] - assert content0 is not None and content0["text"] == "Hello" - assert content1 is not None and content1["text"] == "Hi there" - assert isinstance(events[0]["actions"], bytes) - assert pickle.loads(events[0]["actions"])[0]["type"] == "message" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" + # Content is inside event_json + event0_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + event1_data = json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + assert event0_data["content"]["text"] == "Hello" + assert event1_data["content"]["text"] == "Hi there" async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsyncADKStore) -> None: @@ -230,17 +215,14 @@ async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsync assert hasattr(created["create_time"], "microsecond") event_time = datetime.now(timezone.utc) - event = { - "id": "event-micro", + event: EventRecord = { "session_id": session_id, - "app_name": 
app_name, - "user_id": user_id, "invocation_id": "inv-micro", "author": "system", - "actions": b"", "timestamp": event_time, + "event_json": json.dumps({"app_name": app_name}), } - await mysqlconnector_adk_store.append_event(event) # type: ignore[arg-type] + await mysqlconnector_adk_store.append_event(event) events = await mysqlconnector_adk_store.get_events(session_id) assert len(events) == 1 diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index 52cd230f4..26ed4730a 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -1,6 +1,6 @@ """Oracle-specific ADK store tests for LOB handling, JSON types, and FK columns.""" -import pickle +import json from collections.abc import AsyncGenerator, Generator from datetime import datetime, timezone from typing import Any, cast @@ -220,8 +220,8 @@ async def test_state_lob_deserialization(oracle_async_store: "OracleAsyncADKStor assert retrieved["state"]["large_field"] == "x" * 10000 -async def test_event_content_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event content CLOB is correctly deserialized.""" +async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_json CLOB is correctly deserialized.""" session_id = _unique_session_id("event-lob") app_name = "test-app" user_id = "user-123" @@ -229,78 +229,57 @@ async def test_event_content_lob_deserialization(oracle_async_store: "OracleAsyn await oracle_async_store.create_session(session_id, app_name, user_id, {}) content = {"message": "x" * 5000, "data": {"nested": True}} - grounding_metadata = {"sources": ["a" * 1000, "b" * 1000]} - custom_metadata = {"tags": ["tag1", "tag2"], "priority": "high"} + event_data = { + "content": content, + "app_name": app_name, 
+ "user_id": user_id, + "grounding_metadata": {"sources": ["a" * 1000, "b" * 1000]}, + "custom_metadata": {"tags": ["tag1", "tag2"], "priority": "high"}, + } event_record: EventRecord = { - "id": "event-1", "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "", "author": "assistant", - "actions": pickle.dumps([{"name": "test", "args": {}}]), - "content": content, - "grounding_metadata": grounding_metadata, - "custom_metadata": custom_metadata, "timestamp": datetime.now(timezone.utc), - "partial": False, - "turn_complete": True, - "interrupted": False, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": json.dumps(event_data), } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["content"] == content - assert events[0]["grounding_metadata"] == grounding_metadata - assert events[0]["custom_metadata"] == custom_metadata + # event_json contains all the data + retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + assert retrieved_data["content"] == content + assert retrieved_data["grounding_metadata"] == {"sources": ["a" * 1000, "b" * 1000]} + assert retrieved_data["custom_metadata"] == {"tags": ["tag1", "tag2"], "priority": "high"} -async def test_actions_blob_handling(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test actions BLOB is correctly read and unpickled.""" - session_id = _unique_session_id("actions-blob") +async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_json blob is correctly stored and retrieved.""" + session_id = _unique_session_id("event-json") app_name = "test-app" user_id = "user-123" await oracle_async_store.create_session(session_id, app_name, user_id, {}) - test_actions = [{"function": 
"test_func", "args": {"param": "value"}, "result": 42}] - actions_bytes = pickle.dumps(test_actions) + event_data = {"function": "test_func", "args": {"param": "value"}, "result": 42} event_record: EventRecord = { - "id": "event-actions", "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "", "author": "user", - "actions": actions_bytes, - "content": None, - "grounding_metadata": None, - "custom_metadata": None, "timestamp": datetime.now(timezone.utc), - "partial": None, - "turn_complete": None, - "interrupted": None, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": json.dumps(event_data), } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["actions"] == actions_bytes - unpickled = pickle.loads(events[0]["actions"]) - assert unpickled == test_actions + retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + assert retrieved_data == event_data def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None: @@ -318,80 +297,61 @@ def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") assert retrieved["state"] == state -async def test_boolean_fields_conversion(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test partial, turn_complete, interrupted converted to NUMBER(1).""" - session_id = _unique_session_id("bool-session") +async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test the new 5-column EventRecord contract with append_event.""" + session_id = _unique_session_id("5col-session") app_name = "test-app" user_id = "user-123" await oracle_async_store.create_session(session_id, app_name, user_id, {}) event_record: EventRecord = { - "id": "bool-event-1", 
"session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "inv-001", "author": "assistant", - "actions": b"", - "content": None, - "grounding_metadata": None, - "custom_metadata": None, "timestamp": datetime.now(timezone.utc), - "partial": True, - "turn_complete": False, - "interrupted": True, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": json.dumps({ + "content": {"text": "Hello"}, + "partial": True, + "turn_complete": False, + "interrupted": True, + }), } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["partial"] is True - assert events[0]["turn_complete"] is False - assert events[0]["interrupted"] is True + assert events[0]["session_id"] == session_id + assert events[0]["invocation_id"] == "inv-001" + assert events[0]["author"] == "assistant" + retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + assert retrieved_data["partial"] is True + assert retrieved_data["turn_complete"] is False + assert retrieved_data["interrupted"] is True -async def test_boolean_fields_none_values(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test None values for boolean fields.""" - session_id = _unique_session_id("bool-none-session") + +async def test_event_with_none_values(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event with minimal event_json content.""" + session_id = _unique_session_id("none-session") app_name = "test-app" user_id = "user-123" await oracle_async_store.create_session(session_id, app_name, user_id, {}) event_record: EventRecord = { - "id": "bool-event-none", "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "", "author": "user", - "actions": b"", - "content": None, - "grounding_metadata": 
None, - "custom_metadata": None, "timestamp": datetime.now(timezone.utc), - "partial": None, - "turn_complete": None, - "interrupted": None, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": json.dumps({"app_name": app_name}), } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["partial"] is None - assert events[0]["turn_complete"] is None - assert events[0]["interrupted"] is None async def test_create_session_with_owner_id(oracle_store_with_fk: "OracleAsyncADKStore") -> None: @@ -465,7 +425,7 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync "complex": { "nested": {"deep": {"structure": "value"}}, "array": [1, 2, 3, {"key": "value"}], - "unicode": "こんにちは世界", + "unicode": "日本語テスト", "special_chars": "test@example.com | value > 100", } } @@ -476,7 +436,7 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync retrieved = await oracle_async_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state - assert retrieved["state"]["complex"]["unicode"] == "こんにちは世界" + assert retrieved["state"]["complex"]["unicode"] == "日本語テスト" def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyncADKStore") -> None: diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index 98a313d15..61bc624cd 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -1,5 +1,6 @@ """Integration tests for Spanner ADK store (sync).""" +import json from typing import Any import pytest @@ -65,11 +66,15 @@ def test_create_and_list_events(spanner_adk_store: Any) -> None: "user", author="assistant", 
content={"msg": "ok"}, - partial=False, - turn_complete=True, ) events = spanner_adk_store.list_events(session_id) - ids = [e["id"] for e in events] - assert ids == ["event-1", "event-2"] - assert events[0]["content"] == {"msg": "hi"} + assert len(events) == 2 + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" + + # Content is inside event_json in the new 5-column schema + event0_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + event1_data = json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + assert event0_data["content"] == {"msg": "hi"} + assert event1_data["content"] == {"msg": "ok"} From 1794e8ab539db36480c8ff776881f5c3d45b7dae Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 17:53:28 +0000 Subject: [PATCH 17/23] refactor(adk): unify all stores behind async interface via sync_tools.async_() Convert all sync ADK session and memory stores from BaseSyncADKStore/ BaseSyncADKMemoryStore to BaseAsyncADKStore/BaseAsyncADKMemoryStore. Each sync store now: - Extends the async base class instead of the sync base class - Has private sync methods (e.g., _create_session, _get_session) - Has public async wrappers using async_() from sync_tools - Maps create_event_and_update_state -> append_event_and_update_state - Maps list_events -> get_events (with after_timestamp/limit params) - Adds append_event method - Drops legacy create_event method Stores converted (session + memory): - DuckDB, ADBC, Psycopg (sync), CockroachDB Psycopg (sync), MySQL Connector (sync), PyMySQL, Oracle (sync), Spanner (sync), SQLite (memory store only) The SQLite session store was already using this pattern and served as the reference implementation. 
--- sqlspec/adapters/adbc/adk/store.py | 184 ++++---- .../adapters/cockroach_psycopg/adk/store.py | 184 ++++---- sqlspec/adapters/duckdb/adk/store.py | 398 +++++++++++------- sqlspec/adapters/mysqlconnector/adk/store.py | 178 ++++---- sqlspec/adapters/oracledb/adk/store.py | 181 ++++---- sqlspec/adapters/psycopg/adk/store.py | 183 ++++---- sqlspec/adapters/pymysql/adk/store.py | 180 ++++---- sqlspec/adapters/spanner/adk/store.py | 163 ++++--- sqlspec/adapters/sqlite/adk/store.py | 43 +- 9 files changed, 992 insertions(+), 702 deletions(-) diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 3d2963d8e..803d0955b 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -4,10 +4,11 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from sqlspec.adapters.adbc.config import AdbcConfig @@ -26,7 +27,7 @@ ADBC_TABLE_NOT_FOUND_PATTERNS: Final = ("no such table", "table or view does not exist", "relation does not exist") -class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]): +class AdbcADKStore(BaseAsyncADKStore["AdbcConfig"]): """ADBC synchronous ADK store for Arrow Database Connectivity. 
Implements session and event storage for Google Agent Development Kit @@ -176,7 +177,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return None return from_json(str(data)) # type: ignore[no-any-return] - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: """Get CREATE TABLE SQL for sessions with dialect dispatch. Returns: @@ -282,7 +283,7 @@ def _get_sessions_ddl_generic(self) -> str: ) """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: """Get CREATE TABLE SQL for events with dialect dispatch. Returns: @@ -410,7 +411,7 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" with self._config.provide_connection() as conn: cursor = conn.cursor() @@ -446,6 +447,11 @@ def create_tables(self) -> None: finally: cursor.close() # type: ignore[no-untyped-call] + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None: """Enable foreign key constraints for SQLite. @@ -463,7 +469,7 @@ def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None: except Exception: logger.debug("Foreign key enforcement not supported or already enabled") - def create_session( + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: """Create a new session. 
@@ -505,7 +511,14 @@ def create_session( return self.get_session(session_id) # type: ignore[return-value] - def get_session(self, session_id: str) -> "SessionRecord | None": + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID. Args: @@ -549,7 +562,12 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: @@ -575,7 +593,12 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None finally: cursor.close() # type: ignore[no-untyped-call] - def delete_session(self, session_id: str) -> None: + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: """Delete session and all associated events (cascade). 
Args: @@ -595,7 +618,12 @@ def delete_session(self, session_id: str) -> None: finally: cursor.close() # type: ignore[no-untyped-call] - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. Args: @@ -651,69 +679,10 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> "EventRecord": - """Create a new event using the new 5-column EventRecord contract. - - Args: - event_id: Unique event identifier (unused in new schema, kept for API compat). - session_id: Session identifier. - app_name: Application name (stored inside event_json). - user_id: User identifier (stored inside event_json). - author: Event author (user/assistant/system). - actions: Pickled actions object (stored inside event_json if provided). - content: Event content (stored inside event_json). - **kwargs: Additional optional fields (stored inside event_json). - - Returns: - Created event record. - Notes: - Builds an event_json blob from all provided fields and stores it - alongside the indexed scalar columns (session_id, invocation_id, - author, timestamp). 
- """ - timestamp = kwargs.pop("timestamp", None) - if timestamp is None: - timestamp = datetime.now(timezone.utc) - - invocation_id = kwargs.pop("invocation_id", "") or "" - - # Build event_json from all provided data - event_data: dict[str, Any] = { - "id": event_id, - "app_name": app_name, - "user_id": user_id, - } - if content is not None: - event_data["content"] = content - if actions is not None: - event_data["actions"] = actions.hex() - if author is not None: - event_data["author"] = author - # Include remaining kwargs in event_json - event_data.update({k: v for k, v in kwargs.items() if v is not None}) - - event_json_str = to_json(event_data) - - event_record = EventRecord( - session_id=session_id, - invocation_id=invocation_id, - author=author or "", - timestamp=timestamp, - event_json=event_json_str, - ) - self._insert_event(event_record) - return event_record + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) def _insert_event(self, event_record: "EventRecord") -> None: """Insert an event record into the events table. @@ -744,7 +713,7 @@ def _insert_event(self, event_record: "EventRecord") -> None: finally: cursor.close() # type: ignore[no-untyped-call] - def create_event_and_update_state( + def _append_event_and_update_state( self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" ) -> None: """Atomically insert an event and update the session's durable state. 
@@ -793,7 +762,16 @@ def create_event_and_update_state( finally: cursor.close() # type: ignore[no-untyped-call] - def list_events(self, session_id: str) -> "list[EventRecord]": + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: @@ -840,7 +818,22 @@ def list_events(self, session_id: str) -> "list[EventRecord]": raise -class AdbcADKMemoryStore(BaseSyncADKMemoryStore["AdbcConfig"]): + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._append_event_and_update_state(event_record, event_record["session_id"], {}) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + +class AdbcADKMemoryStore(BaseAsyncADKMemoryStore["AdbcConfig"]): """ADBC synchronous ADK memory store for Arrow Database Connectivity.""" __slots__ = ("_dialect",) @@ -885,7 +878,7 @@ def _decode_timestamp(self, value: Any) -> datetime: return datetime.fromisoformat(value) return datetime.fromisoformat(str(value)) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: if self._dialect == DIALECT_POSTGRESQL: return self._get_memory_ddl_postgresql() if self._dialect == DIALECT_SQLITE: @@ -989,7 +982,7 
@@ def _get_memory_ddl_generic(self) -> str: def _get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return @@ -1014,7 +1007,12 @@ def create_tables(self) -> None: finally: cursor.close() # type: ignore[no-untyped-call] - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -1120,7 +1118,12 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -1160,7 +1163,14 @@ def search_entries( return self._rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} if use_returning: sql = f"DELETE FROM {self._memory_table} WHERE session_id = ? 
RETURNING 1" @@ -1179,7 +1189,12 @@ def delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() # type: ignore[no-untyped-call] - def delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: cutoff = self._encode_timestamp(datetime.now(timezone.utc) - timedelta(days=days)) use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} if use_returning: @@ -1199,6 +1214,11 @@ def delete_entries_older_than(self, days: int) -> int: finally: cursor.close() # type: ignore[no-untyped-call] + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index fa589168b..959c450d9 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -6,9 +6,10 @@ from psycopg import sql as pg_sql from psycopg.types.json import Jsonb -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from datetime import datetime @@ -337,7 +338,7 @@ async def get_events( return [] -class 
CockroachPsycopgSyncADKStore(BaseSyncADKStore["CockroachPsycopgSyncConfig"]): +class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig"]): """CockroachDB ADK store using psycopg sync driver. Implements session and event storage for Google Agent Development Kit @@ -356,7 +357,7 @@ class CockroachPsycopgSyncADKStore(BaseSyncADKStore["CockroachPsycopgSyncConfig" def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -382,7 +383,7 @@ def _get_create_sessions_table_sql(self) -> str: WHERE state != '{{}}'::jsonb; """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( session_id VARCHAR(128) NOT NULL, @@ -403,12 +404,17 @@ def _get_create_events_table_sql(self) -> str: def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(self._get_create_sessions_table_sql()) driver.execute_script(self._get_create_events_table_sql()) - def create_session( + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = Jsonb(state) @@ -437,7 +443,14 @@ def create_session( raise RuntimeError(msg) return result - def get_session(self, session_id: str) -> "SessionRecord | None": + + async def create_session( + self, session_id: str, app_name: str, user_id: str, 
state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -463,7 +476,12 @@ def get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP @@ -474,14 +492,24 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cur.execute(sql.encode(), (Jsonb(state), session_id)) conn.commit() - def delete_session(self, session_id: str) -> None: + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(sql.encode(), (session_id,)) conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, 
update_time @@ -518,71 +546,12 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess except errors.UndefinedTable: return [] - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - """Create a new event using the legacy positional API. - - This method is required by the BaseSyncADKStore contract. For new code, - prefer ``create_event_and_update_state`` which atomically persists the - event and updates session state. - """ - from datetime import datetime, timezone - - event_json: dict[str, Any] = {} - if author is not None: - event_json["author"] = author - if actions is not None: - event_json["actions"] = actions.hex() - if content is not None: - event_json["content"] = content - event_json.update({k: v for k, v in kwargs.items() if v is not None}) - - invocation_id = kwargs.get("invocation_id", "") - ts = kwargs.get("timestamp") or datetime.now(timezone.utc) - - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - RETURNING session_id, invocation_id, author, timestamp, event_json - """ - - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute( - sql.encode(), - ( - session_id, - invocation_id, - author or "", - ts, - Jsonb(event_json), - ), - ) - row = cur.fetchone() - conn.commit() - if row is None: - msg = f"Failed to create event {event_id}" - raise RuntimeError(msg) - - return EventRecord( - session_id=row["session_id"], - invocation_id=row["invocation_id"], - author=row["author"], - timestamp=row["timestamp"], - event_json=row["event_json"], - ) + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, 
user_id) - def create_event_and_update_state( + def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: insert_sql = f""" @@ -613,7 +582,16 @@ def create_event_and_update_state( cur.execute(update_sql.encode(), (Jsonb(state), session_id)) conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} @@ -640,6 +618,21 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [] + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._append_event_and_update_state(event_record, event_record["session_id"], {}) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgAsyncConfig"]): """CockroachDB ADK memory store using psycopg async driver.""" @@ -805,7 +798,7 @@ async def delete_entries_older_than(self, days: int) -> int: return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 -class 
CockroachPsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["CockroachPsycopgSyncConfig"]): +class CockroachPsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgSyncConfig"]): """CockroachDB ADK memory store using psycopg sync driver.""" __slots__ = () @@ -813,7 +806,7 @@ class CockroachPsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["CockroachPsycop def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -851,14 +844,19 @@ def _get_create_memory_table_sql(self) -> str: def _get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: driver.execute_script(self._get_create_memory_table_sql()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -901,7 +899,12 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object inserted_count += cur.rowcount return inserted_count - def search_entries( + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> 
"list[MemoryRecord]": if not self._enabled: @@ -942,7 +945,14 @@ def search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - def delete_entries_by_session(self, session_id: str) -> int: + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -953,7 +963,12 @@ def delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - def delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -966,3 +981,8 @@ def delete_entries_older_than(self, days: int) -> int: cur.execute(sql.encode()) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 0255ca758..ddb3059a9 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -16,10 +16,11 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Final, cast -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store 
import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from sqlspec.adapters.duckdb.config import DuckDBConfig @@ -33,11 +34,12 @@ DUCKDB_TABLE_NOT_FOUND_ERROR: Final = "does not exist" -class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]): +class DuckdbADKStore(BaseAsyncADKStore["DuckDBConfig"]): """DuckDB ADK store for Google Agent Development Kit. Implements session and event storage for Google Agent Development Kit - using DuckDB's synchronous driver. Provides: + using DuckDB's synchronous driver with async wrappers via ``async_()``. + Provides: - Session state management with native JSON type - Event history with single JSON blob (event_json) plus indexed scalars - Native TIMESTAMPTZ type support @@ -62,7 +64,7 @@ class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]): } ) store = DuckdbADKStore(config) - store.ensure_tables() + await store.ensure_tables() Notes: - Uses DuckDB native JSON type for event_json and state @@ -90,7 +92,7 @@ def __init__(self, config: "DuckDBConfig") -> None: """ super().__init__(config) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for sessions. Returns: @@ -122,7 +124,7 @@ def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for events. 
Returns: @@ -160,31 +162,53 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" + def _create_tables(self) -> None: + """Synchronous implementation of create_tables.""" with self._config.provide_connection() as conn: - conn.execute(self._get_create_sessions_table_sql()) - conn.execute(self._get_create_events_table_sql()) + conn.execute(self.__get_create_sessions_table_sql_sync()) + conn.execute(self.__get_create_events_table_sql_sync()) - def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). + def __get_create_sessions_table_sql_sync(self) -> str: + """Synchronous version of DDL generation for use in _create_tables.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" - Returns: - Created session record. + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR PRIMARY KEY, + app_name VARCHAR NOT NULL, + user_id VARCHAR NOT NULL{owner_id_line}, + state JSON NOT NULL, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user ON {self._session_table}(app_name, user_id); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + """ - Notes: - Uses current UTC timestamp for create_time and update_time. - State is JSON-serialized using SQLSpec serializers. 
+ def __get_create_events_table_sql_sync(self) -> str: + """Synchronous version of DDL generation for use in _create_tables.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + session_id VARCHAR NOT NULL, + invocation_id VARCHAR NOT NULL, + author VARCHAR NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_json JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) + ); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ + + async def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Synchronous implementation of create_session.""" now = datetime.now(timezone.utc) state_json = to_json(state) @@ -211,19 +235,29 @@ def create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. Args: - session_id: Session identifier. + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner_id_column (if configured). Returns: - Session record or None if not found. + Created session record. Notes: - DuckDB returns datetime objects for TIMESTAMPTZ columns. - JSON is parsed from database storage. + Uses current UTC timestamp for create_time and update_time. + State is JSON-serialized using SQLSpec serializers. 
""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": + """Synchronous implementation of get_session.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -255,17 +289,23 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID. Args: session_id: Session identifier. - state: New state dictionary (replaces existing state). + + Returns: + Session record or None if not found. Notes: - This replaces the entire state dictionary. - Update time is automatically set to current UTC timestamp. + DuckDB returns datetime objects for TIMESTAMPTZ columns. + JSON is parsed from database storage. """ + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of update_session_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) @@ -279,15 +319,21 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None conn.execute(sql, (state_json, now, session_id)) conn.commit() - def delete_session(self, session_id: str) -> None: - """Delete session and all associated events. + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. Args: session_id: Session identifier. + state: New state dictionary (replaces existing state). Notes: - DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first. + This replaces the entire state dictionary. + Update time is automatically set to current UTC timestamp. 
""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: + """Synchronous implementation of delete_session.""" delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = ?" delete_session_sql = f"DELETE FROM {self._session_table} WHERE id = ?" @@ -296,19 +342,19 @@ def delete_session(self, session_id: str) -> None: conn.execute(delete_session_sql, (session_id,)) conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. + async def delete_session(self, session_id: str) -> None: + """Delete session and all associated events. Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. + session_id: Session identifier. Notes: - Uses composite index on (app_name, user_id) when user_id is provided. + DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first. """ + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + """Synchronous implementation of list_sessions.""" if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -347,63 +393,26 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - """Create a new event using the legacy decomposed-parameter signature. - - This method satisfies the abstract base class contract. It builds an - ``EventRecord`` from the provided arguments and delegates to the new - 5-column schema. 
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. Args: - event_id: Unique event identifier (unused in new schema, kept for API compat). - session_id: Session identifier. - app_name: Application name (stored inside event_json). - user_id: User identifier (stored inside event_json). - author: Event author (user/assistant/system). - actions: Legacy actions bytes (ignored in new schema). - content: Event content dict (stored inside event_json). - **kwargs: Additional optional fields folded into event_json. + app_name: Application name. + user_id: User identifier. If None, lists all sessions for the app. Returns: - Created event record with the new 5-key shape. + List of session records ordered by update_time DESC. + + Notes: + Uses composite index on (app_name, user_id) when user_id is provided. """ - timestamp = kwargs.get("timestamp", datetime.now(timezone.utc)) - - # Build the event_json blob from all provided fields - event_data: dict[str, Any] = { - "id": event_id, - "app_name": app_name, - "user_id": user_id, - } - if content is not None: - event_data["content"] = content - for key in ( - "invocation_id", - "branch", - "grounding_metadata", - "custom_metadata", - "long_running_tool_ids_json", - "partial", - "turn_complete", - "interrupted", - "error_code", - "error_message", - ): - val = kwargs.get(key) - if val is not None: - event_data[key] = val - - event_json_str = to_json(event_data) + return await async_(self._list_sessions)(app_name, user_id) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + event_json_value = event_record["event_json"] + if not isinstance(event_json_value, str): + event_json_value = to_json(event_json_value) sql = f""" INSERT INTO {self._events_table} @@ -415,37 +424,28 @@ def create_event( conn.execute( sql, ( - session_id, - kwargs.get("invocation_id", ""), - 
author or "", - timestamp, - event_json_str, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_value, ), ) conn.commit() - return EventRecord( - session_id=session_id, - invocation_id=kwargs.get("invocation_id", ""), - author=author or "", - timestamp=timestamp, - event_json=event_json_str, - ) - - def create_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> None: - """Atomically create an event and update the session's durable state. - - The event insert and state update succeed together or fail together - within a single DuckDB transaction. + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. Args: - event_record: Event record to store (5-key shape). - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + event_record: Event record with 5 keys (session_id, invocation_id, + author, timestamp, event_json). """ + await async_(self._append_event)(event_record) + + def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Synchronous implementation of append_event_and_update_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) event_json_value = event_record["event_json"] @@ -478,25 +478,46 @@ def create_event_and_update_state( conn.execute(update_sql, (state_json, now, session_id)) conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": - """List events for a session ordered by timestamp. + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. - Args: - session_id: Session identifier. 
+ The event insert and state update succeed together or fail together + within a single DuckDB transaction. - Returns: - List of event records ordered by timestamp ASC. + Args: + event_record: Event record to store (5-key shape). + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). """ + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Synchronous implementation of get_events.""" + where_clauses = ["session_id = ?"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > ?") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = f" LIMIT {limit}" if limit else "" + sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = ? - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: with self._config.provide_connection() as conn: - cursor = conn.execute(sql, (session_id,)) + cursor = conn.execute(sql, params) rows = cursor.fetchall() return [ @@ -514,12 +535,28 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [] raise + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session. + + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. 
+ """ + return await async_(self._get_events)(session_id, after_timestamp, limit) + -class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]): - """DuckDB ADK memory store using synchronous DuckDB driver. +class DuckdbADKMemoryStore(BaseAsyncADKMemoryStore["DuckDBConfig"]): + """DuckDB ADK memory store using synchronous DuckDB driver with async wrappers. Implements memory entry storage for Google Agent Development Kit - using DuckDB's synchronous driver. Provides: + using DuckDB's synchronous driver with async wrappers via ``async_()``. + Provides: - Session memory storage with native JSON type - Simple ILIKE search or BM25 full-text search via FTS extension - Native TIMESTAMP type support @@ -544,7 +581,7 @@ class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]): } ) store = DuckdbADKMemoryStore(config) - store.ensure_tables() + await store.ensure_tables() Notes: - Uses DuckDB native JSON type (not JSONB) @@ -622,7 +659,7 @@ def _refresh_fts_index(self, conn: Any) -> None: except Exception as exc: logger.debug("Failed to refresh DuckDB FTS index: %s", exc) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for memory entries. 
Returns: @@ -658,21 +695,51 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """Get DuckDB DROP TABLE SQL statements.""" return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: - """Create the memory table and indexes if they don't exist.""" + def _create_tables(self) -> None: + """Synchronous implementation of create_tables.""" if not self._enabled: return + ddl = self.__get_create_memory_table_sql_sync() with self._config.provide_connection() as conn: - conn.execute(self._get_create_memory_table_sql()) + conn.execute(ddl) if self._use_fts: self._create_fts_index(conn) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication. + def __get_create_memory_table_sql_sync(self) -> str: + """Synchronous version of DDL generation for use in _create_tables.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" - After successful inserts, refreshes the FTS index if FTS is enabled. 
+ return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMP NOT NULL, + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); """ + + async def create_tables(self) -> None: + """Create the memory table and indexes if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Synchronous implementation of insert_memory_entries.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -741,14 +808,17 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query. + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication. - When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. - Falls back to ILIKE for simple substring matching. + After successful inserts, refreshes the FTS index if FTS is enabled. 
""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Synchronous implementation of search_entries.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -795,8 +865,18 @@ def search_entries( records.append(record) return records - def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query. + + When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. + Falls back to ILIKE for simple substring matching. + """ + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: + """Synchronous implementation of delete_entries_by_session.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -810,8 +890,12 @@ def delete_entries_by_session(self, session_id: str) -> int: self._refresh_fts_index(conn) return deleted_count - def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: + """Synchronous implementation of delete_entries_older_than.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -828,3 +912,7 @@ def delete_entries_older_than(self, days: int) -> int: if self._use_fts and deleted_count > 0: self._refresh_fts_index(conn) return deleted_count + + async def delete_entries_older_than(self, days: int) 
-> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 2c3079688..aaa014a1f 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -5,9 +5,10 @@ import mysql.connector -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from datetime import datetime @@ -385,7 +386,7 @@ async def get_events( raise -class MysqlConnectorSyncADKStore(BaseSyncADKStore["MysqlConnectorSyncConfig"]): +class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]): """MySQL/MariaDB ADK store using mysql-connector sync driver. 
Provides: @@ -408,21 +409,26 @@ def __init__(self, config: "MysqlConnectorSyncConfig") -> None: def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": return _parse_owner_id_column_for_mysql(column_ddl) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: return _mysql_events_ddl(self._events_table, self._session_table) def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(self._get_create_sessions_table_sql()) driver.execute_script(self._get_create_events_table_sql()) - def create_session( + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = to_json(state) @@ -455,7 +461,14 @@ def create_session( raise RuntimeError(msg) return result - def get_session(self, session_id: str) -> "SessionRecord | None": + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -489,7 +502,12 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None 
raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) sql = f""" @@ -506,7 +524,12 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cursor.close() conn.commit() - def delete_session(self, session_id: str) -> None: + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" with self._config.provide_connection() as conn: @@ -517,7 +540,12 @@ def delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -560,65 +588,12 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - """Create a new event. - Args: - event_id: Unique event identifier (unused in new schema, kept for contract). - session_id: Session identifier. 
- app_name: Application name (unused in new schema, kept for contract). - user_id: User identifier (unused in new schema, kept for contract). - author: Event author. - actions: Unused in new contract (kept for interface compatibility). - content: Event content dictionary. - **kwargs: Additional fields including invocation_id, timestamp, event_json. - - Returns: - Created event record. - """ - from datetime import datetime, timezone - - invocation_id = kwargs.get("invocation_id", "") - timestamp = kwargs.get("timestamp", datetime.now(tz=timezone.utc)) - event_json = kwargs.get("event_json", content or {}) - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - """ - - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute( - sql, - (session_id, invocation_id, author or "", timestamp, event_json_str), - ) - finally: - cursor.close() - conn.commit() - - return EventRecord( - session_id=session_id, - invocation_id=invocation_id, - author=author or "", - timestamp=timestamp, - event_json=event_json, - ) + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) - def create_event_and_update_state( + def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: """Atomically create an event and update the session's durable state. 
@@ -665,7 +640,16 @@ def create_event_and_update_state( cursor.close() conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: @@ -706,6 +690,21 @@ def list_events(self, session_id: str) -> "list[EventRecord]": raise + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._append_event_and_update_state(event_record, event_record["session_id"], {}) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorAsyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector async driver.""" @@ -895,7 +894,7 @@ async def delete_entries_older_than(self, days: int) -> int: await cursor.close() -class MysqlConnectorSyncADKMemoryStore(BaseSyncADKMemoryStore["MysqlConnectorSyncConfig"]): +class MysqlConnectorSyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorSyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector sync driver.""" __slots__ = () @@ -903,7 +902,7 @@ class MysqlConnectorSyncADKMemoryStore(BaseSyncADKMemoryStore["MysqlConnectorSyn 
def __init__(self, config: "MysqlConnectorSyncConfig") -> None: super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" if self._owner_id_column_ddl: @@ -937,14 +936,19 @@ def _get_create_memory_table_sql(self) -> str: def _get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: driver.execute_script(self._get_create_memory_table_sql()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -1010,7 +1014,12 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object conn.commit() return inserted_count - def search_entries( + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -1050,7 +1059,14 @@ def search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - def delete_entries_by_session(self, session_id: str) -> int: + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await 
async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -1065,7 +1081,12 @@ def delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - def delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -1082,3 +1103,8 @@ def delete_entries_older_than(self, days: int) -> int: return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 finally: cursor.close() + + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 2e7d994b7..69d2b3879 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -12,10 +12,11 @@ OracledbSyncDataDictionary, OracleVersionInfo, ) -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ from sqlspec.utils.type_guards import is_async_readable, is_readable if TYPE_CHECKING: @@ -785,7 +786,7 @@ async def get_events( raise -class 
OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]): +class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): """Oracle synchronous ADK store using oracledb sync driver. Implements session and event storage for Google Agent Development Kit @@ -832,7 +833,7 @@ def __init__(self, config: "OracleSyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: """Get Oracle CREATE TABLE SQL for sessions table. Auto-detects optimal JSON storage type based on Oracle version. @@ -841,7 +842,7 @@ def _get_create_sessions_table_sql(self) -> str: storage_type = self._detect_json_storage_type() return self._get_create_sessions_table_sql_for_type(storage_type) - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: """Get Oracle CREATE TABLE SQL for events table. Auto-detects optimal JSON storage type based on Oracle version. @@ -1114,7 +1115,7 @@ def _get_drop_tables_sql(self) -> "list[str]": """, ] - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create both sessions and events tables if they don't exist. Notes: @@ -1131,7 +1132,12 @@ def create_tables(self) -> None: events_sql = SQL(self._get_create_events_table_sql_for_type(storage_type)) driver.execute_script(events_sql) - def create_session( + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: """Create a new session. 
@@ -1179,7 +1185,14 @@ def create_session( return self.get_session(session_id) # type: ignore[return-value] - def get_session(self, session_id: str) -> "SessionRecord | None": + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID. Args: @@ -1226,7 +1239,12 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: @@ -1251,7 +1269,12 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cursor.execute(sql, {"state": state_data, "id": session_id}) conn.commit() - def delete_session(self, session_id: str) -> None: + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: """Delete session and all associated events (cascade). 
Args: @@ -1267,7 +1290,12 @@ def delete_session(self, session_id: str) -> None: cursor.execute(sql, {"id": session_id}) conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. Args: @@ -1326,68 +1354,12 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> "EventRecord": - """Create a new event. - Args: - event_id: Unused (kept for base class compatibility). - session_id: Session identifier. - app_name: Unused (kept for base class compatibility). - user_id: Unused (kept for base class compatibility). - author: Event author. - actions: Unused (no longer stored). - content: Unused (no longer stored separately). - **kwargs: Must include ``invocation_id``, ``timestamp``, and - ``event_json``. - - Returns: - Created event record. 
- """ - event_json: str = kwargs["event_json"] - invocation_id: str = kwargs.get("invocation_id", "") - timestamp = kwargs.get("timestamp") - - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json - ) - """ - - with self._config.provide_connection() as conn: - cursor = conn.cursor() - cursor.execute( - sql, - { - "session_id": session_id, - "invocation_id": invocation_id, - "author": author or "", - "timestamp": timestamp, - "event_json": event_json, - }, - ) - conn.commit() - - return EventRecord( - session_id=session_id, - invocation_id=invocation_id, - author=author or "", - timestamp=timestamp, - event_json=event_json, - ) + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) - def create_event_and_update_state( + def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: """Atomically create an event and update the session's durable state. @@ -1432,7 +1404,16 @@ def create_event_and_update_state( cursor.execute(update_sql, {"state": state_data, "id": session_id}) conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": """List events for a session ordered by timestamp. 
Args: @@ -1476,6 +1457,21 @@ def list_events(self, session_id: str) -> "list[EventRecord]": raise + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._append_event_and_update_state(event_record, event_record["session_id"], {}) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + ORACLE_DUPLICATE_KEY_ERROR: Final = 1 @@ -1842,7 +1838,7 @@ async def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return records -class OracleSyncADKMemoryStore(BaseSyncADKMemoryStore["OracleSyncConfig"]): +class OracleSyncADKMemoryStore(BaseAsyncADKMemoryStore["OracleSyncConfig"]): """Oracle ADK memory store using sync oracledb driver.""" __slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info") @@ -1893,7 +1889,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return _extract_json_value(data) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: storage_type = self._detect_json_storage_type() return self._get_create_memory_table_sql_for_type(storage_type) @@ -2008,13 +2004,18 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """, ] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: driver.execute_script(self._get_create_memory_table_sql()) + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") -> 
bool: """Execute an insert and skip duplicate key errors.""" try: @@ -2026,7 +2027,7 @@ def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") raise return True - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -2073,7 +2074,12 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -2092,6 +2098,13 @@ def search_entries( return [] raise + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -2138,7 +2151,7 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: rows = cursor.fetchall() return self._rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: sql = f"DELETE FROM {self._memory_table} WHERE session_id = :session_id" with self._config.provide_connection() as conn: cursor = conn.cursor() @@ -2146,7 +2159,12 @@ def 
delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - def delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} WHERE inserted_at < SYSTIMESTAMP - NUMTODSINTERVAL(:days, 'DAY') @@ -2157,6 +2175,11 @@ def delete_entries_older_than(self, days: int) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index bee3991ee..73dceae2a 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -6,9 +6,10 @@ from psycopg import sql as pg_sql from psycopg.types.json import Jsonb -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from datetime import datetime @@ -333,7 +334,7 @@ async def get_events( return [] -class PsycopgSyncADKStore(BaseSyncADKStore["PsycopgSyncConfig"]): +class 
PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]): """PostgreSQL synchronous ADK store using Psycopg3 driver. Implements session and event storage for Google Agent Development Kit @@ -359,7 +360,7 @@ class PsycopgSyncADKStore(BaseSyncADKStore["PsycopgSyncConfig"]): def __init__(self, config: "PsycopgSyncConfig") -> None: super().__init__(config) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -385,7 +386,7 @@ def _get_create_sessions_table_sql(self) -> str: WHERE state != '{{}}'::jsonb; """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( session_id VARCHAR(128) NOT NULL, @@ -403,12 +404,17 @@ def _get_create_events_table_sql(self) -> str: def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(self._get_create_sessions_table_sql()) driver.execute_script(self._get_create_events_table_sql()) - def create_session( + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: params: tuple[Any, ...] 
@@ -432,7 +438,14 @@ def create_session( return self.get_session(session_id) # type: ignore[return-value] - def get_session(self, session_id: str) -> "SessionRecord | None": + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time FROM {table} @@ -458,7 +471,12 @@ def get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP @@ -468,13 +486,23 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (Jsonb(state), session_id)) - def delete_session(self, session_id: str) -> None: + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (session_id,)) - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + + async 
def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time @@ -511,70 +539,12 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess except errors.UndefinedTable: return [] - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - """Create a new event using the legacy positional API. - - This method is required by the BaseSyncADKStore contract. For new code, - prefer ``create_event_and_update_state`` which atomically persists the - event and updates session state. - """ - from datetime import datetime, timezone - - event_json: dict[str, Any] = {} - if author is not None: - event_json["author"] = author - if actions is not None: - event_json["actions"] = actions.hex() - if content is not None: - event_json["content"] = content - event_json.update({k: v for k, v in kwargs.items() if v is not None}) - - invocation_id = kwargs.get("invocation_id", "") - ts = kwargs.get("timestamp") or datetime.now(timezone.utc) - - query = pg_sql.SQL(""" - INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - RETURNING session_id, invocation_id, author, timestamp, event_json - """).format(table=pg_sql.Identifier(self._events_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute( - query, - ( - session_id, - invocation_id, - author or "", - ts, - Jsonb(event_json), - ), - ) - row = cur.fetchone() - - if row is None: - msg = f"Failed to create event {event_id}" - raise RuntimeError(msg) - - 
return EventRecord( - session_id=row["session_id"], - invocation_id=row["invocation_id"], - author=row["author"], - timestamp=row["timestamp"], - event_json=row["event_json"], - ) + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) - def create_event_and_update_state( + def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: insert_query = pg_sql.SQL(""" @@ -606,7 +576,16 @@ def create_event_and_update_state( cur.execute(update_query, (Jsonb(state), session_id)) conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": query = pg_sql.SQL(""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {table} @@ -633,6 +612,21 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [] + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._append_event_and_update_state(event_record, event_record["session_id"], {}) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await 
async_(self._append_event)(event_record) + class PsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgAsyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 async driver.""" @@ -819,7 +813,7 @@ async def delete_entries_older_than(self, days: int) -> int: return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 -class PsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["PsycopgSyncConfig"]): +class PsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgSyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 sync driver.""" __slots__ = () @@ -828,7 +822,7 @@ def __init__(self, config: "PsycopgSyncConfig") -> None: """Initialize Psycopg sync memory store.""" super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: """Get PostgreSQL CREATE TABLE SQL for memory entries.""" owner_id_line = "" if self._owner_id_column_ddl: @@ -868,7 +862,7 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """Get PostgreSQL DROP TABLE SQL statements.""" return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" if not self._enabled: return @@ -876,7 +870,12 @@ def create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(self._get_create_memory_table_sql()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" if not self._enabled: msg = "Memory store is disabled" @@ -921,7 +920,12 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - 
def search_entries( + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": """Search memory entries by text query.""" @@ -941,6 +945,13 @@ def search_entries( except errors.UndefinedTable: return [] + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = pg_sql.SQL( """ @@ -981,7 +992,7 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: rows = cur.fetchall() return _rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" sql = pg_sql.SQL("DELETE FROM {table} WHERE session_id = %s").format( table=pg_sql.Identifier(self._memory_table) @@ -991,7 +1002,12 @@ def delete_entries_by_session(self, session_id: str) -> int: cur.execute(sql, (session_id,)) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - def delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" sql = pg_sql.SQL( """ @@ -1005,6 +1021,11 @@ def delete_entries_older_than(self, days: int) -> int: 
return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(rows: "list[Any]") -> "list[MemoryRecord]": return [ { diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index a73d3f638..05a5d0c69 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -5,9 +5,10 @@ import pymysql -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from sqlspec.adapters.pymysql.config import PyMysqlConfig @@ -31,7 +32,7 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": return (col_def, fk_constraint) -class PyMysqlADKStore(BaseSyncADKStore["PyMysqlConfig"]): +class PyMysqlADKStore(BaseAsyncADKStore["PyMysqlConfig"]): """MySQL/MariaDB ADK store using PyMySQL. 
Implements session and event storage for Google Agent Development Kit @@ -57,7 +58,7 @@ def __init__(self, config: "PyMysqlConfig") -> None: def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": return _parse_owner_id_column_for_mysql(column_ddl) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: owner_id_col = "" fk_constraint = "" @@ -81,7 +82,7 @@ def _get_create_sessions_table_sql(self) -> str: ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: """Get MySQL CREATE TABLE SQL for events. Post clean-break schema: 5 columns only. @@ -101,12 +102,17 @@ def _get_create_events_table_sql(self) -> str: def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(self._get_create_sessions_table_sql()) driver.execute_script(self._get_create_events_table_sql()) - def create_session( + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = to_json(state) @@ -139,7 +145,14 @@ def create_session( raise RuntimeError(msg) return result - def get_session(self, session_id: str) -> "SessionRecord | None": + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) 
-> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -173,7 +186,12 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) sql = f""" @@ -190,7 +208,12 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cursor.close() conn.commit() - def delete_session(self, session_id: str) -> None: + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" with self._config.provide_connection() as conn: @@ -201,7 +224,12 @@ def delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -244,67 +272,12 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - 
**kwargs: Any, - ) -> EventRecord: - """Create a new event. - - Constructs an EventRecord from the provided fields and inserts it. - - Args: - event_id: Unique event identifier (unused in new schema, kept for contract). - session_id: Session identifier. - app_name: Application name (unused in new schema, kept for contract). - user_id: User identifier (unused in new schema, kept for contract). - author: Event author. - actions: Unused in new contract (kept for interface compatibility). - content: Event content dictionary. - **kwargs: Additional fields including invocation_id, timestamp, event_json. - - Returns: - Created event record. - """ - from datetime import datetime, timezone - invocation_id = kwargs.get("invocation_id", "") - timestamp = kwargs.get("timestamp", datetime.now(tz=timezone.utc)) - event_json = kwargs.get("event_json", content or {}) - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - """ + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute( - sql, - (session_id, invocation_id, author or "", timestamp, event_json_str), - ) - finally: - cursor.close() - conn.commit() - - return EventRecord( - session_id=session_id, - invocation_id=invocation_id, - author=author or "", - timestamp=timestamp, - event_json=event_json, - ) - - def create_event_and_update_state( + def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: """Atomically create an event and update the session's durable state. 
@@ -351,7 +324,16 @@ def create_event_and_update_state( cursor.close() conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: @@ -392,7 +374,22 @@ def list_events(self, session_id: str) -> "list[EventRecord]": raise -class PyMysqlADKMemoryStore(BaseSyncADKMemoryStore["PyMysqlConfig"]): + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._append_event_and_update_state(event_record, event_record["session_id"], {}) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + +class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]): """MySQL/MariaDB ADK memory store using PyMySQL.""" __slots__ = () @@ -400,7 +397,7 @@ class PyMysqlADKMemoryStore(BaseSyncADKMemoryStore["PyMysqlConfig"]): def __init__(self, config: "PyMysqlConfig") -> None: super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" if self._owner_id_column_ddl: @@ -434,14 +431,19 @@ def _get_create_memory_table_sql(self) -> str: def 
_get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: driver.execute_script(self._get_create_memory_table_sql()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -507,7 +509,12 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object conn.commit() return inserted_count - def search_entries( + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -547,7 +554,14 @@ def search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - def delete_entries_by_session(self, session_id: str) -> int: + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -562,7 +576,12 @@ def delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - def 
delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -579,3 +598,8 @@ def delete_entries_older_than(self, days: int) -> int: return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 finally: cursor.close() + + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index dc9090a86..a83aab48d 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -7,10 +7,11 @@ from google.cloud.spanner_v1 import param_types from sqlspec.adapters.spanner.config import SpannerSyncConfig -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.protocols import SpannerParamTypesProtocol from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from google.cloud.spanner_v1.database import Database @@ -41,7 +42,7 @@ def __call__(self, transaction: "Transaction") -> None: transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call] -class SpannerSyncADKStore(BaseSyncADKStore[SpannerSyncConfig]): +class SpannerSyncADKStore(BaseAsyncADKStore[SpannerSyncConfig]): """Spanner ADK store backed by synchronous Spanner client.""" 
connector_name: ClassVar[str] = "spanner" @@ -101,7 +102,7 @@ def _decode_json(self, raw: Any) -> Any: return from_json(raw) return raw - def create_session( + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = to_json(state) @@ -130,7 +131,14 @@ def create_session( "update_time": datetime.now(timezone.utc), } - def get_session(self, session_id: str) -> "SessionRecord | None": + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} FROM {self._session_table} @@ -156,7 +164,12 @@ def get_session(self, session_id: str) -> "SessionRecord | None": } return record - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: params = {"id": session_id, "state": to_json(state)} json_type = _json_param_type() sql = f""" @@ -168,7 +181,12 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" self._run_write([(sql, params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type})]) - def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: 
+ """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} FROM {self._session_table} @@ -199,7 +217,12 @@ def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[Se records.append(record) return records - def delete_session(self, session_id: str) -> None: + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + def _delete_session(self, session_id: str) -> None: shard_clause = ( f" AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" if self._shard_count > 1 else "" ) @@ -209,50 +232,12 @@ def delete_session(self, session_id: str) -> None: types = {"session_id": SPANNER_PARAM_TYPES.STRING} self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)]) - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - invocation_id = kwargs.get("invocation_id", "") - event_json = to_json({ - "id": event_id, - "app_name": app_name, - "user_id": user_id, - "author": author, - "content": content, - **{k: v for k, v in kwargs.items() if v is not None}, - }) - now = datetime.now(timezone.utc) - params: dict[str, Any] = { - "session_id": session_id, - "invocation_id": invocation_id, - "author": author or "", - "timestamp": now, - "event_json": event_json, - } - - sql = f""" - INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) - VALUES (@session_id, @invocation_id, @author, 
PENDING_COMMIT_TIMESTAMP(), @event_json) - """ - self._run_write([(sql, params, self._event_param_types())]) - return { - "session_id": session_id, - "invocation_id": invocation_id, - "author": author or "", - "timestamp": now, - "event_json": event_json, - } + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) - def create_event_and_update_state( + def _append_event_and_update_state( self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" ) -> None: """Atomically insert an event and update session state in one transaction. @@ -292,7 +277,16 @@ def create_event_and_update_state( (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}), ]) - def list_events(self, session_id: str) -> "list[EventRecord]": + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} @@ -315,7 +309,22 @@ def list_events(self, session_id: str) -> "list[EventRecord]": for row in rows ] - def create_tables(self) -> None: + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._append_event_and_update_state(event_record, event_record["session_id"], {}) + + 
async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + + def _create_tables(self) -> None: database = self._database() existing_tables = {t.table_id for t in database.list_tables()} # type: ignore[no-untyped-call] @@ -328,7 +337,12 @@ def create_tables(self) -> None: if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - def _get_create_sessions_table_sql(self) -> str: + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def _get_create_sessions_table_sql(self) -> str: owner_line = "" if self._owner_id_column_ddl: owner_line = f",\n {self._owner_id_column_ddl}" @@ -351,7 +365,7 @@ def _get_create_sessions_table_sql(self) -> str: ) {pk}{options} """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: shard_column = "" pk = "PRIMARY KEY (session_id, timestamp)" if self._shard_count > 1: @@ -403,7 +417,7 @@ def execute_sql( ) -> Iterable[Any]: ... 
-class SpannerSyncADKMemoryStore(BaseSyncADKMemoryStore[SpannerSyncConfig]): +class SpannerSyncADKMemoryStore(BaseAsyncADKMemoryStore[SpannerSyncConfig]): """Spanner ADK memory store backed by synchronous Spanner client.""" connector_name: ClassVar[str] = "spanner" @@ -456,7 +470,7 @@ def _decode_json(self, raw: Any) -> Any: return from_json(raw) return raw - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return @@ -470,7 +484,12 @@ def create_tables(self) -> None: if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - def _get_create_memory_table_sql(self) -> "list[str]": + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def _get_create_memory_table_sql(self) -> "list[str]": owner_line = "" if self._owner_id_column_ddl: owner_line = f",\n {self._owner_id_column_ddl}" @@ -530,7 +549,7 @@ def _get_drop_memory_table_sql(self) -> "list[str]": ]) return statements - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -579,12 +598,17 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object self._run_write(statements) return inserted_count + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + def _event_exists(self, event_id: str) -> bool: sql = f"SELECT event_id FROM {self._memory_table} WHERE event_id = @event_id LIMIT 1" rows = self._run_read(sql, {"event_id": event_id}, {"event_id": SPANNER_PARAM_TYPES.STRING}) return bool(rows) - def 
search_entries( + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -597,6 +621,13 @@ def search_entries( return self._search_entries_fts(query, app_name, user_id, effective_limit) return self._search_entries_simple(query, app_name, user_id, effective_limit) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -640,19 +671,29 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: rows = self._run_read(sql, params, types) return self._rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: sql = f"DELETE FROM {self._memory_table} WHERE session_id = @session_id" params = {"session_id": session_id} types = {"session_id": SPANNER_PARAM_TYPES.STRING} return self._execute_update(sql, params, types) - def delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: cutoff = datetime.now(timezone.utc) - timedelta(days=days) sql = f"DELETE FROM {self._memory_table} WHERE inserted_at < @cutoff" params = {"cutoff": cutoff} types = {"cutoff": SPANNER_PARAM_TYPES.TIMESTAMP} return self._execute_update(sql, params, types) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" 
+ return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return [ { diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 6cfc14399..edbf7b74b 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json from sqlspec.utils.sync_tools import async_, run_ @@ -566,7 +566,7 @@ async def get_events( return await async_(self._get_events)(session_id, after_timestamp, limit) -class SqliteADKMemoryStore(BaseSyncADKMemoryStore["SqliteConfig"]): +class SqliteADKMemoryStore(BaseAsyncADKMemoryStore["SqliteConfig"]): """SQLite ADK memory store using synchronous SQLite driver. Implements memory entry storage for Google Agent Development Kit @@ -625,7 +625,7 @@ def __init__(self, config: "SqliteConfig") -> None: """ super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for memory entries. Returns: @@ -717,7 +717,7 @@ def _enable_foreign_keys(self, connection: Any) -> None: """ connection.execute("PRAGMA foreign_keys = ON") - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create the memory table and indexes if they don't exist. Skips table creation if memory store is disabled. 
@@ -729,7 +729,12 @@ def create_tables(self) -> None: self._enable_foreign_keys(driver.connection) driver.execute_script(self._get_create_memory_table_sql()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication. Uses INSERT OR IGNORE to skip duplicates based on event_id @@ -813,7 +818,12 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": """Search memory entries by text query. 
@@ -843,6 +853,13 @@ def search_entries( logger.warning("FTS search failed; falling back to simple search: %s", exc) return self._search_entries_simple(query, app_name, user_id, effective_limit) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT m.id, m.session_id, m.app_name, m.user_id, m.event_id, m.author, @@ -895,7 +912,7 @@ def _fetch_records(self, sql: str, params: "tuple[Any, ...]") -> "list[MemoryRec for row in rows ] - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session. Args: @@ -914,7 +931,12 @@ def delete_entries_by_session(self, session_id: str) -> int: return deleted_count - def delete_entries_older_than(self, days: int) -> int: + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days. Used for TTL cleanup operations. 
@@ -936,3 +958,8 @@ def delete_entries_older_than(self, days: int) -> int: conn.commit() return deleted_count + + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) From 194aeff89898e02366833e0dfffa609d83efe3bb Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 18:49:36 +0000 Subject: [PATCH 18/23] fix(adk): resolve all production code type errors - Change EventRecord.event_json type from str to dict[str, Any] to match actual usage (model_dump returns dict, stores deserialize from DB to dict) - Fix sync _create_tables methods calling async DDL methods without await by using run_() wrapper consistently across all sync stores - Fix sync _create_session methods calling async get_session by using sync _get_session instead - Add missing datetime import in pymysql store - Fix Oracle/ADBC/DuckDB stores to deserialize event_json from str to dict when reading from database --- .gitignore | 2 +- sqlspec/adapters/adbc/adk/store.py | 25 +++--------- sqlspec/adapters/asyncpg/adk/store.py | 18 ++++----- .../adapters/cockroach_asyncpg/adk/store.py | 18 ++++----- .../adapters/cockroach_psycopg/adk/store.py | 25 +++--------- sqlspec/adapters/duckdb/adk/store.py | 14 +++---- sqlspec/adapters/mysqlconnector/adk/store.py | 25 +++--------- sqlspec/adapters/oracledb/adk/store.py | 23 +++-------- sqlspec/adapters/psycopg/adk/store.py | 31 ++++++--------- sqlspec/adapters/pymysql/adk/store.py | 27 ++++--------- sqlspec/adapters/spanner/adk/store.py | 21 ++-------- sqlspec/adapters/sqlite/adk/store.py | 7 +--- sqlspec/config.py | 1 - sqlspec/extensions/adk/_types.py | 2 +- sqlspec/extensions/adk/memory/converters.py | 5 +-- sqlspec/extensions/adk/memory/service.py | 24 +++--------- sqlspec/extensions/adk/service.py | 4 +- .../adk/test_dialect_integration.py | 4 +- .../extensions/adk/test_event_operations.py | 4 +- 
.../asyncmy/extensions/adk/test_store.py | 8 +++- .../duckdb/extensions/adk/test_store.py | 10 ++++- .../extensions/adk/test_store.py | 8 +++- .../extensions/adk/test_oracle_specific.py | 12 ++++-- .../spanner/extensions/adk/test_adk_store.py | 17 ++++---- .../extensions/test_adk/test_converters.py | 37 +++--------------- .../unit/extensions/test_adk/test_service.py | 39 +++++++------------ .../test_adk/test_store_instantiation.py | 11 ++---- 27 files changed, 148 insertions(+), 274 deletions(-) diff --git a/.gitignore b/.gitignore index afd84ce28..442abbaa4 100644 --- a/.gitignore +++ b/.gitignore @@ -68,4 +68,4 @@ uv.toml .geminiignore .beads/ tools/scripts/profiles/*.prof -.agents/ \ No newline at end of file +.agents/ diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 803d0955b..3c7155b67 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -8,7 +8,7 @@ from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_ +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from sqlspec.adapters.adbc.config import AdbcConfig @@ -418,7 +418,7 @@ def _create_tables(self) -> None: try: self._enable_foreign_keys(cursor, conn) - cursor.execute(self._get_create_sessions_table_sql()) + cursor.execute(run_(self._get_create_sessions_table_sql)()) conn.commit() sessions_idx_app_user = ( @@ -435,7 +435,7 @@ def _create_tables(self) -> None: cursor.execute(sessions_idx_update) conn.commit() - cursor.execute(self._get_create_events_table_sql()) + cursor.execute(run_(self._get_create_events_table_sql)()) conn.commit() events_idx = ( @@ -447,7 +447,6 @@ def _create_tables(self) -> None: finally: cursor.close() # type: ignore[no-untyped-call] - async def create_tables(self) -> None: """Create tables if they don't exist.""" await 
async_(self._create_tables)() @@ -511,7 +510,6 @@ def _create_session( return self.get_session(session_id) # type: ignore[return-value] - async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -562,7 +560,6 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID.""" return await async_(self._get_session)(session_id) @@ -593,7 +590,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non finally: cursor.close() # type: ignore[no-untyped-call] - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" await async_(self._update_session_state)(session_id, state) @@ -618,7 +614,6 @@ def _delete_session(self, session_id: str) -> None: finally: cursor.close() # type: ignore[no-untyped-call] - async def delete_session(self, session_id: str) -> None: """Delete session and associated events.""" await async_(self._delete_session)(session_id) @@ -679,7 +674,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) @@ -762,7 +756,6 @@ def _append_event_and_update_state( finally: cursor.close() # type: ignore[no-untyped-call] - async def append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -805,7 +798,7 @@ def _get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=str(row[4]) if row[4] is not None else "{}", + event_json=self._deserialize_json_field(row[4]) or {}, ) for row in rows ] @@ -817,8 +810,6 @@ def _get_events( return [] raise - - async 
def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -833,6 +824,7 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" await async_(self._append_event)(event_record) + class AdbcADKMemoryStore(BaseAsyncADKMemoryStore["AdbcConfig"]): """ADBC synchronous ADK memory store for Arrow Database Connectivity.""" @@ -989,7 +981,7 @@ def _create_tables(self) -> None: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(self._get_create_memory_table_sql()) + cursor.execute(run_(self._get_create_memory_table_sql)()) conn.commit() idx_app_user = ( @@ -1007,7 +999,6 @@ def _create_tables(self) -> None: finally: cursor.close() # type: ignore[no-untyped-call] - async def create_tables(self) -> None: """Create tables if they don't exist.""" await async_(self._create_tables)() @@ -1118,7 +1109,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -1163,7 +1153,6 @@ def _search_entries( return self._rows_to_records(rows) - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1189,7 +1178,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() # type: ignore[no-untyped-call] - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -1214,7 +1202,6 @@ def _delete_entries_older_than(self, days: int) -> int: finally: cursor.close() # type: ignore[no-untyped-call] - async def 
delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 999395d51..76cffccab 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -224,15 +224,15 @@ async def append_event_and_update_state( """ async with self.config.provide_connection() as conn, conn.transaction(): - await conn.execute( - insert_sql, - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_record["event_json"], - ) - await conn.execute(update_sql, state, session_id) + await conn.execute( + insert_sql, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ) + await conn.execute(update_sql, state, session_id) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index 36979547e..bedaaa788 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -218,15 +218,15 @@ async def append_event_and_update_state( """ async with self._config.provide_connection() as conn, conn.transaction(): - await conn.execute( - insert_sql, - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_record["event_json"], - ) - await conn.execute(update_sql, state, session_id) + await conn.execute( + insert_sql, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ) + await conn.execute(update_sql, state, session_id) async def 
get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 959c450d9..535facd6c 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -9,7 +9,7 @@ from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger -from sqlspec.utils.sync_tools import async_ +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from datetime import datetime @@ -406,9 +406,8 @@ def _get_drop_tables_sql(self) -> "list[str]": def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) - + driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -437,13 +436,12 @@ def _create_session( cur.execute(sql.encode(), params) conn.commit() - result = self.get_session(session_id) + result = self._get_session(session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result - async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -476,7 +474,6 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID.""" return await async_(self._get_session)(session_id) @@ -492,7 +489,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> 
Non cur.execute(sql.encode(), (Jsonb(state), session_id)) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" await async_(self._update_session_state)(session_id, state) @@ -504,7 +500,6 @@ def _delete_session(self, session_id: str) -> None: cur.execute(sql.encode(), (session_id,)) conn.commit() - async def delete_session(self, session_id: str) -> None: """Delete session and associated events.""" await async_(self._delete_session)(session_id) @@ -546,7 +541,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses except errors.UndefinedTable: return [] - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) @@ -582,7 +576,6 @@ def _append_event_and_update_state( cur.execute(update_sql.encode(), (Jsonb(state), session_id)) conn.commit() - async def append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -617,8 +610,6 @@ def _get_events( except errors.UndefinedTable: return [] - - async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -633,6 +624,7 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" await async_(self._append_event)(event_record) + class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgAsyncConfig"]): """CockroachDB ADK memory store using psycopg async driver.""" @@ -849,8 +841,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) - + driver.execute_script(run_(self._get_create_memory_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ 
-899,7 +890,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec inserted_count += cur.rowcount return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -945,7 +935,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -963,7 +952,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -982,7 +970,6 @@ def _delete_entries_older_than(self, days: int) -> int: conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index ddb3059a9..9c8d00715 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -410,9 +410,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - event_json_value = event_record["event_json"] - if not isinstance(event_json_value, str): - event_json_value = to_json(event_json_value) + event_json_str = to_json(event_record["event_json"]) sql = f""" INSERT INTO {self._events_table} @@ -428,7 +426,7 @@ 
def _append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_value, + event_json_str, ), ) conn.commit() @@ -448,9 +446,7 @@ def _append_event_and_update_state( """Synchronous implementation of append_event_and_update_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) - event_json_value = event_record["event_json"] - if not isinstance(event_json_value, str): - event_json_value = to_json(event_json_value) + event_json_str = to_json(event_record["event_json"]) insert_sql = f""" INSERT INTO {self._events_table} @@ -472,7 +468,7 @@ def _append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_value, + event_json_str, ), ) conn.execute(update_sql, (state_json, now, session_id)) @@ -526,7 +522,7 @@ def _get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=row[4] if isinstance(row[4], str) else to_json(row[4]), + event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index aaa014a1f..23c451040 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -8,7 +8,7 @@ from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_ +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from datetime import datetime @@ -420,9 +420,8 @@ def _get_drop_tables_sql(self) -> "list[str]": def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) 
- + driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -455,13 +454,12 @@ def _create_session( cursor.close() conn.commit() - result = self.get_session(session_id) + result = self._get_session(session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -502,7 +500,6 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID.""" return await async_(self._get_session)(session_id) @@ -524,7 +521,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non cursor.close() conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" await async_(self._update_session_state)(session_id, state) @@ -540,7 +536,6 @@ def _delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - async def delete_session(self, session_id: str) -> None: """Delete session and associated events.""" await async_(self._delete_session)(session_id) @@ -588,7 +583,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) @@ -640,7 +634,6 @@ def _append_event_and_update_state( cursor.close() conn.commit() - async def append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -689,8 +682,6 @@ def _get_events( 
return [] raise - - async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -705,6 +696,7 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" await async_(self._append_event)(event_record) + class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorAsyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector async driver.""" @@ -941,8 +933,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) - + driver.execute_script(run_(self._get_create_memory_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -1014,7 +1005,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec conn.commit() return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -1059,7 +1049,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1081,7 +1070,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -1104,7 +1092,6 @@ def _delete_entries_older_than(self, days: int) -> int: finally: cursor.close() - async def delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await 
async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 69d2b3879..1c4647470 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -16,7 +16,7 @@ from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_ +from sqlspec.utils.sync_tools import async_, run_ from sqlspec.utils.type_guards import is_async_readable, is_readable if TYPE_CHECKING: @@ -775,7 +775,7 @@ async def get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=event_json_str, + event_json=from_json(event_json_str) if isinstance(event_json_str, str) else event_json_str, ) ) return results @@ -1132,7 +1132,6 @@ def _create_tables(self) -> None: events_sql = SQL(self._get_create_events_table_sql_for_type(storage_type)) driver.execute_script(events_sql) - async def create_tables(self) -> None: """Create tables if they don't exist.""" await async_(self._create_tables)() @@ -1185,7 +1184,6 @@ def _create_session( return self.get_session(session_id) # type: ignore[return-value] - async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -1239,7 +1237,6 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID.""" return await async_(self._get_session)(session_id) @@ -1269,7 +1266,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non cursor.execute(sql, {"state": state_data, "id": session_id}) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" await 
async_(self._update_session_state)(session_id, state) @@ -1290,7 +1286,6 @@ def _delete_session(self, session_id: str) -> None: cursor.execute(sql, {"id": session_id}) conn.commit() - async def delete_session(self, session_id: str) -> None: """Delete session and associated events.""" await async_(self._delete_session)(session_id) @@ -1354,7 +1349,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) @@ -1404,7 +1398,6 @@ def _append_event_and_update_state( cursor.execute(update_sql, {"state": state_data, "id": session_id}) conn.commit() - async def append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -1446,7 +1439,7 @@ def _get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=event_json_str, + event_json=from_json(event_json_str) if isinstance(event_json_str, str) else event_json_str, ) ) return results @@ -1456,8 +1449,6 @@ def _get_events( return [] raise - - async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -1472,6 +1463,7 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" await async_(self._append_event)(event_record) + ORACLE_DUPLICATE_KEY_ERROR: Final = 1 @@ -2009,8 +2001,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) - + driver.execute_script(run_(self._get_create_memory_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -2074,7 +2065,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return 
inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -2098,7 +2088,6 @@ def _search_entries( return [] raise - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -2159,7 +2148,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -2175,7 +2163,6 @@ def _delete_entries_older_than(self, days: int) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 73dceae2a..5ce2c9c28 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -9,7 +9,7 @@ from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger -from sqlspec.utils.sync_tools import async_ +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from datetime import datetime @@ -406,9 +406,8 @@ def _get_drop_tables_sql(self) -> "list[str]": def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) - + 
driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -436,8 +435,11 @@ def _create_session( with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, params) - return self.get_session(session_id) # type: ignore[return-value] - + result = self._get_session(session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -471,7 +473,6 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID.""" return await async_(self._get_session)(session_id) @@ -486,7 +487,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (Jsonb(state), session_id)) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" await async_(self._update_session_state)(session_id, state) @@ -497,7 +497,6 @@ def _delete_session(self, session_id: str) -> None: with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (session_id,)) - async def delete_session(self, session_id: str) -> None: """Delete session and associated events.""" await async_(self._delete_session)(session_id) @@ -539,7 +538,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses except errors.UndefinedTable: return [] - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await 
async_(self._list_sessions)(app_name, user_id) @@ -576,7 +574,6 @@ def _append_event_and_update_state( cur.execute(update_query, (Jsonb(state), session_id)) conn.commit() - async def append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -611,8 +608,6 @@ def _get_events( except errors.UndefinedTable: return [] - - async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -627,6 +622,7 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" await async_(self._append_event)(event_record) + class PsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgAsyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 async driver.""" @@ -868,8 +864,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) - + driver.execute_script(run_(self._get_create_memory_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -920,7 +915,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -945,7 +939,6 @@ def _search_entries( except errors.UndefinedTable: return [] - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1002,7 +995,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: cur.execute(sql, (session_id,)) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries 
for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -1020,12 +1012,11 @@ def _delete_entries_older_than(self, days: int) -> int: cur.execute(sql) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - - async def delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(rows: "list[Any]") -> "list[MemoryRecord]": return [ { diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index 05a5d0c69..ec32e4765 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -8,9 +8,11 @@ from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_ +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: + from datetime import datetime + from sqlspec.adapters.pymysql.config import PyMysqlConfig from sqlspec.extensions.adk import MemoryRecord @@ -104,9 +106,8 @@ def _get_drop_tables_sql(self) -> "list[str]": def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) - + driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -139,13 +140,12 @@ def _create_session( cursor.close() conn.commit() - result = self.get_session(session_id) + result = self._get_session(session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - async def create_session( self, session_id: str, 
app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -186,7 +186,6 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID.""" return await async_(self._get_session)(session_id) @@ -208,7 +207,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non cursor.close() conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" await async_(self._update_session_state)(session_id, state) @@ -224,7 +222,6 @@ def _delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - async def delete_session(self, session_id: str) -> None: """Delete session and associated events.""" await async_(self._delete_session)(session_id) @@ -272,7 +269,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) @@ -324,7 +320,6 @@ def _append_event_and_update_state( cursor.close() conn.commit() - async def append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -373,8 +368,6 @@ def _get_events( return [] raise - - async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -389,6 +382,7 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" await async_(self._append_event)(event_record) + class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]): """MySQL/MariaDB ADK memory store using PyMySQL.""" @@ -436,8 +430,7 @@ def _create_tables(self) -> None: return with 
self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) - + driver.execute_script(run_(self._get_create_memory_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -509,7 +502,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec conn.commit() return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -554,7 +546,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -576,7 +567,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -599,7 +589,6 @@ def _delete_entries_older_than(self, days: int) -> int: finally: cursor.close() - async def delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index a83aab48d..95462f90a 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -11,7 +11,7 @@ from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.protocols import SpannerParamTypesProtocol from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_ +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from 
google.cloud.spanner_v1.database import Database @@ -131,7 +131,6 @@ def _create_session( "update_time": datetime.now(timezone.utc), } - async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -164,7 +163,6 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": } return record - async def get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID.""" return await async_(self._get_session)(session_id) @@ -181,7 +179,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" self._run_write([(sql, params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type})]) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" await async_(self._update_session_state)(session_id, state) @@ -217,7 +214,6 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S records.append(record) return records - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) @@ -232,7 +228,6 @@ def _delete_session(self, session_id: str) -> None: types = {"session_id": SPANNER_PARAM_TYPES.STRING} self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)]) - async def delete_session(self, session_id: str) -> None: """Delete session and associated events.""" await async_(self._delete_session)(session_id) @@ -277,7 +272,6 @@ def _append_event_and_update_state( (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}), ]) - async def append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -309,7 +303,6 @@ def _get_events( for 
row in rows ] - async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -330,14 +323,13 @@ def _create_tables(self) -> None: ddl_statements: list[str] = [] if self._session_table not in existing_tables: - ddl_statements.append(self._get_create_sessions_table_sql()) + ddl_statements.append(run_(self._get_create_sessions_table_sql)()) if self._events_table not in existing_tables: - ddl_statements.append(self._get_create_events_table_sql()) + ddl_statements.append(run_(self._get_create_events_table_sql)()) if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - async def create_tables(self) -> None: """Create tables if they don't exist.""" await async_(self._create_tables)() @@ -479,12 +471,11 @@ def _create_tables(self) -> None: ddl_statements: list[str] = [] if self._memory_table not in existing_tables: - ddl_statements.extend(self._get_create_memory_table_sql()) + ddl_statements.extend(run_(self._get_create_memory_table_sql)()) if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - async def create_tables(self) -> None: """Create tables if they don't exist.""" await async_(self._create_tables)() @@ -598,7 +589,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec self._run_write(statements) return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -621,7 +611,6 @@ def _search_entries( return self._search_entries_fts(query, app_name, user_id, effective_limit) return self._search_entries_simple(query, app_name, user_id, effective_limit) - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": 
@@ -677,7 +666,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: types = {"session_id": SPANNER_PARAM_TYPES.STRING} return self._execute_update(sql, params, types) - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -689,7 +677,6 @@ def _delete_entries_older_than(self, days: int) -> int: types = {"cutoff": SPANNER_PARAM_TYPES.TIMESTAMP} return self._execute_update(sql, params, types) - async def delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index edbf7b74b..bf4e51e52 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -727,8 +727,7 @@ def _create_tables(self) -> None: with self._config.provide_session() as driver: self._enable_foreign_keys(driver.connection) - driver.execute_script(self._get_create_memory_table_sql()) - + driver.execute_script(run_(self._get_create_memory_table_sql)()) async def create_tables(self) -> None: """Create tables if they don't exist.""" @@ -818,7 +817,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" return await async_(self._insert_memory_entries)(entries, owner_id) @@ -853,7 +851,6 @@ def _search_entries( logger.warning("FTS search failed; falling back to simple search: %s", exc) return self._search_entries_simple(query, app_name, user_id, effective_limit) - async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -931,7 +928,6 @@ def 
_delete_entries_by_session(self, session_id: str) -> int: return deleted_count - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" return await async_(self._delete_entries_by_session)(session_id) @@ -959,7 +955,6 @@ def _delete_entries_older_than(self, days: int) -> int: return deleted_count - async def delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/config.py b/sqlspec/config.py index 29a873110..d59f51869 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -805,7 +805,6 @@ class ADKConfig(TypedDict): specified language will fall back to 'simple' or 'english'. """ - artifact_storage_uri: NotRequired[str] """Base URI for artifact content storage. Default: None (store inline in database). When set, large artifact payloads are stored externally and only metadata diff --git a/sqlspec/extensions/adk/_types.py b/sqlspec/extensions/adk/_types.py index 838cc14a3..3f11b62f0 100644 --- a/sqlspec/extensions/adk/_types.py +++ b/sqlspec/extensions/adk/_types.py @@ -38,4 +38,4 @@ class EventRecord(TypedDict): invocation_id: str author: str timestamp: datetime - event_json: str + event_json: "dict[str, Any]" diff --git a/sqlspec/extensions/adk/memory/converters.py b/sqlspec/extensions/adk/memory/converters.py index 2742d9b8d..fafea6d12 100644 --- a/sqlspec/extensions/adk/memory/converters.py +++ b/sqlspec/extensions/adk/memory/converters.py @@ -99,10 +99,7 @@ def event_to_memory_record(event: "Event", session_id: str, app_name: str, user_ def memory_entry_to_record( - entry: "MemoryEntry", - app_name: str, - user_id: str, - extra_metadata: "dict[str, Any] | None" = None, + entry: "MemoryEntry", app_name: str, user_id: str, extra_metadata: "dict[str, Any] | None" = None ) -> "MemoryRecord | None": """Convert an ADK MemoryEntry to a database record. 
diff --git a/sqlspec/extensions/adk/memory/service.py b/sqlspec/extensions/adk/memory/service.py index dc56a9b32..a94c7e9cd 100644 --- a/sqlspec/extensions/adk/memory/service.py +++ b/sqlspec/extensions/adk/memory/service.py @@ -1,6 +1,5 @@ """SQLSpec-backed memory service for Google ADK.""" -from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING from google.adk.memory.base_memory_service import BaseMemoryService, SearchMemoryResponse @@ -13,6 +12,8 @@ from sqlspec.utils.logging import get_logger if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from google.adk.events.event import Event from google.adk.memory.memory_entry import MemoryEntry from google.adk.sessions import Session @@ -138,10 +139,7 @@ async def add_events_to_memory( records = [] for event in events: record = event_to_memory_record( - event=event, - session_id=session_id or "", - app_name=app_name, - user_id=user_id, + event=event, session_id=session_id or "", app_name=app_name, user_id=user_id ) if record is not None: if metadata_dict: @@ -150,20 +148,13 @@ async def add_events_to_memory( if not records: logger.debug( - "No content to store for events (app=%s, user=%s, count=%d)", - app_name, - user_id, - len(list(events)), + "No content to store for events (app=%s, user=%s, count=%d)", app_name, user_id, len(list(events)) ) return inserted_count = await self._store.insert_memory_entries(records) logger.debug( - "Stored %d memory entries from %d events (app=%s, user=%s)", - inserted_count, - len(records), - app_name, - user_id, + "Stored %d memory entries from %d events (app=%s, user=%s)", inserted_count, len(records), app_name, user_id ) async def add_memory( @@ -192,10 +183,7 @@ async def add_memory( records = [] for entry in memories: record = memory_entry_to_record( - entry=entry, - app_name=app_name, - user_id=user_id, - extra_metadata=call_metadata, + entry=entry, app_name=app_name, user_id=user_id, extra_metadata=call_metadata ) if record is not None: 
records.append(record) diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 200b3c948..132d9ad69 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -216,7 +216,9 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": # Strip temp: keys before persisting state durable_state = filter_temp_state(session.state) - await self._store.append_event_and_update_state(event_record=event_record, session_id=session.id, state=durable_state) + await self._store.append_event_and_update_state( + event_record=event_record, session_id=session.id, state=durable_state + ) log_with_context( logger, logging.DEBUG, diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py index 653682fdc..c481f2f3c 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py @@ -87,7 +87,9 @@ def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: events = sqlite_store.list_events(session_id) assert len(events) == 1 - retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert retrieved_data["content"] == content diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py index bb1784302..01b991d99 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py @@ -159,7 +159,9 @@ def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: assert 
event_data["custom_metadata"] == complex_custom events = adbc_store.list_events(session_fixture["session_id"]) - retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert retrieved_data["content"] == complex_content assert retrieved_data["grounding_metadata"] == complex_grounding assert retrieved_data["custom_metadata"] == complex_custom diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index b2f25a9c6..44a346d66 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -197,8 +197,12 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" # Content is inside event_json - event0_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - event1_data = json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + event0_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + event1_data = ( + json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + ) assert event0_data["content"]["text"] == "Hello" assert event1_data["content"]["text"] == "Hi there" diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index 241ff3de5..aadfefa8a 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ 
b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -185,7 +185,11 @@ def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: assert "event_json" in created_event # Content is stored inside event_json - event_data = json.loads(created_event["event_json"]) if isinstance(created_event["event_json"], str) else created_event["event_json"] + event_data = ( + json.loads(created_event["event_json"]) + if isinstance(created_event["event_json"], str) + else created_event["event_json"] + ) assert event_data["content"] == content @@ -362,7 +366,9 @@ def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: events = duckdb_adk_store.list_events(session_id) assert len(events) == 1 - event_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert event_data["content"] == {"data": "value"} diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index f85a0ee79..67b38f0a2 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -198,8 +198,12 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" # Content is inside event_json - event0_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - event1_data = json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + event0_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + event1_data = ( + 
json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + ) assert event0_data["content"]["text"] == "Hello" assert event1_data["content"]["text"] == "Hi there" diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index 26ed4730a..540cf123f 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -250,7 +250,9 @@ async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncAD events = await oracle_async_store.get_events(session_id) assert len(events) == 1 # event_json contains all the data - retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert retrieved_data["content"] == content assert retrieved_data["grounding_metadata"] == {"sources": ["a" * 1000, "b" * 1000]} assert retrieved_data["custom_metadata"] == {"tags": ["tag1", "tag2"], "priority": "high"} @@ -278,7 +280,9 @@ async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - retrieved_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert retrieved_data == event_data @@ -326,7 +330,9 @@ async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncAD assert events[0]["invocation_id"] == "inv-001" assert events[0]["author"] == "assistant" - retrieved_data = 
json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert retrieved_data["partial"] is True assert retrieved_data["turn_complete"] is False assert retrieved_data["interrupted"] is True diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index 61bc624cd..75b51dc08 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -59,14 +59,7 @@ def test_create_and_list_events(spanner_adk_store: Any) -> None: spanner_adk_store.create_session(session_id, "app", "user", {"x": 1}) spanner_adk_store.create_event("event-1", session_id, "app", "user", author="user", content={"msg": "hi"}) - spanner_adk_store.create_event( - "event-2", - session_id, - "app", - "user", - author="assistant", - content={"msg": "ok"}, - ) + spanner_adk_store.create_event("event-2", session_id, "app", "user", author="assistant", content={"msg": "ok"}) events = spanner_adk_store.list_events(session_id) assert len(events) == 2 @@ -74,7 +67,11 @@ def test_create_and_list_events(spanner_adk_store: Any) -> None: assert events[1]["author"] == "assistant" # Content is inside event_json in the new 5-column schema - event0_data = json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - event1_data = json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + event0_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + event1_data = ( + json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + ) assert 
event0_data["content"] == {"msg": "hi"} assert event1_data["content"] == {"msg": "ok"} diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py index d634e4dd6..903358ae9 100644 --- a/tests/unit/extensions/test_adk/test_converters.py +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -31,7 +31,6 @@ split_scoped_state, ) - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -66,11 +65,7 @@ def _make_event( def _make_session( - *, - session_id: str = "session-1", - app_name: str = "test-app", - user_id: str = "user-1", - state: "dict | None" = None, + *, session_id: str = "session-1", app_name: str = "test-app", user_id: str = "user-1", state: "dict | None" = None ) -> Session: return Session( id=session_id, @@ -183,29 +178,19 @@ def test_split_scoped_state_preserves_full_key_names() -> None: def test_merge_scoped_state_combines_all_buckets() -> None: """All three buckets appear in the merged result.""" - merged = merge_scoped_state( - session_state={"key": "s"}, - app_state={"app:x": "a"}, - user_state={"user:y": "u"}, - ) + merged = merge_scoped_state(session_state={"key": "s"}, app_state={"app:x": "a"}, user_state={"user:y": "u"}) assert merged == {"key": "s", "app:x": "a", "user:y": "u"} def test_merge_scoped_state_overlay_priority_app_over_session() -> None: """app_state overlays session_state for the same key.""" - merged = merge_scoped_state( - session_state={"app:x": "old"}, - app_state={"app:x": "new"}, - ) + merged = merge_scoped_state(session_state={"app:x": "old"}, app_state={"app:x": "new"}) assert merged["app:x"] == "new" def test_merge_scoped_state_overlay_priority_user_over_session() -> None: """user_state overlays session_state for the same key.""" - merged = merge_scoped_state( - session_state={"user:y": "session_val"}, - user_state={"user:y": "user_val"}, - ) + merged 
= merge_scoped_state(session_state={"user:y": "session_val"}, user_state={"user:y": "user_val"}) assert merged["user:y"] == "user_val" @@ -218,11 +203,7 @@ def test_merge_scoped_state_no_app_no_user() -> None: def test_merge_scoped_state_empty_session_state() -> None: """Empty session_state with app/user state returns combined app+user keys.""" - merged = merge_scoped_state( - session_state={}, - app_state={"app:a": 1}, - user_state={"user:b": 2}, - ) + merged = merge_scoped_state(session_state={}, app_state={"app:a": 1}, user_state={"user:b": 2}) assert merged == {"app:a": 1, "user:b": 2} @@ -378,13 +359,7 @@ def test_record_to_event_roundtrip_preserves_turn_complete() -> None: def test_record_to_event_roundtrip_preserves_timestamp() -> None: """timestamp survives the round-trip within float precision.""" fixed_ts = datetime(2024, 6, 1, 10, 30, 0, tzinfo=timezone.utc).timestamp() - event = Event( - id="ts-evt", - invocation_id="inv-1", - author="user", - actions=EventActions(), - timestamp=fixed_ts, - ) + event = Event(id="ts-evt", invocation_id="inv-1", author="user", actions=EventActions(), timestamp=fixed_ts) record = event_to_record(event, "s1") restored = record_to_event(record) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index 407166172..a34d78862 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -12,7 +12,6 @@ import importlib.util from datetime import datetime, timezone from typing import Any -from unittest.mock import AsyncMock, MagicMock import pytest @@ -25,7 +24,6 @@ from sqlspec.extensions.adk.service import SQLSpecSessionService - # --------------------------------------------------------------------------- # Mock store # --------------------------------------------------------------------------- @@ -56,13 +54,13 @@ def __init__(self) -> None: "update_time": datetime.now(timezone.utc), } - async def 
append_event_and_update_state( - self, event_record: Any, session_id: str, state: "dict[str, Any]" - ) -> None: + async def append_event_and_update_state(self, event_record: Any, session_id: str, state: "dict[str, Any]") -> None: self.append_event_and_update_state_called = True - self.append_event_and_update_state_calls.append( - {"event_record": event_record, "session_id": session_id, "state": state} - ) + self.append_event_and_update_state_calls.append({ + "event_record": event_record, + "session_id": session_id, + "state": state, + }) async def get_session(self, session_id: str) -> "dict[str, Any] | None": return self._session_record @@ -70,9 +68,12 @@ async def get_session(self, session_id: str) -> "dict[str, Any] | None": async def create_session( self, *, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]" ) -> "dict[str, Any]": - self.create_session_calls.append( - {"session_id": session_id, "app_name": app_name, "user_id": user_id, "state": state} - ) + self.create_session_calls.append({ + "session_id": session_id, + "app_name": app_name, + "user_id": user_id, + "state": state, + }) return { "id": session_id, "app_name": app_name, @@ -102,11 +103,7 @@ async def delete_session(self, session_id: str) -> None: def _make_session( - *, - session_id: str = "s1", - app_name: str = "app", - user_id: str = "u1", - state: "dict | None" = None, + *, session_id: str = "s1", app_name: str = "app", user_id: str = "u1", state: "dict | None" = None ) -> Session: return Session( id=session_id, @@ -118,11 +115,7 @@ def _make_session( def _make_event( - *, - invocation_id: str = "inv-1", - author: str = "model", - state_delta: "dict | None" = None, - partial: bool = False, + *, invocation_id: str = "inv-1", author: str = "model", state_delta: "dict | None" = None, partial: bool = False ) -> Event: actions = EventActions(state_delta=state_delta or {}) return Event( @@ -274,9 +267,7 @@ async def test_create_session_strips_temp_keys_from_initial_state() -> 
None: store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] - await service.create_session( - app_name="app", user_id="u1", state={"x": 1, "temp:y": 2, "app:z": 3} - ) + await service.create_session(app_name="app", user_id="u1", state={"x": 1, "temp:y": 2, "app:z": 3}) assert len(store.create_session_calls) == 1 persisted_state = store.create_session_calls[0]["state"] diff --git a/tests/unit/extensions/test_adk/test_store_instantiation.py b/tests/unit/extensions/test_adk/test_store_instantiation.py index 1bd442a15..56ebc59ef 100644 --- a/tests/unit/extensions/test_adk/test_store_instantiation.py +++ b/tests/unit/extensions/test_adk/test_store_instantiation.py @@ -18,7 +18,7 @@ "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKStore", "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKStore", "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKStore", -"sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKStore", + "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKStore", "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKStore", "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKStore", "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKStore", @@ -45,7 +45,7 @@ "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKMemoryStore", "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKMemoryStore", "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKMemoryStore", -"sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKMemoryStore", + "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKMemoryStore", "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKMemoryStore", "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKMemoryStore", "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKMemoryStore", @@ -66,12 +66,7 @@ "sqlspec.adapters.sqlite.adk.store.SqliteADKMemoryStore", ] -ALL_STORE_CLASSES = ( - ASYNC_SESSION_STORES - + SYNC_SESSION_STORES - + 
ASYNC_MEMORY_STORES - + SYNC_MEMORY_STORES -) +ALL_STORE_CLASSES = ASYNC_SESSION_STORES + SYNC_SESSION_STORES + ASYNC_MEMORY_STORES + SYNC_MEMORY_STORES @pytest.mark.parametrize("class_path", ALL_STORE_CLASSES) From a234cc79f85c17cf231cea4836221a5e06882d3e Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 18:50:45 +0000 Subject: [PATCH 19/23] fix(adk): add await to all tests after sync-to-async conversion --- .../adk/test_dialect_integration.py | 67 ++-- .../extensions/adk/test_dialect_support.py | 6 +- .../adbc/extensions/adk/test_edge_cases.py | 175 ++++---- .../extensions/adk/test_event_operations.py | 369 ++++++++++------- .../adbc/extensions/adk/test_memory_store.py | 34 +- .../extensions/adk/test_owner_id_column.py | 30 +- .../extensions/adk/test_session_operations.py | 86 ++-- .../asyncmy/extensions/adk/test_store.py | 8 +- .../extensions/adk/test_memory_store.py | 34 +- .../duckdb/extensions/adk/test_store.py | 375 ++++++++++-------- .../extensions/adk/test_store.py | 8 +- .../oracledb/extensions/adk/test_inmemory.py | 16 +- .../extensions/adk/test_oracle_specific.py | 45 +-- .../extensions/adk/test_owner_id_column.py | 16 +- .../spanner/extensions/adk/conftest.py | 6 +- .../extensions/adk/test_memory_store.py | 30 +- .../test_adk/test_store_instantiation.py | 2 +- 17 files changed, 706 insertions(+), 601 deletions(-) diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py index c481f2f3c..e20536d83 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py @@ -23,16 +23,16 @@ @pytest.fixture() -def sqlite_store(tmp_path: Path) -> Any: +async def sqlite_store(tmp_path: Path) -> Any: """SQLite ADBC store fixture.""" db_path = tmp_path / "sqlite_test.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": 
f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None: +async def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None: """Test SQLite dialect creates TEXT columns for JSON.""" with sqlite_store.config.provide_connection() as conn: cursor = conn.cursor() @@ -46,46 +46,51 @@ def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None: cursor.close() # type: ignore[no-untyped-call] -def test_sqlite_dialect_session_operations(sqlite_store: Any) -> None: +async def test_sqlite_dialect_session_operations(sqlite_store: Any) -> None: """Test SQLite dialect with full session CRUD.""" session_id = "sqlite-session-1" app_name = "test-app" user_id = "user-123" state = {"nested": {"key": "value"}, "count": 42} - created = sqlite_store.create_session(session_id, app_name, user_id, state) + created = await sqlite_store.create_session(session_id, app_name, user_id, state) assert created["id"] == session_id assert created["state"] == state - retrieved = sqlite_store.get_session(session_id) + retrieved = await sqlite_store.get_session(session_id) assert retrieved["state"] == state new_state = {"updated": True} - sqlite_store.update_session_state(session_id, new_state) + await sqlite_store.update_session_state(session_id, new_state) - updated = sqlite_store.get_session(session_id) + updated = await sqlite_store.get_session(session_id) assert updated["state"] == new_state -def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: +async def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: """Test SQLite dialect with event operations.""" session_id = "sqlite-session-events" app_name = "test-app" user_id = "user-123" - sqlite_store.create_session(session_id, app_name, user_id, {}) + await sqlite_store.create_session(session_id, app_name, user_id, {}) content = {"message": "Hello"} - event = 
sqlite_store.create_event( - event_id="event-1", session_id=session_id, app_name=app_name, user_id=user_id, content=content - ) + from datetime import datetime, timezone + + from sqlspec.extensions.adk import EventRecord - assert event["session_id"] == session_id - event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] - assert event_data["content"] == content + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "content": content, "app_name": app_name, "user_id": user_id}, + } + await sqlite_store.append_event(event_record) - events = sqlite_store.list_events(session_id) + events = await sqlite_store.get_events(session_id) assert len(events) == 1 retrieved_data = ( json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] @@ -95,7 +100,7 @@ def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: @pytest.mark.postgres @pytest.mark.skipif(True, reason="Requires adbc-driver-postgresql and PostgreSQL server") -def test_postgresql_dialect_creates_jsonb_columns() -> None: +async def test_postgresql_dialect_creates_jsonb_columns() -> None: """Test PostgreSQL dialect creates JSONB columns. This test is skipped by default. 
To run: @@ -108,7 +113,7 @@ def test_postgresql_dialect_creates_jsonb_columns() -> None: connection_config={"driver_name": "postgresql", "uri": "postgresql://user:pass@localhost/testdb"} ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() with store.config.provide_connection() as conn: cursor = conn.cursor() @@ -130,7 +135,7 @@ def test_postgresql_dialect_creates_jsonb_columns() -> None: @pytest.mark.duckdb @pytest.mark.skipif(True, reason="Requires adbc-driver-duckdb") -def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None: +async def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None: """Test DuckDB dialect creates JSON columns. This test is skipped by default. To run: @@ -140,18 +145,18 @@ def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None: db_path = tmp_path / "duckdb_test.db" config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() session_id = "duckdb-session-1" state = {"analytics": {"count": 1000, "revenue": 50000.00}} - created = store.create_session(session_id, "app", "user", state) + created = await store.create_session(session_id, "app", "user", state) assert created["state"] == state @pytest.mark.snowflake @pytest.mark.skipif(True, reason="Requires adbc-driver-snowflake and Snowflake account") -def test_snowflake_dialect_creates_variant_columns() -> None: +async def test_snowflake_dialect_creates_variant_columns() -> None: """Test Snowflake dialect creates VARIANT columns. This test is skipped by default. 
To run: @@ -168,7 +173,7 @@ def test_snowflake_dialect_creates_variant_columns() -> None: } ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() with store.config.provide_connection() as conn: cursor = conn.cursor() @@ -188,7 +193,7 @@ def test_snowflake_dialect_creates_variant_columns() -> None: cursor.close() # type: ignore[no-untyped-call] -def test_sqlite_with_owner_id_column(tmp_path: Path) -> None: +async def test_sqlite_with_owner_id_column(tmp_path: Path) -> None: """Test SQLite with owner ID column creates proper constraints.""" db_path = tmp_path / "sqlite_fk_test.db" base_config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) @@ -208,16 +213,16 @@ def test_sqlite_with_owner_id_column(tmp_path: Path) -> None: extension_config={"adk": {"owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"}}, ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() - session = store.create_session("s1", "app", "user", {"data": "test"}, owner_id=1) + session = await store.create_session("s1", "app", "user", {"data": "test"}, owner_id=1) assert session["id"] == "s1" - retrieved = store.get_session("s1") + retrieved = await store.get_session("s1") assert retrieved is not None -def test_generic_dialect_fallback(tmp_path: Path) -> None: +async def test_generic_dialect_fallback(tmp_path: Path) -> None: """Test generic dialect is used for unknown drivers.""" db_path = tmp_path / "generic_test.db" @@ -226,7 +231,7 @@ def test_generic_dialect_fallback(tmp_path: Path) -> None: store = AdbcADKStore(config) assert store.dialect in ["sqlite", "generic"] - store.create_tables() + await store.create_tables() - session = store.create_session("generic-1", "app", "user", {"test": True}) + session = await store.create_session("generic-1", "app", "user", {"test": True}) assert session["state"]["test"] is True diff --git 
a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py index fa5130dbd..703d40437 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py @@ -131,15 +131,15 @@ def test_snowflake_events_ddl_uses_variant() -> None: assert "event_json" in ddl -def test_ddl_dispatch_uses_correct_dialect() -> None: +async def test_ddl_dispatch_uses_correct_dialect() -> None: """Test that DDL dispatch selects correct dialect method.""" config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"}) store = AdbcADKStore(config) - sessions_ddl = store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] + sessions_ddl = await store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in sessions_ddl - events_ddl = store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] + events_ddl = await store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in events_ddl assert "event_json" in events_ddl diff --git a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py index c028f4511..0e11dd0bb 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py @@ -13,19 +13,19 @@ @pytest.fixture() -def adbc_store(tmp_path: Path) -> AdbcADKStore: +async def adbc_store(tmp_path: Path) -> AdbcADKStore: """Create ADBC ADK store with SQLite backend.""" db_path = tmp_path / "test_adk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_create_tables_idempotent(adbc_store: Any) -> 
None: +async def test_create_tables_idempotent(adbc_store: Any) -> None: """Test that create_tables can be called multiple times safely.""" - adbc_store.create_tables() - adbc_store.create_tables() + await adbc_store.create_tables() + await adbc_store.create_tables() def test_table_names_validation(tmp_path: Path) -> None: @@ -62,23 +62,23 @@ def test_table_names_validation(tmp_path: Path) -> None: AdbcADKStore(config) -def test_operations_before_create_tables(tmp_path: Path) -> None: +async def test_operations_before_create_tables(tmp_path: Path) -> None: """Test operations gracefully handle missing tables.""" db_path = tmp_path / "test_no_tables.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - session = store.get_session("nonexistent") + session = await store.get_session("nonexistent") assert session is None - sessions = store.list_sessions("app", "user") + sessions = await store.list_sessions("app", "user") assert sessions == [] - events = store.list_events("session") + events = await store.get_events("session") assert events == [] -def test_custom_table_names(tmp_path: Path) -> None: +async def test_custom_table_names(tmp_path: Path) -> None: """Test using custom table names.""" db_path = tmp_path / "test_custom.db" config = AdbcConfig( @@ -86,43 +86,56 @@ def test_custom_table_names(tmp_path: Path) -> None: extension_config={"adk": {"session_table": "custom_sessions", "events_table": "custom_events"}}, ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() session_id = "test" - session = store.create_session(session_id, "app", "user", {"data": "test"}) + session = await store.create_session(session_id, "app", "user", {"data": "test"}) assert session["id"] == session_id - retrieved = store.get_session(session_id) + retrieved = await store.get_session(session_id) assert retrieved is not None -def test_unicode_in_fields(adbc_store: Any) -> None: +async def 
test_unicode_in_fields(adbc_store: Any) -> None: """Test Unicode characters in various fields.""" session_id = "unicode-session" - app_name = "测试应用" - user_id = "ユーザー123" - state = {"message": "Hello 世界"} + app_name = "\u6d4b\u8bd5\u5e94\u7528" + user_id = "\u30e6\u30fc\u30b6\u30fc123" + state = {"message": "Hello \u4e16\u754c"} - created_session = adbc_store.create_session(session_id, app_name, user_id, state) + created_session = await adbc_store.create_session(session_id, app_name, user_id, state) assert created_session["app_name"] == app_name assert created_session["user_id"] == user_id - assert created_session["state"]["message"] == "Hello 世界" - - event = adbc_store.create_event( - event_id="unicode-event", - session_id=session_id, - app_name=app_name, - user_id=user_id, - author="アシスタント", - content={"text": "こんにちは"}, - ) + assert created_session["state"]["message"] == "Hello \u4e16\u754c" + + from datetime import datetime, timezone + + from sqlspec.extensions.adk import EventRecord + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "unicode-event", + "content": {"text": "\u3053\u3093\u306b\u3061\u306f"}, + "app_name": app_name, + "user_id": user_id, + }, + } + await adbc_store.append_event(event_record) - assert event["author"] == "アシスタント" - event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] - assert event_data["content"]["text"] == "こんにちは" + events = await adbc_store.get_events(session_id) + assert len(events) == 1 + assert events[0]["author"] == "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8" + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert event_data["content"]["text"] == "\u3053\u3093\u306b\u3061\u306f" -def test_special_characters_in_json(adbc_store: Any) -> None: 
+async def test_special_characters_in_json(adbc_store: Any) -> None: """Test special characters in JSON fields.""" session_id = "special-chars" state = { @@ -132,109 +145,105 @@ def test_special_characters_in_json(adbc_store: Any) -> None: "tab": "Col1\tCol2", } - adbc_store.create_session(session_id, "app", "user", state) - retrieved = adbc_store.get_session(session_id) + await adbc_store.create_session(session_id, "app", "user", state) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state -def test_very_long_strings(adbc_store: Any) -> None: +async def test_very_long_strings(adbc_store: Any) -> None: """Test handling very long strings in VARCHAR fields.""" long_id = "x" * 127 long_app = "a" * 127 long_user = "u" * 127 - session = adbc_store.create_session(long_id, long_app, long_user, {}) + session = await adbc_store.create_session(long_id, long_app, long_user, {}) assert session["id"] == long_id assert session["app_name"] == long_app assert session["user_id"] == long_user -def test_session_state_with_deeply_nested_data(adbc_store: Any) -> None: +async def test_session_state_with_deeply_nested_data(adbc_store: Any) -> None: """Test deeply nested JSON structures.""" session_id = "deep-nest" deeply_nested = {"level1": {"level2": {"level3": {"level4": {"level5": {"value": "deep"}}}}}} - adbc_store.create_session(session_id, "app", "user", deeply_nested) - retrieved = adbc_store.get_session(session_id) + await adbc_store.create_session(session_id, "app", "user", deeply_nested) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"]["level1"]["level2"]["level3"]["level4"]["level5"]["value"] == "deep" -def test_concurrent_session_updates(adbc_store: Any) -> None: +async def test_concurrent_session_updates(adbc_store: Any) -> None: """Test multiple updates to the same session.""" session_id = "concurrent-test" - adbc_store.create_session(session_id, 
"app", "user", {"version": 1}) + await adbc_store.create_session(session_id, "app", "user", {"version": 1}) for i in range(10): - adbc_store.update_session_state(session_id, {"version": i + 2}) + await adbc_store.update_session_state(session_id, {"version": i + 2}) - final_session = adbc_store.get_session(session_id) + final_session = await adbc_store.get_session(session_id) assert final_session is not None assert final_session["state"]["version"] == 11 -def test_event_with_none_values(adbc_store: Any) -> None: +async def test_event_with_none_values(adbc_store: Any) -> None: """Test creating event with explicit None values for optional fields.""" session_id = "none-test" - adbc_store.create_session(session_id, "app", "user", {}) - - event = adbc_store.create_event( - event_id="none-event", - session_id=session_id, - app_name="app", - user_id="user", - invocation_id=None, - author=None, - actions=None, - content=None, - grounding_metadata=None, - custom_metadata=None, - partial=None, - turn_complete=None, - interrupted=None, - error_code=None, - error_message=None, - ) + await adbc_store.create_session(session_id, "app", "user", {}) + + from datetime import datetime, timezone + + from sqlspec.extensions.adk import EventRecord + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "none-event", "app_name": "app", "user_id": "user"}, + } + await adbc_store.append_event(event_record) - # The event should still have the 5-key shape - assert event["session_id"] == session_id - assert "event_json" in event + events = await adbc_store.get_events(session_id) + assert len(events) == 1 + assert events[0]["session_id"] == session_id + assert "event_json" in events[0] -def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None: +async def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None: """Test listing sessions doesn't mix 
data across apps.""" user_id = "user-123" app1 = "app1" app2 = "app2" - adbc_store.create_session("s1", app1, user_id, {}) - adbc_store.create_session("s2", app1, user_id, {}) - adbc_store.create_session("s3", app2, user_id, {}) + await adbc_store.create_session("s1", app1, user_id, {}) + await adbc_store.create_session("s2", app1, user_id, {}) + await adbc_store.create_session("s3", app2, user_id, {}) - app1_sessions = adbc_store.list_sessions(app1, user_id) - app2_sessions = adbc_store.list_sessions(app2, user_id) + app1_sessions = await adbc_store.list_sessions(app1, user_id) + app2_sessions = await adbc_store.list_sessions(app2, user_id) assert len(app1_sessions) == 2 assert len(app2_sessions) == 1 -def test_delete_nonexistent_session(adbc_store: Any) -> None: +async def test_delete_nonexistent_session(adbc_store: Any) -> None: """Test deleting a session that doesn't exist.""" - adbc_store.delete_session("nonexistent-session") + await adbc_store.delete_session("nonexistent-session") -def test_update_nonexistent_session(adbc_store: Any) -> None: +async def test_update_nonexistent_session(adbc_store: Any) -> None: """Test updating a session that doesn't exist.""" - adbc_store.update_session_state("nonexistent-session", {"data": "test"}) + await adbc_store.update_session_state("nonexistent-session", {"data": "test"}) -def test_drop_and_recreate_tables(adbc_store: Any) -> None: +async def test_drop_and_recreate_tables(adbc_store: Any) -> None: """Test dropping and recreating tables.""" session_id = "test-session" - adbc_store.create_session(session_id, "app", "user", {"data": "test"}) + await adbc_store.create_session(session_id, "app", "user", {"data": "test"}) drop_sqls = adbc_store._get_drop_tables_sql() with adbc_store._config.provide_connection() as conn: @@ -246,19 +255,19 @@ def test_drop_and_recreate_tables(adbc_store: Any) -> None: finally: cursor.close() - adbc_store.create_tables() + await adbc_store.create_tables() - session = 
adbc_store.get_session(session_id) + session = await adbc_store.get_session(session_id) assert session is None -def test_json_with_escaped_characters(adbc_store: Any) -> None: +async def test_json_with_escaped_characters(adbc_store: Any) -> None: """Test JSON serialization of escaped characters.""" session_id = "escaped-json" state = {"escaped": r"test\nvalue\t", "quotes": r'"quoted"'} - adbc_store.create_session(session_id, "app", "user", state) - retrieved = adbc_store.get_session(session_id) + await adbc_store.create_session(session_id, "app", "user", state) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py index 01b991d99..d7ad7c2d2 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py @@ -1,6 +1,7 @@ """Tests for ADBC ADK store event operations.""" import json +from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -8,111 +9,139 @@ from sqlspec.adapters.adbc import AdbcConfig from sqlspec.adapters.adbc.adk import AdbcADKStore +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] @pytest.fixture() -def adbc_store(tmp_path: Path) -> AdbcADKStore: +async def adbc_store(tmp_path: Path) -> AdbcADKStore: """Create ADBC ADK store with SQLite backend.""" db_path = tmp_path / "test_adk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store @pytest.fixture() -def session_fixture(adbc_store: Any) -> dict[str, str]: +async def session_fixture(adbc_store: Any) -> dict[str, str]: """Create 
a test session.""" session_id = "test-session" app_name = "test-app" user_id = "user-123" state = {"test": True} - adbc_store.create_session(session_id, app_name, user_id, state) + await adbc_store.create_session(session_id, app_name, user_id, state) return {"session_id": session_id, "app_name": app_name, "user_id": user_id} -def test_create_event(adbc_store: Any, session_fixture: Any) -> None: +async def test_create_event(adbc_store: Any, session_fixture: Any) -> None: """Test creating a new event returns 5-key EventRecord.""" - event = adbc_store.create_event( - event_id="event-1", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - author="user", - content={"message": "Hello"}, - ) - - assert event["session_id"] == session_fixture["session_id"] - assert event["author"] == "user" - assert event["timestamp"] is not None - assert "event_json" in event + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-1", + "content": {"message": "Hello"}, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + assert events[0]["session_id"] == session_fixture["session_id"] + assert events[0]["author"] == "user" + assert events[0]["timestamp"] is not None + assert "event_json" in events[0] # Content is stored inside event_json - event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert event_data["content"] == {"message": "Hello"} -def test_list_events(adbc_store: Any, session_fixture: Any) -> None: 
+async def test_list_events(adbc_store: Any, session_fixture: Any) -> None: """Test listing events for a session.""" - adbc_store.create_event( - event_id="event-1", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - author="user", - content={"seq": 1}, - ) - adbc_store.create_event( - event_id="event-2", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - author="assistant", - content={"seq": 2}, - ) - - events = adbc_store.list_events(session_fixture["session_id"]) + event1: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-1", + "content": {"seq": 1}, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + event2: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-2", + "content": {"seq": 2}, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event1) + await adbc_store.append_event(event2) + + events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 2 assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" -def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: +async def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: """Test listing events when none exist.""" - events = adbc_store.list_events(session_fixture["session_id"]) + events = await adbc_store.get_events(session_fixture["session_id"]) assert events == [] -def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: +async def test_event_with_all_fields(adbc_store: Any, 
session_fixture: Any) -> None: """Test creating event with all optional fields stored in event_json.""" - event = adbc_store.create_event( - event_id="full-event", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - invocation_id="invocation-123", - author="assistant", - actions=b"complex_action_data", - branch="main", - content={"text": "Response"}, - grounding_metadata={"sources": ["doc1", "doc2"]}, - custom_metadata={"custom": "data"}, - partial=True, - turn_complete=False, - interrupted=False, - error_code="NONE", - error_message="No errors", - ) + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "invocation-123", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "full-event", + "content": {"text": "Response"}, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + "branch": "main", + "grounding_metadata": {"sources": ["doc1", "doc2"]}, + "custom_metadata": {"custom": "data"}, + "partial": True, + "turn_complete": False, + "interrupted": False, + "error_code": "NONE", + "error_message": "No errors", + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 # Top-level indexed columns - assert event["invocation_id"] == "invocation-123" - assert event["author"] == "assistant" + assert events[0]["invocation_id"] == "invocation-123" + assert events[0]["author"] == "assistant" # Everything else is in event_json - event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert event_data["content"] == {"text": "Response"} assert event_data["branch"] == "main" assert event_data["grounding_metadata"] == 
{"sources": ["doc1", "doc2"]} @@ -124,138 +153,174 @@ def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: assert event_data["error_message"] == "No errors" -def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None: +async def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None: """Test creating event with only required fields.""" - event = adbc_store.create_event( - event_id="minimal-event", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - assert event["session_id"] == session_fixture["session_id"] - assert "event_json" in event - - -def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "minimal-event", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + assert events[0]["session_id"] == session_fixture["session_id"] + assert "event_json" in events[0] + + +async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: """Test event JSON field serialization and deserialization via event_json.""" complex_content = {"nested": {"data": "value"}, "list": [1, 2, 3], "null": None} complex_grounding = {"sources": [{"title": "Doc", "url": "http://example.com"}]} complex_custom = {"metadata": {"version": 1, "tags": ["tag1", "tag2"]}} - event = adbc_store.create_event( - event_id="json-event", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - content=complex_content, - grounding_metadata=complex_grounding, - custom_metadata=complex_custom, + 
event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "json-event", + "content": complex_content, + "grounding_metadata": complex_grounding, + "custom_metadata": complex_custom, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) - - event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] assert event_data["content"] == complex_content assert event_data["grounding_metadata"] == complex_grounding assert event_data["custom_metadata"] == complex_custom - events = adbc_store.list_events(session_fixture["session_id"]) - retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert retrieved_data["content"] == complex_content - assert retrieved_data["grounding_metadata"] == complex_grounding - assert retrieved_data["custom_metadata"] == complex_custom - -def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: +async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: """Test that events are ordered by timestamp ASC.""" import time - adbc_store.create_event( - event_id="event-1", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) + ev1: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + 
} + await adbc_store.append_event(ev1) time.sleep(0.01) - adbc_store.create_event( - event_id="event-2", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) + ev2: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + await adbc_store.append_event(ev2) time.sleep(0.01) - adbc_store.create_event( - event_id="event-3", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) + ev3: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-3", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + await adbc_store.append_event(ev3) - events = adbc_store.list_events(session_fixture["session_id"]) + events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 3 assert events[0]["timestamp"] < events[1]["timestamp"] assert events[1]["timestamp"] < events[2]["timestamp"] -def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, tmp_path: Path) -> None: +async def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, tmp_path: Path) -> None: """Test that deleting a session cascades to delete events. Note: SQLite with ADBC requires foreign key enforcement to be explicitly enabled for cascade deletes to work. This test manually enables it. 
""" - adbc_store.create_event( - event_id="event-1", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - adbc_store.create_event( - event_id="event-2", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - events_before = adbc_store.list_events(session_fixture["session_id"]) + ev1: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + ev2: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + await adbc_store.append_event(ev1) + await adbc_store.append_event(ev2) + + events_before = await adbc_store.get_events(session_fixture["session_id"]) assert len(events_before) == 2 - adbc_store.delete_session(session_fixture["session_id"]) + await adbc_store.delete_session(session_fixture["session_id"]) - session_after = adbc_store.get_session(session_fixture["session_id"]) + session_after = await adbc_store.get_session(session_fixture["session_id"]) assert session_after is None -def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None: +async def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None: """Test creating event with empty actions bytes.""" - event = adbc_store.create_event( - event_id="empty-actions", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - actions=b"", - ) - - # actions=b"" is either ignored or stored as hex in event_json - assert "event_json" in event - - -def 
test_event_with_large_content(adbc_store: Any, session_fixture: Any) -> None: + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "empty-actions", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + assert "event_json" in events[0] + + +async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) -> None: """Test creating event with large content in event_json.""" large_content = {"data": "x" * 10000} - event = adbc_store.create_event( - event_id="large-content", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - content=large_content, + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "large-content", + "content": large_content, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) - - event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] assert event_data["content"] == large_content diff --git a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py index a417cd84f..0c3afa838 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py +++ 
b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py @@ -30,63 +30,63 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -def _build_store(tmp_path: Path) -> AdbcADKMemoryStore: +async def _build_store(tmp_path: Path) -> AdbcADKMemoryStore: db_path = tmp_path / "test_adk_memory.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKMemoryStore(config) - store.create_tables() + await store.create_tables() return store -def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None: +async def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None: """Insert memory entries, search by text, and skip duplicates.""" - store = _build_store(tmp_path) + store = await _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = store.insert_memory_entries([record1, record2]) + inserted = await store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = store.search_entries(query="espresso", app_name="app", user_id="user") + results = await store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = store.insert_memory_entries([record1]) + deduped = await store.insert_memory_entries([record1]) assert deduped == 0 -def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None: +async def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None: """Delete memory entries by session id.""" - store = _build_store(tmp_path) + store = await _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = 
_build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_by_session("s1") + deleted = await store.delete_entries_by_session("s1") assert deleted == 1 - remaining = store.search_entries(query="latte", app_name="app", user_id="user") + remaining = await store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None: +async def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None: """Delete memory entries older than a cutoff.""" - store = _build_store(tmp_path) + store = await _build_store(tmp_path) now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_older_than(30) + deleted = await store.delete_entries_older_than(30) assert deleted == 1 - remaining = store.search_entries(query="new", app_name="app", user_id="user") + remaining = await store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py index ce2a1bbfa..1b1752ae4 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py @@ -9,7 +9,7 @@ @pytest.fixture() -def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] +async 
def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] """Create ADBC ADK store with owner ID column (SQLite).""" db_path = tmp_path / "test_fk.db" config = AdbcConfig( @@ -29,21 +29,21 @@ def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] finally: cursor.close() # type: ignore[no-untyped-call] - store.create_tables() + await store.create_tables() return store @pytest.fixture() -def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def] +async def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def] """Create ADBC ADK store without owner ID column (SQLite).""" db_path = tmp_path / "test_no_fk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def] +async def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test creating session with owner ID value.""" session_id = "test-session-1" app_name = "test-app" @@ -51,32 +51,32 @@ def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-un state = {"key": "value"} tenant_id = 1 - session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id) + session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id) assert session["id"] == session_id assert session["state"] == state -def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def] +async def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test creating session without providing owner ID value still works.""" session_id = "test-session-2" app_name = "test-app" user_id = "user-123" state = {"key": "value"} - session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state) + 
session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id -def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def] +async def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def] """Test creating session when no FK column configured.""" session_id = "test-session-3" app_name = "test-app" user_id = "user-123" state = {"key": "value"} - session = adbc_store_no_fk.create_session(session_id, app_name, user_id, state) + session = await adbc_store_no_fk.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["state"] == state @@ -109,16 +109,16 @@ def test_owner_id_column_complex_ddl() -> None: assert store._owner_id_column_ddl == complex_ddl # pyright: ignore[reportPrivateUsage] -def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def] +async def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test sessions are properly isolated by tenant.""" app_name = "test-app" user_id = "user-123" - adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, owner_id=1) - adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2) + await adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, owner_id=1) + await adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2) - retrieved1 = adbc_store_with_fk.get_session("session-tenant1") - retrieved2 = adbc_store_with_fk.get_session("session-tenant2") + retrieved1 = await adbc_store_with_fk.get_session("session-tenant1") + retrieved2 = await adbc_store_with_fk.get_session("session-tenant2") assert retrieved1["state"]["data"] == "tenant1" assert retrieved2["state"]["data"] == "tenant2" diff --git 
a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index 819002edc..b749461d7 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -12,23 +12,23 @@ @pytest.fixture() -def adbc_store(tmp_path: Path) -> AdbcADKStore: +async def adbc_store(tmp_path: Path) -> AdbcADKStore: """Create ADBC ADK store with SQLite backend.""" db_path = tmp_path / "test_adk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_create_session(adbc_store: Any) -> None: +async def test_create_session(adbc_store: Any) -> None: """Test creating a new session.""" session_id = "test-session-1" app_name = "test-app" user_id = "user-123" state = {"key": "value", "count": 42} - session = adbc_store.create_session(session_id, app_name, user_id, state) + session = await adbc_store.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["app_name"] == app_name @@ -38,82 +38,82 @@ def test_create_session(adbc_store: Any) -> None: assert session["update_time"] is not None -def test_get_session(adbc_store: Any) -> None: +async def test_get_session(adbc_store: Any) -> None: """Test retrieving a session by ID.""" session_id = "test-session-2" app_name = "test-app" user_id = "user-123" state = {"data": "test"} - adbc_store.create_session(session_id, app_name, user_id, state) - retrieved = adbc_store.get_session(session_id) + await adbc_store.create_session(session_id, app_name, user_id, state) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["id"] == session_id assert retrieved["state"] == state -def test_get_nonexistent_session(adbc_store: Any) -> None: 
+async def test_get_nonexistent_session(adbc_store: Any) -> None: """Test retrieving a session that doesn't exist.""" - result = adbc_store.get_session("nonexistent-id") + result = await adbc_store.get_session("nonexistent-id") assert result is None -def test_update_session_state(adbc_store: Any) -> None: +async def test_update_session_state(adbc_store: Any) -> None: """Test updating session state.""" session_id = "test-session-3" app_name = "test-app" user_id = "user-123" initial_state = {"version": 1} - adbc_store.create_session(session_id, app_name, user_id, initial_state) + await adbc_store.create_session(session_id, app_name, user_id, initial_state) new_state = {"version": 2, "updated": True} - adbc_store.update_session_state(session_id, new_state) + await adbc_store.update_session_state(session_id, new_state) - updated = adbc_store.get_session(session_id) + updated = await adbc_store.get_session(session_id) assert updated is not None assert updated["state"] == new_state assert updated["state"] != initial_state -def test_delete_session(adbc_store: Any) -> None: +async def test_delete_session(adbc_store: Any) -> None: """Test deleting a session.""" session_id = "test-session-4" app_name = "test-app" user_id = "user-123" state = {"data": "test"} - adbc_store.create_session(session_id, app_name, user_id, state) - assert adbc_store.get_session(session_id) is not None + await adbc_store.create_session(session_id, app_name, user_id, state) + assert await adbc_store.get_session(session_id) is not None - adbc_store.delete_session(session_id) - assert adbc_store.get_session(session_id) is None + await adbc_store.delete_session(session_id) + assert await adbc_store.get_session(session_id) is None -def test_list_sessions(adbc_store: Any) -> None: +async def test_list_sessions(adbc_store: Any) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" user_id = "user-123" - adbc_store.create_session("session-1", app_name, user_id, {"num": 1}) - 
adbc_store.create_session("session-2", app_name, user_id, {"num": 2}) - adbc_store.create_session("session-3", "other-app", user_id, {"num": 3}) + await adbc_store.create_session("session-1", app_name, user_id, {"num": 1}) + await adbc_store.create_session("session-2", app_name, user_id, {"num": 2}) + await adbc_store.create_session("session-3", "other-app", user_id, {"num": 3}) - sessions = adbc_store.list_sessions(app_name, user_id) + sessions = await adbc_store.list_sessions(app_name, user_id) assert len(sessions) == 2 session_ids = {s["id"] for s in sessions} assert session_ids == {"session-1", "session-2"} -def test_list_sessions_empty(adbc_store: Any) -> None: +async def test_list_sessions_empty(adbc_store: Any) -> None: """Test listing sessions when none exist.""" - sessions = adbc_store.list_sessions("nonexistent-app", "nonexistent-user") + sessions = await adbc_store.list_sessions("nonexistent-app", "nonexistent-user") assert sessions == [] -def test_session_state_with_complex_data(adbc_store: Any) -> None: +async def test_session_state_with_complex_data(adbc_store: Any) -> None: """Test session state with nested complex data structures.""" session_id = "complex-session" app_name = "test-app" @@ -125,41 +125,41 @@ def test_session_state_with_complex_data(adbc_store: Any) -> None: "null_value": None, } - session = adbc_store.create_session(session_id, app_name, user_id, complex_state) + session = await adbc_store.create_session(session_id, app_name, user_id, complex_state) assert session["state"] == complex_state - retrieved = adbc_store.get_session(session_id) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == complex_state -def test_session_state_empty_dict(adbc_store: Any) -> None: +async def test_session_state_empty_dict(adbc_store: Any) -> None: """Test creating session with empty state dictionary.""" session_id = "empty-state-session" app_name = "test-app" user_id = "user-123" empty_state: 
dict[str, Any] = {} - session = adbc_store.create_session(session_id, app_name, user_id, empty_state) + session = await adbc_store.create_session(session_id, app_name, user_id, empty_state) assert session["state"] == empty_state - retrieved = adbc_store.get_session(session_id) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == empty_state -def test_multiple_users_same_app(adbc_store: Any) -> None: +async def test_multiple_users_same_app(adbc_store: Any) -> None: """Test sessions for multiple users in the same app.""" app_name = "test-app" user1 = "user-1" user2 = "user-2" - adbc_store.create_session("session-user1-1", app_name, user1, {"user": 1}) - adbc_store.create_session("session-user1-2", app_name, user1, {"user": 1}) - adbc_store.create_session("session-user2-1", app_name, user2, {"user": 2}) + await adbc_store.create_session("session-user1-1", app_name, user1, {"user": 1}) + await adbc_store.create_session("session-user1-2", app_name, user1, {"user": 1}) + await adbc_store.create_session("session-user2-1", app_name, user2, {"user": 2}) - user1_sessions = adbc_store.list_sessions(app_name, user1) - user2_sessions = adbc_store.list_sessions(app_name, user2) + user1_sessions = await adbc_store.list_sessions(app_name, user1) + user2_sessions = await adbc_store.list_sessions(app_name, user2) assert len(user1_sessions) == 2 assert len(user2_sessions) == 1 @@ -167,18 +167,18 @@ def test_multiple_users_same_app(adbc_store: Any) -> None: assert all(s["user_id"] == user2 for s in user2_sessions) -def test_session_ordering(adbc_store: Any) -> None: +async def test_session_ordering(adbc_store: Any) -> None: """Test that sessions are ordered by update_time DESC.""" app_name = "test-app" user_id = "user-123" - adbc_store.create_session("session-1", app_name, user_id, {"order": 1}) - adbc_store.create_session("session-2", app_name, user_id, {"order": 2}) - adbc_store.create_session("session-3", app_name, 
user_id, {"order": 3}) + await adbc_store.create_session("session-1", app_name, user_id, {"order": 1}) + await adbc_store.create_session("session-2", app_name, user_id, {"order": 2}) + await adbc_store.create_session("session-3", app_name, user_id, {"order": 3}) - adbc_store.update_session_state("session-1", {"order": 1, "updated": True}) + await adbc_store.update_session_state("session-1", {"order": 1, "updated": True}) - sessions = adbc_store.list_sessions(app_name, user_id) + sessions = await adbc_store.list_sessions(app_name, user_id) assert len(sessions) == 3 assert sessions[0]["id"] == "session-1" diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index 44a346d66..bd52cade2 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -148,7 +148,7 @@ async def test_delete_session_cascade(asyncmy_adk_store: AsyncmyADKStore) -> Non "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}), + "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, } await asyncmy_adk_store.append_event(event_record) @@ -177,7 +177,7 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({"content": {"text": "Hello", "role": "user"}, "app_name": app_name}), + "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, } event2: EventRecord = { @@ -185,7 +185,7 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None "invocation_id": "inv-002", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({"content": {"text": 
"Hi there", "role": "assistant"}, "app_name": app_name}), + "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, } await asyncmy_adk_store.append_event(event1) @@ -224,7 +224,7 @@ async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None: "invocation_id": "inv-micro", "author": "system", "timestamp": event_time, - "event_json": json.dumps({"app_name": app_name}), + "event_json": {"app_name": app_name}, } await asyncmy_adk_store.append_event(event) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py index e3092176b..b86fef3ae 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py @@ -30,63 +30,63 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -def _build_store(tmp_path: Path, worker_id: str) -> DuckdbADKMemoryStore: +async def _build_store(tmp_path: Path, worker_id: str) -> DuckdbADKMemoryStore: db_path = tmp_path / f"test_adk_memory_{worker_id}.duckdb" config = DuckDBConfig(connection_config={"database": str(db_path)}) store = DuckdbADKMemoryStore(config) - store.create_tables() + await store.create_tables() return store -def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path, worker_id: str) -> None: +async def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path, worker_id: str) -> None: """Insert memory entries, search by text, and skip duplicates.""" - store = _build_store(tmp_path, worker_id) + store = await _build_store(tmp_path, worker_id) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = store.insert_memory_entries([record1, record2]) + inserted = 
await store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = store.search_entries(query="espresso", app_name="app", user_id="user") + results = await store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = store.insert_memory_entries([record1]) + deduped = await store.insert_memory_entries([record1]) assert deduped == 0 -def test_duckdb_memory_store_delete_by_session(tmp_path: Path, worker_id: str) -> None: +async def test_duckdb_memory_store_delete_by_session(tmp_path: Path, worker_id: str) -> None: """Delete memory entries by session id.""" - store = _build_store(tmp_path, worker_id) + store = await _build_store(tmp_path, worker_id) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_by_session("s1") + deleted = await store.delete_entries_by_session("s1") assert deleted == 1 - remaining = store.search_entries(query="latte", app_name="app", user_id="user") + remaining = await store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -def test_duckdb_memory_store_delete_older_than(tmp_path: Path, worker_id: str) -> None: +async def test_duckdb_memory_store_delete_older_than(tmp_path: Path, worker_id: str) -> None: """Delete memory entries older than a cutoff.""" - store = _build_store(tmp_path, worker_id) + store = await _build_store(tmp_path, worker_id) now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", 
event_id="evt-2", content_text="new", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_older_than(30) + deleted = await store.delete_entries_older_than(30) assert deleted == 1 - remaining = store.search_entries(query="new", app_name="app", user_id="user") + remaining = await store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index aadfefa8a..bc67ad007 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -1,7 +1,7 @@ """Integration tests for DuckDB ADK session store.""" import json -from collections.abc import Generator +from collections.abc import AsyncGenerator from datetime import datetime, timezone from pathlib import Path @@ -9,12 +9,13 @@ from sqlspec.adapters.duckdb.adk.store import DuckdbADKStore from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @pytest.fixture -def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "Generator[DuckdbADKStore, None, None]": +async def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "AsyncGenerator[DuckdbADKStore, None]": """Create DuckDB ADK store with temporary file-based database. 
Args: @@ -35,27 +36,27 @@ def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "Generator[DuckdbADKStor extension_config={"adk": {"session_table": "test_sessions", "events_table": "test_events"}}, ) store = DuckdbADKStore(config) - store.create_tables() + await store.create_tables() yield store finally: if db_path.exists(): db_path.unlink() -def test_create_tables(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_create_tables(duckdb_adk_store: DuckdbADKStore) -> None: """Test table creation succeeds without errors.""" assert duckdb_adk_store.session_table == "test_sessions" assert duckdb_adk_store.events_table == "test_events" -def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" app_name = "test-app" user_id = "user-001" state = {"key": "value", "count": 42} - created_session = duckdb_adk_store.create_session( + created_session = await duckdb_adk_store.create_session( session_id=session_id, app_name=app_name, user_id=user_id, state=state ) @@ -66,49 +67,51 @@ def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: assert isinstance(created_session["create_time"], datetime) assert isinstance(created_session["update_time"], datetime) - retrieved_session = duckdb_adk_store.get_session(session_id) + retrieved_session = await duckdb_adk_store.get_session(session_id) assert retrieved_session is not None assert retrieved_session["id"] == session_id assert retrieved_session["state"] == state -def test_get_nonexistent_session(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_get_nonexistent_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test getting a non-existent session returns None.""" - result = duckdb_adk_store.get_session("nonexistent-session") + result = await duckdb_adk_store.get_session("nonexistent-session") assert result is None -def 
test_update_session_state(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_update_session_state(duckdb_adk_store: DuckdbADKStore) -> None: """Test updating session state.""" session_id = "session-002" initial_state = {"status": "active"} updated_state = {"status": "completed", "result": "success"} - duckdb_adk_store.create_session(session_id=session_id, app_name="test-app", user_id="user-002", state=initial_state) + await duckdb_adk_store.create_session( + session_id=session_id, app_name="test-app", user_id="user-002", state=initial_state + ) - session_before = duckdb_adk_store.get_session(session_id) + session_before = await duckdb_adk_store.get_session(session_id) assert session_before is not None assert session_before["state"] == initial_state - duckdb_adk_store.update_session_state(session_id, updated_state) + await duckdb_adk_store.update_session_state(session_id, updated_state) - session_after = duckdb_adk_store.get_session(session_id) + session_after = await duckdb_adk_store.get_session(session_id) assert session_after is not None assert session_after["state"] == updated_state assert session_after["update_time"] >= session_before["update_time"] -def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" user_id = "user-003" - duckdb_adk_store.create_session("session-1", app_name, user_id, {"num": 1}) - duckdb_adk_store.create_session("session-2", app_name, user_id, {"num": 2}) - duckdb_adk_store.create_session("session-3", app_name, user_id, {"num": 3}) - duckdb_adk_store.create_session("session-other", "other-app", user_id, {"num": 999}) + await duckdb_adk_store.create_session("session-1", app_name, user_id, {"num": 1}) + await duckdb_adk_store.create_session("session-2", app_name, user_id, {"num": 2}) + await duckdb_adk_store.create_session("session-3", app_name, user_id, {"num": 3}) + await 
duckdb_adk_store.create_session("session-other", "other-app", user_id, {"num": 999}) - sessions = duckdb_adk_store.list_sessions(app_name, user_id) + sessions = await duckdb_adk_store.list_sessions(app_name, user_id) assert len(sessions) == 3 session_ids = {s["id"] for s in sessions} @@ -117,105 +120,110 @@ def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None: assert all(s["user_id"] == user_id for s in sessions) -def test_list_sessions_empty(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_sessions_empty(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing sessions when none exist.""" - sessions = duckdb_adk_store.list_sessions("nonexistent-app", "nonexistent-user") + sessions = await duckdb_adk_store.list_sessions("nonexistent-app", "nonexistent-user") assert sessions == [] -def test_delete_session(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_delete_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test deleting a session.""" session_id = "session-to-delete" - duckdb_adk_store.create_session(session_id, "test-app", "user-004", {"data": "test"}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-004", {"data": "test"}) - assert duckdb_adk_store.get_session(session_id) is not None + assert await duckdb_adk_store.get_session(session_id) is not None - duckdb_adk_store.delete_session(session_id) + await duckdb_adk_store.delete_session(session_id) - assert duckdb_adk_store.get_session(session_id) is None + assert await duckdb_adk_store.get_session(session_id) is None -def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None: """Test deleting a session also deletes associated events.""" session_id = "session-with-events" - duckdb_adk_store.create_session(session_id, "test-app", "user-005", {"data": "test"}) - - event = duckdb_adk_store.create_event( - event_id="event-001", - 
session_id=session_id, - app_name="test-app", - user_id="user-005", - author="user", - content={"message": "Hello"}, - ) + await duckdb_adk_store.create_session(session_id, "test-app", "user-005", {"data": "test"}) + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-001", + "content": {"message": "Hello"}, + "app_name": "test-app", + "user_id": "user-005", + }, + } + await duckdb_adk_store.append_event(event_record) - assert event["session_id"] == session_id - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert len(events) == 1 - duckdb_adk_store.delete_session(session_id) + await duckdb_adk_store.delete_session(session_id) - assert duckdb_adk_store.get_session(session_id) is None - events_after = duckdb_adk_store.list_events(session_id) + assert await duckdb_adk_store.get_session(session_id) is None + events_after = await duckdb_adk_store.get_events(session_id) assert len(events_after) == 0 -def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating an event and verifying the returned 5-key EventRecord.""" session_id = "session-006" - duckdb_adk_store.create_session(session_id, "test-app", "user-006", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-006", {}) timestamp = datetime.now(timezone.utc) content = {"text": "Test message", "role": "user"} - created_event = duckdb_adk_store.create_event( - event_id="event-002", - session_id=session_id, - app_name="test-app", - user_id="user-006", - author="user", - content=content, - timestamp=timestamp, - ) + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "user", + "timestamp": timestamp, + "event_json": {"id": "event-002", "content": content, "app_name": "test-app", 
"user_id": "user-006"}, + } + await duckdb_adk_store.append_event(event_record) - # Returned record has the 5-key shape - assert created_event["session_id"] == session_id - assert created_event["author"] == "user" - assert created_event["timestamp"] == timestamp - assert "event_json" in created_event + events = await duckdb_adk_store.get_events(session_id) + assert len(events) == 1 + assert events[0]["session_id"] == session_id + assert events[0]["author"] == "user" # Content is stored inside event_json event_data = ( - json.loads(created_event["event_json"]) - if isinstance(created_event["event_json"], str) - else created_event["event_json"] + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) assert event_data["content"] == content -def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing events for a session.""" session_id = "session-007" - duckdb_adk_store.create_session(session_id, "test-app", "user-007", {}) - - duckdb_adk_store.create_event( - event_id="event-1", - session_id=session_id, - app_name="test-app", - user_id="user-007", - author="user", - content={"message": "First"}, - ) - duckdb_adk_store.create_event( - event_id="event-2", - session_id=session_id, - app_name="test-app", - user_id="user-007", - author="assistant", - content={"message": "Second"}, - ) + await duckdb_adk_store.create_session(session_id, "test-app", "user-007", {}) + + event1: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "content": {"message": "First"}, "app_name": "test-app", "user_id": "user-007"}, + } + event2: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-2", + "content": {"message": 
"Second"}, + "app_name": "test-app", + "user_id": "user-007", + }, + } + await duckdb_adk_store.append_event(event1) + await duckdb_adk_store.append_event(event2) - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert len(events) == 2 assert events[0]["author"] == "user" @@ -223,69 +231,92 @@ def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: assert events[0]["timestamp"] <= events[1]["timestamp"] -def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing events when none exist.""" session_id = "session-no-events" - duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert events == [] -def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating events with optional fields stored in event_json.""" session_id = "session-008" - duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) - - event = duckdb_adk_store.create_event( - event_id="event-full", - session_id=session_id, - app_name="test-app", - user_id="user-008", - author="assistant", - content={"text": "Response"}, - invocation_id="inv-123", - branch="main", - grounding_metadata={"sources": ["doc1", "doc2"]}, - custom_metadata={"priority": "high"}, - partial=True, - turn_complete=False, - interrupted=False, - error_code=None, - error_message=None, - ) + await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "inv-123", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + 
"event_json": { + "id": "event-full", + "content": {"text": "Response"}, + "app_name": "test-app", + "user_id": "user-008", + "branch": "main", + "grounding_metadata": {"sources": ["doc1", "doc2"]}, + "custom_metadata": {"priority": "high"}, + "partial": True, + "turn_complete": False, + "interrupted": False, + }, + } + await duckdb_adk_store.append_event(event_record) + + events = await duckdb_adk_store.get_events(session_id) + assert len(events) == 1 # The 5-key record has invocation_id as a top-level indexed column - assert event["invocation_id"] == "inv-123" + assert events[0]["invocation_id"] == "inv-123" # Other fields are inside event_json - event_data = json.loads(event["event_json"]) if isinstance(event["event_json"], str) else event["event_json"] + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) assert event_data["branch"] == "main" assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]} assert event_data["partial"] is True assert event_data["turn_complete"] is False -def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: """Test events are ordered by timestamp ascending.""" session_id = "session-009" - duckdb_adk_store.create_session(session_id, "test-app", "user-009", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-009", {}) t1 = datetime.now(timezone.utc) t2 = datetime.now(timezone.utc) t3 = datetime.now(timezone.utc) - duckdb_adk_store.create_event( - event_id="event-middle", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t2 - ) - duckdb_adk_store.create_event( - event_id="event-last", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t3 - ) - duckdb_adk_store.create_event( - event_id="event-first", session_id=session_id, app_name="test-app", user_id="user-009", 
timestamp=t1 - ) + ev_middle: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": t2, + "event_json": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"}, + } + ev_last: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": t3, + "event_json": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"}, + } + ev_first: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": t1, + "event_json": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"}, + } - events = duckdb_adk_store.list_events(session_id) + await duckdb_adk_store.append_event(ev_middle) + await duckdb_adk_store.append_event(ev_last) + await duckdb_adk_store.append_event(ev_first) + + events = await duckdb_adk_store.get_events(session_id) assert len(events) == 3 # Events should be ordered by timestamp ASC @@ -296,7 +327,7 @@ def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: assert event_ids == ["event-first", "event-middle", "event-last"] -def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None: """Test session state with nested JSON structures.""" session_id = "session-complex" complex_state = { @@ -309,62 +340,60 @@ def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> No "flags": [True, False, True], } - duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state) + await duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state) - session = duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session(session_id) assert session is not None assert session["state"] == complex_state assert session["state"]["user"]["preferences"]["theme"] == "dark" assert 
session["state"]["conversation"]["turn_count"] == 5 -def test_empty_state(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_empty_state(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating session with empty state.""" session_id = "session-empty-state" - duckdb_adk_store.create_session(session_id, "test-app", "user-011", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-011", {}) - session = duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session(session_id) assert session is not None assert session["state"] == {} -def test_table_not_found_handling(tmp_path: Path, worker_id: str) -> None: +async def test_table_not_found_handling(tmp_path: Path, worker_id: str) -> None: """Test graceful handling when tables don't exist.""" db_path = tmp_path / f"test_no_tables_{worker_id}.duckdb" try: config = DuckDBConfig(connection_config={"database": str(db_path)}) store = DuckdbADKStore(config) - result = store.get_session("nonexistent") + result = await store.get_session("nonexistent") assert result is None - sessions = store.list_sessions("app", "user") + sessions = await store.list_sessions("app", "user") assert sessions == [] - events = store.list_events("session") + events = await store.get_events("session") assert events == [] finally: if db_path.exists(): db_path.unlink() -def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: """Test storing and retrieving event data via event_json.""" session_id = "session-json-rt" - duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) - - event = duckdb_adk_store.create_event( - event_id="event-json", - session_id=session_id, - app_name="test-app", - user_id="user-012", - author="system", - content={"data": "value"}, - ) - - assert "event_json" in event + await duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) + + 
event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "system", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-json", "content": {"data": "value"}, "app_name": "test-app", "user_id": "user-012"}, + } + await duckdb_adk_store.append_event(event_record) - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert len(events) == 1 event_data = ( json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] @@ -372,23 +401,23 @@ def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: assert event_data["content"] == {"data": "value"} -def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None: """Test multiple updates to same session.""" session_id = "session-concurrent" - duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0}) for i in range(10): - session = duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session(session_id) assert session is not None current_counter = session["state"]["counter"] - duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1}) + await duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1}) - final_session = duckdb_adk_store.get_session(session_id) + final_session = await duckdb_adk_store.get_session(session_id) assert final_session is not None assert final_session["state"]["counter"] == 10 -def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: """Test owner ID column with INTEGER type.""" db_path = tmp_path / 
f"test_owner_id_int_{worker_id}.duckdb" try: @@ -410,12 +439,12 @@ def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() assert store.owner_id_column_name == "tenant_id" assert store.owner_id_column_ddl == "tenant_id INTEGER NOT NULL REFERENCES tenants(id)" - session = store.create_session( + session = await store.create_session( session_id="session-tenant-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=1 ) @@ -431,7 +460,7 @@ def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: db_path.unlink() -def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: """Test owner ID column with DuckDB UBIGINT type.""" db_path = tmp_path / f"test_owner_id_ubigint_{worker_id}.duckdb" try: @@ -453,11 +482,11 @@ def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() assert store.owner_id_column_name == "owner_id" - session = store.create_session( + session = await store.create_session( session_id="session-user-1", app_name="test-app", user_id="user-001", @@ -477,7 +506,7 @@ def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: db_path.unlink() -def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) -> None: """Test that FK constraint is enforced.""" db_path = tmp_path / f"test_owner_id_constraint_{worker_id}.duckdb" try: @@ -499,14 +528,14 @@ def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() 
- store.create_session( + await store.create_session( session_id="session-org-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=100 ) with pytest.raises(Exception) as exc_info: - store.create_session( + await store.create_session( session_id="session-org-invalid", app_name="test-app", user_id="user-002", @@ -520,7 +549,7 @@ def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) db_path.unlink() -def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None: """Test creating session without owner_id when column is configured but nullable.""" db_path = tmp_path / f"test_owner_id_nullable_{worker_id}.duckdb" try: @@ -541,22 +570,22 @@ def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() - session = store.create_session( + session = await store.create_session( session_id="session-no-fk", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=None ) assert session["id"] == "session-no-fk" - retrieved = store.get_session("session-no-fk") + retrieved = await store.get_session("session-no-fk") assert retrieved is not None finally: if db_path.exists(): db_path.unlink() -def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: """Test owner ID column with VARCHAR type.""" db_path = tmp_path / f"test_owner_id_varchar_{worker_id}.duckdb" try: @@ -578,9 +607,9 @@ def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() - session = store.create_session( + session = await store.create_session( session_id="session-company-1", app_name="test-app", 
user_id="user-001", @@ -600,7 +629,7 @@ def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: db_path.unlink() -def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> None: """Test multiple sessions with same FK value.""" db_path = tmp_path / f"test_owner_id_multiple_{worker_id}.duckdb" try: @@ -622,10 +651,10 @@ def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> No }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() for i in range(5): - store.create_session( + await store.create_session( session_id=f"session-dept-{i}", app_name="test-app", user_id=f"user-{i}", @@ -643,7 +672,7 @@ def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> No db_path.unlink() -def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None: """Test querying sessions by FK column value.""" db_path = tmp_path / f"test_owner_id_query_{worker_id}.duckdb" try: @@ -665,11 +694,11 @@ def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() - store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1) - store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1) - store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2) + await store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1) + await store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1) + await store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2) with config.provide_connection() as conn: cursor = conn.execute("SELECT id FROM sessions_with_project WHERE project_id = ? 
ORDER BY id", (1,)) diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index 67b38f0a2..eb12c69f8 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -149,7 +149,7 @@ async def test_delete_session_cascade(mysqlconnector_adk_store: MysqlConnectorAs "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}), + "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, } await mysqlconnector_adk_store.append_event(event_record) @@ -178,7 +178,7 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({"content": {"text": "Hello", "role": "user"}, "app_name": app_name}), + "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, } event2: EventRecord = { @@ -186,7 +186,7 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy "invocation_id": "inv-002", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}), + "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, } await mysqlconnector_adk_store.append_event(event1) @@ -224,7 +224,7 @@ async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsync "invocation_id": "inv-micro", "author": "system", "timestamp": event_time, - "event_json": json.dumps({"app_name": app_name}), + "event_json": {"app_name": app_name}, } await mysqlconnector_adk_store.append_event(event) diff --git 
a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py index 26d86165f..5a32a911b 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py @@ -302,14 +302,14 @@ async def test_inmemory_tables_functional_async(oracle_async_config: OracleAsync @pytest.mark.oracledb -def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None: +async def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None: """Test that in_memory=True works with sync store.""" config = OracleSyncConfig( connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": True}} ) store = OracleSyncADKStore(config) - store.create_tables() + await store.create_tables() try: with config.provide_connection() as conn: @@ -343,14 +343,14 @@ def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None: @pytest.mark.oracledb -def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None: +async def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None: """Test that in_memory=False works with sync store.""" config = OracleSyncConfig( connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": False}} ) store = OracleSyncADKStore(config) - store.create_tables() + await store.create_tables() try: with config.provide_connection() as conn: @@ -382,14 +382,14 @@ def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None: @pytest.mark.oracledb -def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) -> None: +async def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) -> None: """Test that INMEMORY tables work correctly in sync mode.""" config = OracleSyncConfig( connection_config=oracle_sync_config.connection_config, 
extension_config={"adk": {"in_memory": True}} ) store = OracleSyncADKStore(config) - store.create_tables() + await store.create_tables() try: session_id = "inmemory-sync-session" @@ -397,11 +397,11 @@ def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) - user_id = "user-456" state = {"sync": True, "value": 99} - session = store.create_session(session_id, app_name, user_id, state) + session = await store.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["state"] == state - retrieved = store.get_session(session_id) + retrieved = await store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index 540cf123f..1a05b040c 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -1,7 +1,7 @@ """Oracle-specific ADK store tests for LOB handling, JSON types, and FK columns.""" import json -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from datetime import datetime, timezone from typing import Any, cast from uuid import uuid4 @@ -63,10 +63,10 @@ async def oracle_async_store(oracle_async_config: "OracleAsyncConfig") -> "Async @pytest.fixture(scope="module") -def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "Generator[OracleSyncADKStore, None, None]": +async def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "AsyncGenerator[OracleSyncADKStore, None]": """Create a sync Oracle ADK store with tables created once per module.""" store = OracleSyncADKStore(oracle_sync_config) - store.create_tables() + await store.create_tables() try: yield store finally: @@ -140,7 +140,9 @@ async def oracle_store_with_fk( 
@pytest.fixture -def oracle_config_with_users_table(oracle_sync_config: "OracleSyncConfig") -> "Generator[OracleSyncConfig, None, None]": +async def oracle_config_with_users_table( + oracle_sync_config: "OracleSyncConfig", +) -> "AsyncGenerator[OracleSyncConfig, None]": """Create a users table for FK testing.""" with oracle_sync_config.provide_connection() as conn: cursor = conn.cursor() @@ -187,9 +189,9 @@ def oracle_config_with_users_table(oracle_sync_config: "OracleSyncConfig") -> "G @pytest.fixture -def oracle_store_sync_with_fk( +async def oracle_store_sync_with_fk( oracle_config_with_users_table: "OracleSyncConfig", -) -> "Generator[OracleSyncADKStore, None, None]": +) -> "AsyncGenerator[OracleSyncADKStore, None]": """Create a sync Oracle ADK store with owner_id FK column.""" config_with_extension = OracleSyncConfig( connection_config=oracle_config_with_users_table.connection_config, @@ -197,7 +199,7 @@ def oracle_store_sync_with_fk( ) store = OracleSyncADKStore(config_with_extension) _cleanup_sync_store(store, config_with_extension) - store.create_tables() + await store.create_tables() try: yield store finally: @@ -242,7 +244,7 @@ async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncAD "invocation_id": "", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps(event_data), + "event_json": event_data, } await oracle_async_store.append_event(event_record) @@ -273,7 +275,7 @@ async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps(event_data), + "event_json": event_data, } await oracle_async_store.append_event(event_record) @@ -286,17 +288,17 @@ async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> assert retrieved_data == event_data -def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None: +async def 
test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None: """Test state CLOB/BLOB is correctly deserialized in sync mode.""" session_id = _unique_session_id("lob-session-sync") app_name = "test-app" user_id = "user-123" state = {"large_field": "y" * 10000, "nested": {"data": [4, 5, 6]}} - session = oracle_sync_store.create_session(session_id, app_name, user_id, state) + session = await oracle_sync_store.create_session(session_id, app_name, user_id, state) assert session["state"] == state - retrieved = oracle_sync_store.get_session(session_id) + retrieved = await oracle_sync_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state @@ -314,12 +316,7 @@ async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncAD "invocation_id": "inv-001", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({ - "content": {"text": "Hello"}, - "partial": True, - "turn_complete": False, - "interrupted": True, - }), + "event_json": {"content": {"text": "Hello"}, "partial": True, "turn_complete": False, "interrupted": True}, } await oracle_async_store.append_event(event_record) @@ -351,7 +348,7 @@ async def test_event_with_none_values(oracle_async_store: "OracleAsyncADKStore") "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": json.dumps({"app_name": app_name}), + "event_json": {"app_name": app_name}, } await oracle_async_store.append_event(event_record) @@ -431,7 +428,7 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync "complex": { "nested": {"deep": {"structure": "value"}}, "array": [1, 2, 3, {"key": "value"}], - "unicode": "日本語テスト", + "unicode": "\u65e5\u672c\u8a9e\u30c6\u30b9\u30c8", "special_chars": "test@example.com | value > 100", } } @@ -442,10 +439,10 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync retrieved = await 
oracle_async_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state - assert retrieved["state"]["complex"]["unicode"] == "日本語テスト" + assert retrieved["state"]["complex"]["unicode"] == "\u65e5\u672c\u8a9e\u30c6\u30b9\u30c8" -def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyncADKStore") -> None: +async def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyncADKStore") -> None: """Test creating session with owner_id in sync mode.""" session_id = _unique_session_id("sync-fk") app_name = "test-app" @@ -453,10 +450,10 @@ def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyn state = {"data": "sync test"} owner_id = 100 - session = oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id) + session = await oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id) assert session["id"] == session_id assert session["state"] == state - retrieved = oracle_store_sync_with_fk.get_session(session_id) + retrieved = await oracle_store_sync_with_fk.get_session(session_id) assert retrieved is not None assert retrieved["id"] == session_id diff --git a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py index 789eff133..5e7fb0123 100644 --- a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py @@ -1,6 +1,6 @@ """Integration tests for Psycopg ADK store owner_id_column feature.""" -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any import pytest @@ -42,7 +42,7 @@ async def psycopg_async_store_with_fk(postgres_service: "PostgresService") -> "A @pytest.fixture -def psycopg_sync_store_with_fk(postgres_service: 
"PostgresService") -> "Generator[Any, None, None]": +async def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "AsyncGenerator[Any, None]": """Create Psycopg sync ADK store with owner_id_column configured.""" config = PsycopgSyncConfig( connection_config={ @@ -57,7 +57,7 @@ def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "Generato }, ) store = PsycopgSyncADKStore(config) - store.create_tables() + await store.create_tables() yield store with config.provide_connection() as conn, conn.cursor() as cur: @@ -74,7 +74,7 @@ async def test_async_store_owner_id_column_initialization(psycopg_async_store_wi assert psycopg_async_store_with_fk.owner_id_column_name == "tenant_id" -def test_sync_store_owner_id_column_initialization(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: +async def test_sync_store_owner_id_column_initialization(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: """Test that owner_id_column is properly initialized in sync store.""" assert psycopg_sync_store_with_fk.owner_id_column_ddl == "account_id VARCHAR(64) NOT NULL" assert psycopg_sync_store_with_fk.owner_id_column_name == "account_id" @@ -105,7 +105,7 @@ async def test_async_store_inherits_owner_id_column(postgres_service: "PostgresS await config.close_pool() -def test_sync_store_inherits_owner_id_column(postgres_service: "PostgresService") -> None: +async def test_sync_store_inherits_owner_id_column(postgres_service: "PostgresService") -> None: """Test that sync store correctly inherits owner_id_column from base class.""" config = PsycopgSyncConfig( connection_config={ @@ -147,7 +147,7 @@ async def test_async_store_without_owner_id_column(postgres_service: "PostgresSe await config.close_pool() -def test_sync_store_without_owner_id_column(postgres_service: "PostgresService") -> None: +async def test_sync_store_without_owner_id_column(postgres_service: "PostgresService") -> None: """Test that sync store works without owner_id_column (default 
behavior).""" config = PsycopgSyncConfig( connection_config={ @@ -172,9 +172,9 @@ async def test_async_ddl_includes_owner_id_column(psycopg_async_store_with_fk: P assert "test_sessions_fk" in ddl -def test_sync_ddl_includes_owner_id_column(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: +async def test_sync_ddl_includes_owner_id_column(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: """Test that the DDL generation includes the owner_id_column.""" - ddl = psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] + ddl = await psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] assert "account_id VARCHAR(64) NOT NULL" in ddl assert "test_sessions_sync_fk" in ddl diff --git a/tests/integration/adapters/spanner/extensions/adk/conftest.py b/tests/integration/adapters/spanner/extensions/adk/conftest.py index 4ae8782d2..57ad9bace 100644 --- a/tests/integration/adapters/spanner/extensions/adk/conftest.py +++ b/tests/integration/adapters/spanner/extensions/adk/conftest.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING import pytest @@ -30,7 +30,7 @@ def spanner_adk_config(spanner_service: SpannerService, spanner_database: "Datab @pytest.fixture -def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> Generator[SpannerSyncADKStore, None, None]: +async def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> AsyncGenerator[SpannerSyncADKStore, None]: store = SpannerSyncADKStore(spanner_adk_config) - store.create_tables() + await store.create_tables() yield store diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py index e3397dd8e..71645b388 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py +++ 
b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py @@ -30,64 +30,64 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -def test_sqlite_memory_store_insert_search_dedup() -> None: +async def test_sqlite_memory_store_insert_search_dedup() -> None: """Insert memory entries, search by text, and skip duplicates.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - store.create_tables() + await store.create_tables() now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = store.insert_memory_entries([record1, record2]) + inserted = await store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = store.search_entries(query="espresso", app_name="app", user_id="user") + results = await store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = store.insert_memory_entries([record1]) + deduped = await store.insert_memory_entries([record1]) assert deduped == 0 -def test_sqlite_memory_store_delete_by_session() -> None: +async def test_sqlite_memory_store_delete_by_session() -> None: """Delete memory entries by session id.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - store.create_tables() + await store.create_tables() now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await 
store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_by_session("s1") + deleted = await store.delete_entries_by_session("s1") assert deleted == 1 - remaining = store.search_entries(query="latte", app_name="app", user_id="user") + remaining = await store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -def test_sqlite_memory_store_delete_older_than() -> None: +async def test_sqlite_memory_store_delete_older_than() -> None: """Delete memory entries older than a cutoff.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - store.create_tables() + await store.create_tables() now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_older_than(30) + deleted = await store.delete_entries_older_than(30) assert deleted == 1 - remaining = store.search_entries(query="new", app_name="app", user_id="user") + remaining = await store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/unit/extensions/test_adk/test_store_instantiation.py b/tests/unit/extensions/test_adk/test_store_instantiation.py index 56ebc59ef..b98662fee 100644 --- a/tests/unit/extensions/test_adk/test_store_instantiation.py +++ b/tests/unit/extensions/test_adk/test_store_instantiation.py @@ -83,5 +83,5 @@ def test_store_has_no_abstract_methods(class_path: str) -> None: except ImportError: pytest.skip(f"Module {module_path} not importable (missing optional dependency)") 
cls = getattr(module, class_name) - abstract = getattr(cls, "__abstractmethods__", set()) + abstract: set[str] = getattr(cls, "__abstractmethods__", set()) assert not abstract, f"{class_path} has unsatisfied abstract methods: {abstract}" From 50221d71c196d528391912e417b212782672eec5 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 19:04:56 +0000 Subject: [PATCH 20/23] refactor: Remove detailed comment block explaining artifact storage URI configuration. --- sqlspec/config.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/sqlspec/config.py b/sqlspec/config.py index d59f51869..63d1ae2dd 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -805,22 +805,6 @@ class ADKConfig(TypedDict): specified language will fall back to 'simple' or 'english'. """ - """Base URI for artifact content storage. Default: None (store inline in database). - - When set, large artifact payloads are stored externally and only metadata - is kept in the database. The URI scheme determines the storage backend: - - ``file:///path/to/artifacts`` — local filesystem - - ``s3://bucket/prefix`` — AWS S3 or S3-compatible storage - - ``gs://bucket/prefix`` — Google Cloud Storage - - ``az://container/prefix`` — Azure Blob Storage - - When None, artifact content is stored inline in the database tables, - which is suitable for small payloads but may cause performance issues - with large binary artifacts. - - Integrates with the ``StorageRegistry`` for pluggable storage backends. - """ - schema_version: NotRequired[int] """Explicit schema version for ADK tables. Default: None (auto-detect). From 93690e454d5da7dba9e7a5e56b7dd01945477792 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 19:38:16 +0000 Subject: [PATCH 21/23] feat: Enhance ADK store `_get_events` with `after_timestamp` and `limit` parameters across adapters, and refactor event insertion logic. 
--- sqlspec/adapters/adbc/adk/store.py | 27 +++-- .../adapters/cockroach_psycopg/adk/store.py | 42 ++++++- sqlspec/adapters/mysqlconnector/adk/store.py | 48 +++++++- sqlspec/adapters/oracledb/adk/store.py | 85 ++++++++++---- sqlspec/adapters/psycopg/adk/store.py | 52 ++++++++- sqlspec/adapters/pymysql/adk/store.py | 48 +++++++- sqlspec/adapters/spanner/adk/store.py | 28 ++++- .../extensions/adk/test_event_operations.py | 80 ++++++++++++- .../extensions/adk/test_oracle_specific.py | 2 +- .../test_oracledb/test_oracle_adk_store.py | 33 +++++- .../adapters/test_psycopg/test_adk_store.py | 109 ++++++++++++++++++ 11 files changed, 499 insertions(+), 55 deletions(-) create mode 100644 tests/unit/adapters/test_psycopg/test_adk_store.py diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 3c7155b67..7999d8c2b 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -508,7 +508,7 @@ def _create_session( finally: cursor.close() # type: ignore[no-untyped-call] - return self.get_session(session_id) # type: ignore[return-value] + return self._get_session(session_id) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -684,6 +684,7 @@ def _insert_event(self, event_record: "EventRecord") -> None: Args: event_record: Event record to store. """ + event_json = self._serialize_json_field(event_record["event_json"]) sql = f""" INSERT INTO {self._events_table} ( session_id, invocation_id, author, timestamp, event_json @@ -700,7 +701,7 @@ def _insert_event(self, event_record: "EventRecord") -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_json, ), ) conn.commit() @@ -733,6 +734,7 @@ def _append_event_and_update_state( WHERE id = ? 
""" state_json = self._serialize_state(state) + event_json = self._serialize_json_field(event_record["event_json"]) with self._config.provide_connection() as conn: cursor = conn.cursor() @@ -744,7 +746,7 @@ def _append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_json, ), ) cursor.execute(update_sql, (state_json, session_id)) @@ -769,6 +771,8 @@ def _get_events( Args: session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. Returns: List of event records ordered by timestamp ASC. @@ -778,18 +782,27 @@ def _get_events( Returns the 5-column EventRecord (session_id, invocation_id, author, timestamp, event_json). """ + where_clauses = ["session_id = ?"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > ?") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = ? 
- ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + cursor.execute(sql, params) rows = cursor.fetchall() return [ @@ -818,7 +831,7 @@ async def get_events( def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - self._append_event_and_update_state(event_record, event_record["session_id"], {}) + self._insert_event(event_record) async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 535facd6c..d2906c7ea 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -582,19 +582,53 @@ async def append_event_and_update_state( """Atomically append an event and update the session's durable state.""" await async_(self._append_event_and_update_state)(event_record, session_id, state) + def _insert_event(self, event_record: EventRecord) -> None: + sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute( + sql.encode(), + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), + ) + conn.commit() + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + 
where_clauses.append("timestamp > %s") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = " LIMIT %s" if limit else "" sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = %s - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ + if limit: + params.append(limit) try: with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (session_id,)) + cur.execute(sql.encode(), tuple(params)) rows = cur.fetchall() return [ @@ -618,7 +652,7 @@ async def get_events( def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - self._append_event_and_update_state(event_record, event_record["session_id"], {}) + self._insert_event(event_record) async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 23c451040..1a25702f7 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -640,6 +640,33 @@ async def append_event_and_update_state( """Atomically append an event and update the session's durable state.""" await async_(self._append_event_and_update_state)(event_record, session_id, state) + def _insert_event(self, event_record: EventRecord) -> None: + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + + sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + 
event_record["timestamp"], + event_json_str, + ), + ) + finally: + cursor.close() + conn.commit() + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -647,22 +674,35 @@ def _get_events( Args: session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. Returns: List of event records ordered by timestamp ASC. """ + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = " LIMIT %s" if limit else "" sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = %s - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ + if limit: + params.append(limit) try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + cursor.execute(sql, tuple(params)) rows = cursor.fetchall() finally: cursor.close() @@ -690,7 +730,7 @@ async def get_events( def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - self._append_event_and_update_state(event_record, event_record["session_id"], {}) + self._insert_event(event_record) async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 1c4647470..a95d3b841 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -98,14 +98,14 @@ def _event_json_column_ddl(storage_type: JSONStorageType) -> str: """Return the DDL fragment for the event_json column. For JSON_NATIVE (Oracle 21c+) we use the native JSON type. 
- For older versions we use CLOB since event_json is a JSON text string. - BLOB_JSON gets a CHECK constraint; BLOB_PLAIN does not. + For older versions we use BLOB since Oracle recommends BLOB over CLOB for + JSON storage. BLOB_JSON gets a CHECK constraint; BLOB_PLAIN does not. """ if storage_type == JSONStorageType.JSON_NATIVE: return "event_json JSON NOT NULL" if storage_type == JSONStorageType.BLOB_JSON: - return "event_json CLOB CHECK (event_json IS JSON) NOT NULL" - return "event_json CLOB NOT NULL" + return "event_json BLOB CHECK (event_json IS JSON) NOT NULL" + return "event_json BLOB NOT NULL" class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): @@ -125,7 +125,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - event_json stored as JSON (21c+) or CLOB (older versions) + - event_json stored as JSON (21c+) or BLOB (older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Named parameters using :param_name - State merging handled at application level @@ -182,8 +182,8 @@ async def _detect_json_storage_type(self) -> JSONStorageType: Notes: Queries product_component_version to determine Oracle version. - Oracle 21c+ with compatible >= 20: Native JSON type - - Oracle 12c+: CLOB with IS JSON constraint - - Oracle 11g and earlier: CLOB without constraint + - Oracle 12c+: BLOB with IS JSON constraint + - Oracle 11g and earlier: plain BLOB Result is cached in self._json_storage_type. 
""" @@ -254,6 +254,13 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] + async def _serialize_event_json(self, event_json: Any) -> "str | bytes": + """Serialize event_json to the configured Oracle JSON storage format.""" + storage_type = await self._detect_json_storage_type() + if storage_type == JSONStorageType.JSON_NATIVE: + return to_json(event_json) + return to_json(event_json, as_bytes=True) + async def _read_event_json(self, data: Any) -> str: """Read event_json from database, handling LOB types. @@ -338,7 +345,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - The events table uses the new 5-column contract: session_id, invocation_id, author, timestamp, and event_json. The event_json column stores the full - ADK Event as JSON (21c+) or CLOB (older versions). + ADK Event as JSON (21c+) or BLOB (older versions). Args: storage_type: JSON storage type to use. @@ -676,7 +683,7 @@ async def append_event(self, event_record: EventRecord) -> None: "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": event_record["event_json"], + "event_json": await self._serialize_event_json(event_record["event_json"]), }, ) await conn.commit() @@ -720,7 +727,7 @@ async def append_event_and_update_state( "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": event_record["event_json"], + "event_json": await self._serialize_event_json(event_record["event_json"]), }, ) await cursor.execute(update_sql, {"state": state_data, "id": session_id}) @@ -803,7 +810,7 @@ class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - event_json stored as JSON (21c+) or CLOB (older versions) + - event_json stored as JSON (21c+) or BLOB 
(older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Named parameters using :param_name - State merging handled at application level @@ -860,8 +867,8 @@ def _detect_json_storage_type(self) -> JSONStorageType: Notes: Queries product_component_version to determine Oracle version. - Oracle 21c+ with compatible >= 20: Native JSON type - - Oracle 12c+: CLOB with IS JSON constraint - - Oracle 11g and earlier: CLOB without constraint + - Oracle 12c+: BLOB with IS JSON constraint + - Oracle 11g and earlier: plain BLOB Result is cached in self._json_storage_type. """ @@ -930,6 +937,13 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] + def _serialize_event_json(self, event_json: Any) -> "str | bytes": + """Serialize event_json to the configured Oracle JSON storage format.""" + storage_type = self._detect_json_storage_type() + if storage_type == JSONStorageType.JSON_NATIVE: + return to_json(event_json) + return to_json(event_json, as_bytes=True) + def _read_event_json(self, data: Any) -> str: """Read event_json from database, handling LOB types. @@ -1012,7 +1026,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - The events table uses the new 5-column contract: session_id, invocation_id, author, timestamp, and event_json. The event_json column stores the full - ADK Event as JSON (21c+) or CLOB (older versions). + ADK Event as JSON (21c+) or BLOB (older versions). Args: storage_type: JSON storage type to use. 
@@ -1182,7 +1196,7 @@ def _create_session( cursor.execute(sql, params) conn.commit() - return self.get_session(session_id) # type: ignore[return-value] + return self._get_session(session_id) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -1392,7 +1406,7 @@ def _append_event_and_update_state( "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": event_record["event_json"], + "event_json": self._serialize_event_json(event_record["event_json"]), }, ) cursor.execute(update_sql, {"state": state_data, "id": session_id}) @@ -1411,22 +1425,33 @@ def _get_events( Args: session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. Returns: List of event records ordered by timestamp ASC. """ + where_clauses = ["session_id = :session_id"] + params: dict[str, Any] = {"session_id": session_id} + + if after_timestamp is not None: + where_clauses.append("timestamp > :after_timestamp") + params["after_timestamp"] = after_timestamp + + where_clause = " AND ".join(where_clauses) + limit_clause = f" FETCH FIRST {limit} ROWS ONLY" if limit else "" sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = :session_id - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute(sql, {"session_id": session_id}) + cursor.execute(sql, params) rows = cursor.fetchall() results = [] @@ -1457,7 +1482,27 @@ async def get_events( def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - self._append_event_and_update_state(event_record, event_record["session_id"], {}) + sql = f""" + INSERT INTO {self._events_table} ( + 
session_id, invocation_id, author, timestamp, event_json + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_json + ) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute( + sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": self._serialize_event_json(event_record["event_json"]), + }, + ) + conn.commit() async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 5ce2c9c28..00d68aade 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -542,6 +542,29 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) + def _insert_event(self, event_record: EventRecord) -> None: + insert_query = pg_sql.SQL(""" + INSERT INTO {table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """).format(table=pg_sql.Identifier(self._events_table)) + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute( + insert_query, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), + ) + conn.commit() + def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> None: @@ -583,16 +606,33 @@ async def append_event_and_update_state( def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | 
None" = None ) -> "list[EventRecord]": - query = pg_sql.SQL(""" + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + if limit: + params.append(limit) + + query = pg_sql.SQL( + """ SELECT session_id, invocation_id, author, timestamp, event_json FROM {table} - WHERE session_id = %s - ORDER BY timestamp ASC - """).format(table=pg_sql.Identifier(self._events_table)) + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} + """ + ).format( + table=pg_sql.Identifier(self._events_table), + where_clause=pg_sql.SQL(where_clause), # pyright: ignore[reportArgumentType] + limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), # pyright: ignore[reportArgumentType] + ) try: with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (session_id,)) + cur.execute(query, tuple(params)) rows = cur.fetchall() return [ @@ -616,7 +656,7 @@ async def get_events( def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - self._append_event_and_update_state(event_record, event_record["session_id"], {}) + self._insert_event(event_record) async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index ec32e4765..5e7ea8513 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -326,6 +326,33 @@ async def append_event_and_update_state( """Atomically append an event and update the session's durable state.""" await async_(self._append_event_and_update_state)(event_record, session_id, state) + def _insert_event(self, event_record: EventRecord) -> None: + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not 
isinstance(event_json, str) else event_json + + sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, + ), + ) + finally: + cursor.close() + conn.commit() + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -333,22 +360,35 @@ def _get_events( Args: session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. Returns: List of event records ordered by timestamp ASC. """ + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = " LIMIT %s" if limit else "" sql = f""" SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = %s - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ + if limit: + params.append(limit) try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + cursor.execute(sql, tuple(params)) rows = cursor.fetchall() finally: cursor.close() @@ -376,7 +416,7 @@ async def get_events( def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - self._append_event_and_update_state(event_record, event_record["session_id"], {}) + self._insert_event(event_record) async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" diff --git 
a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index 95462f90a..1e115d58d 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -278,6 +278,20 @@ async def append_event_and_update_state( """Atomically append an event and update the session's durable state.""" await async_(self._append_event_and_update_state)(event_record, session_id, state) + def _insert_event(self, event_record: "EventRecord") -> None: + event_params: dict[str, Any] = { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": event_record["event_json"], + } + insert_sql = f""" + INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) + VALUES (@session_id, @invocation_id, @author, PENDING_COMMIT_TIMESTAMP(), @event_json) + """ + self._run_write([(insert_sql, event_params, self._event_param_types())]) + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -288,9 +302,17 @@ def _get_events( """ if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" + params: dict[str, Any] = {"session_id": session_id} + types: dict[str, Any] = {"session_id": SPANNER_PARAM_TYPES.STRING} + if after_timestamp is not None: + sql = f"{sql} AND timestamp > @after_timestamp" + params["after_timestamp"] = after_timestamp + types["after_timestamp"] = SPANNER_PARAM_TYPES.TIMESTAMP sql = f"{sql} ORDER BY timestamp ASC" - params = {"session_id": session_id} - types = {"session_id": SPANNER_PARAM_TYPES.STRING} + if limit is not None: + sql = f"{sql} LIMIT @limit" + params["limit"] = limit + types["limit"] = SPANNER_PARAM_TYPES.INT64 rows = self._run_read(sql, params, types) return [ { @@ -311,7 +333,7 @@ async def get_events( def _append_event(self, 
event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - self._append_event_and_update_state(event_record, event_record["session_id"], {}) + self._insert_event(event_record) async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session.""" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py index d7ad7c2d2..8e54bf766 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py @@ -1,7 +1,8 @@ """Tests for ADBC ADK store event operations.""" +import asyncio import json -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any @@ -208,8 +209,6 @@ async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: """Test that events are ordered by timestamp ASC.""" - import time - ev1: EventRecord = { "session_id": session_fixture["session_id"], "invocation_id": "", @@ -219,7 +218,7 @@ async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: } await adbc_store.append_event(ev1) - time.sleep(0.01) + await asyncio.sleep(0.01) ev2: EventRecord = { "session_id": session_fixture["session_id"], @@ -230,7 +229,7 @@ async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: } await adbc_store.append_event(ev2) - time.sleep(0.01) + await asyncio.sleep(0.01) ev3: EventRecord = { "session_id": session_fixture["session_id"], @@ -324,3 +323,74 @@ async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) assert event_data["content"] == large_content + + +async def 
test_append_event_preserves_existing_session_state(adbc_store: Any, session_fixture: Any) -> None: + """append_event must not overwrite the durable session state.""" + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "append-only", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "append-only-event", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + + await adbc_store.append_event(event_record) + + session = await adbc_store.get_session(session_fixture["session_id"]) + assert session is not None + assert session["state"] == {"test": True} + + +async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, session_fixture: Any) -> None: + """get_events must respect both after_timestamp and limit.""" + base_time = datetime(2026, 1, 1, tzinfo=timezone.utc) + event_records = [ + { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "user", + "timestamp": base_time, + "event_json": { + "id": "event-1", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + }, + { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "assistant", + "timestamp": base_time + timedelta(seconds=1), + "event_json": { + "id": "event-2", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + }, + { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "assistant", + "timestamp": base_time + timedelta(seconds=2), + "event_json": { + "id": "event-3", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + }, + ] + + for event_record in event_records: + await adbc_store.append_event(event_record) + + filtered_events = await adbc_store.get_events(session_fixture["session_id"], after_timestamp=base_time, limit=1) + + assert len(filtered_events) == 1 + filtered_event 
= filtered_events[0]["event_json"] + filtered_data = json.loads(filtered_event) if isinstance(filtered_event, str) else filtered_event + assert filtered_data["id"] == "event-2" diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index 1a05b040c..f9034eed5 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -416,7 +416,7 @@ async def test_json_storage_type_detection(oracle_async_store: "OracleAsyncADKSt detector = cast("Any", oracle_async_store) storage_type = await detector._detect_json_storage_type() - assert storage_type in ["json", "blob_json", "clob_json", "blob_plain"] + assert storage_type in ["json", "blob_json", "blob_plain"] async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsyncADKStore") -> None: diff --git a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py index ec618e94e..b275006c5 100644 --- a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py +++ b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py @@ -2,7 +2,12 @@ from decimal import Decimal -from sqlspec.adapters.oracledb.adk.store import OracleAsyncADKStore, OracleSyncADKStore +from sqlspec.adapters.oracledb.adk.store import ( + JSONStorageType, + OracleAsyncADKStore, + OracleSyncADKStore, + _event_json_column_ddl, +) async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None: @@ -43,3 +48,29 @@ def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None: result = store._deserialize_state(payload) # type: ignore[attr-defined] assert result == {"state": 5.0} + + +def test_oracle_event_json_column_ddl_prefers_blob_over_clob() -> None: + assert _event_json_column_ddl(JSONStorageType.JSON_NATIVE) == "event_json JSON NOT 
NULL" + assert _event_json_column_ddl(JSONStorageType.BLOB_JSON) == "event_json BLOB CHECK (event_json IS JSON) NOT NULL" + assert _event_json_column_ddl(JSONStorageType.BLOB_PLAIN) == "event_json BLOB NOT NULL" + + +async def test_oracle_async_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: + store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] + store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] + + result = await store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + + assert isinstance(result, bytes) + assert b'"value":1' in result + + +def test_oracle_sync_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: + store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg] + store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] + + result = store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + + assert isinstance(result, bytes) + assert b'"value":1' in result diff --git a/tests/unit/adapters/test_psycopg/test_adk_store.py b/tests/unit/adapters/test_psycopg/test_adk_store.py new file mode 100644 index 000000000..a0d3eb6c1 --- /dev/null +++ b/tests/unit/adapters/test_psycopg/test_adk_store.py @@ -0,0 +1,109 @@ +"""Unit tests for psycopg ADK store sync wrappers.""" + +from datetime import datetime, timezone +from typing import Any + +from psycopg.types.json import Jsonb +from typing_extensions import Self + +from sqlspec.adapters.psycopg.adk.store import PsycopgSyncADKStore + + +class _DummyCursor: + def __init__(self, rows: "list[dict[str, Any]] | None" = None) -> None: + self.execute_calls: list[tuple[Any, Any]] = [] + self._rows = rows or [] + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + return None + + def execute(self, query: Any, params: Any) -> None: + self.execute_calls.append((query, params)) + + def fetchall(self) 
-> "list[dict[str, Any]]": + return self._rows + + +class _DummyConnection: + def __init__(self, cursor: _DummyCursor) -> None: + self._cursor = cursor + self.commit_called = False + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + return None + + def cursor(self) -> _DummyCursor: + return self._cursor + + def commit(self) -> None: + self.commit_called = True + + +class _DummyConfig: + def __init__(self, connection: _DummyConnection) -> None: + self._connection = connection + + def provide_connection(self) -> _DummyConnection: + return self._connection + + +def _build_store( + rows: "list[dict[str, Any]] | None" = None, +) -> "tuple[PsycopgSyncADKStore, _DummyCursor, _DummyConnection]": + cursor = _DummyCursor(rows) + connection = _DummyConnection(cursor) + store = PsycopgSyncADKStore.__new__(PsycopgSyncADKStore) # type: ignore[call-arg] + store._config = _DummyConfig(connection) # type: ignore[attr-defined] + store._events_table = "test_events" # type: ignore[attr-defined] + store._session_table = "test_sessions" # type: ignore[attr-defined] + store._owner_id_column_ddl = None # type: ignore[attr-defined] + store._owner_id_column_name = None # type: ignore[attr-defined] + return store, cursor, connection + + +def test_sync_append_event_inserts_without_session_update() -> None: + """append_event must insert a single event without writing session state.""" + store, cursor, connection = _build_store() + event_record = { + "session_id": "session-1", + "invocation_id": "", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1"}, + } + + store._append_event(event_record) # type: ignore[arg-type] + + assert len(cursor.execute_calls) == 1 + _, params = cursor.execute_calls[0] + assert params[0] == "session-1" + assert isinstance(params[4], Jsonb) + assert connection.commit_called + + +def test_sync_get_events_passes_after_timestamp_and_limit() -> None: + 
"""get_events must forward after_timestamp and limit to the sync query.""" + base_time = datetime(2026, 1, 1, tzinfo=timezone.utc) + rows = [ + { + "session_id": "session-1", + "invocation_id": "", + "author": "assistant", + "timestamp": base_time, + "event_json": {"id": "event-2"}, + } + ] + store, cursor, _ = _build_store(rows) + + result = store._get_events("session-1", after_timestamp=base_time, limit=1) + + assert len(cursor.execute_calls) == 1 + _, params = cursor.execute_calls[0] + assert params == ("session-1", base_time, 1) + assert result[0]["event_json"]["id"] == "event-2" From a88d4a9659714a6f4b21636432dcb08fbce2951c Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 21:15:35 +0000 Subject: [PATCH 22/23] fix: Enhance ADK store robustness by improving Oracle NULL and JSON handling, refining Spanner JSON serialization, and adding session creation error checks. --- .github/workflows/ci.yml | 2 +- pyproject.toml | 4 +- sqlspec/adapters/adbc/adk/store.py | 6 +- sqlspec/adapters/oracledb/adk/store.py | 80 ++++++++++++------- sqlspec/adapters/spanner/adk/store.py | 4 +- sqlspec/core/splitter.py | 16 +++- tests/conftest.py | 21 +++++ .../extensions/adk/test_oracle_specific.py | 6 +- .../spanner/extensions/adk/test_adk_store.py | 80 ++++++++++++------- uv.lock | 11 ++- 10 files changed, 154 insertions(+), 76 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f04a450c4..d4381dd65 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,7 +108,7 @@ jobs: - name: Test env: PYTHONFAULTHANDLER: "1" - PYTEST_ADDOPTS: "--max-worker-restart=0 -s" + PYTEST_ADDOPTS: "--max-worker-restart=0" run: timeout 900s uv run pytest -n 2 --dist=loadgroup # test-linux-freethreaded: diff --git a/pyproject.toml b/pyproject.toml index cd9e7484a..b834e0203 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ attrs = ["attrs", "cattrs"] bigquery = ["google-cloud-bigquery", "google-cloud-storage"] cloud-sql 
= ["cloud-sql-python-connector"] cockroachdb = ["psycopg[binary,pool]", "asyncpg"] -duckdb = ["duckdb"] +duckdb = ["duckdb", "pytz"] fastapi = ["fastapi"] flask = ["flask"] fsspec = ["fsspec"] @@ -296,7 +296,7 @@ version = "{current_version}" """ [tool.codespell] -ignore-words-list = "te,ECT,SELCT,froms,ccompiler" +ignore-words-list = "te,ECT,SELCT,froms,ccompiler,BRIN" skip = 'uv.lock' [tool.coverage.run] diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 7999d8c2b..50d6c72f4 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -508,7 +508,11 @@ def _create_session( finally: cursor.close() # type: ignore[no-untyped-call] - return self._get_session(session_id) + result = self._get_session(session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index a95d3b841..56248882d 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -108,6 +108,16 @@ def _event_json_column_ddl(storage_type: JSONStorageType) -> str: return "event_json BLOB NOT NULL" +def _oracle_text_value(value: Any) -> str: + """Normalize Oracle VARCHAR2 values back to Python strings. + + Oracle stores empty strings as ``NULL``. The ADK event contract allows + empty strings for fields like ``invocation_id``, so reads coerce ``NULL`` + back to ``""``. + """ + return "" if value is None else str(value) + + class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): """Oracle async ADK store using oracledb async driver. 
@@ -254,6 +264,12 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] + async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": + """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" + if data is None: + return None + return await self._deserialize_state(data) + async def _serialize_event_json(self, event_json: Any) -> "str | bytes": """Serialize event_json to the configured Oracle JSON storage format.""" storage_type = await self._detect_json_storage_type() @@ -360,8 +376,8 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( session_id VARCHAR2(128) NOT NULL, - invocation_id VARCHAR2(256) NOT NULL, - author VARCHAR2(256) NOT NULL, + invocation_id VARCHAR2(256), + author VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, {event_json_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) @@ -772,20 +788,16 @@ async def get_events( await cursor.execute(sql, params) rows = await cursor.fetchall() - results = [] - for row in rows: - event_json_str = await self._read_event_json(row[4]) - - results.append( - EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], - timestamp=row[3], - event_json=from_json(event_json_str) if isinstance(event_json_str, str) else event_json_str, - ) + return [ + EventRecord( + session_id=row[0], + invocation_id=_oracle_text_value(row[1]), + author=_oracle_text_value(row[2]), + timestamp=row[3], + event_json=await self._deserialize_json_field(row[4]) or {}, ) - return results + for row in rows + ] except oracledb.DatabaseError as e: error_obj = e.args[0] if e.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: @@ -937,6 +949,12 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: 
ignore[no-any-return] + def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": + """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" + if data is None: + return None + return self._deserialize_state(data) + def _serialize_event_json(self, event_json: Any) -> "str | bytes": """Serialize event_json to the configured Oracle JSON storage format.""" storage_type = self._detect_json_storage_type() @@ -1041,8 +1059,8 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( session_id VARCHAR2(128) NOT NULL, - invocation_id VARCHAR2(256) NOT NULL, - author VARCHAR2(256) NOT NULL, + invocation_id VARCHAR2(256), + author VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, {event_json_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) @@ -1196,7 +1214,11 @@ def _create_session( cursor.execute(sql, params) conn.commit() - return self._get_session(session_id) + result = self._get_session(session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -1454,20 +1476,16 @@ def _get_events( cursor.execute(sql, params) rows = cursor.fetchall() - results = [] - for row in rows: - event_json_str = self._read_event_json(row[4]) - - results.append( - EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], - timestamp=row[3], - event_json=from_json(event_json_str) if isinstance(event_json_str, str) else event_json_str, - ) + return [ + EventRecord( + session_id=row[0], + invocation_id=_oracle_text_value(row[1]), + author=_oracle_text_value(row[2]), + timestamp=row[3], + event_json=self._deserialize_json_field(row[4]) or {}, ) - return results + for row in rows + ] except oracledb.DatabaseError as e: error_obj = 
e.args[0] if e.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index 1e115d58d..a180de7e5 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -250,7 +250,7 @@ def _append_event_and_update_state( "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": event_record["event_json"], + "event_json": to_json(event_record["event_json"]), } insert_sql = f""" INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) @@ -284,7 +284,7 @@ def _insert_event(self, event_record: "EventRecord") -> None: "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": event_record["event_json"], + "event_json": to_json(event_record["event_json"]), } insert_sql = f""" INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) diff --git a/sqlspec/core/splitter.py b/sqlspec/core/splitter.py index 45908cf27..0f5414f62 100644 --- a/sqlspec/core/splitter.py +++ b/sqlspec/core/splitter.py @@ -623,6 +623,8 @@ def statement_terminators(self) -> "set[str]": _pattern_cache: LRUCache | None = None _result_cache: LRUCache | None = None _cache_lock = threading.Lock() +_unknown_dialect_warning_lock = threading.Lock() +_warned_unknown_dialects: set[str] = set() def _get_pattern_cache() -> LRUCache: @@ -653,6 +655,16 @@ def _get_result_cache() -> LRUCache: return _result_cache +def _warn_unknown_dialect_once(dialect: "str | None") -> None: + """Emit the generic splitter fallback warning once per dialect.""" + key = "" if dialect is None else dialect.lower() + with _unknown_dialect_warning_lock: + if key in _warned_unknown_dialects: + return + _warned_unknown_dialects.add(key) + logger.warning("Unknown dialect '%s', 
using generic SQL splitter", dialect) + + @mypyc_attr(allow_interpreted_subclasses=False) class StatementSplitter: """SQL script splitter with caching and dialect support.""" @@ -933,7 +945,7 @@ def split_sql_script(script: str, dialect: str | None = None, strip_trailing_ter config = dialect_configs.get(dialect.lower()) if not config: - logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect) + _warn_unknown_dialect_once(dialect) config = GenericDialectConfig() splitter = StatementSplitter(config, strip_trailing_semicolon=strip_trailing_terminator) @@ -949,6 +961,8 @@ def clear_splitter_caches() -> None: result_cache = _get_result_cache() pattern_cache.clear() result_cache.clear() + with _unknown_dialect_warning_lock: + _warned_unknown_dialects.clear() def get_splitter_cache_stats() -> "dict[str, Any]": diff --git a/tests/conftest.py b/tests/conftest.py index 0a98005ed..b3fabd251 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import logging import os import warnings @@ -63,6 +64,26 @@ def disable_spanner_builtin_metrics() -> "Generator[None, None, None]": yield +@pytest.fixture(scope="session", autouse=True) +def suppress_noisy_test_loggers() -> "Generator[None, None, None]": + """Lower especially noisy library loggers during test runs.""" + overrides = { + "httpx": logging.WARNING, + "httpcore": logging.WARNING, + "mysql.connector": logging.WARNING, + "asyncmy": logging.ERROR, + "sqlspec.migrations.tracker": logging.WARNING, + } + original_levels = {name: logging.getLogger(name).level for name in overrides} + for name, level in overrides.items(): + logging.getLogger(name).setLevel(level) + try: + yield + finally: + for name, level in original_levels.items(): + logging.getLogger(name).setLevel(level) + + @pytest.fixture(scope="session") def minio_client(minio_service: "MinioService", minio_default_bucket_name: str) -> Generator[Minio, None, None]: """Override pytest-databases minio_client to use new minio API with keyword 
arguments.""" diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index f9034eed5..1e6cb1f94 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -62,9 +62,9 @@ async def oracle_async_store(oracle_async_config: "OracleAsyncConfig") -> "Async await _cleanup_async_store(store, oracle_async_config) -@pytest.fixture(scope="module") +@pytest.fixture async def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "AsyncGenerator[OracleSyncADKStore, None]": - """Create a sync Oracle ADK store with tables created once per module.""" + """Create a sync Oracle ADK store with tables created per test.""" store = OracleSyncADKStore(oracle_sync_config) await store.create_tables() try: @@ -223,7 +223,7 @@ async def test_state_lob_deserialization(oracle_async_store: "OracleAsyncADKStor async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event_json CLOB is correctly deserialized.""" + """Test event_json LOB data is correctly deserialized.""" session_id = _unique_session_id("event-lob") app_name = "test-app" user_id = "user-123" diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index 75b51dc08..b7cca39f2 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -1,67 +1,85 @@ -"""Integration tests for Spanner ADK store (sync).""" +"""Integration tests for Spanner ADK store.""" import json +from datetime import datetime, timezone from typing import Any import pytest +from sqlspec.extensions.adk import EventRecord + pytestmark = [pytest.mark.spanner, pytest.mark.integration] -def 
test_create_and_get_session(spanner_adk_store: Any) -> None: +async def test_create_and_get_session(spanner_adk_store: Any) -> None: session_id = "session-create" - spanner_adk_store.delete_session(session_id) - created = spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) + await spanner_adk_store.delete_session(session_id) + created = await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) assert created["id"] == session_id - fetched = spanner_adk_store.get_session(session_id) + fetched = await spanner_adk_store.get_session(session_id) assert fetched is not None assert fetched["state"] == {"a": 1} -def test_update_session_state(spanner_adk_store: Any) -> None: +async def test_update_session_state(spanner_adk_store: Any) -> None: session_id = "session-update" - spanner_adk_store.delete_session(session_id) - spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) + await spanner_adk_store.delete_session(session_id) + await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) - spanner_adk_store.update_session_state(session_id, {"a": 2, "b": True}) + await spanner_adk_store.update_session_state(session_id, {"a": 2, "b": True}) - fetched = spanner_adk_store.get_session(session_id) + fetched = await spanner_adk_store.get_session(session_id) assert fetched is not None assert fetched["state"] == {"a": 2, "b": True} -def test_list_sessions(spanner_adk_store: Any) -> None: - spanner_adk_store.delete_session("session-list-1") - spanner_adk_store.delete_session("session-list-2") - spanner_adk_store.delete_session("session-list-3") - spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1}) - spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2}) - spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3}) +async def test_list_sessions(spanner_adk_store: Any) -> None: + await spanner_adk_store.delete_session("session-list-1") + await 
spanner_adk_store.delete_session("session-list-2") + await spanner_adk_store.delete_session("session-list-3") + await spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1}) + await spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2}) + await spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3}) - sessions = spanner_adk_store.list_sessions("app-list", "user1") + sessions = await spanner_adk_store.list_sessions("app-list", "user1") session_ids = {s["id"] for s in sessions} assert session_ids == {"session-list-1", "session-list-2"} -def test_delete_session(spanner_adk_store: Any) -> None: +async def test_delete_session(spanner_adk_store: Any) -> None: session_id = "session-delete" - spanner_adk_store.delete_session(session_id) - spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"}) - spanner_adk_store.delete_session(session_id) + await spanner_adk_store.delete_session(session_id) + await spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"}) + await spanner_adk_store.delete_session(session_id) - assert spanner_adk_store.get_session(session_id) is None + assert await spanner_adk_store.get_session(session_id) is None -def test_create_and_list_events(spanner_adk_store: Any) -> None: +async def test_create_and_list_events(spanner_adk_store: Any) -> None: session_id = "session-events" - spanner_adk_store.delete_session(session_id) - spanner_adk_store.create_session(session_id, "app", "user", {"x": 1}) - - spanner_adk_store.create_event("event-1", session_id, "app", "user", author="user", content={"msg": "hi"}) - spanner_adk_store.create_event("event-2", session_id, "app", "user", author="assistant", content={"msg": "ok"}) - - events = spanner_adk_store.list_events(session_id) + await spanner_adk_store.delete_session(session_id) + await spanner_adk_store.create_session(session_id, "app", "user", {"x": 1}) + + event_one: EventRecord = { + "session_id": 
session_id, + "invocation_id": "event-1", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "content": {"msg": "hi"}, "app_name": "app", "user_id": "user"}, + } + event_two: EventRecord = { + "session_id": session_id, + "invocation_id": "event-2", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-2", "content": {"msg": "ok"}, "app_name": "app", "user_id": "user"}, + } + + await spanner_adk_store.append_event(event_one) + await spanner_adk_store.append_event(event_two) + + events = await spanner_adk_store.get_events(session_id) assert len(events) == 2 assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" diff --git a/uv.lock b/uv.lock index 4300db512..3a10e6a3b 100644 --- a/uv.lock +++ b/uv.lock @@ -1509,7 +1509,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -1963,7 +1963,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.141.0" +version = "1.142.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -1979,13 +1979,14 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/dc/1209c7aab43bd7233cf631165a3b1b4284d22fc7fe7387c66228d07868ab/google_cloud_aiplatform-1.141.0.tar.gz", hash = "sha256:e3b1cdb28865dd862aac9c685dfc5ac076488705aba0a5354016efadcddd59c6", size = 10152688, 
upload-time = "2026-03-10T22:20:08.692Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" }, + { url = "https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" }, ] [package.optional-dependencies] agent-engines = [ + { name = "aiohttp" }, { name = "cloudpickle" }, { name = "google-cloud-iam" }, { name = "google-cloud-logging" }, @@ -7121,6 +7122,7 @@ cockroachdb = [ ] duckdb = [ { name = "duckdb" }, + { name = "pytz" }, ] fastapi = [ { name = "fastapi" }, @@ -7397,6 +7399,7 @@ requires-dist = [ { name = "pydantic-extra-types", marker = "extra == 'pydantic'" }, { name = "pymssql", marker = "extra == 'pymssql'" }, { name = "pymysql", marker = "extra == 'pymysql'" }, + { name = "pytz", marker = "extra == 'duckdb'" }, { name = "rich-click", specifier = ">=1.9.0" }, { name = "sqlglot", specifier = ">=30.0.0" }, { name = "sqlglot", extras = ["c"], marker = "extra == 'mypyc'", specifier = ">=30.0.0" }, From f9fd8a8b0433f06d07734a8e2e8bb158207a7b01 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 21 Mar 2026 22:28:31 +0000 Subject: [PATCH 23/23] feat: Streamline migration CLI commands, update ADK schema guidance, and enhance migration configuration 
documentation. --- docs/_tapes/migration_workflow.tape | 8 +++---- docs/extensions/adk/migrations.rst | 35 +++++++++++++++++++++++++++-- docs/extensions/adk/quickstart.rst | 11 +++++++-- docs/extensions/adk/schema.rst | 4 +++- docs/usage/cli.rst | 13 ++++++----- docs/usage/migrations.rst | 19 +++++++++++----- pyproject.toml | 4 ++-- uv.lock | 2 +- 8 files changed, 74 insertions(+), 22 deletions(-) diff --git a/docs/_tapes/migration_workflow.tape b/docs/_tapes/migration_workflow.tape index f7703a6f6..d08a78408 100644 --- a/docs/_tapes/migration_workflow.tape +++ b/docs/_tapes/migration_workflow.tape @@ -28,7 +28,7 @@ Type "# Initialize the migration environment" Enter Sleep 500ms -Type "sqlspec db init" +Type "sqlspec init" Enter Sleep 3s @@ -36,7 +36,7 @@ Type "# Create a new migration" Enter Sleep 500ms -Type 'sqlspec db create-migration -m "add users table"' +Type 'sqlspec create-migration -m "add users table"' Enter Sleep 3s @@ -44,7 +44,7 @@ Type "# Apply the migration" Enter Sleep 500ms -Type "sqlspec db upgrade" +Type "sqlspec upgrade" Enter Sleep 3s @@ -52,7 +52,7 @@ Type "# Check current revision" Enter Sleep 500ms -Type "sqlspec db show-current-revision" +Type "sqlspec show-current-revision" Enter Sleep 3s diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst index 0bbb9c85c..9359fbcd3 100644 --- a/docs/extensions/adk/migrations.rst +++ b/docs/extensions/adk/migrations.rst @@ -8,7 +8,8 @@ used by your ADK backend, then run them with the SQLSpec migration CLI. Schema Bootstrapping ==================== -For development, use ``ensure_tables()`` to create tables on first use: +You can programmatically create ADK tables with ``create_tables()`` / +``ensure_tables()``: .. code-block:: python @@ -16,7 +17,37 @@ For development, use ``ensure_tables()`` to create tables on first use: await memory_store.ensure_tables() await artifact_store.ensure_table() -For production, run migrations ahead of deployment to avoid runtime DDL. 
+Alternatively, configure SQLSpec migrations on the database config and run the +migration CLI ahead of deployment: + +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://localhost/app"}, + migration_config={"script_location": "migrations/postgres"}, + ) + +.. code-block:: console + + sqlspec upgrade + +Use the programmatic table-creation path when you want the store to bootstrap +its own schema. Use migrations when you want schema changes tracked and applied +through your deployment workflow. + +.. note:: + + The migration CLI resolves configuration from ``--config``, + ``SQLSPEC_CONFIG``, or ``[tool.sqlspec]`` in ``pyproject.toml``. + + When ``extension_config["adk"]`` is present, ADK extension migrations are + auto-included. Use ``migration_config={"exclude_extensions": ["adk"]}`` + to skip only ADK extension migrations, or + ``migration_config={"include_extensions": ["adk"]}`` to opt in explicitly + by extension name. Use ``migration_config={"enabled": False}`` to disable + migrations entirely for a given database config. Clean-Break Migration Notes ============================ diff --git a/docs/extensions/adk/quickstart.rst b/docs/extensions/adk/quickstart.rst index af14b00ca..69bc0e6b6 100644 --- a/docs/extensions/adk/quickstart.rst +++ b/docs/extensions/adk/quickstart.rst @@ -124,8 +124,8 @@ automatic versioning. Metadata lives in SQL; content lives in object storage. Schema Setup ============ -Stores create their tables on first use via ``ensure_tables()``. For -production, run table creation ahead of deployment: +You can programmatically create ADK tables ahead of first use with +``ensure_tables()`` / ``ensure_table()``: .. 
code-block:: python @@ -133,6 +133,13 @@ production, run table creation ahead of deployment: await memory_store.ensure_tables() await artifact_store.ensure_table() +Alternatively, configure SQLSpec migrations for your database and run the +migration CLI as part of deployment: + +.. code-block:: console + + sqlspec upgrade + Next Steps ========== diff --git a/docs/extensions/adk/schema.rst b/docs/extensions/adk/schema.rst index b49a4996d..82bdd5ef2 100644 --- a/docs/extensions/adk/schema.rst +++ b/docs/extensions/adk/schema.rst @@ -5,7 +5,9 @@ Schema ADK stores create tables for sessions, events, memory entries, and artifact metadata. Table names are configurable via ``extension_config["adk"]``. -Use ``create_tables()`` or ``ensure_tables()`` on a store to apply the schema. +You can programmatically create the schema with ``create_tables()`` or +``ensure_tables()`` on a store. For managed deployments, configure SQLSpec +migrations for the target database and run ``sqlspec upgrade`` instead. .. contents:: On this page :local: diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst index f1ae3cef9..1bbdfe740 100644 --- a/docs/usage/cli.rst +++ b/docs/usage/cli.rst @@ -4,15 +4,18 @@ Command Line Interface SQLSpec includes a CLI for managing migrations and inspecting configuration. Use it when you want a fast, explicit workflow without additional tooling. +Configuration can come from ``--config``, ``SQLSPEC_CONFIG``, or +``[tool.sqlspec]`` in ``pyproject.toml``. + Core Commands ------------- .. code-block:: console - sqlspec db init - sqlspec db create-migration -m "add users" - sqlspec db upgrade - sqlspec db downgrade + sqlspec init + sqlspec create-migration -m "add users" + sqlspec upgrade + sqlspec downgrade Common Options -------------- @@ -28,7 +31,7 @@ Tips ---- - Run ``sqlspec --help`` to see global options. -- Run ``sqlspec db --help`` to see migration command details. +- Run ``sqlspec upgrade --help`` to see command-specific migration options. 
Related Guides -------------- diff --git a/docs/usage/migrations.rst b/docs/usage/migrations.rst index 7d4aa5826..2e5cd1b62 100644 --- a/docs/usage/migrations.rst +++ b/docs/usage/migrations.rst @@ -21,9 +21,9 @@ Common Commands .. code-block:: console - sqlspec db init - sqlspec db create-migration -m "add users" - sqlspec db upgrade + sqlspec init + sqlspec create-migration -m "add users" + sqlspec upgrade Configuration ------------- @@ -31,6 +31,9 @@ Configuration Set ``migration_config`` on your database configuration to customize script locations, version table names, and extension migration behavior. +The migration CLI resolves config from ``--config``, ``SQLSPEC_CONFIG``, or +``[tool.sqlspec]`` in ``pyproject.toml``. + .. code-block:: python from sqlspec.adapters.duckdb import DuckDBConfig @@ -38,7 +41,7 @@ locations, version table names, and extension migration behavior. config = DuckDBConfig( connection_config={"database": "/tmp/analytics.db"}, migration_config={ - "migration_dir": "migrations/duckdb", + "script_location": "migrations/duckdb", "version_table": "_schema_versions", }, ) @@ -60,11 +63,17 @@ For async configs, ``migrate_up()`` returns an awaitable: config = AsyncpgConfig( connection_config={"dsn": "postgresql://localhost/app"}, - migration_config={"migration_dir": "migrations/postgres"}, + migration_config={"script_location": "migrations/postgres"}, ) await config.migrate_up() +Extension migrations are auto-included when the corresponding entry exists in +``extension_config``. Use ``migration_config["exclude_extensions"]`` to skip a +specific extension, ``migration_config["include_extensions"]`` to opt in +explicitly by extension name, or ``migration_config["enabled"] = False`` to +disable migrations entirely for a database config. 
+ Logging and Echo Controls ------------------------- diff --git a/pyproject.toml b/pyproject.toml index b834e0203..53240438b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }] name = "sqlspec" readme = "README.md" requires-python = ">=3.10, <4.0" -version = "0.41.1" +version = "0.42.0" [project.urls] Discord = "https://discord.gg/litestar" @@ -254,7 +254,7 @@ opt_level = "3" # Maximum optimization (0-3) allow_dirty = true commit = false commit_args = "--no-verify" -current_version = "0.41.1" +current_version = "0.42.0" ignore_missing_files = false ignore_missing_version = false message = "chore(release): bump to v{new_version}" diff --git a/uv.lock b/uv.lock index 3a10e6a3b..b499bb1d7 100644 --- a/uv.lock +++ b/uv.lock @@ -7072,7 +7072,7 @@ wheels = [ [[package]] name = "sqlspec" -version = "0.41.1" +version = "0.42.0" source = { editable = "." } dependencies = [ { name = "mypy-extensions" },