diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f04a450c4..d4381dd65 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,7 +108,7 @@ jobs: - name: Test env: PYTHONFAULTHANDLER: "1" - PYTEST_ADDOPTS: "--max-worker-restart=0 -s" + PYTEST_ADDOPTS: "--max-worker-restart=0" run: timeout 900s uv run pytest -n 2 --dist=loadgroup # test-linux-freethreaded: diff --git a/.gitignore b/.gitignore index 20a641773..442abbaa4 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,4 @@ uv.toml .geminiignore .beads/ tools/scripts/profiles/*.prof +.agents/ diff --git a/docs/_tapes/migration_workflow.tape b/docs/_tapes/migration_workflow.tape index f7703a6f6..d08a78408 100644 --- a/docs/_tapes/migration_workflow.tape +++ b/docs/_tapes/migration_workflow.tape @@ -28,7 +28,7 @@ Type "# Initialize the migration environment" Enter Sleep 500ms -Type "sqlspec db init" +Type "sqlspec init" Enter Sleep 3s @@ -36,7 +36,7 @@ Type "# Create a new migration" Enter Sleep 500ms -Type 'sqlspec db create-migration -m "add users table"' +Type 'sqlspec create-migration -m "add users table"' Enter Sleep 3s @@ -44,7 +44,7 @@ Type "# Apply the migration" Enter Sleep 500ms -Type "sqlspec db upgrade" +Type "sqlspec upgrade" Enter Sleep 3s @@ -52,7 +52,7 @@ Type "# Check current revision" Enter Sleep 500ms -Type "sqlspec db show-current-revision" +Type "sqlspec show-current-revision" Enter Sleep 3s diff --git a/docs/extensions/adk/adapters.rst b/docs/extensions/adk/adapters.rst index 5ff658122..72b0d06b7 100644 --- a/docs/extensions/adk/adapters.rst +++ b/docs/extensions/adk/adapters.rst @@ -10,11 +10,36 @@ Choosing an Adapter Use async adapters for best performance with ADK runners: -- **PostgreSQL**: ``asyncpg`` (recommended), ``psycopg`` (async mode) -- **SQLite**: ``aiosqlite`` -- **MySQL**: ``asyncmy`` +- **PostgreSQL** (recommended): ``asyncpg``, ``psycopg`` (async mode), ``psqlpy`` +- **CockroachDB**: ``cockroach_asyncpg``, ``cockroach_psycopg`` (full 
FTS support) +- **MySQL/MariaDB**: ``asyncmy`` +- **SQLite**: ``aiosqlite`` (development and single-process) +- **Oracle**: ``oracledb`` +- **DuckDB**: ``duckdb`` (analytics; reduced-scope for ADK) +- **ADBC**: ``adbc`` (Arrow-native, driver-agnostic) +- **Spanner**: ``spanner`` (Google Cloud, globally distributed) -Sync adapters work but require wrapping with ``anyio`` for async ADK runners. +Sync adapters (``psycopg`` sync mode, ``sqlite``, ``mysqlconnector``, ``pymysql``) +work but require wrapping with ``anyio`` for async ADK runners. + +Each Adapter Provides +===================== + +Every adapter with ADK support ships three store classes: + +- **Session store** (e.g., ``AsyncpgADKStore``) -- sessions and events. +- **Memory store** (e.g., ``AsyncpgADKMemoryStore``) -- long-term memory with FTS. +- **Artifact store** (e.g., ``AsyncpgADKArtifactStore``) -- artifact metadata. + +Import from the adapter's ``adk`` subpackage: + +.. code-block:: python + + from sqlspec.adapters.asyncpg.adk import ( + AsyncpgADKStore, + AsyncpgADKMemoryStore, + AsyncpgADKArtifactStore, + ) Example ======= @@ -30,6 +55,6 @@ Example See Also ======== -- :doc:`backends` for the full adapter support matrix. +- :doc:`backends` for the full support matrix and backend-specific notes. - :doc:`/usage/drivers_and_querying` for adapter configuration patterns. - :doc:`/reference/adapters` for the complete adapter API. diff --git a/docs/extensions/adk/api.rst b/docs/extensions/adk/api.rst index 95d2431d4..e95a28a31 100644 --- a/docs/extensions/adk/api.rst +++ b/docs/extensions/adk/api.rst @@ -19,8 +19,20 @@ Services :show-inheritance: :no-index: -Base Stores -=========== +.. autoclass:: sqlspec.extensions.adk.memory.SQLSpecSyncMemoryService + :members: + :undoc-members: + :show-inheritance: + :no-index: + +.. autoclass:: SQLSpecArtifactService + :members: + :undoc-members: + :show-inheritance: + :no-index: + +Session Stores +============== .. 
autoclass:: BaseAsyncADKStore :members: @@ -34,6 +46,9 @@ Base Stores :show-inheritance: :no-index: +Memory Stores +============= + .. autoclass:: BaseAsyncADKMemoryStore :members: :undoc-members: @@ -45,3 +60,57 @@ Base Stores :undoc-members: :show-inheritance: :no-index: + +Artifact Stores +=============== + +.. autoclass:: BaseAsyncADKArtifactStore + :members: + :undoc-members: + :show-inheritance: + :no-index: + +.. autoclass:: BaseSyncADKArtifactStore + :members: + :undoc-members: + :show-inheritance: + :no-index: + +Record Types +============ + +.. autoclass:: SessionRecord + :members: + :show-inheritance: + :no-index: + +.. autoclass:: EventRecord + :members: + :show-inheritance: + :no-index: + +.. autoclass:: MemoryRecord + :members: + :show-inheritance: + :no-index: + +.. autoclass:: ArtifactRecord + :members: + :show-inheritance: + :no-index: + +Configuration +============= + +.. autoclass:: ADKConfig + :members: + :show-inheritance: + :no-index: + +Converters +========== + +.. automodule:: sqlspec.extensions.adk.converters + :members: + :undoc-members: + :no-index: diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst index e550d0024..1f770f68d 100644 --- a/docs/extensions/adk/backends.rst +++ b/docs/extensions/adk/backends.rst @@ -2,51 +2,249 @@ Backends ======== -ADK stores are implemented per adapter. Use the backend config helpers when -connecting to multiple databases or configuring advanced options. +ADK stores are implemented per adapter. Each backend has different capabilities +for session, event, memory, and artifact storage. Use the support matrix below +to select the right backend for your deployment. -Example -======= +.. _adk-support-matrix: -.. 
literalinclude:: /examples/extensions/adk/backend_config.py - :language: python - :caption: ``adk backend config`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: +Support Matrix +============== -Supported Backends -================== +The table below classifies every backend by its ADK support level. .. list-table:: :header-rows: 1 + :widths: 20 15 15 15 15 20 * - Adapter - Status + - Session/Event + - Memory (FTS) + - Artifacts + - Notes * - asyncpg - - Production - * - psycopg - - Production + - Recommended + - Full + - Full + - Full + - Best async PostgreSQL driver. + * - psycopg (async) + - Recommended + - Full + - Full + - Full + - Supports both sync and async modes. * - psqlpy - - Production + - Supported + - Full + - Full + - Full + - Rust-backed PostgreSQL driver. + * - cockroach_asyncpg + - Supported + - Full + - Full + - Full + - CockroachDB with full FTS support. + * - cockroach_psycopg + - Supported + - Full + - Full + - Full + - CockroachDB with full FTS support. * - asyncmy - - Production - * - sqlite - - Production + - Supported + - Full + - Full + - Full + - MySQL/MariaDB async driver. + * - mysqlconnector + - Supported + - Full + - Full + - Full + - MySQL/MariaDB sync driver. + * - pymysql + - Supported + - Full + - Full + - Full + - MySQL/MariaDB sync driver. * - aiosqlite - - Production + - Supported + - Full + - Full + - Full + - SQLite async, ideal for development. + * - sqlite + - Supported + - Full + - Full + - Full + - SQLite sync with thread-local pools. * - oracledb - - Production + - Supported + - Full + - Full + - Full + - Oracle Database driver. * - duckdb - - Production (analytics) - * - bigquery - - Production + - Reduced-scope + - Full + - Limited + - Full + - Analytics-oriented; no concurrent writes. * - adbc - - Production + - Supported + - Full + - Full + - Full + - Arrow-native database connectivity. 
+ * - spanner + - Supported + - Full + - Full + - Full + - Google Cloud Spanner (cloud-managed). + +Status Definitions +------------------ + +**Recommended** + Production-grade, fully tested, actively optimized. Start here unless you + have a specific reason not to. + +**Supported** + Fully implemented and tested. Works correctly for all ADK operations. + +**Reduced-scope** + Implemented with known limitations. Specific features may be absent or + behave differently. See backend-specific notes. + +**Removed** + Previously available but no longer supported. See the removal notice for + migration guidance. + +Removed Backends +---------------- + +**BigQuery** was removed from the ADK backend surface. BigQuery's batch-oriented +architecture is incompatible with the low-latency, transactional write patterns +that ADK session and event storage require. If you were using BigQuery for ADK +storage, migrate to PostgreSQL (asyncpg or psycopg) or any other supported +backend. + +Backend Details +=============== + +PostgreSQL Family +----------------- + +PostgreSQL backends (asyncpg, psycopg, psqlpy) provide the fullest feature set: + +- Native ``JSONB`` storage for session state and event JSON. +- Full-text search via ``tsvector`` for memory entries. +- ``UPSERT`` and ``RETURNING`` clauses for atomic operations. +- ``append_event_and_update_state()`` executes as a single transaction. + +**Recommended for production deployments.** + +CockroachDB +------------ + +CockroachDB backends (cockroach_asyncpg, cockroach_psycopg) provide full ADK +support including full-text search. CockroachDB is a distributed SQL database +compatible with the PostgreSQL wire protocol. + +- Full FTS support for memory search. +- Distributed transactions for session and event atomicity. +- Horizontal scalability for high-throughput agent deployments. 
+ +MySQL Family +------------ + +MySQL backends (asyncmy, mysqlconnector, pymysql) provide full ADK support: + +- JSON column storage for session state and event records. +- Full-text search on ``InnoDB`` tables for memory entries. +- Transactional writes for ``append_event_and_update_state()``. + +SQLite +------ + +SQLite backends (aiosqlite, sqlite) are ideal for local development, testing, +and single-process deployments: + +- JSON1 extension for state and event storage. +- FTS5 virtual tables for memory full-text search. +- File-based or in-memory operation. + +.. note:: + + SQLite does not support concurrent writers. Use a server-backed database + for production multi-process deployments. + +Oracle +------ + +Oracle Database (oracledb) provides full ADK support: + +- Native JSON column support (Oracle 21c+). +- Oracle Text for full-text search on memory entries. +- Full transactional support for atomic operations. + +DuckDB +------ + +DuckDB provides session and event storage but has limitations: + +- Optimized for analytics, not OLTP workloads. +- Single-writer constraint limits concurrent access. +- Memory search capabilities are limited compared to server databases. + +**Best suited for analytics pipelines and offline agent evaluation.** + +ADBC +---- + +ADBC (Arrow Database Connectivity) provides a driver-agnostic interface: + +- Works with any ADBC-compatible driver (PostgreSQL, SQLite, DuckDB, etc.). +- Arrow-native data transfer for high-throughput event ingestion. +- Backend capabilities depend on the underlying database driver. + +Spanner +------- + +Google Cloud Spanner provides globally distributed ADK storage: + +- Cloud-managed, horizontally scalable. +- Full-text search support for memory entries. +- Strong consistency across regions. +- Suitable for multi-region agent deployments. + +Configuration +============= + +All backends are configured through ``extension_config["adk"]``: + +.. 
code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig -Notes -===== + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://localhost/mydb"}, + extension_config={ + "adk": { + "session_table": "adk_sessions", + "events_table": "adk_events", + "memory_table": "adk_memory_entries", + "memory_use_fts": True, + "artifact_table": "adk_artifact_versions", + "owner_id_column": "tenant_id INTEGER NOT NULL", + } + }, + ) -- Use async backends for ADK runners; sync backends can be wrapped with anyio. -- Backend stores expose ``create_tables`` to bootstrap schema. +See :doc:`adapters` for adapter-specific configuration patterns. diff --git a/docs/extensions/adk/index.rst b/docs/extensions/adk/index.rst index d83c786c1..de771ccc8 100644 --- a/docs/extensions/adk/index.rst +++ b/docs/extensions/adk/index.rst @@ -2,8 +2,23 @@ Google ADK Extension ==================== -SQLSpec provides an ADK extension for session, event, and memory storage with -SQL-backed persistence. +SQLSpec provides a full-featured backend for +`Google Agent Development Kit `_, +covering session, event, memory, and artifact storage with SQL-backed +persistence across 14 database adapters. + +Key capabilities: + +- **Session and event storage** with atomic ``append_event_and_update_state()`` + ensuring events and state are always consistent. +- **Full-event JSON storage** (EventRecord) that captures the entire ADK Event + in a single column, eliminating schema drift with upstream ADK releases. +- **Scoped state semantics** (``app:``, ``user:``, ``temp:``) for controlling + state visibility and persistence across sessions. +- **Memory service** with database-native full-text search (tsvector, FTS5, + InnoDB FT) for long-term agent context. +- **Artifact service** with append-only versioning, SQL metadata, and pluggable + object storage backends. 
Choose a guide ============== @@ -22,45 +37,45 @@ Choose a guide :link: quickstart :link-type: doc - Persist memory and sessions with minimal setup. + Persist sessions, memory, and artifacts with minimal setup. - .. grid-item-card:: API Reference - :link: api + .. grid-item-card:: Support Matrix + :link: backends :link-type: doc - Interfaces, stores, and configuration helpers. + See which backends are recommended, supported, or reduced-scope. .. grid-item-card:: Adapters :link: adapters :link-type: doc - Configure supported SQLSpec adapters. + Configure supported SQLSpec adapters for ADK. - .. grid-item-card:: Backends - :link: backends + .. grid-item-card:: Schema + :link: schema :link-type: doc - Storage backends and connection profiles. + Table layouts, EventRecord, scoped state, and artifact metadata. - .. grid-item-card:: Migrations - :link: migrations + .. grid-item-card:: API Reference + :link: api :link-type: doc - Apply schema changes safely over time. + Services, stores, and record types. - .. grid-item-card:: Schema - :link: schema + .. grid-item-card:: Migrations + :link: migrations :link-type: doc - Table layouts for sessions and memory records. + Apply schema changes safely over time. .. toctree:: :hidden: installation quickstart - api - adapters backends - migrations + adapters schema + api + migrations diff --git a/docs/extensions/adk/installation.rst b/docs/extensions/adk/installation.rst index 8304d72f7..347f8aa13 100644 --- a/docs/extensions/adk/installation.rst +++ b/docs/extensions/adk/installation.rst @@ -6,7 +6,7 @@ Install SQLSpec with a database adapter and the Google ADK SDK. .. tab-set:: - .. tab-item:: PostgreSQL + .. tab-item:: PostgreSQL (recommended) .. tab-set:: @@ -90,6 +90,34 @@ Install SQLSpec with a database adapter and the Google ADK SDK. pdm add "sqlspec[asyncmy,adk]" + .. tab-item:: CockroachDB + + .. tab-set:: + + .. tab-item:: uv + + .. code-block:: bash + + uv add "sqlspec[cockroach-asyncpg,adk]" + + .. 
tab-item:: pip + + .. code-block:: bash + + pip install "sqlspec[cockroach-asyncpg,adk]" + + .. tab-item:: Poetry + + .. code-block:: bash + + poetry add "sqlspec[cockroach-asyncpg,adk]" + + .. tab-item:: PDM + + .. code-block:: bash + + pdm add "sqlspec[cockroach-asyncpg,adk]" + .. tab-item:: DuckDB .. tab-set:: @@ -123,11 +151,17 @@ What This Provides The ``adk`` extra includes the Google ADK SDK (``google-genai``). SQLSpec provides: -- **Session Store** - Persist ADK agent sessions to your database. -- **Memory Store** - Store agent memory for context across conversations. -- **Event Store** - Log agent events for observability. +- **Session Service** -- Persist ADK agent sessions and events to your database + with atomic ``append_event_and_update_state()`` writes. +- **Memory Service** -- Store agent memory with database-native full-text search + for context retrieval across conversations. +- **Artifact Service** -- Version and store binary artifacts with SQL metadata + and pluggable object storage backends. +- **Event Storage** -- Full-event JSON storage (EventRecord) that captures the + entire ADK Event without schema drift. Next Steps ---------- -Proceed to :doc:`quickstart` to set up stores for your ADK agent. +Proceed to :doc:`quickstart` to set up stores for your ADK agent, or see +:doc:`backends` for the full support matrix. diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst index 948969ae8..9359fbcd3 100644 --- a/docs/extensions/adk/migrations.rst +++ b/docs/extensions/adk/migrations.rst @@ -5,4 +5,63 @@ Migrations ADK stores use standard SQLSpec migrations. Generate migrations for the database used by your ADK backend, then run them with the SQLSpec migration CLI. +Schema Bootstrapping +==================== + +You can programmatically create ADK tables with ``create_tables()`` / +``ensure_tables()``: + +.. 
code-block:: python + + await session_store.ensure_tables() + await memory_store.ensure_tables() + await artifact_store.ensure_table() + +Alternatively, configure SQLSpec migrations on the database config and run the +migration CLI ahead of deployment: + +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://localhost/app"}, + migration_config={"script_location": "migrations/postgres"}, + ) + +.. code-block:: console + + sqlspec upgrade + +Use the programmatic table-creation path when you want the store to bootstrap +its own schema. Use migrations when you want schema changes tracked and applied +through your deployment workflow. + +.. note:: + + The migration CLI resolves configuration from ``--config``, + ``SQLSPEC_CONFIG``, or ``[tool.sqlspec]`` in ``pyproject.toml``. + + When ``extension_config["adk"]`` is present, ADK extension migrations are + auto-included. Use ``migration_config={"exclude_extensions": ["adk"]}`` + to skip only ADK extension migrations, or + ``migration_config={"include_extensions": ["adk"]}`` to opt in explicitly + by extension name. Use ``migration_config={"enabled": False}`` to disable + migrations entirely for a given database config. + +Clean-Break Migration Notes +============================ + +If you are upgrading from a pre-clean-break version of the ADK extension, +note the following schema changes: + +- **Events table**: The column layout changed to full-event JSON storage. + The ``event_json`` column now stores the entire ADK Event as a JSON blob. + Individual event columns (``content``, ``actions``, ``branch``, etc.) have + been replaced by indexed scalar columns (``invocation_id``, ``author``, + ``timestamp``) plus ``event_json``. +- **Artifact table**: New table (``adk_artifact_versions``) for artifact + metadata. Create this table when enabling the artifact service. +- **BigQuery**: Removed. 
Migrate to PostgreSQL or any other supported backend. + See :doc:`/usage/migrations` for the full workflow and commands. diff --git a/docs/extensions/adk/quickstart.rst b/docs/extensions/adk/quickstart.rst index 93c829100..69bc0e6b6 100644 --- a/docs/extensions/adk/quickstart.rst +++ b/docs/extensions/adk/quickstart.rst @@ -2,56 +2,147 @@ Quickstart ========== -Wire SQLSpec stores into your ADK agent to persist sessions and memory across restarts. +Wire SQLSpec stores into your ADK agent to persist sessions, events, memory, +and artifacts across restarts. How It Works ============ -1. Create a SQLSpec database config. -2. Initialize ADK stores (session, memory, event) backed by that config. -3. Pass the stores to your ADK agent. +1. Create a SQLSpec database config with ADK extension settings. +2. Initialize the appropriate stores (session, memory, artifact). +3. Pass the service wrappers to your ADK agent. -Session Store -============= +Session Service +=============== -The session store persists agent state between conversations. When a user returns, -the agent can resume from where it left off. +The session service persists agent state and events between conversations. +When a user returns, the agent can resume from where it left off. -.. literalinclude:: /examples/extensions/adk/memory_store.py - :language: python - :caption: ``adk session store`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: +.. 
code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore + from sqlspec.extensions.adk import SQLSpecSessionService + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://localhost/mydb"}, + extension_config={ + "adk": { + "session_table": "adk_sessions", + "events_table": "adk_events", + } + }, + ) + + store = AsyncpgADKStore(config) + await store.ensure_tables() + + session_service = SQLSpecSessionService(store) + + # Create a session with scoped state + session = await session_service.create_session( + app_name="my_agent", + user_id="user_123", + state={ + "app:model": "gemini-2.0", # shared across all sessions + "user:name": "Alice", # shared across user's sessions + "conversation_turn": 0, # session-local + "temp:scratch": "...", # runtime-only, never persisted + }, + ) + +Events are persisted automatically when you use the session service with an +ADK runner. Each call to ``append_event()`` atomically stores the event and +updates the session's durable state via ``append_event_and_update_state()``. + +Scoped State +------------ + +State keys use prefixes to control their scope and persistence: + +- ``app:`` -- shared across all sessions for the same application. +- ``user:`` -- shared across all sessions for the same user. +- ``temp:`` -- runtime-only, stripped before every write to storage. +- *(no prefix)* -- private to the current session. + +See :ref:`scoped-state` for full details. + +Memory Service +============== + +The memory service retains context that the agent can reference later. This +enables long-term memory across sessions with full-text search. + +.. 
code-block:: python + + from sqlspec.adapters.asyncpg.adk import AsyncpgADKMemoryStore + from sqlspec.extensions.adk import SQLSpecMemoryService + + memory_store = AsyncpgADKMemoryStore(config) + await memory_store.ensure_tables() + + memory_service = SQLSpecMemoryService(memory_store) + +Enable full-text search by setting ``memory_use_fts: True`` in the ADK config. +This creates database-native FTS indexes (tsvector, FTS5, InnoDB FT) for +efficient memory retrieval. + +Artifact Service +================ + +The artifact service stores binary artifacts (files, images, reports) with +automatic versioning. Metadata lives in SQL; content lives in object storage. -Memory Store Integration -======================== +.. code-block:: python + + from sqlspec.adapters.asyncpg.adk import AsyncpgADKArtifactStore + from sqlspec.extensions.adk import SQLSpecArtifactService + + artifact_store = AsyncpgADKArtifactStore(config) + await artifact_store.ensure_table() + + artifact_service = SQLSpecArtifactService( + store=artifact_store, + artifact_storage_uri="s3://my-bucket/adk-artifacts/", + ) -The memory store retains context that the agent can reference later. This enables -long-term memory across sessions. + # Save an artifact (returns version number starting from 0) + version = await artifact_service.save_artifact( + app_name="my_agent", + user_id="user_123", + filename="report.pdf", + artifact=part, + ) -.. literalinclude:: /examples/extensions/adk/tool_integration.py - :language: python - :caption: ``adk memory integration`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: + # Load the latest version + loaded = await artifact_service.load_artifact( + app_name="my_agent", + user_id="user_123", + filename="report.pdf", + ) Schema Setup ============ -Stores create their tables automatically on first use. 
For production, run migrations -ahead of time: +You can programmatically create ADK tables ahead of first use with +``ensure_tables()`` / ``ensure_table()``: .. code-block:: python - await session_store.create_tables() - await memory_store.create_tables() + await session_store.ensure_tables() + await memory_store.ensure_tables() + await artifact_store.ensure_table() + +Alternatively, configure SQLSpec migrations for your database and run the +migration CLI as part of deployment: + +.. code-block:: console + + sqlspec upgrade Next Steps ========== -- :doc:`backends` for adapter-specific configuration. -- :doc:`schema` for table layouts and indexes. +- :doc:`backends` for the full support matrix and backend-specific details. +- :doc:`schema` for table layouts, EventRecord format, and scoped state semantics. +- :doc:`api` for the complete API reference. diff --git a/docs/extensions/adk/schema.rst b/docs/extensions/adk/schema.rst index 18a1482bd..82bdd5ef2 100644 --- a/docs/extensions/adk/schema.rst +++ b/docs/extensions/adk/schema.rst @@ -2,7 +2,308 @@ Schema ====== -ADK stores create tables for sessions, events, and memory entries. Table names -and schemas are configurable via the store config. +ADK stores create tables for sessions, events, memory entries, and artifact +metadata. Table names are configurable via ``extension_config["adk"]``. -Use ``create_tables()`` on a store to apply the schema. +You can programmatically create the schema with ``create_tables()`` or +``ensure_tables()`` on a store. For managed deployments, configure SQLSpec +migrations for the target database and run ``sqlspec upgrade`` instead. + +.. contents:: On this page + :local: + :depth: 2 + +Sessions Table +============== + +The sessions table stores agent session metadata and durable state. + +Default name: ``adk_sessions`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``id`` + - ``VARCHAR`` / ``TEXT`` + - Primary key. 
UUID assigned by the service layer. + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Application identifier. + * - ``user_id`` + - ``VARCHAR`` / ``TEXT`` + - User identifier. + * - ``state`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - Durable session state (see :ref:`scoped-state`). + * - ``create_time`` + - ``TIMESTAMP`` + - When the session was created (UTC). + * - ``update_time`` + - ``TIMESTAMP`` + - Last state update time (UTC). + +An optional ``owner_id`` column can be added via ``owner_id_column`` in the ADK +config for multi-tenant deployments. + +.. _event-record: + +Events Table (EventRecord) +========================== + +The events table uses **full-event JSON storage**: the entire ADK ``Event`` is +serialized into a single ``event_json`` column alongside a small set of indexed +scalar columns used for query filtering. + +This design eliminates column drift with upstream ADK releases. New ``Event`` +fields are automatically captured in ``event_json`` without schema changes. + +Default name: ``adk_events`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``session_id`` + - ``VARCHAR`` / ``TEXT`` + - Foreign key to the sessions table. + * - ``invocation_id`` + - ``VARCHAR`` / ``TEXT`` + - ADK invocation identifier (indexed for filtering). + * - ``author`` + - ``VARCHAR`` / ``TEXT`` + - Event author: ``"user"``, ``"agent"``, or ``"system"``. + * - ``timestamp`` + - ``TIMESTAMP`` + - Event timestamp (UTC, indexed for range queries). + * - ``event_json`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - Full ADK Event serialized via ``Event.model_dump()``. + +**Serialization and reconstruction:** + +Events are converted to records via ``event_to_record()``, which calls +``event.model_dump(exclude_none=True, mode="json")`` to produce the JSON blob. +Reconstruction is lossless: ``record_to_event()`` restores the full ``Event`` +via ``Event.model_validate()``. + +.. 
code-block:: python + + from sqlspec.extensions.adk.converters import event_to_record, record_to_event + + # Serialize: Event -> EventRecord + record = event_to_record(event=adk_event, session_id="sess_123") + + # Reconstruct: EventRecord -> Event + restored_event = record_to_event(record) + +.. _scoped-state: + +Scoped State Semantics +====================== + +ADK uses key prefixes to scope state visibility across sessions. SQLSpec +respects these prefixes when persisting and loading state. + +.. list-table:: + :header-rows: 1 + + * - Prefix + - Scope + - Persisted + - Description + * - ``app:`` + - Application + - Yes + - Shared across all sessions for the same ``app_name``. + * - ``user:`` + - User + - Yes + - Shared across all sessions for the same ``app_name`` + ``user_id``. + * - ``temp:`` + - Runtime + - **No** + - Process-local state. Stripped before every write to storage. + * - *(no prefix)* + - Session + - Yes + - Private to a single session. + +**How scoped state is handled:** + +1. On ``create_session()``, the service strips ``temp:`` keys before the + initial ``INSERT``. + +2. On ``append_event()``, the service calls ``filter_temp_state()`` to produce + a durable state snapshot, then calls ``append_event_and_update_state()`` to + atomically persist the event and the state update. + +3. On ``get_session()``, state is loaded from the database. Since ``temp:`` + keys were never written, they are absent from the loaded state. + +.. 
code-block:: python + + from sqlspec.extensions.adk.converters import filter_temp_state, split_scoped_state + + state = { + "app:model_version": "v2", + "user:preferences": {"theme": "dark"}, + "temp:scratch_pad": "...", + "conversation_turn": 5, + } + + # Strip temp keys before persisting + durable = filter_temp_state(state) + # {"app:model_version": "v2", "user:preferences": {...}, "conversation_turn": 5} + + # Split into scoped buckets + app_state, user_state, session_state = split_scoped_state(durable) + # app_state: {"app:model_version": "v2"} + # user_state: {"user:preferences": {"theme": "dark"}} + # session_state: {"conversation_turn": 5} + +.. _append-event-contract: + +The ``append_event_and_update_state()`` Contract +================================================= + +This method is the **authoritative durable write boundary** for post-creation +session mutations. It atomically: + +1. Inserts the event record into the events table. +2. Updates the session's durable state in the sessions table. + +Both operations succeed together or fail together within a single database +transaction. + +.. code-block:: python + + # Called by SQLSpecSessionService.append_event() internally: + await store.append_event_and_update_state( + event_record=event_record, + session_id=session.id, + state=durable_state, # temp: keys already stripped + ) + +**Why this matters:** + +- Prevents state from advancing without the corresponding event being recorded. +- Prevents orphaned events that reference a stale session state. +- Ensures that on session reload, the state always reflects all persisted events. + +Every backend store implements this as a single transaction (or equivalent +atomic operation for the backend's concurrency model). + +Memory Table +============ + +The memory table stores long-term context entries that agents can search and +reference across sessions. + +Default name: ``adk_memory_entries`` + +.. 
list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``id`` + - ``VARCHAR`` / ``TEXT`` + - Primary key. + * - ``session_id`` + - ``VARCHAR`` / ``TEXT`` + - Session that produced this memory. + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Application identifier. + * - ``user_id`` + - ``VARCHAR`` / ``TEXT`` + - User identifier. + * - ``content_text`` + - ``TEXT`` + - Searchable text content (used by FTS). + * - ``content_json`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - Structured content. + * - ``inserted_at`` + - ``TIMESTAMP`` + - When the entry was created. + +When ``memory_use_fts`` is enabled in the ADK config, backends create +full-text search indexes on ``content_text`` using the database's native +FTS engine (tsvector, FTS5, InnoDB FT, etc.). + +.. _artifact-schema: + +Artifact Metadata Table +======================= + +The artifact table stores versioning metadata for binary artifacts. Content +bytes are stored separately in object storage; this table tracks ownership, +versioning, and canonical URIs. + +Default name: ``adk_artifact_versions`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Application identifier. + * - ``user_id`` + - ``VARCHAR`` / ``TEXT`` + - User identifier. + * - ``session_id`` + - ``VARCHAR`` / ``TEXT`` (nullable) + - Session identifier. NULL for user-scoped artifacts. + * - ``filename`` + - ``VARCHAR`` / ``TEXT`` + - Artifact filename. + * - ``version`` + - ``INTEGER`` + - Monotonically increasing version (starts at 0). + * - ``mime_type`` + - ``VARCHAR`` / ``TEXT`` (nullable) + - MIME type of the artifact content. + * - ``canonical_uri`` + - ``VARCHAR`` / ``TEXT`` + - URI pointing to content in object storage. + * - ``custom_metadata`` + - ``JSONB`` / ``JSON`` / ``TEXT`` (nullable) + - User-defined metadata. + * - ``created_at`` + - ``TIMESTAMP`` + - When this version was created. 
+ +The composite key is ``(app_name, user_id, session_id, filename, version)``. + +Table Name Configuration +======================== + +All table names are configurable: + +.. code-block:: python + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://..."}, + extension_config={ + "adk": { + "session_table": "my_sessions", # default: "adk_sessions" + "events_table": "my_events", # default: "adk_events" + "memory_table": "my_memory", # default: "adk_memory_entries" + "artifact_table": "my_artifacts", # default: "adk_artifact_versions" + } + }, + ) + +Table names are validated on store initialization: they must start with a +letter or underscore, contain only alphanumeric characters and underscores, +and be at most 63 characters long. diff --git a/docs/reference/extensions/adk.rst b/docs/reference/extensions/adk.rst index aa0ca33a9..3c08b6ed7 100644 --- a/docs/reference/extensions/adk.rst +++ b/docs/reference/extensions/adk.rst @@ -2,7 +2,7 @@ Google ADK ========== -Session, event, and memory storage backends for +Session, event, memory, and artifact storage backends for `Google Agent Development Kit `_. Session Service @@ -23,6 +23,13 @@ Memory Services :members: :show-inheritance: +Artifact Service +================ + +.. autoclass:: sqlspec.extensions.adk.SQLSpecArtifactService + :members: + :show-inheritance: + Store Base Classes ================== @@ -45,6 +52,17 @@ Memory Store Base Classes :members: :show-inheritance: +Artifact Store Base Classes +=========================== + +.. autoclass:: sqlspec.extensions.adk.BaseAsyncADKArtifactStore + :members: + :show-inheritance: + +.. autoclass:: sqlspec.extensions.adk.BaseSyncADKArtifactStore + :members: + :show-inheritance: + Record Types ============ @@ -59,3 +77,14 @@ Record Types .. autoclass:: sqlspec.extensions.adk.MemoryRecord :members: :show-inheritance: + +.. autoclass:: sqlspec.extensions.adk.ArtifactRecord + :members: + :show-inheritance: + +Configuration +============= + +.. 
autoclass:: sqlspec.extensions.adk.ADKConfig + :members: + :show-inheritance: diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst index f1ae3cef9..1bbdfe740 100644 --- a/docs/usage/cli.rst +++ b/docs/usage/cli.rst @@ -4,15 +4,18 @@ Command Line Interface SQLSpec includes a CLI for managing migrations and inspecting configuration. Use it when you want a fast, explicit workflow without additional tooling. +Configuration can come from ``--config``, ``SQLSPEC_CONFIG``, or +``[tool.sqlspec]`` in ``pyproject.toml``. + Core Commands ------------- .. code-block:: console - sqlspec db init - sqlspec db create-migration -m "add users" - sqlspec db upgrade - sqlspec db downgrade + sqlspec init + sqlspec create-migration -m "add users" + sqlspec upgrade + sqlspec downgrade Common Options -------------- @@ -28,7 +31,7 @@ Tips ---- - Run ``sqlspec --help`` to see global options. -- Run ``sqlspec db --help`` to see migration command details. +- Run ``sqlspec upgrade --help`` to see command-specific migration options. Related Guides -------------- diff --git a/docs/usage/migrations.rst b/docs/usage/migrations.rst index 7d4aa5826..2e5cd1b62 100644 --- a/docs/usage/migrations.rst +++ b/docs/usage/migrations.rst @@ -21,9 +21,9 @@ Common Commands .. code-block:: console - sqlspec db init - sqlspec db create-migration -m "add users" - sqlspec db upgrade + sqlspec init + sqlspec create-migration -m "add users" + sqlspec upgrade Configuration ------------- @@ -31,6 +31,9 @@ Configuration Set ``migration_config`` on your database configuration to customize script locations, version table names, and extension migration behavior. +The migration CLI resolves config from ``--config``, ``SQLSPEC_CONFIG``, or +``[tool.sqlspec]`` in ``pyproject.toml``. + .. code-block:: python from sqlspec.adapters.duckdb import DuckDBConfig @@ -38,7 +41,7 @@ locations, version table names, and extension migration behavior. 
config = DuckDBConfig( connection_config={"database": "/tmp/analytics.db"}, migration_config={ - "migration_dir": "migrations/duckdb", + "script_location": "migrations/duckdb", "version_table": "_schema_versions", }, ) @@ -60,11 +63,17 @@ For async configs, ``migrate_up()`` returns an awaitable: config = AsyncpgConfig( connection_config={"dsn": "postgresql://localhost/app"}, - migration_config={"migration_dir": "migrations/postgres"}, + migration_config={"script_location": "migrations/postgres"}, ) await config.migrate_up() +Extension migrations are auto-included when the corresponding entry exists in +``extension_config``. Use ``migration_config["exclude_extensions"]`` to skip a +specific extension, ``migration_config["include_extensions"]`` to opt in +explicitly by extension name, or ``migration_config["enabled"] = False`` to +disable migrations entirely for a database config. + Logging and Echo Controls ------------------------- diff --git a/pyproject.toml b/pyproject.toml index cd9e7484a..53240438b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }] name = "sqlspec" readme = "README.md" requires-python = ">=3.10, <4.0" -version = "0.41.1" +version = "0.42.0" [project.urls] Discord = "https://discord.gg/litestar" @@ -43,7 +43,7 @@ attrs = ["attrs", "cattrs"] bigquery = ["google-cloud-bigquery", "google-cloud-storage"] cloud-sql = ["cloud-sql-python-connector"] cockroachdb = ["psycopg[binary,pool]", "asyncpg"] -duckdb = ["duckdb"] +duckdb = ["duckdb", "pytz"] fastapi = ["fastapi"] flask = ["flask"] fsspec = ["fsspec"] @@ -254,7 +254,7 @@ opt_level = "3" # Maximum optimization (0-3) allow_dirty = true commit = false commit_args = "--no-verify" -current_version = "0.41.1" +current_version = "0.42.0" ignore_missing_files = false ignore_missing_version = false message = "chore(release): bump to v{new_version}" @@ -296,7 +296,7 @@ version = "{current_version}" """ 
[tool.codespell] -ignore-words-list = "te,ECT,SELCT,froms,ccompiler" +ignore-words-list = "te,ECT,SELCT,froms,ccompiler,BRIN" skip = 'uv.lock' [tool.coverage.run] diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 65e5c4975..50d6c72f4 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -1,12 +1,14 @@ """ADBC ADK store for Google Agent Development Kit session/event storage.""" +import contextlib from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from sqlspec.adapters.adbc.config import AdbcConfig @@ -25,16 +27,23 @@ ADBC_TABLE_NOT_FOUND_PATTERNS: Final = ("no such table", "table or view does not exist", "relation does not exist") -class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]): +class AdbcADKStore(BaseAsyncADKStore["AdbcConfig"]): """ADBC synchronous ADK store for Arrow Database Connectivity. Implements session and event storage for Google Agent Development Kit using ADBC. ADBC provides a vendor-neutral API with Arrow-native data transfer across multiple databases (PostgreSQL, SQLite, DuckDB, etc.). + Events use the new 5-column contract: session_id, invocation_id, author, + timestamp, and event_json. The full ADK Event payload is stored as a + single JSON blob in event_json using a dialect-appropriate column type + (JSONB for PostgreSQL, JSON for DuckDB, VARIANT for Snowflake, TEXT for + SQLite and generic fallback). 
+ Provides: - - Session state management with JSON serialization (TEXT storage) - - Event history tracking with BLOB-serialized actions + - Session state management with JSON serialization + - Event history tracking via single event_json blob + - Atomic event insert + session state update - Timezone-aware timestamps - Foreign key constraints with cascade delete - Database-agnostic SQL (supports multiple backends) @@ -60,12 +69,9 @@ class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]): store.ensure_tables() Notes: - - TEXT for JSON storage (compatible across all ADBC backends) - - BLOB for pre-serialized actions from Google ADK + - Dialect-appropriate JSON type for event_json storage - TIMESTAMP for timezone-aware timestamps (driver-dependent precision) - - INTEGER for booleans (0/1/NULL) - - Parameter style varies by backend (?, $1, :name, etc.) - - Uses dialect-agnostic SQL for maximum compatibility + - Parameter style: ``?`` universally across ADBC backends - State and JSON fields use to_json/from_json for serialization - ADBC drivers handle parameter binding automatically - Configuration is read from config.extension_config["adk"] @@ -171,7 +177,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return None return from_json(str(data)) # type: ignore[no-any-return] - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: """Get CREATE TABLE SQL for sessions with dialect dispatch. Returns: @@ -277,7 +283,7 @@ def _get_sessions_ddl_generic(self) -> str: ) """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: """Get CREATE TABLE SQL for events with dialect dispatch. Returns: @@ -298,27 +304,17 @@ def _get_events_ddl_postgresql(self) -> str: Returns: SQL to create events table optimized for PostgreSQL. + + Notes: + Uses JSONB for event_json to enable indexing and query support. 
""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json TEXT, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -328,27 +324,17 @@ def _get_events_ddl_sqlite(self) -> str: Returns: SQL to create events table optimized for SQLite. + + Notes: + Uses TEXT for event_json (SQLite has no native JSON column type). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL, - invocation_id TEXT, - author TEXT, - actions BLOB, - long_running_tool_ids_json TEXT, - branch TEXT, + invocation_id TEXT NOT NULL, + author TEXT NOT NULL, timestamp REAL NOT NULL, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code TEXT, - error_message TEXT, + event_json TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -358,27 +344,17 @@ def _get_events_ddl_duckdb(self) -> str: Returns: SQL to create events table optimized for DuckDB. + + Notes: + Uses JSON for event_json (DuckDB native JSON type). 
""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BLOB, - long_running_tool_ids_json VARCHAR, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -388,27 +364,17 @@ def _get_events_ddl_snowflake(self) -> str: Returns: SQL to create events table optimized for Snowflake. + + Notes: + Uses VARIANT for event_json (Snowflake semi-structured type). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - app_name VARCHAR NOT NULL, - user_id VARCHAR NOT NULL, - invocation_id VARCHAR, - author VARCHAR, - actions BINARY, - long_running_tool_ids_json VARCHAR, - branch VARCHAR, + invocation_id VARCHAR NOT NULL, + author VARCHAR NOT NULL, timestamp TIMESTAMP_TZ NOT NULL DEFAULT CURRENT_TIMESTAMP(), - content VARIANT, - grounding_metadata VARIANT, - custom_metadata VARIANT, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR, - error_message VARCHAR, + event_json VARIANT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ) """ @@ -418,27 +384,17 @@ def _get_events_ddl_generic(self) -> str: Returns: SQL to create events table using generic types. + + Notes: + Uses TEXT for event_json (maximum portability). 
""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BLOB, - long_running_tool_ids_json TEXT, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -455,14 +411,14 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" with self._config.provide_connection() as conn: cursor = conn.cursor() try: self._enable_foreign_keys(cursor, conn) - cursor.execute(self._get_create_sessions_table_sql()) + cursor.execute(run_(self._get_create_sessions_table_sql)()) conn.commit() sessions_idx_app_user = ( @@ -479,7 +435,7 @@ def create_tables(self) -> None: cursor.execute(sessions_idx_update) conn.commit() - cursor.execute(self._get_create_events_table_sql()) + cursor.execute(run_(self._get_create_events_table_sql)()) conn.commit() events_idx = ( @@ -491,6 +447,10 @@ def create_tables(self) -> None: finally: cursor.close() # type: ignore[no-untyped-call] + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None: """Enable foreign key constraints for SQLite. 
@@ -508,7 +468,7 @@ def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None: except Exception: logger.debug("Foreign key enforcement not supported or already enabled") - def create_session( + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: """Create a new session. @@ -548,9 +508,19 @@ def create_session( finally: cursor.close() # type: ignore[no-untyped-call] - return self.get_session(session_id) # type: ignore[return-value] + result = self._get_session(session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID. Args: @@ -594,7 +564,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. 
Args: @@ -620,7 +594,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None finally: cursor.close() # type: ignore[no-untyped-call] - def delete_session(self, session_id: str) -> None: + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: """Delete session and all associated events (cascade). Args: @@ -640,7 +618,11 @@ def delete_session(self, session_id: str) -> None: finally: cursor.close() # type: ignore[no-untyped-call] - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. Args: @@ -696,149 +678,144 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> "EventRecord": - """Create a new event. + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) - Args: - event_id: Unique event identifier. - session_id: Session identifier. - app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSON). - **kwargs: Additional optional fields. - - Returns: - Created event record. 
+ def _insert_event(self, event_record: "EventRecord") -> None: + """Insert an event record into the events table. - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSON fields are serialized to JSON strings. - Boolean fields are converted to INTEGER (0/1). + Args: + event_record: Event record to store. """ - content_json = self._serialize_json_field(content) - grounding_metadata_json = self._serialize_json_field(kwargs.get("grounding_metadata")) - custom_metadata_json = self._serialize_json_field(kwargs.get("custom_metadata")) - - partial_int = self._to_int_bool(kwargs.get("partial")) - turn_complete_int = self._to_int_bool(kwargs.get("turn_complete")) - interrupted_int = self._to_int_bool(kwargs.get("interrupted")) - + event_json = self._serialize_json_field(event_record["event_json"]) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (?, ?, ?, ?, ?) 
""" - timestamp = kwargs.get("timestamp") - if timestamp is None: - timestamp = datetime.now(timezone.utc) - with self._config.provide_connection() as conn: cursor = conn.cursor() try: cursor.execute( sql, ( - event_id, - session_id, - app_name, - user_id, - kwargs.get("invocation_id"), - author, - actions, - kwargs.get("long_running_tool_ids_json"), - kwargs.get("branch"), - timestamp, - content_json, - grounding_metadata_json, - custom_metadata_json, - partial_int, - turn_complete_int, - interrupted_int, - kwargs.get("error_code"), - kwargs.get("error_message"), + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json, ), ) conn.commit() finally: cursor.close() # type: ignore[no-untyped-call] - events = self.list_events(session_id) - for event in events: - if event["id"] == event_id: - return event + def _append_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically insert an event and update the session's durable state. - msg = f"Failed to retrieve created event {event_id}" - raise RuntimeError(msg) + The event insert and state update are executed within a single + connection and committed together. If either statement fails the + transaction is rolled back so the two writes remain consistent. - def list_events(self, session_id: str) -> "list[EventRecord]": + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). + """ + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (?, ?, ?, ?, ?) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = CURRENT_TIMESTAMP + WHERE id = ? 
+ """ + state_json = self._serialize_state(state) + event_json = self._serialize_json_field(event_record["event_json"]) + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json, + ), + ) + cursor.execute(update_sql, (state_json, session_id)) + conn.commit() + except Exception: + with contextlib.suppress(Exception): + conn.rollback() + raise + finally: + cursor.close() # type: ignore[no-untyped-call] + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. Returns: List of event records ordered by timestamp ASC. Notes: Uses index on (session_id, timestamp ASC). - JSON fields deserialized from JSON strings. - Converts INTEGER booleans to Python bool. + Returns the 5-column EventRecord (session_id, invocation_id, + author, timestamp, event_json). 
""" + where_clauses = ["session_id = ?"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > ?") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = ? - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + cursor.execute(sql, params) rows = cursor.fetchall() return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]) if row[6] is not None else b"", - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=self._deserialize_json_field(row[10]), - grounding_metadata=self._deserialize_json_field(row[11]), - custom_metadata=self._deserialize_json_field(row[12]), - partial=self._from_int_bool(row[13]), - turn_complete=self._from_int_bool(row[14]), - interrupted=self._from_int_bool(row[15]), - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=self._deserialize_json_field(row[4]) or {}, ) for row in rows ] @@ -850,36 +827,22 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [] raise - @staticmethod - def _to_int_bool(value: "bool | None") -> "int | None": - """Convert Python boolean to INTEGER (0/1). 
+ async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) - Args: - value: Python boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. - """ - if value is None: - return None - return 1 if value else 0 + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._insert_event(event_record) - @staticmethod - def _from_int_bool(value: "int | None") -> "bool | None": - """Convert INTEGER to Python boolean. + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) - Args: - value: INTEGER value (0, 1, or None). - Returns: - Python boolean or None. - """ - if value is None: - return None - return bool(value) - - -class AdbcADKMemoryStore(BaseSyncADKMemoryStore["AdbcConfig"]): +class AdbcADKMemoryStore(BaseAsyncADKMemoryStore["AdbcConfig"]): """ADBC synchronous ADK memory store for Arrow Database Connectivity.""" __slots__ = ("_dialect",) @@ -924,7 +887,7 @@ def _decode_timestamp(self, value: Any) -> datetime: return datetime.fromisoformat(value) return datetime.fromisoformat(str(value)) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: if self._dialect == DIALECT_POSTGRESQL: return self._get_memory_ddl_postgresql() if self._dialect == DIALECT_SQLITE: @@ -1028,14 +991,14 @@ def _get_memory_ddl_generic(self) -> str: def _get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(self._get_create_memory_table_sql()) 
+ cursor.execute(run_(self._get_create_memory_table_sql)()) conn.commit() idx_app_user = ( @@ -1053,7 +1016,11 @@ def create_tables(self) -> None: finally: cursor.close() # type: ignore[no-untyped-call] - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -1159,7 +1126,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -1199,7 +1170,13 @@ def search_entries( return self._rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} if use_returning: sql = f"DELETE FROM {self._memory_table} WHERE session_id = ? 
RETURNING 1" @@ -1218,7 +1195,11 @@ def delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() # type: ignore[no-untyped-call] - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: cutoff = self._encode_timestamp(datetime.now(timezone.utc) - timedelta(days=days)) use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} if use_returning: @@ -1238,6 +1219,10 @@ def delete_entries_older_than(self, days: int) -> int: finally: cursor.close() # type: ignore[no-untyped-call] + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index 7a63095f1..f68805297 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -52,34 +52,6 @@ def _julian_to_datetime(julian: float) -> datetime: return datetime.fromtimestamp(timestamp, tz=timezone.utc) -def _to_sqlite_bool(value: "bool | None") -> "int | None": - """Convert Python bool to SQLite INTEGER. - - Args: - value: Boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. - """ - if value is None: - return None - return 1 if value else 0 - - -def _from_sqlite_bool(value: "int | None") -> "bool | None": - """Convert SQLite INTEGER to Python bool. - - Args: - value: Integer value (0/1) or None. - - Returns: - True for 1, False for 0, None for None. 
- """ - if value is None: - return None - return bool(value) - - class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]): """Aiosqlite ADK store using asynchronous SQLite driver. @@ -88,10 +60,11 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]): Provides: - Session state management with JSON storage (as TEXT) - - Event history tracking with BLOB-serialized actions + - Event history tracking with full-event JSON storage - Julian Day timestamps (REAL) for efficient date operations - Foreign key constraints with cascade delete - - Efficient upserts using INSERT OR REPLACE + - Atomic event+state writes via append_event_and_update_state + - PRAGMA optimization profile for file-based databases Args: config: AiosqliteConfig with extension_config["adk"] settings. @@ -114,9 +87,8 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]): Notes: - JSON stored as TEXT with SQLSpec serializers (msgspec/orjson/stdlib) - - BOOLEAN as INTEGER (0/1, with None for NULL) - Timestamps as REAL (Julian day: julianday('now')) - - BLOB for pre-serialized actions from Google ADK + - Full event stored as JSON TEXT in event_data column - PRAGMA foreign_keys = ON (enable per connection) - Configuration is read from config.extension_config["adk"] """ @@ -136,6 +108,22 @@ def __init__(self, config: "AiosqliteConfig") -> None: """ super().__init__(config) + async def _apply_pragmas(self, connection: Any) -> None: + """Apply PRAGMA optimization profile for this connection. + + Args: + connection: Aiosqlite connection. + + Notes: + Enables foreign keys and applies performance PRAGMAs. + For file-based databases, adds cache_size, mmap_size, + and journal_size_limit optimizations. 
+ """ + await connection.execute("PRAGMA foreign_keys = ON") + await connection.execute("PRAGMA cache_size = -64000") + await connection.execute("PRAGMA mmap_size = 30000000") + await connection.execute("PRAGMA journal_size_limit = 67108864") + async def _get_create_sessions_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for sessions. @@ -170,9 +158,8 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. Notes: - - TEXT for IDs, strings, and JSON content - - BLOB for pickled actions - - INTEGER for booleans (0/1/NULL) + - TEXT for IDs and indexed scalars + - TEXT for full event JSON (event_data) - REAL for Julian Day timestamps - Foreign key to sessions with CASCADE delete - Index on (session_id, timestamp ASC) @@ -181,22 +168,10 @@ async def _get_create_events_table_sql(self) -> str: CREATE TABLE IF NOT EXISTS {self._events_table} ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL, - invocation_id TEXT NOT NULL, - author TEXT NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json TEXT, - branch TEXT, + invocation_id TEXT, + author TEXT, timestamp REAL NOT NULL, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code TEXT, - error_message TEXT, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session @@ -215,21 +190,10 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def _enable_foreign_keys(self, connection: Any) -> None: - """Enable foreign key constraints for this connection. - - Args: - connection: Aiosqlite connection. - - Notes: - SQLite requires PRAGMA foreign_keys = ON per connection. 
- """ - await connection.execute("PRAGMA foreign_keys = ON") - async def create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: - await self._enable_foreign_keys(driver.connection) + await self._apply_pragmas(driver.connection) await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -250,11 +214,11 @@ async def create_session( Notes: Uses Julian Day for create_time and update_time. - State is JSON-serialized before insertion. + State is always JSON-serialized (empty dict becomes '{}', never NULL). """ now = datetime.now(timezone.utc) now_julian = _datetime_to_julian(now) - state_json = to_json(state) if state else None + state_json = to_json(state) params: tuple[Any, ...] if self._owner_id_column_name: @@ -272,7 +236,7 @@ async def create_session( params = (session_id, app_name, user_id, state_json, now_julian, now_julian) async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute(sql, params) await conn.commit() @@ -300,7 +264,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) cursor = await conn.execute(sql, (session_id,)) row = await cursor.fetchone() @@ -326,9 +290,10 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - Notes: This replaces the entire state dictionary. Updates update_time to current Julian Day. + Empty dict is serialized as '{}', never NULL. 
""" now_julian = _datetime_to_julian(datetime.now(timezone.utc)) - state_json = to_json(state) if state else None + state_json = to_json(state) sql = f""" UPDATE {self._session_table} @@ -337,7 +302,7 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute(sql, (state_json, now_julian, session_id)) await conn.commit() @@ -372,7 +337,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis params = (app_name, user_id) async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) cursor = await conn.execute(sql, params) rows = await cursor.fetchall() @@ -400,7 +365,7 @@ async def delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = ?" async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute(sql, (session_id,)) await conn.commit() @@ -408,63 +373,88 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. Notes: Uses Julian Day for timestamp. - JSON fields are serialized to TEXT. - Boolean fields converted to INTEGER (0/1/NULL). + event_json dict is serialized to TEXT as event_data column. 
""" - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + import uuid - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) - - partial_int = _to_sqlite_bool(event_record.get("partial")) - turn_complete_int = _to_sqlite_bool(event_record.get("turn_complete")) - interrupted_int = _to_sqlite_bool(event_record.get("interrupted")) + timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + event_data_json = to_json(event_record["event_json"]) + event_id = str(uuid.uuid4()) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ) + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) await conn.execute( sql, ( - event_record["id"], + event_id, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + timestamp_julian, + event_data_json, + ), + ) + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + Inserts the event and updates the session state + update_time in a + single transaction. Both operations succeed or fail together. + + Args: + event_record: Event record to store. 
+ session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (temp: keys already + stripped by the service layer). + """ + import uuid + + timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + event_data_json = to_json(event_record["event_json"]) + now_julian = _datetime_to_julian(datetime.now(timezone.utc)) + state_json = to_json(state) + event_id = str(uuid.uuid4()) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = ? + WHERE id = ? + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute( + insert_sql, + ( + event_id, event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), timestamp_julian, - content_json, - grounding_metadata_json, - custom_metadata_json, - partial_int, - turn_complete_int, - interrupted_int, - event_record.get("error_code"), - event_record.get("error_message"), + event_data_json, ), ) + await conn.execute(update_sql, (state_json, now_julian, session_id)) await conn.commit() async def get_events( @@ -482,8 +472,7 @@ async def get_events( Notes: Uses index on (session_id, timestamp ASC). - Parses JSON fields and converts BLOB actions to bytes. - Converts INTEGER booleans back to bool/None. + Parses event_data TEXT back to dict for event_json field. 
""" where_clauses = ["session_id = ?"] params: list[Any] = [session_id] @@ -496,40 +485,24 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT id, session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} """ async with self._config.provide_connection() as conn: - await self._enable_foreign_keys(conn) + await self._apply_pragmas(conn) cursor = await conn.execute(sql, params) rows = await cursor.fetchall() return [ EventRecord( - id=row[0], session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=_julian_to_datetime(row[9]), - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=_from_sqlite_bool(row[13]), - turn_complete=_from_sqlite_bool(row[14]), - interrupted=_from_sqlite_bool(row[15]), - error_code=row[16], - error_message=row[17], + invocation_id=row[2], + author=row[3], + timestamp=_julian_to_datetime(row[4]), + event_json=from_json(row[5]) if row[5] else {}, ) for row in rows ] diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index ff74e6851..446defc78 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -27,36 +27,16 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]): Implements session and event storage for Google Agent Development Kit using MySQL/MariaDB via the AsyncMy driver. 
Provides: - Session state management with JSON storage - - Event history tracking with BLOB-serialized actions + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-append + state-update in one transaction - Microsecond-precision timestamps - Foreign key constraints with cascade delete - Efficient upserts using ON DUPLICATE KEY UPDATE - Args: - config: AsyncmyConfig with extension_config["adk"] settings. - - Example: - from sqlspec.adapters.asyncmy import AsyncmyConfig - from sqlspec.adapters.asyncmy.adk import AsyncmyADKStore - - config = AsyncmyConfig( - connection_config={"host": "localhost", ...}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = AsyncmyADKStore(config) - await store.ensure_tables() - Notes: - MySQL JSON type used (not JSONB) - requires MySQL 5.7.8+ - TIMESTAMP(6) provides microsecond precision - InnoDB engine required for foreign key support - - State merging handled at application level - Configuration is read from config.extension_config["adk"] """ @@ -67,12 +47,6 @@ def __init__(self, config: "AsyncmyConfig") -> None: Args: config: AsyncmyConfig instance. 
- - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) """ super().__init__(config) @@ -88,10 +62,6 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]" Returns: Tuple of (column_definition, foreign_key_constraint) - - Example: - Input: "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - Output: ("tenant_id BIGINT NOT NULL", "FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE") """ references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) @@ -110,16 +80,6 @@ async def _get_create_sessions_table_sql(self) -> str: Returns: SQL statement to create adk_sessions table with indexes. - - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSON type for state storage (MySQL 5.7.8+) - - TIMESTAMP(6) with microsecond precision - - AUTO-UPDATE on update_time - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Optional owner ID column for multi-tenancy - - MySQL requires explicit FOREIGN KEY syntax (inline REFERENCES is ignored) """ owner_id_col = "" fk_constraint = "" @@ -151,34 +111,18 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BLOB for pickled actions (up to 64KB) - - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval + Post clean-break schema: 5 columns only. 
+ - session_id, invocation_id, author: indexed scalars + - timestamp: microsecond-precision TIMESTAMP + - event_json: full Event as native JSON """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json JSON, - branch VARCHAR(256), + author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -189,10 +133,6 @@ def _get_drop_tables_sql(self) -> "list[str]": Returns: List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - MySQL automatically drops indexes when dropping tables. """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] @@ -216,11 +156,6 @@ async def create_session( Returns: Created session record. - - Notes: - Uses INSERT with UTC_TIMESTAMP(6) for create_time and update_time. - State is JSON-serialized before insertion. - If owner_id_column is configured, owner_id must be provided. """ state_json = to_json(state) @@ -252,10 +187,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": Returns: Session record or None if not found. - - Notes: - MySQL returns datetime objects for TIMESTAMP columns. - JSON is parsed from database storage. 
""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -292,10 +223,6 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - Args: session_id: Session identifier. state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses update_time auto-update trigger. """ state_json = to_json(state) @@ -314,9 +241,6 @@ async def delete_session(self, session_id: str) -> None: Args: session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. """ sql = f"DELETE FROM {self._session_table} WHERE id = %s" @@ -333,9 +257,6 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis Returns: List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. """ if user_id is None: sql = f""" @@ -379,55 +300,72 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. - - Notes: - Uses UTC_TIMESTAMP(6) for timestamp if not provided. - JSON fields are serialized before insertion. + event_record: Event record with 5 keys (session_id, invocation_id, + author, timestamp, event_json). 
""" - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ async with self._config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute( sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, + ), + ) + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. 
+ + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. + """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, ), ) + await cursor.execute(update_sql, (state_json, session_id)) await conn.commit() async def get_events( @@ -442,10 +380,6 @@ async def get_events( Returns: List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - Parses JSON fields and converts BLOB actions to bytes. 
""" where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -458,10 +392,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -474,24 +405,11 @@ async def get_events( return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] and isinstance(row[10], str) else row[10], - grounding_metadata=from_json(row[11]) if row[11] and isinstance(row[11], str) else row[11], - custom_metadata=from_json(row[12]) if row[12] and isinstance(row[12], str) else row[12], - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 4c4624a87..76cffccab 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -1,6 +1,6 @@ """AsyncPG ADK store for Google Agent Development Kit session/event storage.""" -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final import asyncpg @@ -21,87 +21,32 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]): - """PostgreSQL ADK store base class for all PostgreSQL drivers. 
+ """PostgreSQL ADK store using asyncpg driver. Implements session and event storage for Google Agent Development Kit - using PostgreSQL via any PostgreSQL driver (AsyncPG, Psycopg, Psqlpy). - All drivers share the same SQL dialect and parameter style ($1, $2, etc). + using PostgreSQL via asyncpg. Events are stored as a single JSONB blob + (``event_json``) alongside indexed scalar columns for efficient querying. Provides: - - Session state management with JSONB storage and merge operations - - Event history tracking with BYTEA-serialized actions + - Session state management with JSONB storage + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 - - Optional user FK column for multi-tenancy + - Optional owner ID column for multi-tenancy Args: config: PostgreSQL database config with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.asyncpg import AsyncpgConfig - from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore - - config = AsyncpgConfig( - connection_config={"dsn": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = AsyncpgADKStore(config) - await store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - AsyncPG automatically converts Python dicts to/from JSONB (no manual serialization) - - TIMESTAMPTZ provides timezone-aware microsecond precision - - State merging uses `state || $1::jsonb` operator for efficiency - - BYTEA for pre-serialized actions from Google ADK (not pickled here) - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Generic over PostgresConfigT to support all PostgreSQL drivers - - Owner ID column enables multi-tenant isolation with referential integrity - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: AsyncConfigT) -> None: - """Initialize AsyncPG ADK store. - - Args: - config: PostgreSQL database config. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or owner references - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -128,61 +73,24 @@ async def _get_create_sessions_table_sql(self) -> str: """ async def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pickled actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) 
REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" async with self.config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -190,23 +98,6 @@ async def create_tables(self) -> None: async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is passed as dict and asyncpg converts to JSONB automatically. - If owner_id_column is configured, owner_id value must be provided. - """ async with self.config.provide_connection() as conn: if self._owner_id_column_name: sql = f""" @@ -225,18 +116,6 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. 
- - Returns: - Session record or None if not found. - - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically parsed by asyncpg. - """ sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -262,16 +141,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": return None async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - """ sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP @@ -282,32 +151,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - await conn.execute(sql, state, session_id) async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. - """ sql = f"DELETE FROM {self._session_table} WHERE id = $1" async with self.config.provide_connection() as conn: await conn.execute(sql, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. 
- """ if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -344,70 +193,50 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return [] async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record to store. - - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSONB fields are passed as dicts and asyncpg converts automatically. - """ - content_json = event_record.get("content") - grounding_metadata_json = event_record.get("grounding_metadata") - custom_metadata_json = event_record.get("custom_metadata") - sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) """ async with self.config.provide_connection() as conn: await conn.execute( sql, - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], - event_record.get("invocation_id"), - event_record.get("author"), - event_record.get("actions"), - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = $1, 
update_time = CURRENT_TIMESTAMP + WHERE id = $2 + """ + + async with self.config.provide_connection() as conn, conn.transaction(): + await conn.execute( + insert_sql, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_record["event_json"], ) + await conn.execute(update_sql, state, session_id) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - Parses JSONB fields and converts BYTEA actions to bytes. 
- """ where_clauses = ["session_id = $1"] params: list[Any] = [session_id] @@ -421,10 +250,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -436,24 +262,11 @@ async def get_events( return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -475,66 +288,14 @@ class AsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["AsyncpgConfig"]): Args: config: AsyncpgConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.asyncpg import AsyncpgConfig - from sqlspec.adapters.asyncpg.adk.store import AsyncpgADKMemoryStore - - config = AsyncpgConfig( - connection_config={"dsn": "postgresql://..."}, - extension_config={ - "adk": { - "memory_table": "adk_memory_entries", - "memory_use_fts": True, - "memory_max_results": 20, - } - } - ) - store = AsyncpgADKMemoryStore(config) - await store.ensure_tables() - - Notes: - - JSONB type for content_json and metadata_json - - TIMESTAMPTZ with microsecond precision - - GIN index on content_text tsvector for FTS queries - - Composite index on (app_name, user_id) for filtering - - event_id UNIQUE constraint for deduplication - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "AsyncpgConfig") -> None: - """Initialize AsyncPG ADK memory store. - - Args: - config: AsyncpgConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") - - memory_use_fts: Enable full-text search when supported (default: False) - - memory_max_results: Max search results (default: 20) - - owner_id_column: Optional owner FK column DDL (default: None) - - enable_memory: Whether memory is enabled (default: True) - """ super().__init__(config) async def _get_create_memory_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for memory entries. - - Returns: - SQL statement to create memory table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names - - JSONB for content and metadata storage - - TIMESTAMPTZ with microsecond precision - - UNIQUE constraint on event_id for deduplication - - Composite index on (app_name, user_id, timestamp DESC) - - GIN index on content_text tsvector for FTS - - Optional owner ID column for multi-tenancy - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -570,21 +331,9 @@ async def _get_create_memory_table_sql(self) -> str: """ def _get_drop_memory_table_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop the memory table. - - Notes: - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._memory_table}"] async def create_tables(self) -> None: - """Create the memory table and indexes if they don't exist. - - Skips table creation if memory store is disabled. - """ if not self._enabled: return @@ -592,21 +341,6 @@ async def create_tables(self) -> None: await driver.execute_script(await self._get_create_memory_table_sql()) async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication. - - Uses UPSERT pattern (ON CONFLICT DO NOTHING) to skip duplicates - based on event_id unique constraint. - - Args: - entries: List of memory records to insert. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Number of entries actually inserted (excludes duplicates). - - Raises: - RuntimeError: If memory store is disabled. 
- """ if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -673,19 +407,6 @@ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: " async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": - """Search memory entries by text query. - - Uses the configured search strategy (simple ILIKE or FTS). - - Args: - query: Text query to search for. - app_name: Application name to filter by. - user_id: User ID to filter by. - limit: Maximum number of results (defaults to max_results config). - - Returns: - List of memory records. - """ if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -693,6 +414,8 @@ async def search_entries( if not query: return [] + from typing import cast + limit_value = limit or self._max_results if self._use_fts: sql = f""" @@ -717,7 +440,6 @@ async def search_entries( return [cast("MemoryRecord", dict(row)) for row in rows] async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -731,7 +453,6 @@ async def delete_entries_by_session(self, session_id: str) -> int: return 0 async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) diff --git a/sqlspec/adapters/bigquery/adk/__init__.py b/sqlspec/adapters/bigquery/adk/__init__.py deleted file mode 100644 index 6d11c84b8..000000000 --- a/sqlspec/adapters/bigquery/adk/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""BigQuery ADK store for Google Agent Development Kit session/event storage.""" - -from sqlspec.adapters.bigquery.adk.store import BigQueryADKMemoryStore, BigQueryADKStore - -__all__ = ("BigQueryADKMemoryStore", "BigQueryADKStore") diff --git 
a/sqlspec/adapters/bigquery/adk/store.py b/sqlspec/adapters/bigquery/adk/store.py deleted file mode 100644 index 61339a7a4..000000000 --- a/sqlspec/adapters/bigquery/adk/store.py +++ /dev/null @@ -1,827 +0,0 @@ -"""BigQuery ADK store for Google Agent Development Kit session/event storage.""" - -from collections.abc import Mapping -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, cast - -from google.api_core.exceptions import NotFound -from google.cloud.bigquery import QueryJobConfig, ScalarQueryParameter - -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore -from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ - -if TYPE_CHECKING: - from sqlspec.adapters.bigquery.config import BigQueryConfig - from sqlspec.extensions.adk import MemoryRecord - - -__all__ = ("BigQueryADKMemoryStore", "BigQueryADKStore") - - -class BigQueryADKStore(BaseAsyncADKStore["BigQueryConfig"]): - """BigQuery ADK store using synchronous BigQuery client with async wrapper. - - Implements session and event storage for Google Agent Development Kit - using Google Cloud BigQuery. Uses BigQuery's native JSON type for state/metadata - storage and async_() wrapper to provide async interface. - - Provides: - - Serverless, scalable session state management with JSON storage - - Event history tracking optimized for analytics - - Microsecond-precision timestamps with TIMESTAMP type - - Cost-optimized queries with partitioning and clustering - - Efficient JSON handling with BigQuery's JSON type - - Manual cascade delete pattern (no foreign key support) - - Args: - config: BigQueryConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.bigquery import BigQueryConfig - from sqlspec.adapters.bigquery.adk import BigQueryADKStore - - config = BigQueryConfig( - connection_config={ - "project": "my-project", - "dataset_id": "my_dataset", - }, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INT64 NOT NULL" - } - } - ) - store = BigQueryADKStore(config) - await store.ensure_tables() - - Notes: - - JSON type for state, content, and metadata (native BigQuery JSON) - - BYTES for pre-serialized actions from Google ADK - - TIMESTAMP for timezone-aware microsecond precision - - Partitioned by DATE(create_time) for cost optimization - - Clustered by app_name, user_id for query performance - - Uses to_json/from_json for serialization to JSON columns - - BigQuery has eventual consistency - handle appropriately - - No true foreign keys but implements cascade delete pattern - - Configuration is read from config.extension_config["adk"] - """ - - __slots__ = ("_dataset_id",) - - def __init__(self, config: "BigQueryConfig") -> None: - """Initialize BigQuery ADK store. - - Args: - config: BigQueryConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ - super().__init__(config) - self._dataset_id = config.connection_config.get("dataset_id") - - def _get_full_table_name(self, table_name: str) -> str: - """Get fully qualified table name for BigQuery. - - Args: - table_name: Base table name. - - Returns: - Fully qualified table name with backticks. - - Notes: - BigQuery requires backtick-quoted identifiers for table names. 
- Format: `project.dataset.table` or `dataset.table` - """ - if self._dataset_id: - return f"`{self._dataset_id}.{table_name}`" - return f"`{table_name}`" - - async def _get_create_sessions_table_sql(self) -> str: - """Get BigQuery CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table. - - Notes: - - STRING for IDs and names - - JSON type for state storage (native BigQuery JSON) - - TIMESTAMP for timezone-aware microsecond precision - - Partitioned by DATE(create_time) for cost optimization - - Clustered by app_name, user_id for query performance - - No indexes needed (BigQuery auto-optimizes) - - Optional owner ID column for multi-tenant scenarios - - Note: BigQuery doesn't enforce FK constraints - """ - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - table_name = self._get_full_table_name(self._session_table) - return f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id STRING NOT NULL, - app_name STRING NOT NULL, - user_id STRING NOT NULL{owner_id_line}, - state JSON NOT NULL, - create_time TIMESTAMP NOT NULL, - update_time TIMESTAMP NOT NULL - ) - PARTITION BY DATE(create_time) - CLUSTER BY app_name, user_id - """ - - async def _get_create_events_table_sql(self) -> str: - """Get BigQuery CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table. 
- - Notes: - - STRING for IDs and text fields - - BYTES for pickled actions - - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOL for boolean flags - - TIMESTAMP for timezone-aware timestamps - - Partitioned by DATE(timestamp) for cost optimization - - Clustered by session_id, timestamp for ordered retrieval - """ - table_name = self._get_full_table_name(self._events_table) - return f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id STRING NOT NULL, - session_id STRING NOT NULL, - app_name STRING NOT NULL, - user_id STRING NOT NULL, - invocation_id STRING, - author STRING, - actions BYTES, - long_running_tool_ids_json JSON, - branch STRING, - timestamp TIMESTAMP NOT NULL, - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOL, - turn_complete BOOL, - interrupted BOOL, - error_code STRING, - error_message STRING - ) - PARTITION BY DATE(timestamp) - CLUSTER BY session_id, timestamp - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get BigQuery DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables. - - Notes: - Order matters: drop events table before sessions table. - BigQuery uses IF EXISTS for idempotent drops. 
- """ - events_table = self._get_full_table_name(self._events_table) - sessions_table = self._get_full_table_name(self._session_table) - return [f"DROP TABLE IF EXISTS {events_table}", f"DROP TABLE IF EXISTS {sessions_table}"] - - def _create_tables(self) -> None: - """Synchronous implementation of create_tables.""" - with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_sessions_table_sql)()) - driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - await async_(self._create_tables)() - - def _create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Synchronous implementation of create_session.""" - now = datetime.now(timezone.utc) - state_json = to_json(state) if state else "{}" - - table_name = self._get_full_table_name(self._session_table) - - if self._owner_id_column_name: - sql = f""" - INSERT INTO {table_name} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time) - VALUES (@id, @app_name, @user_id, @owner_id, JSON(@state), @create_time, @update_time) - """ - - params = [ - ScalarQueryParameter("id", "STRING", session_id), - ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ScalarQueryParameter("owner_id", "STRING", str(owner_id) if owner_id is not None else None), - ScalarQueryParameter("state", "STRING", state_json), - ScalarQueryParameter("create_time", "TIMESTAMP", now), - ScalarQueryParameter("update_time", "TIMESTAMP", now), - ] - else: - sql = f""" - INSERT INTO {table_name} (id, app_name, user_id, state, create_time, update_time) - VALUES (@id, @app_name, @user_id, JSON(@state), @create_time, @update_time) - """ - - params = [ - ScalarQueryParameter("id", "STRING", session_id), - 
ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ScalarQueryParameter("state", "STRING", state_json), - ScalarQueryParameter("create_time", "TIMESTAMP", now), - ScalarQueryParameter("update_time", "TIMESTAMP", now), - ] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - conn.query(sql, job_config=job_config).result() - - return SessionRecord( - id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now - ) - - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP() for timestamps. - State is JSON-serialized then stored in JSON column. - If owner_id_column is configured, owner_id value must be provided. - BigQuery doesn't enforce FK constraints, but column is useful for JOINs. 
- """ - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": - """Synchronous implementation of get_session.""" - table_name = self._get_full_table_name(self._session_table) - sql = f""" - SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time - FROM {table_name} - WHERE id = @session_id - """ - - params = [ScalarQueryParameter("session_id", "STRING", session_id)] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - query_job = conn.query(sql, job_config=job_config) - results = list(query_job.result()) - - if not results: - return None - - row = results[0] - return SessionRecord( - id=row.id, - app_name=row.app_name, - user_id=row.user_id, - state=from_json(row.state) if row.state else {}, - create_time=row.create_time, - update_time=row.update_time, - ) - - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - - Notes: - BigQuery returns datetime objects for TIMESTAMP columns. - JSON_VALUE extracts string representation for parsing. 
- """ - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Synchronous implementation of update_session_state.""" - now = datetime.now(timezone.utc) - state_json = to_json(state) if state else "{}" - - table_name = self._get_full_table_name(self._session_table) - sql = f""" - UPDATE {table_name} - SET state = JSON(@state), update_time = @update_time - WHERE id = @session_id - """ - - params = [ - ScalarQueryParameter("state", "STRING", state_json), - ScalarQueryParameter("update_time", "TIMESTAMP", now), - ScalarQueryParameter("session_id", "STRING", session_id), - ] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - conn.query(sql, job_config=job_config).result() - - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - Replaces entire state dictionary. - Updates update_time to CURRENT_TIMESTAMP(). 
- """ - await async_(self._update_session_state)(session_id, state) - - def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]": - """Synchronous implementation of list_sessions.""" - table_name = self._get_full_table_name(self._session_table) - - if user_id is None: - sql = f""" - SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time - FROM {table_name} - WHERE app_name = @app_name - ORDER BY update_time DESC - """ - params = [ScalarQueryParameter("app_name", "STRING", app_name)] - else: - sql = f""" - SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time - FROM {table_name} - WHERE app_name = @app_name AND user_id = @user_id - ORDER BY update_time DESC - """ - params = [ - ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - query_job = conn.query(sql, job_config=job_config) - results = list(query_job.result()) - - return [ - SessionRecord( - id=row.id, - app_name=row.app_name, - user_id=row.user_id, - state=from_json(row.state) if row.state else {}, - create_time=row.create_time, - update_time=row.update_time, - ) - for row in results - ] - - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses clustering on (app_name, user_id) when user_id is provided for efficiency. 
- """ - return await async_(self._list_sessions)(app_name, user_id) - - def _delete_session(self, session_id: str) -> None: - """Synchronous implementation of delete_session.""" - events_table = self._get_full_table_name(self._events_table) - sessions_table = self._get_full_table_name(self._session_table) - - params = [ScalarQueryParameter("session_id", "STRING", session_id)] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - conn.query(f"DELETE FROM {events_table} WHERE session_id = @session_id", job_config=job_config).result() - conn.query(f"DELETE FROM {sessions_table} WHERE id = @session_id", job_config=job_config).result() - - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events. - - Args: - session_id: Session identifier. - - Notes: - BigQuery doesn't support foreign keys, so we manually delete events first. - Uses two separate DELETE statements in sequence. - """ - await async_(self._delete_session)(session_id) - - def _append_event(self, event_record: EventRecord) -> None: - """Synchronous implementation of append_event.""" - table_name = self._get_full_table_name(self._events_table) - - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) - - sql = f""" - INSERT INTO {table_name} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - @id, @session_id, @app_name, @user_id, @invocation_id, @author, @actions, - @long_running_tool_ids_json, @branch, @timestamp, - 
{"JSON(@content)" if content_json else "NULL"}, - {"JSON(@grounding_metadata)" if grounding_metadata_json else "NULL"}, - {"JSON(@custom_metadata)" if custom_metadata_json else "NULL"}, - @partial, @turn_complete, @interrupted, @error_code, @error_message - ) - """ - - actions_value = event_record.get("actions") - params = [ - ScalarQueryParameter("id", "STRING", event_record["id"]), - ScalarQueryParameter("session_id", "STRING", event_record["session_id"]), - ScalarQueryParameter("app_name", "STRING", event_record["app_name"]), - ScalarQueryParameter("user_id", "STRING", event_record["user_id"]), - ScalarQueryParameter("invocation_id", "STRING", event_record.get("invocation_id")), - ScalarQueryParameter("author", "STRING", event_record.get("author")), - ScalarQueryParameter( - "actions", - "BYTES", - actions_value.decode("latin1") if isinstance(actions_value, bytes) else actions_value, - ), - ScalarQueryParameter( - "long_running_tool_ids_json", "STRING", event_record.get("long_running_tool_ids_json") - ), - ScalarQueryParameter("branch", "STRING", event_record.get("branch")), - ScalarQueryParameter("timestamp", "TIMESTAMP", event_record["timestamp"]), - ScalarQueryParameter("partial", "BOOL", event_record.get("partial")), - ScalarQueryParameter("turn_complete", "BOOL", event_record.get("turn_complete")), - ScalarQueryParameter("interrupted", "BOOL", event_record.get("interrupted")), - ScalarQueryParameter("error_code", "STRING", event_record.get("error_code")), - ScalarQueryParameter("error_message", "STRING", event_record.get("error_message")), - ] - - if content_json: - params.append(ScalarQueryParameter("content", "STRING", content_json)) - if grounding_metadata_json: - params.append(ScalarQueryParameter("grounding_metadata", "STRING", grounding_metadata_json)) - if custom_metadata_json: - params.append(ScalarQueryParameter("custom_metadata", "STRING", custom_metadata_json)) - - with self._config.provide_connection() as conn: - job_config = 
QueryJobConfig(query_parameters=params) - conn.query(sql, job_config=job_config).result() - - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record to store. - - Notes: - Uses BigQuery TIMESTAMP for timezone-aware timestamps. - JSON fields are serialized to STRING then cast to JSON. - Boolean fields stored natively as BOOL. - """ - await async_(self._append_event)(event_record) - - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Synchronous implementation of get_events.""" - table_name = self._get_full_table_name(self._events_table) - - where_clauses = ["session_id = @session_id"] - params: list[ScalarQueryParameter] = [ScalarQueryParameter("session_id", "STRING", session_id)] - - if after_timestamp is not None: - where_clauses.append("timestamp > @after_timestamp") - params.append(ScalarQueryParameter("after_timestamp", "TIMESTAMP", after_timestamp)) - - where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" - - sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, - JSON_VALUE(content) as content, - JSON_VALUE(grounding_metadata) as grounding_metadata, - JSON_VALUE(custom_metadata) as custom_metadata, - partial, turn_complete, interrupted, error_code, error_message - FROM {table_name} - WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} - """ - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - query_job = conn.query(sql, job_config=job_config) - results = list(query_job.result()) - - return [ - EventRecord( - id=row.id, - session_id=row.session_id, - app_name=row.app_name, - user_id=row.user_id, - invocation_id=row.invocation_id, - author=row.author, - actions=bytes(row.actions) if row.actions else 
b"", - long_running_tool_ids_json=row.long_running_tool_ids_json, - branch=row.branch, - timestamp=row.timestamp, - content=from_json(row.content) if row.content else None, - grounding_metadata=from_json(row.grounding_metadata) if row.grounding_metadata else None, - custom_metadata=from_json(row.custom_metadata) if row.custom_metadata else None, - partial=row.partial, - turn_complete=row.turn_complete, - interrupted=row.interrupted, - error_code=row.error_code, - error_message=row.error_message, - ) - for row in results - ] - - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses clustering on (session_id, timestamp) for efficient retrieval. - Parses JSON fields and converts BYTES actions to bytes. 
- """ - return await async_(self._get_events)(session_id, after_timestamp, limit) - - -class BigQueryADKMemoryStore(BaseAsyncADKMemoryStore["BigQueryConfig"]): - """BigQuery ADK memory store using synchronous BigQuery client with async wrapper.""" - - __slots__ = ("_dataset_id",) - - def __init__(self, config: "BigQueryConfig") -> None: - """Initialize BigQuery ADK memory store.""" - super().__init__(config) - self._dataset_id = config.connection_config.get("dataset_id") - - def _get_full_table_name(self, table_name: str) -> str: - """Get fully qualified table name for BigQuery.""" - if self._dataset_id: - return f"`{self._dataset_id}.{table_name}`" - return f"`{table_name}`" - - async def _get_create_memory_table_sql(self) -> str: - """Get BigQuery CREATE TABLE SQL for memory entries.""" - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - table_name = self._get_full_table_name(self._memory_table) - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE SEARCH INDEX idx_{self._memory_table}_fts - ON {table_name}(content_text) - """ - - return f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id STRING NOT NULL, - session_id STRING NOT NULL, - app_name STRING NOT NULL, - user_id STRING NOT NULL, - event_id STRING NOT NULL, - author STRING{owner_id_line}, - timestamp TIMESTAMP NOT NULL, - content_json JSON NOT NULL, - content_text STRING NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP NOT NULL - ) - PARTITION BY DATE(timestamp) - CLUSTER BY app_name, user_id; - {fts_index} - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get BigQuery DROP TABLE SQL statements.""" - table_name = self._get_full_table_name(self._memory_table) - return [f"DROP TABLE IF EXISTS {table_name}"] - - def _create_tables(self) -> None: - """Synchronous implementation of create_tables.""" - with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_memory_table_sql)()) - 
- async def create_tables(self) -> None: - """Create the memory table if it doesn't exist.""" - if not self._enabled: - return - await async_(self._create_tables)() - - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Synchronous implementation of insert_memory_entries.""" - table_name = self._get_full_table_name(self._memory_table) - inserted_count = 0 - - with self._config.provide_connection() as conn: - for entry in entries: - content_json = to_json(entry["content_json"]) - metadata_json = to_json(entry["metadata_json"]) if entry["metadata_json"] is not None else None - metadata_expr = "JSON(@metadata_json)" if metadata_json is not None else "NULL" - - owner_column = f", {self._owner_id_column_name}" if self._owner_id_column_name else "" - owner_value = ", @owner_id" if self._owner_id_column_name else "" - - sql = f""" - MERGE {table_name} T - USING (SELECT @event_id AS event_id) S - ON T.event_id = S.event_id - WHEN NOT MATCHED THEN - INSERT (id, session_id, app_name, user_id, event_id, author{owner_column}, - timestamp, content_json, content_text, metadata_json, inserted_at) - VALUES (@id, @session_id, @app_name, @user_id, @event_id, @author{owner_value}, - @timestamp, JSON(@content_json), @content_text, {metadata_expr}, @inserted_at) - """ - - params = [ - ScalarQueryParameter("id", "STRING", entry["id"]), - ScalarQueryParameter("session_id", "STRING", entry["session_id"]), - ScalarQueryParameter("app_name", "STRING", entry["app_name"]), - ScalarQueryParameter("user_id", "STRING", entry["user_id"]), - ScalarQueryParameter("event_id", "STRING", entry["event_id"]), - ScalarQueryParameter("author", "STRING", entry["author"]), - ScalarQueryParameter("timestamp", "TIMESTAMP", entry["timestamp"]), - ScalarQueryParameter("content_json", "STRING", content_json), - ScalarQueryParameter("content_text", "STRING", entry["content_text"]), - ScalarQueryParameter("inserted_at", "TIMESTAMP", entry["inserted_at"]), 
- ] - - if self._owner_id_column_name: - params.append(ScalarQueryParameter("owner_id", "STRING", str(owner_id) if owner_id else None)) - if metadata_json is not None: - params.append(ScalarQueryParameter("metadata_json", "STRING", metadata_json)) - - job_config = QueryJobConfig(query_parameters=params) - job = conn.query(sql, job_config=job_config) - job.result() - if job.num_dml_affected_rows: - inserted_count += int(job.num_dml_affected_rows) - - return inserted_count - - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - if not entries: - return 0 - - return await async_(self._insert_memory_entries)(entries, owner_id) - - def _search_entries(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": - """Synchronous implementation of search_entries.""" - table_name = self._get_full_table_name(self._memory_table) - base_params = [ - ScalarQueryParameter("app_name", "STRING", app_name), - ScalarQueryParameter("user_id", "STRING", user_id), - ScalarQueryParameter("limit", "INT64", limit), - ] - - if self._use_fts: - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at - FROM {table_name} - WHERE app_name = @app_name - AND user_id = @user_id - AND SEARCH(content_text, @query) - ORDER BY timestamp DESC - LIMIT @limit - """ - params = [*base_params, ScalarQueryParameter("query", "STRING", query)] - else: - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at - FROM {table_name} - WHERE app_name = @app_name - AND user_id = @user_id - AND LOWER(content_text) LIKE LOWER(@pattern) - ORDER BY timestamp DESC - LIMIT @limit - """ - pattern = f"%{query}%" - params = 
[*base_params, ScalarQueryParameter("pattern", "STRING", pattern)] - - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - rows = conn.query(sql, job_config=job_config).result() - return _rows_to_records(rows) - - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - effective_limit = limit if limit is not None else self._max_results - - try: - return await async_(self._search_entries)(query, app_name, user_id, effective_limit) - except NotFound: - return [] - - def _delete_entries_by_session(self, session_id: str) -> int: - table_name = self._get_full_table_name(self._memory_table) - sql = f"DELETE FROM {table_name} WHERE session_id = @session_id" - params = [ScalarQueryParameter("session_id", "STRING", session_id)] - with self._config.provide_connection() as conn: - job_config = QueryJobConfig(query_parameters=params) - job = conn.query(sql, job_config=job_config) - job.result() - return int(job.num_dml_affected_rows or 0) - - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - - def _delete_entries_older_than(self, days: int) -> int: - table_name = self._get_full_table_name(self._memory_table) - sql = f""" - DELETE FROM {table_name} - WHERE inserted_at < TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {days} DAY) - """ - with self._config.provide_connection() as conn: - job = conn.query(sql) - job.result() - return int(job.num_dml_affected_rows or 0) - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - - -def _decode_json_field(value: 
Any) -> "dict[str, Any] | None": - if value is None: - return None - if isinstance(value, str): - return cast("dict[str, Any]", from_json(value)) - if isinstance(value, Mapping): - return dict(value) - return None - - -def _rows_to_records(rows: Any) -> "list[MemoryRecord]": - return [ - { - "id": row["id"], - "session_id": row["session_id"], - "app_name": row["app_name"], - "user_id": row["user_id"], - "event_id": row["event_id"], - "author": row["author"], - "timestamp": row["timestamp"], - "content_json": _decode_json_field(row["content_json"]) or {}, - "content_text": row["content_text"], - "metadata_json": _decode_json_field(row["metadata_json"]), - "inserted_at": row["inserted_at"], - } - for row in rows - ] diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index c81e5f6cc..bedaaa788 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -7,6 +7,8 @@ from sqlspec.utils.logging import get_logger if TYPE_CHECKING: + from datetime import datetime + from sqlspec.adapters.cockroach_asyncpg.config import CockroachAsyncpgConfig from sqlspec.extensions.adk import MemoryRecord @@ -17,7 +19,19 @@ class CockroachAsyncpgADKStore(BaseAsyncADKStore["CockroachAsyncpgConfig"]): - """CockroachDB ADK store using asyncpg driver.""" + """CockroachDB ADK store using asyncpg driver. + + Implements session and event storage for Google Agent Development Kit + using CockroachDB via asyncpg in PostgreSQL compatibility mode. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. 
+ + CockroachDB-specific differences from native PostgreSQL: + - No FILLFACTOR (CockroachDB uses different storage engine) + - No BRIN indexes (different physical storage layout) + - GIN/Inverted indexes on JSONB are fully supported (v23.1+) + - Native tsvector/tsquery FTS with GIN is supported (v23.1+) + """ __slots__ = () @@ -44,34 +58,28 @@ async def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; """ async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, - actions BYTEA NOT NULL, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json + ON {self._events_table} USING GIN (event_json); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -181,72 +189,77 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - 
long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: await conn.execute( sql, - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - event_record.get("content"), - event_record.get("grounding_metadata"), - event_record.get("custom_metadata"), - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_record["event_json"], + ) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = $1, update_time = CURRENT_TIMESTAMP + WHERE id = $2 + """ + + async with self._config.provide_connection() as conn, conn.transaction(): + await conn.execute( + insert_sql, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], ) + await conn.execute(update_sql, state, session_id) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + where_clauses = ["session_id = $1"] + params: list[Any] = [session_id] + + 
if after_timestamp is not None: + where_clauses.append(f"timestamp > ${len(params) + 1}") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = f" LIMIT ${len(params) + 1}" if limit else "" + if limit: + params.append(limit) - async def list_events(self, session_id: str) -> "list[EventRecord]": sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = $1 - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ async with self._config.provide_connection() as conn: - rows = await conn.fetch(sql, session_id) + rows = await conn.fetch(sql, *params) return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]), - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -265,6 +278,13 @@ async def _get_create_memory_table_sql(self) -> str: if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + return f""" CREATE TABLE IF NOT EXISTS 
{self._memory_table} ( id VARCHAR(128) PRIMARY KEY, @@ -285,6 +305,7 @@ async def _get_create_memory_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session ON {self._memory_table}(session_id); + {fts_index} """ def _get_drop_memory_table_sql(self) -> "list[str]": @@ -371,18 +392,27 @@ async def search_entries( return [] effective_limit = limit if limit is not None else self._max_results - if self._use_fts: - logger.debug("CockroachDB full-text search not supported; using simple search") - sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = $1 AND user_id = $2 AND content_text ILIKE $3 - ORDER BY timestamp DESC - LIMIT $4 - """ + if self._use_fts: + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = $1 AND user_id = $2 + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', $3) + ORDER BY timestamp DESC + LIMIT $4 + """ + params: tuple[Any, ...] = (app_name, user_id, query, effective_limit) + else: + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = $1 AND user_id = $2 AND content_text ILIKE $3 + ORDER BY timestamp DESC + LIMIT $4 + """ + params = (app_name, user_id, f"%{query}%", effective_limit) async with self._config.provide_connection() as conn: - rows = await conn.fetch(sql, app_name, user_id, f"%{query}%", effective_limit) + rows = await conn.fetch(sql, *params) return [cast("MemoryRecord", dict(row)) for row in rows] @@ -403,8 +433,8 @@ async def delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} - WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL $1 DAY) + WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL '{days} days') """ async with self._config.provide_connection() as conn: - result = await conn.execute(sql, days) + result = await conn.execute(sql) return int(result.split()[-1]) if result else 0 diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 
97ee76bc6..d2906c7ea 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -6,11 +6,14 @@ from psycopg import sql as pg_sql from psycopg.types.json import Jsonb -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: + from datetime import datetime + from sqlspec.adapters.cockroach_psycopg.config import CockroachPsycopgAsyncConfig, CockroachPsycopgSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -59,7 +62,19 @@ def _build_insert_params_with_owner(entry: "MemoryRecord", owner_id: "object | N class CockroachPsycopgAsyncADKStore(BaseAsyncADKStore["CockroachPsycopgAsyncConfig"]): - """CockroachDB ADK store using psycopg async driver.""" + """CockroachDB ADK store using psycopg async driver. + + Implements session and event storage for Google Agent Development Kit + using CockroachDB via psycopg in PostgreSQL compatibility mode. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. 
+ + CockroachDB-specific differences from native PostgreSQL: + - No FILLFACTOR (CockroachDB uses different storage engine) + - SQL strings require ``.encode()`` for cockroach-psycopg driver + - GIN/Inverted indexes on JSONB are fully supported (v23.1+) + - Native tsvector/tsquery FTS with GIN is supported (v23.1+) + """ __slots__ = () @@ -67,7 +82,6 @@ def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get CockroachDB CREATE TABLE SQL for sessions.""" owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -87,35 +101,28 @@ async def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; """ async def _get_create_events_table_sql(self) -> str: - """Get CockroachDB CREATE TABLE SQL for events.""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, - actions BYTEA NOT NULL, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json + ON 
{self._events_table} USING GIN (event_json); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -239,77 +246,91 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( sql.encode(), ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - event_record.get("content"), - event_record.get("grounding_metadata"), - event_record.get("custom_metadata"), - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + jsonb_value, + ), + ) + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = %s, update_time = 
CURRENT_TIMESTAMP + WHERE id = %s + """ + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute( + insert_sql.encode(), + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, ), ) + await cur.execute(update_sql.encode(), (Jsonb(state), session_id)) await conn.commit() - async def list_events(self, session_id: str) -> "list[EventRecord]": + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = " LIMIT %s" if limit else "" + if limit: + params.append(limit) + sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = %s - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (session_id,)) + await cur.execute(sql.encode(), tuple(params)) rows = await cur.fetchall() return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]), - 
long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -317,15 +338,26 @@ async def list_events(self, session_id: str) -> "list[EventRecord]": return [] -class CockroachPsycopgSyncADKStore(BaseSyncADKStore["CockroachPsycopgSyncConfig"]): - """CockroachDB ADK store using psycopg sync driver.""" +class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig"]): + """CockroachDB ADK store using psycopg sync driver. + + Implements session and event storage for Google Agent Development Kit + using CockroachDB via psycopg in PostgreSQL compatibility mode (sync). + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. 
+ + CockroachDB-specific differences from native PostgreSQL: + - No FILLFACTOR (CockroachDB uses different storage engine) + - SQL strings require ``.encode()`` for cockroach-psycopg driver + - GIN/Inverted indexes on JSONB are fully supported (v23.1+) + """ __slots__ = () def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -345,45 +377,43 @@ def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, - actions BYTEA NOT NULL, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json + ON {self._events_table} USING GIN (event_json); """ def _get_drop_tables_sql(self) -> "list[str]": return 
[f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) - def create_session( + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = Jsonb(state) @@ -406,13 +436,19 @@ def create_session( cur.execute(sql.encode(), params) conn.commit() - result = self.get_session(session_id) + result = self._get_session(session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result - def get_session(self, session_id: str) -> "SessionRecord | None": + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -438,7 +474,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, 
state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP @@ -449,14 +489,22 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cur.execute(sql.encode(), (Jsonb(state), session_id)) conn.commit() - def delete_session(self, session_id: str) -> None: + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(sql.encode(), (session_id,)) conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -493,86 +541,123 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess except errors.UndefinedTable: return [] - def append_event(self, event_record: EventRecord) -> None: + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE 
id = %s + """ + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute( + insert_sql.encode(), + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), + ) + cur.execute(update_sql.encode(), (Jsonb(state), session_id)) + conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _insert_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( sql.encode(), ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - event_record.get("content"), - event_record.get("grounding_metadata"), - event_record.get("custom_metadata"), - 
event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + jsonb_value, ), ) conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = %s - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ + if limit: + params.append(limit) try: with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (session_id,)) + cur.execute(sql.encode(), tuple(params)) rows = cur.fetchall() return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]), - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + 
event_json=row["event_json"], ) for row in rows ] except errors.UndefinedTable: return [] + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._insert_event(event_record) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgAsyncConfig"]): """CockroachDB ADK memory store using psycopg async driver.""" @@ -587,6 +672,13 @@ async def _get_create_memory_table_sql(self) -> str: if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + return f""" CREATE TABLE IF NOT EXISTS {self._memory_table} ( id VARCHAR(128) PRIMARY KEY, @@ -607,6 +699,7 @@ async def _get_create_memory_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session ON {self._memory_table}(session_id); + {fts_index} """ def _get_drop_memory_table_sql(self) -> "list[str]": @@ -675,16 +768,25 @@ async def search_entries( return [] effective_limit = limit if limit is not None else self._max_results + if self._use_fts: - logger.debug("CockroachDB full-text search not supported; using simple search") + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = %s AND user_id = %s + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', %s) + ORDER BY timestamp DESC + LIMIT %s + """ + else: + sql = f""" + SELECT * FROM {self._memory_table} + 
WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s + ORDER BY timestamp DESC + LIMIT %s + """ - sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s - ORDER BY timestamp DESC - LIMIT %s - """ - params = (app_name, user_id, f"%{query}%", effective_limit) + search_param = query if self._use_fts else f"%{query}%" + params = (app_name, user_id, search_param, effective_limit) try: async with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -714,15 +816,15 @@ async def delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} - WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL %s DAY) + WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days' """ async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (days,)) + await cur.execute(sql.encode()) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 -class CockroachPsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["CockroachPsycopgSyncConfig"]): +class CockroachPsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgSyncConfig"]): """CockroachDB ADK memory store using psycopg sync driver.""" __slots__ = () @@ -730,11 +832,18 @@ class CockroachPsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["CockroachPsycop def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + return f""" CREATE TABLE IF NOT EXISTS {self._memory_table} ( id VARCHAR(128) PRIMARY KEY, @@ -755,19 +864,24 
@@ def _get_create_memory_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session ON {self._memory_table}(session_id); + {fts_index} """ def _get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) + driver.execute_script(run_(self._get_create_memory_table_sql)()) + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -810,7 +924,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object inserted_count += cur.rowcount return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -821,16 +939,25 @@ def search_entries( return [] effective_limit = limit if limit is not None else self._max_results + if self._use_fts: - logger.debug("CockroachDB full-text search not supported; using simple search") + sql = f""" + SELECT * FROM {self._memory_table} + WHERE app_name = %s AND user_id = %s + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', %s) + ORDER BY timestamp DESC + LIMIT %s + """ + else: + sql = f""" + SELECT * FROM 
{self._memory_table} + WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s + ORDER BY timestamp DESC + LIMIT %s + """ - sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s - ORDER BY timestamp DESC - LIMIT %s - """ - params = (app_name, user_id, f"%{query}%", effective_limit) + search_param = query if self._use_fts else f"%{query}%" + params = (app_name, user_id, search_param, effective_limit) try: with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -842,7 +969,13 @@ def search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - def delete_entries_by_session(self, session_id: str) -> int: + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -853,16 +986,24 @@ def delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) sql = f""" DELETE FROM {self._memory_table} - WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL %s DAY) + WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days' """ with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (days,)) + cur.execute(sql.encode()) 
conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 4db753b58..9c8d00715 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -16,10 +16,11 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Final, cast -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from sqlspec.adapters.duckdb.config import DuckDBConfig @@ -33,15 +34,16 @@ DUCKDB_TABLE_NOT_FOUND_ERROR: Final = "does not exist" -class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]): +class DuckdbADKStore(BaseAsyncADKStore["DuckDBConfig"]): """DuckDB ADK store for Google Agent Development Kit. Implements session and event storage for Google Agent Development Kit - using DuckDB's synchronous driver. Provides: + using DuckDB's synchronous driver with async wrappers via ``async_()``. 
+ Provides: - Session state management with native JSON type - - Event history tracking with BLOB-serialized actions - - Native TIMESTAMP type support - - Foreign key constraints (manual cascade in delete_session) + - Event history with single JSON blob (event_json) plus indexed scalars + - Native TIMESTAMPTZ type support + - Manual cascade delete (DuckDB has no FK CASCADE) - Columnar storage for analytical queries Args: @@ -62,20 +64,12 @@ class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]): } ) store = DuckdbADKStore(config) - store.ensure_tables() - - session = store.create_session( - session_id="session-123", - app_name="my-app", - user_id="user-456", - state={"context": "conversation"} - ) + await store.ensure_tables() Notes: - - Uses DuckDB native JSON type (not JSONB) - - TIMESTAMP for date/time storage with microsecond precision - - BLOB for binary actions data - - BOOLEAN native type support + - Uses DuckDB native JSON type for event_json and state + - TIMESTAMPTZ for date/time storage with microsecond precision + - event_json stores the full ADK Event as a single JSON blob - Columnar storage provides excellent analytical query performance - DuckDB doesn't support CASCADE in foreign keys (manual cascade required) - Optimized for OLAP workloads; for high-concurrency writes use PostgreSQL @@ -98,7 +92,7 @@ def __init__(self, config: "DuckDBConfig") -> None: """ super().__init__(config) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for sessions. 
Returns: @@ -107,7 +101,7 @@ def _get_create_sessions_table_sql(self) -> str: Notes: - VARCHAR for IDs and names - JSON type for state storage (DuckDB native) - - TIMESTAMP for create_time and update_time + - TIMESTAMPTZ for create_time and update_time - CURRENT_TIMESTAMP for defaults - Optional owner ID column for multi-tenant scenarios - Composite index on (app_name, user_id) for listing @@ -123,48 +117,34 @@ def _get_create_sessions_table_sql(self) -> str: app_name VARCHAR NOT NULL, user_id VARCHAR NOT NULL{owner_id_line}, state JSON NOT NULL, - create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user ON {self._session_table}(app_name, user_id); CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for events. Returns: SQL statement to create adk_events table with indexes. 
Notes: - - VARCHAR for string fields - - BLOB for pickled actions - - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for flags + - 5-column schema: session_id, invocation_id, author, timestamp, event_json + - event_json stores the full ADK Event as a single JSON blob + - No decomposed columns -- eliminates column drift with upstream ADK - Foreign key constraint (DuckDB doesn't support CASCADE) - Index on (session_id, timestamp ASC) for ordered event retrieval - Manual cascade delete required in delete_session method """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - app_name VARCHAR NOT NULL, - user_id VARCHAR NOT NULL, - invocation_id VARCHAR, - author VARCHAR, - actions BLOB, - long_running_tool_ids_json JSON, - branch VARCHAR, - timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR, - error_message VARCHAR, + invocation_id VARCHAR NOT NULL, + author VARCHAR NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); @@ -182,31 +162,53 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" + def _create_tables(self) -> None: + """Synchronous implementation of create_tables.""" with self._config.provide_connection() as conn: - conn.execute(self._get_create_sessions_table_sql()) - conn.execute(self._get_create_events_table_sql()) - - def create_session( - self, session_id: str, 
app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. + conn.execute(self.__get_create_sessions_table_sql_sync()) + conn.execute(self.__get_create_events_table_sql_sync()) - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). + def __get_create_sessions_table_sql_sync(self) -> str: + """Synchronous version of DDL generation for use in _create_tables.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" - Returns: - Created session record. + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR PRIMARY KEY, + app_name VARCHAR NOT NULL, + user_id VARCHAR NOT NULL{owner_id_line}, + state JSON NOT NULL, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user ON {self._session_table}(app_name, user_id); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); + """ - Notes: - Uses current UTC timestamp for create_time and update_time. - State is JSON-serialized using SQLSpec serializers. 
+ def __get_create_events_table_sql_sync(self) -> str: + """Synchronous version of DDL generation for use in _create_tables.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + session_id VARCHAR NOT NULL, + invocation_id VARCHAR NOT NULL, + author VARCHAR NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_json JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) + ); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ + + async def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Synchronous implementation of create_session.""" now = datetime.now(timezone.utc) state_json = to_json(state) @@ -233,19 +235,29 @@ def create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. Args: - session_id: Session identifier. + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner_id_column (if configured). Returns: - Session record or None if not found. + Created session record. Notes: - DuckDB returns datetime objects for TIMESTAMP columns. - JSON is parsed from database storage. + Uses current UTC timestamp for create_time and update_time. + State is JSON-serialized using SQLSpec serializers. 
""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": + """Synchronous implementation of get_session.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -277,17 +289,23 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID. Args: session_id: Session identifier. - state: New state dictionary (replaces existing state). + + Returns: + Session record or None if not found. Notes: - This replaces the entire state dictionary. - Update time is automatically set to current UTC timestamp. + DuckDB returns datetime objects for TIMESTAMPTZ columns. + JSON is parsed from database storage. """ + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of update_session_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) @@ -301,15 +319,21 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None conn.execute(sql, (state_json, now, session_id)) conn.commit() - def delete_session(self, session_id: str) -> None: - """Delete session and all associated events. + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. Args: session_id: Session identifier. + state: New state dictionary (replaces existing state). Notes: - DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first. + This replaces the entire state dictionary. + Update time is automatically set to current UTC timestamp. 
""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: + """Synchronous implementation of delete_session.""" delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = ?" delete_session_sql = f"DELETE FROM {self._session_table} WHERE id = ?" @@ -318,19 +342,19 @@ def delete_session(self, session_id: str) -> None: conn.execute(delete_session_sql, (session_id,)) conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. + async def delete_session(self, session_id: str) -> None: + """Delete session and all associated events. Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. + session_id: Session identifier. Notes: - Uses composite index on (app_name, user_id) when user_id is provided. + DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first. """ + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + """Synchronous implementation of list_sessions.""" if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -369,194 +393,136 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - """Create a new event. + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. Args: - event_id: Unique event identifier. 
- session_id: Session identifier. app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSON). - **kwargs: Additional optional fields. + user_id: User identifier. If None, lists all sessions for the app. Returns: - Created event record. + List of session records ordered by update_time DESC. Notes: - Uses current UTC timestamp if not provided in kwargs. - JSON fields are serialized using SQLSpec serializers. + Uses composite index on (app_name, user_id) when user_id is provided. """ - timestamp = kwargs.get("timestamp", datetime.now(timezone.utc)) - content_json = to_json(content) if content else None - grounding_metadata = kwargs.get("grounding_metadata") - grounding_metadata_json = to_json(grounding_metadata) if grounding_metadata else None - custom_metadata = kwargs.get("custom_metadata") - custom_metadata_json = to_json(custom_metadata) if custom_metadata else None + return await async_(self._list_sessions)(app_name, user_id) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + event_json_str = to_json(event_record["event_json"]) sql = f""" - INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO {self._events_table} + (session_id, invocation_id, author, timestamp, event_json) + VALUES (?, ?, ?, ?, ?) 
""" with self._config.provide_connection() as conn: conn.execute( sql, ( - event_id, - session_id, - app_name, - user_id, - kwargs.get("invocation_id"), - author, - actions, - kwargs.get("long_running_tool_ids_json"), - kwargs.get("branch"), - timestamp, - content_json, - grounding_metadata_json, - custom_metadata_json, - kwargs.get("partial"), - kwargs.get("turn_complete"), - kwargs.get("interrupted"), - kwargs.get("error_code"), - kwargs.get("error_message"), + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, ), ) conn.commit() - return EventRecord( - id=event_id, - session_id=session_id, - app_name=app_name, - user_id=user_id, - invocation_id=kwargs.get("invocation_id", ""), - author=author or "", - actions=actions or b"", - long_running_tool_ids_json=kwargs.get("long_running_tool_ids_json"), - branch=kwargs.get("branch"), - timestamp=timestamp, - content=content, - grounding_metadata=grounding_metadata, - custom_metadata=custom_metadata, - partial=kwargs.get("partial"), - turn_complete=kwargs.get("turn_complete"), - interrupted=kwargs.get("interrupted"), - error_code=kwargs.get("error_code"), - error_message=kwargs.get("error_message"), - ) - - def get_event(self, event_id: str) -> "EventRecord | None": - """Get event by ID. + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. Args: - event_id: Event identifier. + event_record: Event record with 5 keys (session_id, invocation_id, + author, timestamp, event_json). + """ + await async_(self._append_event)(event_record) - Returns: - Event record or None if not found. 
+ def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Synchronous implementation of append_event_and_update_state.""" + now = datetime.now(timezone.utc) + state_json = to_json(state) + event_json_str = to_json(event_record["event_json"]) + + insert_sql = f""" + INSERT INTO {self._events_table} + (session_id, invocation_id, author, timestamp, event_json) + VALUES (?, ?, ?, ?, ?) """ - sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - FROM {self._events_table} + + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = ? WHERE id = ? """ - try: - with self._config.provide_connection() as conn: - cursor = conn.execute(sql, (event_id,)) - row = cursor.fetchone() - - if row is None: - return None + with self._config.provide_connection() as conn: + conn.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, + ), + ) + conn.execute(update_sql, (state_json, now, session_id)) + conn.commit() - return EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]) if row[6] else b"", - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], - ) - except Exception as e: - if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): - return None - raise + async def append_event_and_update_state( + self, 
event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. - def list_events(self, session_id: str) -> "list[EventRecord]": - """List events for a session ordered by timestamp. + The event insert and state update succeed together or fail together + within a single DuckDB transaction. Args: - session_id: Session identifier. - - Returns: - List of event records ordered by timestamp ASC. + event_record: Event record to store (5-key shape). + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). """ + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Synchronous implementation of get_events.""" + where_clauses = ["session_id = ?"] + params: list[Any] = [session_id] + + if after_timestamp is not None: + where_clauses.append("timestamp > ?") + params.append(after_timestamp) + + where_clause = " AND ".join(where_clauses) + limit_clause = f" LIMIT {limit}" if limit else "" + sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = ? 
- ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: with self._config.provide_connection() as conn: - cursor = conn.execute(sql, (session_id,)) + cursor = conn.execute(sql, params) rows = cursor.fetchall() return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]) if row[6] else b"", - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] @@ -565,14 +531,30 @@ def list_events(self, session_id: str) -> "list[EventRecord]": return [] raise + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session. -class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]): - """DuckDB ADK memory store using synchronous DuckDB driver. + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. + """ + return await async_(self._get_events)(session_id, after_timestamp, limit) + + +class DuckdbADKMemoryStore(BaseAsyncADKMemoryStore["DuckDBConfig"]): + """DuckDB ADK memory store using synchronous DuckDB driver with async wrappers. Implements memory entry storage for Google Agent Development Kit - using DuckDB's synchronous driver. Provides: + using DuckDB's synchronous driver with async wrappers via ``async_()``. 
+ Provides: - Session memory storage with native JSON type - - Simple ILIKE search + - Simple ILIKE search or BM25 full-text search via FTS extension - Native TIMESTAMP type support - Deduplication via event_id unique constraint - Efficient upserts using INSERT OR IGNORE @@ -595,13 +577,15 @@ class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]): } ) store = DuckdbADKMemoryStore(config) - store.ensure_tables() + await store.ensure_tables() Notes: - Uses DuckDB native JSON type (not JSONB) - TIMESTAMP for date/time storage with microsecond precision - event_id UNIQUE constraint for deduplication - Composite index on (app_name, user_id, timestamp DESC) + - FTS uses match_bm25() for BM25-ranked results (not @@ operator) + - FTS index is refreshed after inserts, not on every search - Columnar storage provides excellent analytical query performance - Optimized for OLAP workloads; for high-concurrency writes use PostgreSQL - Configuration is read from config.extension_config["adk"] @@ -644,12 +628,19 @@ def _create_fts_index(self, conn: Any) -> None: return try: - conn.execute(f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text')") + conn.execute( + f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text', " + f"stemmer='porter', stopwords='english', strip_accents=1, lower=1)" + ) except Exception as exc: logger.debug("Failed to create DuckDB FTS index: %s", exc) def _refresh_fts_index(self, conn: Any) -> None: - """Rebuild the FTS index to reflect recent changes.""" + """Rebuild the FTS index to reflect recent inserts. + + DuckDB FTS indexes do not auto-update. This must be called after + insert/update/delete operations, NOT on every search. 
+ """ if not self._ensure_fts_extension(conn): return @@ -657,11 +648,14 @@ def _refresh_fts_index(self, conn: Any) -> None: conn.execute(f"PRAGMA drop_fts_index('{self._memory_table}')") try: - conn.execute(f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text')") + conn.execute( + f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text', " + f"overwrite=1, stemmer='porter', stopwords='english', strip_accents=1, lower=1)" + ) except Exception as exc: logger.debug("Failed to refresh DuckDB FTS index: %s", exc) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for memory entries. Returns: @@ -697,18 +691,51 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """Get DuckDB DROP TABLE SQL statements.""" return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: - """Create the memory table and indexes if they don't exist.""" + def _create_tables(self) -> None: + """Synchronous implementation of create_tables.""" if not self._enabled: return + ddl = self.__get_create_memory_table_sql_sync() with self._config.provide_connection() as conn: - conn.execute(self._get_create_memory_table_sql()) + conn.execute(ddl) if self._use_fts: self._create_fts_index(conn) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" + def __get_create_memory_table_sql_sync(self) -> str: + """Synchronous version of DDL generation for use in _create_tables.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + 
timestamp TIMESTAMP NOT NULL, + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + """ + + async def create_tables(self) -> None: + """Create the memory table and indexes if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Synchronous implementation of insert_memory_entries.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -770,12 +797,24 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object result = conn.execute(sql, params) inserted_count += len(result.fetchall()) conn.commit() + + # Refresh FTS index after inserts, not on search + if self._use_fts and inserted_count > 0: + self._refresh_fts_index(conn) + return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication. + + After successful inserts, refreshes the FTS index if FTS is enabled. 
+ """ + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" + """Synchronous implementation of search_entries.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -785,13 +824,19 @@ def search_entries( limit_value = limit or self._max_results if self._use_fts: + # Use match_bm25() -- the correct DuckDB FTS syntax sql = f""" - SELECT * FROM {self._memory_table} - WHERE app_name = ? AND user_id = ? AND content_text @@ ? - ORDER BY timestamp DESC + SELECT m.* + FROM {self._memory_table} m + JOIN ( + SELECT id, fts_main_{self._memory_table}.match_bm25(id, ?, fields := 'content_text') AS score + FROM {self._memory_table} + ) fts ON m.id = fts.id + WHERE m.app_name = ? AND m.user_id = ? AND fts.score IS NOT NULL + ORDER BY fts.score DESC LIMIT ? """ - params = (app_name, user_id, query, limit_value) + params = (query, app_name, user_id, limit_value) else: sql = f""" SELECT * FROM {self._memory_table} @@ -814,13 +859,20 @@ def search_entries( if isinstance(metadata_value, (str, bytes)): record["metadata_json"] = from_json(metadata_value) records.append(record) - if self._use_fts: - with self._config.provide_connection() as conn: - self._refresh_fts_index(conn) return records - def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query. + + When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. + Falls back to ILIKE for simple substring matching. 
+ """ + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: + """Synchronous implementation of delete_entries_by_session.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -830,10 +882,16 @@ def delete_entries_by_session(self, session_id: str) -> int: result = conn.execute(sql, (session_id,)) deleted_count = len(result.fetchall()) conn.commit() + if self._use_fts and deleted_count > 0: + self._refresh_fts_index(conn) return deleted_count - def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: + """Synchronous implementation of delete_entries_older_than.""" if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -847,4 +905,10 @@ def delete_entries_older_than(self, days: int) -> int: result = conn.execute(sql) deleted_count = len(result.fetchall()) conn.commit() + if self._use_fts and deleted_count > 0: + self._refresh_fts_index(conn) return deleted_count + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 3aed6258e..1a25702f7 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -5,9 +5,10 @@ import mysql.connector -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from 
sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from datetime import datetime @@ -38,8 +39,61 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": return (col_def, fk_constraint) +def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") -> str: + """Generate shared MySQL sessions CREATE TABLE DDL.""" + owner_id_col = "" + fk_constraint = "" + + if owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(owner_id_column_ddl) + owner_id_col = f"{col_def}," + if fk_def: + fk_constraint = f",\n {fk_def}" + + return f""" + CREATE TABLE IF NOT EXISTS {session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + {owner_id_col} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{session_table}_app_user (app_name, user_id), + INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_events_ddl(events_table: str, session_table: str) -> str: + """Generate shared MySQL events CREATE TABLE DDL (post clean-break, 5 columns).""" + return f""" + CREATE TABLE IF NOT EXISTS {events_table} ( + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(128) NOT NULL, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + event_json JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, + INDEX idx_{events_table}_session (session_id, timestamp ASC) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + 
""" + + class MysqlConnectorAsyncADKStore(BaseAsyncADKStore["MysqlConnectorAsyncConfig"]): - """MySQL/MariaDB ADK store using mysql-connector async driver.""" + """MySQL/MariaDB ADK store using mysql-connector async driver. + + Provides: + - Session state management with JSON storage + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-append + state-update in one transaction + - Microsecond-precision timestamps + - Foreign key constraints with cascade delete + + Notes: + - Uses ``cast()`` extensively because mysql-connector returns ``Any`` types + - Configuration is read from config.extension_config["adk"] + """ __slots__ = () @@ -50,54 +104,10 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]" return _parse_owner_id_column_for_mysql(column_ddl) async def _get_create_sessions_table_sql(self) -> str: - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - 
user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json JSON, - branch VARCHAR(256), - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + return _mysql_events_ddl(self._events_table, self._session_table) def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] @@ -242,23 +252,19 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis raise async def append_event(self, event_record: EventRecord) -> None: - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + """Append an event to a session. + + Args: + event_record: Event record with 5 keys (session_id, invocation_id, + author, timestamp, event_json). 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ async with self._config.provide_connection() as conn: @@ -267,26 +273,60 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, + ), + ) + finally: + await cursor.close() + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, ), ) + await cursor.execute(update_sql, (state_json, session_id)) finally: await cursor.close() await conn.commit() @@ -294,6 +334,16 @@ async def append_event(self, event_record: EventRecord) -> None: async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": + """Get events for a session. + + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. 
+ """ where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -305,10 +355,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -325,30 +372,11 @@ async def get_events( return [ EventRecord( - id=cast("str", row[0]), - session_id=cast("str", row[1]), - app_name=cast("str", row[2]), - user_id=cast("str", row[3]), - invocation_id=cast("str", row[4]), - author=cast("str", row[5]), - actions=bytes(cast("bytes", row[6])), - long_running_tool_ids_json=cast("str | None", row[7]), - branch=cast("str | None", row[8]), - timestamp=cast("datetime", row[9]), - content=from_json(row[10]) - if row[10] and isinstance(row[10], str) - else cast("dict[str, Any] | None", row[10]), - grounding_metadata=from_json(row[11]) - if row[11] and isinstance(row[11], str) - else cast("dict[str, Any] | None", row[11]), - custom_metadata=from_json(row[12]) - if row[12] and isinstance(row[12], str) - else cast("dict[str, Any] | None", row[12]), - partial=cast("bool | None", row[13]), - turn_complete=cast("bool | None", row[14]), - interrupted=cast("bool | None", row[15]), - error_code=cast("str | None", row[16]), - error_message=cast("str | None", row[17]), + session_id=cast("str", row[0]), + invocation_id=cast("str", row[1]), + author=cast("str", row[2]), + timestamp=cast("datetime", row[3]), + event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), ) for row in rows ] @@ -358,8 +386,20 @@ async def get_events( raise -class MysqlConnectorSyncADKStore(BaseSyncADKStore["MysqlConnectorSyncConfig"]): - """MySQL/MariaDB ADK store 
using mysql-connector sync driver.""" +class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]): + """MySQL/MariaDB ADK store using mysql-connector sync driver. + + Provides: + - Session state management with JSON storage + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-create + state-update in one transaction + - Microsecond-precision timestamps + - Foreign key constraints with cascade delete + + Notes: + - Uses ``cast()`` extensively because mysql-connector returns ``Any`` types + - Configuration is read from config.extension_config["adk"] + """ __slots__ = () @@ -369,65 +409,25 @@ def __init__(self, config: "MysqlConnectorSyncConfig") -> None: def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": return _parse_owner_id_column_for_mysql(column_ddl) - def _get_create_sessions_table_sql(self) -> str: - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + async def _get_create_sessions_table_sql(self) -> str: + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) - def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, - 
session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json JSON, - branch VARCHAR(256), - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + async def _get_create_events_table_sql(self) -> str: + return _mysql_events_ddl(self._events_table, self._session_table) def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() - def create_session( + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = to_json(state) @@ -454,13 +454,19 @@ def create_session( cursor.close() conn.commit() - result = self.get_session(session_id) + result = self._get_session(session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - def get_session(self, session_id: 
str) -> "SessionRecord | None": + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -494,7 +500,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) sql = f""" @@ -511,7 +521,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cursor.close() conn.commit() - def delete_session(self, session_id: str) -> None: + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" with self._config.provide_connection() as conn: @@ -522,7 +536,11 @@ def delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" SELECT id, app_name, user_id, state, 
create_time, update_time @@ -565,24 +583,71 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def append_event(self, event_record: EventRecord) -> None: - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, + ), + ) + cursor.execute(update_sql, (state_json, session_id)) + finally: + cursor.close() + conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _insert_event(self, event_record: EventRecord) -> None: + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ with self._config.provide_connection() as conn: @@ -591,33 +656,30 @@ def append_event(self, event_record: EventRecord) -> None: cursor.execute( sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], 
event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, ), ) finally: cursor.close() conn.commit() - def get_events( + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": + """List events for a session ordered by timestamp. + + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. + """ where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -626,53 +688,32 @@ def get_events( params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" - + limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} """ + if limit: + params.append(limit) try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, params) + cursor.execute(sql, tuple(params)) rows = cursor.fetchall() finally: cursor.close() return [ EventRecord( - id=cast("str", row[0]), - session_id=cast("str", row[1]), - app_name=cast("str", row[2]), - user_id=cast("str", row[3]), - 
invocation_id=cast("str", row[4]), - author=cast("str", row[5]), - actions=bytes(cast("bytes", row[6])), - long_running_tool_ids_json=cast("str | None", row[7]), - branch=cast("str | None", row[8]), - timestamp=cast("datetime", row[9]), - content=from_json(row[10]) - if row[10] and isinstance(row[10], str) - else cast("dict[str, Any] | None", row[10]), - grounding_metadata=from_json(row[11]) - if row[11] and isinstance(row[11], str) - else cast("dict[str, Any] | None", row[11]), - custom_metadata=from_json(row[12]) - if row[12] and isinstance(row[12], str) - else cast("dict[str, Any] | None", row[12]), - partial=cast("bool | None", row[13]), - turn_complete=cast("bool | None", row[14]), - interrupted=cast("bool | None", row[15]), - error_code=cast("str | None", row[16]), - error_message=cast("str | None", row[17]), + session_id=cast("str", row[0]), + invocation_id=cast("str", row[1]), + author=cast("str", row[2]), + timestamp=cast("datetime", row[3]), + event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), ) for row in rows ] @@ -681,6 +722,20 @@ def get_events( return [] raise + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._insert_event(event_record) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorAsyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector async driver.""" @@ -871,7 +926,7 @@ async def delete_entries_older_than(self, days: int) -> int: await cursor.close() -class 
MysqlConnectorSyncADKMemoryStore(BaseSyncADKMemoryStore["MysqlConnectorSyncConfig"]): +class MysqlConnectorSyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorSyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector sync driver.""" __slots__ = () @@ -879,7 +934,7 @@ class MysqlConnectorSyncADKMemoryStore(BaseSyncADKMemoryStore["MysqlConnectorSyn def __init__(self, config: "MysqlConnectorSyncConfig") -> None: super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" if self._owner_id_column_ddl: @@ -913,14 +968,18 @@ def _get_create_memory_table_sql(self) -> str: def _get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) + driver.execute_script(run_(self._get_create_memory_table_sql)()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -986,7 +1045,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object conn.commit() return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) 
-> "list[MemoryRecord]": if not self._enabled: @@ -1026,7 +1089,13 @@ def search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - def delete_entries_by_session(self, session_id: str) -> int: + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -1041,7 +1110,11 @@ def delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -1058,3 +1131,7 @@ def delete_entries_older_than(self, days: int) -> int: return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 finally: cursor.close() + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 4aa9b2811..56248882d 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -12,10 +12,11 @@ OracledbSyncDataDictionary, OracleVersionInfo, ) -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from 
sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_, run_ from sqlspec.utils.type_guards import is_async_readable, is_readable if TYPE_CHECKING: @@ -93,43 +94,28 @@ def storage_type_from_version(version_info: "OracleVersionInfo | None") -> JSONS return _storage_type_from_version(version_info) -def _to_oracle_bool(value: "bool | None") -> "int | None": - """Convert Python boolean to Oracle NUMBER(1). +def _event_json_column_ddl(storage_type: JSONStorageType) -> str: + """Return the DDL fragment for the event_json column. - Args: - value: Python boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. + For JSON_NATIVE (Oracle 21c+) we use the native JSON type. + For older versions we use BLOB since Oracle recommends BLOB over CLOB for + JSON storage. BLOB_JSON gets a CHECK constraint; BLOB_PLAIN does not. """ - if value is None: - return None - return 1 if value else 0 + if storage_type == JSONStorageType.JSON_NATIVE: + return "event_json JSON NOT NULL" + if storage_type == JSONStorageType.BLOB_JSON: + return "event_json BLOB CHECK (event_json IS JSON) NOT NULL" + return "event_json BLOB NOT NULL" -def _from_oracle_bool(value: "int | None") -> "bool | None": - """Convert Oracle NUMBER(1) to Python boolean. - - Args: - value: Oracle NUMBER value (0, 1, or None). +def _oracle_text_value(value: Any) -> str: + """Normalize Oracle VARCHAR2 values back to Python strings. - Returns: - Python boolean or None. + Oracle stores empty strings as ``NULL``. The ADK event contract allows + empty strings for fields like ``invocation_id``, so reads coerce ``NULL`` + back to ``""``. 
""" - if value is None: - return None - return bool(value) - - -def _coerce_bytes_payload(value: Any) -> bytes: - """Coerce a LOB payload into bytes.""" - if value is None: - return b"" - if isinstance(value, bytes): - return value - if isinstance(value, str): - return value.encode("utf-8") - return str(value).encode("utf-8") + return "" if value is None else str(value) class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): @@ -138,7 +124,8 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb async driver. Provides: - Session state management with version-specific JSON storage - - Event history tracking with BLOB-serialized actions + - Full-fidelity event storage via ``event_json`` column + - Atomic ``append_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete - Efficient upserts using MERGE statement @@ -146,28 +133,10 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Args: config: OracleAsyncConfig with extension_config["adk"] settings. 
- Example: - from sqlspec.adapters.oracledb import OracleAsyncConfig - from sqlspec.adapters.oracledb.adk import OracleAsyncADKStore - - config = OracleAsyncConfig( - connection_config={"dsn": "oracle://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id NUMBER(10) REFERENCES tenants(id)" - } - } - ) - store = OracleAsyncADKStore(config) - await store.ensure_tables() - Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - BLOB for pre-serialized actions from Google ADK + - event_json stored as JSON (21c+) or BLOB (older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - - NUMBER(1) for booleans (0/1/NULL) - Named parameters using :param_name - State merging handled at application level - owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types @@ -223,10 +192,9 @@ async def _detect_json_storage_type(self) -> JSONStorageType: Notes: Queries product_component_version to determine Oracle version. - Oracle 21c+ with compatible >= 20: Native JSON type - - Oracle 12c+: BLOB with IS JSON constraint (preferred) - - Oracle 11g and earlier: BLOB without constraint + - Oracle 12c+: BLOB with IS JSON constraint + - Oracle 11g and earlier: plain BLOB - BLOB is preferred over CLOB for 12c+ as per Oracle recommendations. Result is cached in self._json_storage_type. """ if self._json_storage_type is not None: @@ -296,55 +264,40 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] - async def _serialize_json_field(self, value: Any) -> "str | bytes | None": - """Serialize optional JSON field for event storage. - - Args: - value: Value to serialize (dict or None). - - Returns: - Serialized JSON or None. 
- """ - if value is None: + async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": + """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" + if data is None: return None + return await self._deserialize_state(data) + async def _serialize_event_json(self, event_json: Any) -> "str | bytes": + """Serialize event_json to the configured Oracle JSON storage format.""" storage_type = await self._detect_json_storage_type() - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(value) + return to_json(event_json) + return to_json(event_json, as_bytes=True) - return to_json(value, as_bytes=True) - - async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": - """Deserialize optional JSON field from database. + async def _read_event_json(self, data: Any) -> str: + """Read event_json from database, handling LOB types. Args: - data: Data from database (may be LOB, str, bytes, dict, or None). + data: Data from database (may be LOB, str, or dict). Returns: - Deserialized dictionary or None. - - Notes: - Oracle JSON type may return dict directly. + JSON string. """ - if data is None: - return None - if is_async_readable(data): data = await data.read() elif is_readable(data): data = data.read() if isinstance(data, dict): - return cast("dict[str, Any]", _coerce_decimal_values(data)) + return to_json(data) if isinstance(data, bytes): - return from_json(data) # type: ignore[no-any-return] - - if isinstance(data, str): - return from_json(data) # type: ignore[no-any-return] + return data.decode("utf-8") - return from_json(str(data)) # type: ignore[no-any-return] + return str(data) def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for sessions with specified storage type. 
@@ -406,54 +359,27 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for events with specified storage type. + The events table uses the new 5-column contract: session_id, invocation_id, + author, timestamp, and event_json. The event_json column stores the full + ADK Event as JSON (21c+) or BLOB (older versions). + Args: storage_type: JSON storage type to use. Returns: SQL statement to create adk_events table. """ - if storage_type == JSONStorageType.JSON_NATIVE: - json_columns = """ - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - long_running_tool_ids_json JSON - """ - elif storage_type == JSONStorageType.BLOB_JSON: - json_columns = """ - content BLOB CHECK (content IS JSON), - grounding_metadata BLOB CHECK (grounding_metadata IS JSON), - custom_metadata BLOB CHECK (custom_metadata IS JSON), - long_running_tool_ids_json BLOB CHECK (long_running_tool_ids_json IS JSON) - """ - else: - json_columns = """ - content BLOB, - grounding_metadata BLOB, - custom_metadata BLOB, - long_running_tool_ids_json BLOB - """ - + event_json_col = _event_json_column_ddl(storage_type) inmemory_clause = " INMEMORY PRIORITY HIGH" if self._in_memory else "" return f""" BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( - id VARCHAR2(128) PRIMARY KEY, session_id VARCHAR2(128) NOT NULL, - app_name VARCHAR2(128) NOT NULL, - user_id VARCHAR2(128) NOT NULL, invocation_id VARCHAR2(256), author VARCHAR2(256), - actions BLOB, - branch VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {json_columns}, - partial NUMBER(1), - turn_complete NUMBER(1), - interrupted NUMBER(1), - error_code VARCHAR2(256), - error_message VARCHAR2(1024), + {event_json_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE 
){inmemory_clause}'; @@ -753,28 +679,14 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. - - Notes: - Uses SYSTIMESTAMP for timestamp if not provided. - JSON fields are serialized using version-appropriate format. - Boolean fields are converted to NUMBER(1). + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. """ - content_data = await self._serialize_json_field(event_record.get("content")) - grounding_metadata_data = await self._serialize_json_field(event_record.get("grounding_metadata")) - custom_metadata_data = await self._serialize_json_field(event_record.get("custom_metadata")) - sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + session_id, invocation_id, author, timestamp, event_json ) VALUES ( - :id, :session_id, :app_name, :user_id, :invocation_id, :author, :actions, - :long_running_tool_ids_json, :branch, :timestamp, :content, - :grounding_metadata, :custom_metadata, :partial, :turn_complete, - :interrupted, :error_code, :error_message + :session_id, :invocation_id, :author, :timestamp, :event_json ) """ @@ -783,26 +695,58 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, { - "id": event_record["id"], "session_id": event_record["session_id"], - "app_name": event_record["app_name"], - "user_id": event_record["user_id"], "invocation_id": event_record["invocation_id"], "author": event_record["author"], - "actions": event_record["actions"], - "long_running_tool_ids_json": event_record.get("long_running_tool_ids_json"), - "branch": event_record.get("branch"), "timestamp": event_record["timestamp"], - "content": content_data, - "grounding_metadata": 
grounding_metadata_data, - "custom_metadata": custom_metadata_data, - "partial": _to_oracle_bool(event_record.get("partial")), - "turn_complete": _to_oracle_bool(event_record.get("turn_complete")), - "interrupted": _to_oracle_bool(event_record.get("interrupted")), - "error_code": event_record.get("error_code"), - "error_message": event_record.get("error_message"), + "event_json": await self._serialize_event_json(event_record["event_json"]), + }, + ) + await conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + Both the event insert and session state update are executed within a + single transaction so they succeed or fail together. + + Args: + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). 
+ """ + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_json + ) + """ + + state_data = await self._serialize_state(state) + update_sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE id = :id + """ + + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute( + insert_sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": await self._serialize_event_json(event_record["event_json"]), }, ) + await cursor.execute(update_sql, {"state": state_data, "id": session_id}) await conn.commit() async def get_events( @@ -817,11 +761,6 @@ async def get_events( Returns: List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - JSON fields deserialized using version-appropriate format. - Converts BLOB actions to bytes and NUMBER(1) booleans to Python bool. 
""" where_clauses = ["session_id = :session_id"] @@ -837,10 +776,7 @@ async def get_events( limit_clause = f" FETCH FIRST {limit} ROWS ONLY" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -852,43 +788,16 @@ async def get_events( await cursor.execute(sql, params) rows = await cursor.fetchall() - results = [] - for row in rows: - actions_blob = row[6] - if is_async_readable(actions_blob): - actions_data = await actions_blob.read() - elif is_readable(actions_blob): - actions_data = actions_blob.read() - else: - actions_data = actions_blob - - content = await self._deserialize_json_field(row[10]) - grounding_metadata = await self._deserialize_json_field(row[11]) - custom_metadata = await self._deserialize_json_field(row[12]) - - results.append( - EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=_coerce_bytes_payload(actions_data), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=content, - grounding_metadata=grounding_metadata, - custom_metadata=custom_metadata, - partial=_from_oracle_bool(row[13]), - turn_complete=_from_oracle_bool(row[14]), - interrupted=_from_oracle_bool(row[15]), - error_code=row[16], - error_message=row[17], - ) + return [ + EventRecord( + session_id=row[0], + invocation_id=_oracle_text_value(row[1]), + author=_oracle_text_value(row[2]), + timestamp=row[3], + event_json=await self._deserialize_json_field(row[4]) or {}, ) - return results + for row in rows + ] except oracledb.DatabaseError as e: error_obj = e.args[0] if e.args else None if error_obj and error_obj.code == 
ORACLE_TABLE_NOT_FOUND_ERROR: @@ -896,13 +805,14 @@ async def get_events( raise -class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]): +class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): """Oracle synchronous ADK store using oracledb sync driver. Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb synchronous driver. Provides: - Session state management with version-specific JSON storage - - Event history tracking with BLOB-serialized actions + - Full-fidelity event storage via ``event_json`` column + - Atomic ``create_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete - Efficient upserts using MERGE statement @@ -910,28 +820,10 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]): Args: config: OracleSyncConfig with extension_config["adk"] settings. - Example: - from sqlspec.adapters.oracledb import OracleSyncConfig - from sqlspec.adapters.oracledb.adk import OracleSyncADKStore - - config = OracleSyncConfig( - connection_config={"dsn": "oracle://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "account_id NUMBER(19) REFERENCES accounts(id)" - } - } - ) - store = OracleSyncADKStore(config) - store.ensure_tables() - Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - BLOB for pre-serialized actions from Google ADK + - event_json stored as JSON (21c+) or BLOB (older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - - NUMBER(1) for booleans (0/1/NULL) - Named parameters using :param_name - State merging handled at application level - owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types @@ -960,7 +852,7 @@ def __init__(self, config: "OracleSyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) 
self._in_memory: bool = bool(adk_config.get("in_memory", False)) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: """Get Oracle CREATE TABLE SQL for sessions table. Auto-detects optimal JSON storage type based on Oracle version. @@ -969,7 +861,7 @@ def _get_create_sessions_table_sql(self) -> str: storage_type = self._detect_json_storage_type() return self._get_create_sessions_table_sql_for_type(storage_type) - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: """Get Oracle CREATE TABLE SQL for events table. Auto-detects optimal JSON storage type based on Oracle version. @@ -987,10 +879,9 @@ def _detect_json_storage_type(self) -> JSONStorageType: Notes: Queries product_component_version to determine Oracle version. - Oracle 21c+ with compatible >= 20: Native JSON type - - Oracle 12c+: BLOB with IS JSON constraint (preferred) - - Oracle 11g and earlier: BLOB without constraint + - Oracle 12c+: BLOB with IS JSON constraint + - Oracle 11g and earlier: plain BLOB - BLOB is preferred over CLOB for 12c+ as per Oracle recommendations. Result is cached in self._json_storage_type. """ if self._json_storage_type is not None: @@ -1058,53 +949,38 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] - def _serialize_json_field(self, value: Any) -> "str | bytes | None": - """Serialize optional JSON field for event storage. - - Args: - value: Value to serialize (dict or None). - - Returns: - Serialized JSON or None. 
- """ - if value is None: + def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": + """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" + if data is None: return None + return self._deserialize_state(data) + def _serialize_event_json(self, event_json: Any) -> "str | bytes": + """Serialize event_json to the configured Oracle JSON storage format.""" storage_type = self._detect_json_storage_type() - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(value) + return to_json(event_json) + return to_json(event_json, as_bytes=True) - return to_json(value, as_bytes=True) - - def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": - """Deserialize optional JSON field from database. + def _read_event_json(self, data: Any) -> str: + """Read event_json from database, handling LOB types. Args: - data: Data from database (may be LOB, str, bytes, dict, or None). + data: Data from database (may be LOB, str, or dict). Returns: - Deserialized dictionary or None. - - Notes: - Oracle JSON type may return dict directly. + JSON string. """ - if data is None: - return None - if is_readable(data): data = data.read() if isinstance(data, dict): - return cast("dict[str, Any]", _coerce_decimal_values(data)) + return to_json(data) if isinstance(data, bytes): - return from_json(data) # type: ignore[no-any-return] + return data.decode("utf-8") - if isinstance(data, str): - return from_json(data) # type: ignore[no-any-return] - - return from_json(str(data)) # type: ignore[no-any-return] + return str(data) def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for sessions with specified storage type. @@ -1166,54 +1042,27 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for events with specified storage type. 
+ The events table uses the new 5-column contract: session_id, invocation_id, + author, timestamp, and event_json. The event_json column stores the full + ADK Event as JSON (21c+) or BLOB (older versions). + Args: storage_type: JSON storage type to use. Returns: SQL statement to create adk_events table. """ - if storage_type == JSONStorageType.JSON_NATIVE: - json_columns = """ - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - long_running_tool_ids_json JSON - """ - elif storage_type == JSONStorageType.BLOB_JSON: - json_columns = """ - content BLOB CHECK (content IS JSON), - grounding_metadata BLOB CHECK (grounding_metadata IS JSON), - custom_metadata BLOB CHECK (custom_metadata IS JSON), - long_running_tool_ids_json BLOB CHECK (long_running_tool_ids_json IS JSON) - """ - else: - json_columns = """ - content BLOB, - grounding_metadata BLOB, - custom_metadata BLOB, - long_running_tool_ids_json BLOB - """ - + event_json_col = _event_json_column_ddl(storage_type) inmemory_clause = " INMEMORY PRIORITY HIGH" if self._in_memory else "" return f""" BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( - id VARCHAR2(128) PRIMARY KEY, session_id VARCHAR2(128) NOT NULL, - app_name VARCHAR2(128) NOT NULL, - user_id VARCHAR2(128) NOT NULL, invocation_id VARCHAR2(256), author VARCHAR2(256), - actions BLOB, - branch VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {json_columns}, - partial NUMBER(1), - turn_complete NUMBER(1), - interrupted NUMBER(1), - error_code VARCHAR2(256), - error_message VARCHAR2(1024), + {event_json_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ){inmemory_clause}'; @@ -1298,7 +1147,7 @@ def _get_drop_tables_sql(self) -> "list[str]": """, ] - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create both sessions and events tables if they don't exist. 
Notes: @@ -1315,7 +1164,11 @@ def create_tables(self) -> None: events_sql = SQL(self._get_create_events_table_sql_for_type(storage_type)) driver.execute_script(events_sql) - def create_session( + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: """Create a new session. @@ -1361,9 +1214,19 @@ def create_session( cursor.execute(sql, params) conn.commit() - return self.get_session(session_id) # type: ignore[return-value] + result = self._get_session(session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result - def get_session(self, session_id: str) -> "SessionRecord | None": + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": """Get session by ID. Args: @@ -1410,7 +1273,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. 
Args: @@ -1435,7 +1302,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cursor.execute(sql, {"state": state_data, "id": session_id}) conn.commit() - def delete_session(self, session_id: str) -> None: + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: """Delete session and all associated events (cascade). Args: @@ -1451,7 +1322,11 @@ def delete_session(self, session_id: str) -> None: cursor.execute(sql, {"id": session_id}) conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. Args: @@ -1510,159 +1385,147 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> "EventRecord": - """Create a new event. + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) - Args: - event_id: Unique event identifier. - session_id: Session identifier. - app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSONB/JSON). - **kwargs: Additional optional fields. 
+ def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. - Returns: - Created event record. + Both the event insert and session state update are executed within a + single transaction so they succeed or fail together. - Notes: - Uses SYSTIMESTAMP for timestamp if not provided. - JSON fields are serialized using version-appropriate format. - Boolean fields are converted to NUMBER(1). + Args: + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). """ - content_data = self._serialize_json_field(content) - grounding_metadata_data = self._serialize_json_field(kwargs.get("grounding_metadata")) - custom_metadata_data = self._serialize_json_field(kwargs.get("custom_metadata")) - - sql = f""" + insert_sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + session_id, invocation_id, author, timestamp, event_json ) VALUES ( - :id, :session_id, :app_name, :user_id, :invocation_id, :author, :actions, - :long_running_tool_ids_json, :branch, :timestamp, :content, - :grounding_metadata, :custom_metadata, :partial, :turn_complete, - :interrupted, :error_code, :error_message + :session_id, :invocation_id, :author, :timestamp, :event_json ) """ + state_data = self._serialize_state(state) + update_sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE id = :id + """ + with self._config.provide_connection() as conn: cursor = conn.cursor() cursor.execute( - sql, + insert_sql, 
{ - "id": event_id, - "session_id": session_id, - "app_name": app_name, - "user_id": user_id, - "invocation_id": kwargs.get("invocation_id"), - "author": author, - "actions": actions, - "long_running_tool_ids_json": kwargs.get("long_running_tool_ids_json"), - "branch": kwargs.get("branch"), - "timestamp": kwargs.get("timestamp"), - "content": content_data, - "grounding_metadata": grounding_metadata_data, - "custom_metadata": custom_metadata_data, - "partial": _to_oracle_bool(kwargs.get("partial")), - "turn_complete": _to_oracle_bool(kwargs.get("turn_complete")), - "interrupted": _to_oracle_bool(kwargs.get("interrupted")), - "error_code": kwargs.get("error_code"), - "error_message": kwargs.get("error_message"), + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": self._serialize_event_json(event_record["event_json"]), }, ) + cursor.execute(update_sql, {"state": state_data, "id": session_id}) conn.commit() - events = self.list_events(session_id) - for event in events: - if event["id"] == event_id: - return event - - msg = f"Failed to retrieve created event {event_id}" - raise RuntimeError(msg) + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) - def list_events(self, session_id: str) -> "list[EventRecord]": + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. Returns: List of event records ordered by timestamp ASC. 
- - Notes: - Uses index on (session_id, timestamp ASC). - JSON fields deserialized using version-appropriate format. - Converts BLOB actions to bytes and NUMBER(1) booleans to Python bool. """ + where_clauses = ["session_id = :session_id"] + params: dict[str, Any] = {"session_id": session_id} + + if after_timestamp is not None: + where_clauses.append("timestamp > :after_timestamp") + params["after_timestamp"] = after_timestamp + + where_clause = " AND ".join(where_clauses) + limit_clause = f" FETCH FIRST {limit} ROWS ONLY" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} - WHERE session_id = :session_id - ORDER BY timestamp ASC + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute(sql, {"session_id": session_id}) + cursor.execute(sql, params) rows = cursor.fetchall() - results = [] - for row in rows: - actions_blob = row[6] - actions_data = actions_blob.read() if is_readable(actions_blob) else actions_blob - - content = self._deserialize_json_field(row[10]) - grounding_metadata = self._deserialize_json_field(row[11]) - custom_metadata = self._deserialize_json_field(row[12]) - - results.append( - EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=_coerce_bytes_payload(actions_data), - long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=content, - grounding_metadata=grounding_metadata, - custom_metadata=custom_metadata, - partial=_from_oracle_bool(row[13]), - turn_complete=_from_oracle_bool(row[14]), - interrupted=_from_oracle_bool(row[15]), - 
error_code=row[16], - error_message=row[17], - ) + return [ + EventRecord( + session_id=row[0], + invocation_id=_oracle_text_value(row[1]), + author=_oracle_text_value(row[2]), + timestamp=row[3], + event_json=self._deserialize_json_field(row[4]) or {}, ) - return results + for row in rows + ] except oracledb.DatabaseError as e: error_obj = e.args[0] if e.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: return [] raise + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_json + ) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute( + sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": self._serialize_event_json(event_record["event_json"]), + }, + ) + conn.commit() + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + ORACLE_DUPLICATE_KEY_ERROR: Final = 1 @@ -2030,7 +1893,7 @@ async def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return records -class OracleSyncADKMemoryStore(BaseSyncADKMemoryStore["OracleSyncConfig"]): +class OracleSyncADKMemoryStore(BaseAsyncADKMemoryStore["OracleSyncConfig"]): """Oracle ADK memory store using sync oracledb driver.""" __slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info") @@ 
-2081,7 +1944,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return _extract_json_value(data) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: storage_type = self._detect_json_storage_type() return self._get_create_memory_table_sql_for_type(storage_type) @@ -2196,12 +2059,16 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """, ] - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) + driver.execute_script(run_(self._get_create_memory_table_sql)()) + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") -> bool: """Execute an insert and skip duplicate key errors.""" @@ -2214,7 +2081,7 @@ def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") raise return True - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -2261,7 +2128,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -2280,6 +2151,12 @@ def search_entries( return [] raise + async def 
search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -2326,7 +2203,7 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: rows = cursor.fetchall() return self._rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: sql = f"DELETE FROM {self._memory_table} WHERE session_id = :session_id" with self._config.provide_connection() as conn: cursor = conn.cursor() @@ -2334,7 +2211,11 @@ def delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} WHERE inserted_at < SYSTIMESTAMP - NUMTODSINTERVAL(:days, 'DAY') @@ -2345,6 +2226,10 @@ def delete_entries_older_than(self, days: int) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: diff --git a/sqlspec/adapters/psqlpy/adk/store.py 
b/sqlspec/adapters/psqlpy/adk/store.py index c1629e05d..9cfdcea27 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -29,79 +29,28 @@ class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via the high-performance Rust-based psqlpy driver. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Event history tracking with BYTEA-serialized actions + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 Args: config: PsqlpyConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.psqlpy import PsqlpyConfig - from sqlspec.adapters.psqlpy.adk import PsqlpyADKStore - - config = PsqlpyConfig( - connection_config={"dsn": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = PsqlpyADKStore(config) - await store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - Psqlpy automatically converts Python dicts to/from JSONB - - TIMESTAMPTZ provides timezone-aware microsecond precision - - BYTEA for pre-serialized actions from Google ADK - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Uses PostgreSQL numeric parameter style ($1, $2, $3) - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "PsqlpyConfig") -> None: - """Initialize Psqlpy ADK store. - - Args: - config: PsqlpyConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or user references - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -128,66 +77,24 @@ async def _get_create_sessions_table_sql(self) -> str: """ async def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pre-serialized actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) 
REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist. - - Notes: - Uses driver.execute_script() which handles multiple statements. - Creates sessions table first, then events table (FK dependency). - """ async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -195,23 +102,6 @@ async def create_tables(self) -> None: async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is passed as dict and psqlpy converts to JSONB automatically. - If owner_id_column is configured, owner_id value must be provided. 
- """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] if self._owner_id_column_name: sql = f""" @@ -230,19 +120,6 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically parsed by psqlpy to Python dicts. - Returns None if table doesn't exist (catches database errors). - """ sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -273,17 +150,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": raise async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - Psqlpy automatically converts dict to JSONB. - """ sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP @@ -294,33 +160,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - await conn.execute(sql, [state, session_id]) async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. 
- """ sql = f"DELETE FROM {self._session_table} WHERE id = $1" async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] await conn.execute(sql, [session_id]) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. - Returns empty list if table doesn't exist. - """ if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -361,74 +206,54 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis raise async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record to store. - - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSONB fields are passed as dicts and psqlpy converts automatically. - BYTEA actions field stores pre-serialized data from Google ADK. 
- """ - content_json = event_record.get("content") - grounding_metadata_json = event_record.get("grounding_metadata") - custom_metadata_json = event_record.get("custom_metadata") - sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] await conn.execute( sql, [ - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], - event_record.get("invocation_id"), - event_record.get("author"), - event_record.get("actions"), - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_record["event_json"], + ], + ) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES ($1, $2, $3, $4, $5) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = $1, update_time = CURRENT_TIMESTAMP + WHERE id = $2 + """ + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute( + insert_sql, + [ + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - 
event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_record["event_json"], ], ) + await conn.execute(update_sql, [state, session_id]) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - Parses JSONB fields and converts BYTEA actions to bytes. - Returns empty list if table doesn't exist. - """ where_clauses = ["session_id = $1"] params: list[Any] = [session_id] @@ -442,10 +267,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -458,24 +280,11 @@ async def get_events( return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in 
rows ] diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 6011e6dc2..00d68aade 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -6,9 +6,10 @@ from psycopg import sql as pg_sql from psycopg.types.json import Jsonb -from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from datetime import datetime @@ -56,84 +57,32 @@ def _build_insert_params_with_owner(entry: "MemoryRecord", owner_id: "object | N class PsycopgAsyncADKStore(BaseAsyncADKStore["PsycopgAsyncConfig"]): - """PostgreSQL ADK store using Psycopg3 driver. + """PostgreSQL ADK store using Psycopg3 async driver. Implements session and event storage for Google Agent Development Kit using PostgreSQL via psycopg3 with native async/await support. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. Provides: - - Session state management with JSONB storage and merge operations - - Event history tracking with BYTEA-serialized actions + - Session state management with JSONB storage + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 Args: config: PsycopgAsyncConfig with extension_config["adk"] settings. 
- - Example: - from sqlspec.adapters.psycopg import PsycopgAsyncConfig - from sqlspec.adapters.psycopg.adk import PsycopgAsyncADKStore - - config = PsycopgAsyncConfig( - connection_config={"conninfo": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = PsycopgAsyncADKStore(config) - await store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - Psycopg requires wrapping dicts with Jsonb() for type safety - - TIMESTAMPTZ provides timezone-aware microsecond precision - - State merging uses `state || $1::jsonb` operator for efficiency - - BYTEA for pre-serialized actions from Google ADK - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Parameter style: $1, $2, $3 (PostgreSQL numeric placeholders) - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "PsycopgAsyncConfig") -> None: - """Initialize Psycopg ADK store. - - Args: - config: PsycopgAsyncConfig instance. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) async def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. 
- - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or user references - """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -160,61 +109,24 @@ async def _get_create_sessions_table_sql(self) -> str: """ async def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pickled actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) 
REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. - """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) @@ -222,23 +134,6 @@ async def create_tables(self) -> None: async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - If owner_id_column is configured, owner_id value must be provided. - """ params: tuple[Any, ...] if self._owner_id_column_name: query = pg_sql.SQL(""" @@ -261,18 +156,6 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. 
- - Returns: - Session record or None if not found. - - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically deserialized by psycopg to Python dict. - """ query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time FROM {table} @@ -299,17 +182,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": return None async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - """ query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP @@ -320,32 +192,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - await cur.execute(query, (Jsonb(state), session_id)) async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. - """ query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute(query, (session_id,)) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. 
- """ if user_id is None: query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time @@ -383,73 +235,62 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return [] async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. + query = pg_sql.SQL(""" + INSERT INTO {table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """).format(table=pg_sql.Identifier(self._events_table)) - Args: - event_record: Event record to store. + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided. - JSONB fields are wrapped with Jsonb() for PostgreSQL type safety. - """ - content_json = event_record.get("content") - grounding_metadata_json = event_record.get("grounding_metadata") - custom_metadata_json = event_record.get("custom_metadata") + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute( + query, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), + ) - query = pg_sql.SQL(""" + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) + update_query = pg_sql.SQL(""" + 
UPDATE {table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE id = %s + """).format(table=pg_sql.Identifier(self._session_table)) + + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( - query, + insert_query, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], - event_record.get("invocation_id"), - event_record.get("author"), - event_record.get("actions"), - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), + event_record["invocation_id"], + event_record["author"], event_record["timestamp"], - Jsonb(content_json) if content_json is not None else None, - Jsonb(grounding_metadata_json) if grounding_metadata_json is not None else None, - Jsonb(custom_metadata_json) if custom_metadata_json is not None else None, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + jsonb_value, ), ) + await cur.execute(update_query, (Jsonb(state), session_id)) + await conn.commit() async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - JSONB fields are automatically deserialized by psycopg. - BYTEA actions are converted to bytes. 
- """ where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -463,10 +304,7 @@ async def get_events( query = pg_sql.SQL( """ - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -484,24 +322,11 @@ async def get_events( return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] @@ -509,85 +334,33 @@ async def get_events( return [] -class PsycopgSyncADKStore(BaseSyncADKStore["PsycopgSyncConfig"]): +class PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]): """PostgreSQL synchronous ADK store using Psycopg3 driver. Implements session and event storage for Google Agent Development Kit using PostgreSQL via psycopg3 with synchronous execution. + Events are stored as a single JSONB blob (``event_json``) alongside + indexed scalar columns for efficient querying. 
Provides: - - Session state management with JSONB storage and merge operations - - Event history tracking with BYTEA-serialized actions + - Session state management with JSONB storage + - Full-fidelity event storage via ``event_json`` JSONB column + - Atomic ``create_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete - - Efficient upserts using ON CONFLICT - GIN indexes for JSONB queries - HOT updates with FILLFACTOR 80 Args: config: PsycopgSyncConfig with extension_config["adk"] settings. - - Example: - from sqlspec.adapters.psycopg import PsycopgSyncConfig - from sqlspec.adapters.psycopg.adk import PsycopgSyncADKStore - - config = PsycopgSyncConfig( - connection_config={"conninfo": "postgresql://..."}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - } - } - ) - store = PsycopgSyncADKStore(config) - store.ensure_tables() - - Notes: - - PostgreSQL JSONB type used for state (more efficient than JSON) - - Psycopg requires wrapping dicts with Jsonb() for type safety - - TIMESTAMPTZ provides timezone-aware microsecond precision - - State merging uses `state || $1::jsonb` operator for efficiency - - BYTEA for pre-serialized actions from Google ADK - - GIN index on state for JSONB queries (partial index) - - FILLFACTOR 80 leaves space for HOT updates - - Parameter style: $1, $2, $3 (PostgreSQL numeric placeholders) - - Configuration is read from config.extension_config["adk"] """ __slots__ = () def __init__(self, config: "PsycopgSyncConfig") -> None: - """Initialize Psycopg synchronous ADK store. - - Args: - config: PsycopgSyncConfig instance. 
- - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ super().__init__(config) - def _get_create_sessions_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - - Notes: - - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names) - - JSONB type for state storage with default empty object - - TIMESTAMPTZ with microsecond precision - - FILLFACTOR 80 for HOT updates (reduces table bloat) - - Composite index on (app_name, user_id) for listing - - Index on update_time DESC for recent session queries - - Partial GIN index on state for JSONB queries (only non-empty) - - Optional owner ID column for multi-tenancy or user references - """ + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -613,86 +386,36 @@ def _get_create_sessions_table_sql(self) -> str: WHERE state != '{{}}'::jsonb; """ - def _get_create_events_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. 
- - Notes: - - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256), - branch(256), error_code(256), error_message(1024) - - BYTEA for pickled actions (no size limit) - - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json - - BOOLEAN for partial, turn_complete, interrupted - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) for ordered event retrieval - """ + async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256), - author VARCHAR(256), - actions BYTEA, - long_running_tool_ids_json JSONB, - branch VARCHAR(256), + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - content JSONB, - grounding_metadata JSONB, - custom_metadata JSONB, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); + ) WITH (fillfactor = 80); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ def _get_drop_tables_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - PostgreSQL automatically drops indexes when dropping tables. 
- """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" + def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) - def create_session( + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses CURRENT_TIMESTAMP for create_time and update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - If owner_id_column is configured, owner_id value must be provided. - """ params: tuple[Any, ...] if self._owner_id_column_name: query = pg_sql.SQL(""" @@ -712,21 +435,19 @@ def create_session( with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, params) - return self.get_session(session_id) # type: ignore[return-value] - - def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. + result = self._get_session(session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result - Returns: - Session record or None if not found. 
+ async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - Notes: - PostgreSQL returns datetime objects for TIMESTAMPTZ columns. - JSONB is automatically deserialized by psycopg to Python dict. - """ + def _get_session(self, session_id: str) -> "SessionRecord | None": query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time FROM {table} @@ -752,18 +473,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) - Notes: - This replaces the entire state dictionary. - Uses CURRENT_TIMESTAMP for update_time. - State is wrapped with Jsonb() for PostgreSQL type safety. - """ + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP @@ -773,33 +487,21 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (Jsonb(state), session_id)) - def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. 
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) - Notes: - Foreign key constraint ensures events are cascade-deleted. - """ + def _delete_session(self, session_id: str) -> None: query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (session_id,)) - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. - """ + def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time @@ -836,165 +538,130 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess except errors.UndefinedTable: return [] - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - """Create a new event. - - Args: - event_id: Unique event identifier. - session_id: Session identifier. - app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSONB). 
- **kwargs: Additional optional fields (invocation_id, branch, timestamp, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message, long_running_tool_ids_json). - - Returns: - Created event record. - - Notes: - Uses CURRENT_TIMESTAMP for timestamp if not provided in kwargs. - JSONB fields are wrapped with Jsonb() for PostgreSQL type safety. - """ - content_json = Jsonb(content) if content is not None else None - grounding_metadata = kwargs.get("grounding_metadata") - grounding_metadata_json = Jsonb(grounding_metadata) if grounding_metadata is not None else None - custom_metadata = kwargs.get("custom_metadata") - custom_metadata_json = Jsonb(custom_metadata) if custom_metadata is not None else None + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) - query = pg_sql.SQL(""" + def _insert_event(self, event_record: EventRecord) -> None: + insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, COALESCE(%s, CURRENT_TIMESTAMP), %s, %s, %s, %s, %s, %s, %s, %s - ) - RETURNING id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) + event_json_value = event_record["event_json"] + jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + with 
self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( - query, + insert_query, ( - event_id, - session_id, - app_name, - user_id, - kwargs.get("invocation_id"), - author, - actions, - kwargs.get("long_running_tool_ids_json"), - kwargs.get("branch"), - kwargs.get("timestamp"), - content_json, - grounding_metadata_json, - custom_metadata_json, - kwargs.get("partial"), - kwargs.get("turn_complete"), - kwargs.get("interrupted"), - kwargs.get("error_code"), - kwargs.get("error_message"), + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, ), ) - row = cur.fetchone() - - if row is None: - msg = f"Failed to create event {event_id}" - raise RuntimeError(msg) - - return EventRecord( - id=row["id"], - session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], - invocation_id=row["invocation_id"], - author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], - timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + conn.commit() + + def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + insert_query = pg_sql.SQL(""" + INSERT INTO {table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """).format(table=pg_sql.Identifier(self._events_table)) + + update_query = pg_sql.SQL(""" + UPDATE {table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE id = %s + """).format(table=pg_sql.Identifier(self._session_table)) + + event_json_value = event_record["event_json"] + 
jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute( + insert_query, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + jsonb_value, + ), ) + cur.execute(update_query, (Jsonb(state), session_id)) + conn.commit() - def list_events(self, session_id: str) -> "list[EventRecord]": - """List events for a session ordered by timestamp. + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) - Args: - session_id: Session identifier. + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + where_clauses = ["session_id = %s"] + params: list[Any] = [session_id] - Returns: - List of event records ordered by timestamp ASC. + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) - Notes: - Uses index on (session_id, timestamp ASC). - JSONB fields are automatically deserialized by psycopg. - BYTEA actions are converted to bytes. 
- """ - query = pg_sql.SQL(""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + where_clause = " AND ".join(where_clauses) + if limit: + params.append(limit) + + query = pg_sql.SQL( + """ + SELECT session_id, invocation_id, author, timestamp, event_json FROM {table} - WHERE session_id = %s - ORDER BY timestamp ASC - """).format(table=pg_sql.Identifier(self._events_table)) + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} + """ + ).format( + table=pg_sql.Identifier(self._events_table), + where_clause=pg_sql.SQL(where_clause), # pyright: ignore[reportArgumentType] + limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), # pyright: ignore[reportArgumentType] + ) try: with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (session_id,)) + cur.execute(query, tuple(params)) rows = cur.fetchall() return [ EventRecord( - id=row["id"], session_id=row["session_id"], - app_name=row["app_name"], - user_id=row["user_id"], invocation_id=row["invocation_id"], author=row["author"], - actions=bytes(row["actions"]) if row["actions"] else b"", - long_running_tool_ids_json=row["long_running_tool_ids_json"], - branch=row["branch"], timestamp=row["timestamp"], - content=row["content"], - grounding_metadata=row["grounding_metadata"], - custom_metadata=row["custom_metadata"], - partial=row["partial"], - turn_complete=row["turn_complete"], - interrupted=row["interrupted"], - error_code=row["error_code"], - error_message=row["error_message"], + event_json=row["event_json"], ) for row in rows ] except errors.UndefinedTable: return [] + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await 
async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._insert_event(event_record) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + class PsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgAsyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 async driver.""" @@ -1182,7 +849,7 @@ async def delete_entries_older_than(self, days: int) -> int: return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 -class PsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["PsycopgSyncConfig"]): +class PsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgSyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 sync driver.""" __slots__ = () @@ -1191,7 +858,7 @@ def __init__(self, config: "PsycopgSyncConfig") -> None: """Initialize Psycopg sync memory store.""" super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: """Get PostgreSQL CREATE TABLE SQL for memory entries.""" owner_id_line = "" if self._owner_id_column_ddl: @@ -1231,15 +898,19 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """Get PostgreSQL DROP TABLE SQL statements.""" return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" if not self._enabled: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) + driver.execute_script(run_(self._get_create_memory_table_sql)()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() 
+ + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" if not self._enabled: msg = "Memory store is disabled" @@ -1284,7 +955,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": """Search memory entries by text query.""" @@ -1304,6 +979,12 @@ def search_entries( except errors.UndefinedTable: return [] + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = pg_sql.SQL( """ @@ -1344,7 +1025,7 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: rows = cur.fetchall() return _rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" sql = pg_sql.SQL("DELETE FROM {table} WHERE session_id = %s").format( table=pg_sql.Identifier(self._memory_table) @@ -1354,7 +1035,11 @@ def delete_entries_by_session(self, session_id: str) -> int: cur.execute(sql, (session_id,)) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: 
+ """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" sql = pg_sql.SQL( """ @@ -1367,6 +1052,10 @@ def delete_entries_older_than(self, days: int) -> int: cur.execute(sql) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(rows: "list[Any]") -> "list[MemoryRecord]": return [ diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index e30765c7f..5e7ea8513 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -5,9 +5,10 @@ import pymysql -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from datetime import datetime @@ -33,8 +34,23 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": return (col_def, fk_constraint) -class PyMysqlADKStore(BaseSyncADKStore["PyMysqlConfig"]): - """MySQL/MariaDB ADK store using PyMySQL.""" +class PyMysqlADKStore(BaseAsyncADKStore["PyMysqlConfig"]): + """MySQL/MariaDB ADK store using PyMySQL. + + Implements session and event storage for Google Agent Development Kit + using MySQL/MariaDB via the PyMySQL sync driver. 
Provides: + - Session state management with JSON storage + - Full-event JSON storage (single ``event_json`` column) + - Atomic event-create + state-update in one transaction + - Microsecond-precision timestamps + - Foreign key constraints with cascade delete + + Notes: + - MySQL JSON type used - requires MySQL 5.7.8+ + - TIMESTAMP(6) provides microsecond precision + - InnoDB engine required for foreign key support + - Configuration is read from config.extension_config["adk"] + """ __slots__ = () @@ -44,7 +60,7 @@ def __init__(self, config: "PyMysqlConfig") -> None: def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": return _parse_owner_id_column_for_mysql(column_ddl) - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: owner_id_col = "" fk_constraint = "" @@ -68,27 +84,18 @@ def _get_create_sessions_table_sql(self) -> str: ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for events. + + Post clean-break schema: 5 columns only. 
+ """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( - id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json JSON, - branch VARCHAR(256), + author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOLEAN, - turn_complete BOOLEAN, - interrupted BOOLEAN, - error_code VARCHAR(256), - error_message VARCHAR(1024), + event_json JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -97,12 +104,16 @@ def _get_create_events_table_sql(self) -> str: def _get_drop_tables_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def create_tables(self) -> None: + def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(self._get_create_sessions_table_sql()) - driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(run_(self._get_create_sessions_table_sql)()) + driver.execute_script(run_(self._get_create_events_table_sql)()) + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() - def create_session( + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = to_json(state) @@ -129,13 +140,19 @@ def create_session( cursor.close() conn.commit() - result = self.get_session(session_id) + result = self._get_session(session_id) if result is None: msg = "Failed to fetch 
created session" raise RuntimeError(msg) return result - def get_session(self, session_id: str) -> "SessionRecord | None": + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -169,7 +186,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) sql = f""" @@ -186,7 +207,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None cursor.close() conn.commit() - def delete_session(self, session_id: str) -> None: + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" with self._config.provide_connection() as conn: @@ -197,7 +222,11 @@ def delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + def _list_sessions(self, app_name: str, user_id: str | None = None) -> 
"list[SessionRecord]": if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -240,24 +269,71 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess return [] raise - def append_event(self, event_record: EventRecord) -> None: - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single transaction. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + state_json = to_json(state) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = %s + WHERE id = %s + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_json_str, + ), + ) + cursor.execute(update_sql, (state_json, session_id)) + finally: + cursor.close() + conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _insert_event(self, event_record: EventRecord) -> None: + event_json = event_record["event_json"] + event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s - ) + session_id, invocation_id, author, timestamp, event_json + ) VALUES (%s, %s, %s, %s, %s) """ with self._config.provide_connection() as conn: @@ -266,33 +342,30 @@ def append_event(self, event_record: EventRecord) -> None: cursor.execute( sql, ( - event_record["id"], event_record["session_id"], - event_record["app_name"], - event_record["user_id"], 
event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), event_record["timestamp"], - content_json, - grounding_metadata_json, - custom_metadata_json, - event_record.get("partial"), - event_record.get("turn_complete"), - event_record.get("interrupted"), - event_record.get("error_code"), - event_record.get("error_message"), + event_json_str, ), ) finally: cursor.close() conn.commit() - def get_events( + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": + """List events for a session ordered by timestamp. + + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. + """ where_clauses = ["session_id = %s"] params: list[Any] = [session_id] @@ -301,47 +374,32 @@ def get_events( params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" - + limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} """ + if limit: + params.append(limit) try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, params) + cursor.execute(sql, tuple(params)) rows = cursor.fetchall() finally: cursor.close() return [ EventRecord( - id=row[0], - session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - 
long_running_tool_ids_json=row[7], - branch=row[8], - timestamp=row[9], - content=from_json(row[10]) if row[10] and isinstance(row[10], str) else row[10], - grounding_metadata=from_json(row[11]) if row[11] and isinstance(row[11], str) else row[11], - custom_metadata=from_json(row[12]) if row[12] and isinstance(row[12], str) else row[12], - partial=row[13], - turn_complete=row[14], - interrupted=row[15], - error_code=row[16], - error_message=row[17], + session_id=row[0], + invocation_id=row[1], + author=row[2], + timestamp=row[3], + event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] @@ -350,8 +408,22 @@ def get_events( return [] raise + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._insert_event(event_record) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + -class PyMysqlADKMemoryStore(BaseSyncADKMemoryStore["PyMysqlConfig"]): +class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]): """MySQL/MariaDB ADK memory store using PyMySQL.""" __slots__ = () @@ -359,7 +431,7 @@ class PyMysqlADKMemoryStore(BaseSyncADKMemoryStore["PyMysqlConfig"]): def __init__(self, config: "PyMysqlConfig") -> None: super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" if self._owner_id_column_ddl: @@ -393,14 +465,18 @@ def _get_create_memory_table_sql(self) -> str: def _get_drop_memory_table_sql(self) -> "list[str]": return [f"DROP TABLE IF EXISTS {self._memory_table}"] - def create_tables(self) 
-> None: + def _create_tables(self) -> None: if not self._enabled: return with self._config.provide_session() as driver: - driver.execute_script(self._get_create_memory_table_sql()) + driver.execute_script(run_(self._get_create_memory_table_sql)()) - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -466,7 +542,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object conn.commit() return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -506,7 +586,13 @@ def search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - def delete_entries_by_session(self, session_id: str) -> int: + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -521,7 +607,11 @@ def delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - def delete_entries_older_than(self, days: int) -> int: + async def 
delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -538,3 +628,7 @@ def delete_entries_older_than(self, days: int) -> int: return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 finally: cursor.close() + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index 22ebebc7e..a180de7e5 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -7,11 +7,11 @@ from google.cloud.spanner_v1 import param_types from sqlspec.adapters.spanner.config import SpannerSyncConfig -from sqlspec.adapters.spanner.type_converter import bytes_to_spanner, spanner_to_bytes -from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.protocols import SpannerParamTypesProtocol from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from google.cloud.spanner_v1.database import Database @@ -42,7 +42,7 @@ def __call__(self, transaction: "Transaction") -> None: transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call] -class SpannerSyncADKStore(BaseSyncADKStore[SpannerSyncConfig]): +class SpannerSyncADKStore(BaseAsyncADKStore[SpannerSyncConfig]): """Spanner ADK store backed by synchronous 
Spanner client.""" connector_name: ClassVar[str] = "spanner" @@ -80,30 +80,15 @@ def _session_param_types(self, include_owner: bool) -> "dict[str, Any]": types["owner_id"] = SPANNER_PARAM_TYPES.STRING return types - def _event_param_types(self, has_branch: bool) -> "dict[str, Any]": + def _event_param_types(self) -> "dict[str, Any]": json_type = _json_param_type() - types: dict[str, Any] = { - "id": SPANNER_PARAM_TYPES.STRING, + return { "session_id": SPANNER_PARAM_TYPES.STRING, - "app_name": SPANNER_PARAM_TYPES.STRING, - "user_id": SPANNER_PARAM_TYPES.STRING, - "author": SPANNER_PARAM_TYPES.STRING, - "actions": SPANNER_PARAM_TYPES.BYTES, - "long_running_tool_ids_json": json_type, "invocation_id": SPANNER_PARAM_TYPES.STRING, + "author": SPANNER_PARAM_TYPES.STRING, "timestamp": SPANNER_PARAM_TYPES.TIMESTAMP, - "content": json_type, - "grounding_metadata": json_type, - "custom_metadata": json_type, - "partial": SPANNER_PARAM_TYPES.BOOL, - "turn_complete": SPANNER_PARAM_TYPES.BOOL, - "interrupted": SPANNER_PARAM_TYPES.BOOL, - "error_code": SPANNER_PARAM_TYPES.STRING, - "error_message": SPANNER_PARAM_TYPES.STRING, + "event_json": json_type, } - if has_branch: - types["branch"] = SPANNER_PARAM_TYPES.STRING - return types def _decode_state(self, raw: Any) -> Any: if isinstance(raw, str): @@ -117,7 +102,7 @@ def _decode_json(self, raw: Any) -> Any: return from_json(raw) return raw - def create_session( + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = to_json(state) @@ -146,7 +131,13 @@ def create_session( "update_time": datetime.now(timezone.utc), } - def get_session(self, session_id: str) -> "SessionRecord | None": + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, 
app_name, user_id, state, owner_id) + + def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} FROM {self._session_table} @@ -172,7 +163,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None": } return record - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def get_session(self, session_id: str) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: params = {"id": session_id, "state": to_json(state)} json_type = _json_param_type() sql = f""" @@ -184,7 +179,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" self._run_write([(sql, params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type})]) - def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} FROM {self._session_table} @@ -198,6 +197,7 @@ def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[Se types["user_id"] = SPANNER_PARAM_TYPES.STRING if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(id), {self._shard_count})" + sql = f"{sql} ORDER BY update_time DESC" rows = self._run_read(sql, params, types) records: list[SessionRecord] = [] @@ 
-214,7 +214,11 @@ def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[Se records.append(record) return records - def delete_session(self, session_id: str) -> None: + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + def _delete_session(self, session_id: str) -> None: shard_clause = ( f" AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" if self._shard_count > 1 else "" ) @@ -224,174 +228,135 @@ def delete_session(self, session_id: str) -> None: types = {"session_id": SPANNER_PARAM_TYPES.STRING} self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)]) - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> EventRecord: - branch = kwargs.get("branch") - long_running_serialized = ( - to_json(kwargs.get("long_running_tool_ids_json")) - if kwargs.get("long_running_tool_ids_json") is not None - else None - ) - content_serialized = to_json(content) if content is not None else None - grounding_serialized = ( - to_json(kwargs.get("grounding_metadata")) if kwargs.get("grounding_metadata") is not None else None - ) - custom_serialized = ( - to_json(kwargs.get("custom_metadata")) if kwargs.get("custom_metadata") is not None else None - ) - params: dict[str, Any] = { - "id": event_id, - "session_id": session_id, - "app_name": app_name, - "user_id": user_id, - "author": author, - "actions": bytes_to_spanner(actions), - "long_running_tool_ids_json": long_running_serialized, - "timestamp": datetime.now(timezone.utc), - "content": content_serialized, - "grounding_metadata": grounding_serialized, - "custom_metadata": custom_serialized, - "invocation_id": kwargs.get("invocation_id"), - 
"partial": kwargs.get("partial"), - "turn_complete": kwargs.get("turn_complete"), - "interrupted": kwargs.get("interrupted"), - "error_code": kwargs.get("error_code"), - "error_message": kwargs.get("error_message"), - } - branch = kwargs.get("branch") - columns = [ - "id", - "session_id", - "app_name", - "user_id", - "author", - "actions", - "long_running_tool_ids_json", - "timestamp", - "content", - "grounding_metadata", - "custom_metadata", - "invocation_id", - "partial", - "turn_complete", - "interrupted", - "error_code", - "error_message", - ] - values = [ - "@id", - "@session_id", - "@app_name", - "@user_id", - "@author", - "@actions", - "@long_running_tool_ids_json", - "PENDING_COMMIT_TIMESTAMP()", - "@content", - "@grounding_metadata", - "@custom_metadata", - "@invocation_id", - "@partial", - "@turn_complete", - "@interrupted", - "@error_code", - "@error_message", - ] - has_branch = branch is not None - if has_branch: - params["branch"] = branch - columns.append("branch") - values.append("@branch") + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) - sql = f""" - INSERT INTO {self._events_table} ({", ".join(columns)}) - VALUES ({", ".join(values)}) + def _append_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically insert an event and update session state in one transaction. + + Both the event INSERT and the session state UPDATE execute within a single + Spanner transaction so they succeed or fail together. + + Args: + event_record: Event record to store. + session_id: Session whose state should be updated. + state: Post-append durable state snapshot. 
+ """ + event_params: dict[str, Any] = { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": to_json(event_record["event_json"]), + } + insert_sql = f""" + INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) + VALUES (@session_id, @invocation_id, @author, PENDING_COMMIT_TIMESTAMP(), @event_json) """ - self._run_write([(sql, params, self._event_param_types(has_branch))]) - record: EventRecord = { - "id": event_id, - "session_id": session_id, - "app_name": app_name, - "user_id": user_id, - "author": author or "", - "actions": actions or b"", - "long_running_tool_ids_json": long_running_serialized, - "branch": branch, - "timestamp": params["timestamp"], - "content": from_json(content_serialized) if content_serialized else None, - "grounding_metadata": from_json(grounding_serialized) if grounding_serialized else None, - "custom_metadata": from_json(custom_serialized) if custom_serialized else None, - "invocation_id": kwargs.get("invocation_id", ""), - "partial": kwargs.get("partial"), - "turn_complete": kwargs.get("turn_complete"), - "interrupted": kwargs.get("interrupted"), - "error_code": kwargs.get("error_code"), - "error_message": kwargs.get("error_message"), + json_type = _json_param_type() + state_params: dict[str, Any] = {"id": session_id, "state": to_json(state)} + update_sql = f""" + UPDATE {self._session_table} + SET state = @state, update_time = PENDING_COMMIT_TIMESTAMP() + WHERE id = @id + """ + if self._shard_count > 1: + update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" + + self._run_write([ + (insert_sql, event_params, self._event_param_types()), + (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}), + ]) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: 
"dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state.""" + await async_(self._append_event_and_update_state)(event_record, session_id, state) + + def _insert_event(self, event_record: "EventRecord") -> None: + event_params: dict[str, Any] = { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_json": to_json(event_record["event_json"]), } - return record + insert_sql = f""" + INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) + VALUES (@session_id, @invocation_id, @author, PENDING_COMMIT_TIMESTAMP(), @event_json) + """ + self._run_write([(insert_sql, event_params, self._event_param_types())]) - def list_events(self, session_id: str) -> "list[EventRecord]": + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": sql = f""" - SELECT id, session_id, app_name, user_id, author, actions, long_running_tool_ids_json, branch, - timestamp, content, grounding_metadata, custom_metadata, invocation_id, partial, - turn_complete, interrupted, error_code, error_message + SELECT session_id, invocation_id, author, timestamp, event_json FROM {self._events_table} WHERE session_id = @session_id """ if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" + params: dict[str, Any] = {"session_id": session_id} + types: dict[str, Any] = {"session_id": SPANNER_PARAM_TYPES.STRING} + if after_timestamp is not None: + sql = f"{sql} AND timestamp > @after_timestamp" + params["after_timestamp"] = after_timestamp + types["after_timestamp"] = SPANNER_PARAM_TYPES.TIMESTAMP sql = f"{sql} ORDER BY timestamp ASC" - params = {"session_id": session_id} - types = {"session_id": SPANNER_PARAM_TYPES.STRING} + if limit is not None: + sql = f"{sql} LIMIT 
@limit" + params["limit"] = limit + types["limit"] = SPANNER_PARAM_TYPES.INT64 rows = self._run_read(sql, params, types) return [ { - "id": row[0], - "session_id": row[1], - "app_name": row[2], - "user_id": row[3], - "invocation_id": row[12] or "", - "author": row[4] or "", - "actions": spanner_to_bytes(row[5]) or b"", - "long_running_tool_ids_json": row[6], - "branch": row[7], - "timestamp": row[8], - "content": self._decode_json(row[9]), - "grounding_metadata": self._decode_json(row[10]), - "custom_metadata": self._decode_json(row[11]), - "partial": row[13], - "turn_complete": row[14], - "interrupted": row[15], - "error_code": row[16], - "error_message": row[17], + "session_id": row[0], + "invocation_id": row[1] or "", + "author": row[2] or "", + "timestamp": row[3], + "event_json": row[4], } for row in rows ] - def create_tables(self) -> None: + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._insert_event(event_record) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + + def _create_tables(self) -> None: database = self._database() existing_tables = {t.table_id for t in database.list_tables()} # type: ignore[no-untyped-call] ddl_statements: list[str] = [] if self._session_table not in existing_tables: - ddl_statements.append(self._get_create_sessions_table_sql()) + ddl_statements.append(run_(self._get_create_sessions_table_sql)()) if self._events_table not in existing_tables: - ddl_statements.append(self._get_create_events_table_sql()) + ddl_statements.append(run_(self._get_create_events_table_sql)()) if ddl_statements: 
database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - def _get_create_sessions_table_sql(self) -> str: + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def _get_create_sessions_table_sql(self) -> str: owner_line = "" if self._owner_id_column_ddl: owner_line = f",\n {self._owner_id_column_ddl}" @@ -414,35 +379,22 @@ def _get_create_sessions_table_sql(self) -> str: ) {pk}{options} """ - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: shard_column = "" - pk = "PRIMARY KEY (session_id, timestamp, id)" + pk = "PRIMARY KEY (session_id, timestamp)" if self._shard_count > 1: shard_column = f",\n shard_id INT64 AS (MOD(FARM_FINGERPRINT(session_id), {self._shard_count})) STORED" - pk = "PRIMARY KEY (shard_id, session_id, timestamp, id)" + pk = "PRIMARY KEY (shard_id, session_id, timestamp)" options = "" if self._events_table_options: options = f"\nOPTIONS ({self._events_table_options})" return f""" CREATE TABLE {self._events_table} ( - id STRING(128) NOT NULL, session_id STRING(128) NOT NULL, - app_name STRING(128) NOT NULL, - user_id STRING(128) NOT NULL, - invocation_id STRING(128), - author STRING(64), - actions BYTES(MAX), - long_running_tool_ids_json JSON, - branch STRING(64), + invocation_id STRING(256) NOT NULL, + author STRING(128) NOT NULL, timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), - content JSON, - grounding_metadata JSON, - custom_metadata JSON, - partial BOOL, - turn_complete BOOL, - interrupted BOOL, - error_code STRING(64), - error_message STRING(255){shard_column} + event_json JSON NOT NULL{shard_column} ) {pk}{options} """ @@ -479,7 +431,7 @@ def execute_sql( ) -> Iterable[Any]: ... 
-class SpannerSyncADKMemoryStore(BaseSyncADKMemoryStore[SpannerSyncConfig]): +class SpannerSyncADKMemoryStore(BaseAsyncADKMemoryStore[SpannerSyncConfig]): """Spanner ADK memory store backed by synchronous Spanner client.""" connector_name: ClassVar[str] = "spanner" @@ -532,7 +484,7 @@ def _decode_json(self, raw: Any) -> Any: return from_json(raw) return raw - def create_tables(self) -> None: + def _create_tables(self) -> None: if not self._enabled: return @@ -541,12 +493,16 @@ def create_tables(self) -> None: ddl_statements: list[str] = [] if self._memory_table not in existing_tables: - ddl_statements.extend(self._get_create_memory_table_sql()) + ddl_statements.extend(run_(self._get_create_memory_table_sql)()) if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - def _get_create_memory_table_sql(self) -> "list[str]": + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def _get_create_memory_table_sql(self) -> "list[str]": owner_line = "" if self._owner_id_column_ddl: owner_line = f",\n {self._owner_id_column_ddl}" @@ -590,7 +546,23 @@ def _get_create_memory_table_sql(self) -> "list[str]": statements.append(fts_index) return statements - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get SQL to drop the memory table and its indexes. + + Returns: + List of SQL statements to drop the memory table and associated indexes. 
+ """ + statements: list[str] = [] + if self._use_fts: + statements.append(f"DROP SEARCH INDEX idx_{self._memory_table}_fts") + statements.extend([ + f"DROP INDEX idx_{self._memory_table}_session", + f"DROP INDEX idx_{self._memory_table}_app_user_time", + f"DROP TABLE {self._memory_table}", + ]) + return statements + + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" raise RuntimeError(msg) @@ -639,12 +611,16 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object self._run_write(statements) return inserted_count + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + def _event_exists(self, event_id: str) -> bool: sql = f"SELECT event_id FROM {self._memory_table} WHERE event_id = @event_id LIMIT 1" rows = self._run_read(sql, {"event_id": event_id}, {"event_id": SPANNER_PARAM_TYPES.STRING}) return bool(rows) - def search_entries( + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": if not self._enabled: @@ -657,6 +633,12 @@ def search_entries( return self._search_entries_fts(query, app_name, user_id, effective_limit) return self._search_entries_simple(query, app_name, user_id, effective_limit) + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -700,19 +682,27 @@ def _search_entries_simple(self, 
query: str, app_name: str, user_id: str, limit: rows = self._run_read(sql, params, types) return self._rows_to_records(rows) - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: sql = f"DELETE FROM {self._memory_table} WHERE session_id = @session_id" params = {"session_id": session_id} types = {"session_id": SPANNER_PARAM_TYPES.STRING} return self._execute_update(sql, params, types) - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: cutoff = datetime.now(timezone.utc) - timedelta(days=days) sql = f"DELETE FROM {self._memory_table} WHERE inserted_at < @cutoff" params = {"cutoff": cutoff} types = {"cutoff": SPANNER_PARAM_TYPES.TIMESTAMP} return self._execute_update(sql, params, types) + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return [ { diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index ee3376d9e..bf4e51e52 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json from sqlspec.utils.sync_tools import async_, run_ @@ -58,34 +58,6 @@ def _julian_to_datetime(julian: float) -> 
datetime: return datetime.fromtimestamp(timestamp, tz=timezone.utc) -def _to_sqlite_bool(value: "bool | None") -> "int | None": - """Convert Python bool to SQLite INTEGER. - - Args: - value: Boolean value or None. - - Returns: - 1 for True, 0 for False, None for None. - """ - if value is None: - return None - return 1 if value else 0 - - -def _from_sqlite_bool(value: "int | None") -> "bool | None": - """Convert SQLite INTEGER to Python bool. - - Args: - value: Integer value (0/1) or None. - - Returns: - True for 1, False for 0, None for None. - """ - if value is None: - return None - return bool(value) - - class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]): """SQLite ADK store using synchronous SQLite driver. @@ -95,10 +67,11 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]): Provides: - Session state management with JSON storage (as TEXT) - - Event history tracking with BLOB-serialized actions + - Event history tracking with full-event JSON storage - Julian Day timestamps (REAL) for efficient date operations - Foreign key constraints with cascade delete - - Efficient upserts using INSERT OR REPLACE + - Atomic event+state writes via append_event_and_update_state + - PRAGMA optimization profile for file-based databases Args: config: SqliteConfig instance with extension_config["adk"] settings. 
@@ -122,9 +95,8 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]): Notes: - JSON stored as TEXT with SQLSpec serializers (msgspec/orjson/stdlib) - - BOOLEAN as INTEGER (0/1, with None for NULL) - Timestamps as REAL (Julian day: julianday('now')) - - BLOB for pre-serialized actions from Google ADK + - Full event stored as JSON TEXT in event_data column - PRAGMA foreign_keys = ON (enable per connection) - Configuration is read from config.extension_config["adk"] """ @@ -145,6 +117,22 @@ def __init__(self, config: "SqliteConfig") -> None: """ super().__init__(config) + def _apply_pragmas(self, connection: Any) -> None: + """Apply PRAGMA optimization profile for this connection. + + Args: + connection: SQLite connection. + + Notes: + Enables foreign keys and applies performance PRAGMAs. + For file-based databases, adds cache_size, mmap_size, + and journal_size_limit optimizations. + """ + connection.execute("PRAGMA foreign_keys = ON") + connection.execute("PRAGMA cache_size = -64000") + connection.execute("PRAGMA mmap_size = 30000000") + connection.execute("PRAGMA journal_size_limit = 67108864") + async def _get_create_sessions_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for sessions. @@ -184,9 +172,8 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. 
Notes: - - TEXT for IDs, strings, and JSON content - - BLOB for pickled actions - - INTEGER for booleans (0/1/NULL) + - TEXT for IDs and indexed scalars + - TEXT for full event JSON (event_data) - REAL for Julian Day timestamps - Foreign key to sessions with CASCADE delete - Index on (session_id, timestamp ASC) @@ -195,22 +182,10 @@ async def _get_create_events_table_sql(self) -> str: CREATE TABLE IF NOT EXISTS {self._events_table} ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL, - invocation_id TEXT NOT NULL, - author TEXT NOT NULL, - actions BLOB NOT NULL, - long_running_tool_ids_json TEXT, - branch TEXT, + invocation_id TEXT, + author TEXT, timestamp REAL NOT NULL, - content TEXT, - grounding_metadata TEXT, - custom_metadata TEXT, - partial INTEGER, - turn_complete INTEGER, - interrupted INTEGER, - error_code TEXT, - error_message TEXT, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session @@ -229,21 +204,10 @@ def _get_drop_tables_sql(self) -> "list[str]": """ return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - def _enable_foreign_keys(self, connection: Any) -> None: - """Enable foreign key constraints for this connection. - - Args: - connection: SQLite connection. - - Notes: - SQLite requires PRAGMA foreign_keys = ON per connection. 
- """ - connection.execute("PRAGMA foreign_keys = ON") - def _create_tables(self) -> None: """Synchronous implementation of create_tables.""" with self._config.provide_session() as driver: - driver.connection.execute("PRAGMA foreign_keys = ON") + self._apply_pragmas(driver.connection) driver.execute_script(run_(self._get_create_sessions_table_sql)()) driver.execute_script(run_(self._get_create_events_table_sql)()) @@ -257,7 +221,7 @@ def _create_session( """Synchronous implementation of create_session.""" now = datetime.now(timezone.utc) now_julian = _datetime_to_julian(now) - state_json = to_json(state) if state else None + state_json = to_json(state) params: tuple[Any, ...] if self._owner_id_column_name: @@ -275,7 +239,7 @@ def _create_session( params = (session_id, app_name, user_id, state_json, now_julian, now_julian) with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute(sql, params) conn.commit() @@ -300,7 +264,7 @@ async def create_session( Notes: Uses Julian Day for create_time and update_time. - State is JSON-serialized before insertion. + State is always JSON-serialized (empty dict becomes '{}', never NULL). If owner_id_column is configured, owner_id is inserted into that column. 
""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) @@ -314,7 +278,7 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": """ with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) cursor = conn.execute(sql, (session_id,)) row = cursor.fetchone() @@ -348,7 +312,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now_julian = _datetime_to_julian(datetime.now(timezone.utc)) - state_json = to_json(state) if state else None + state_json = to_json(state) sql = f""" UPDATE {self._session_table} @@ -357,7 +321,7 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non """ with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute(sql, (state_json, now_julian, session_id)) conn.commit() @@ -371,6 +335,7 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - Notes: This replaces the entire state dictionary. Updates update_time to current Julian Day. + Empty dict is serialized as '{}', never NULL. """ await async_(self._update_session_state)(session_id, state) @@ -394,7 +359,7 @@ def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionR params = (app_name, user_id) with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) cursor = conn.execute(sql, params) rows = cursor.fetchall() @@ -430,7 +395,7 @@ def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = ?" 
with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute(sql, (session_id,)) conn.commit() @@ -448,53 +413,29 @@ async def delete_session(self, session_id: str) -> None: def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - - content_json = to_json(event_record.get("content")) if event_record.get("content") else None - grounding_metadata_json = ( - to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None - ) - custom_metadata_json = ( - to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None - ) - - partial_int = _to_sqlite_bool(event_record.get("partial")) - turn_complete_int = _to_sqlite_bool(event_record.get("turn_complete")) - interrupted_int = _to_sqlite_bool(event_record.get("interrupted")) + event_data_json = to_json(event_record["event_json"]) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ) + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) 
""" + import uuid + + event_id = str(uuid.uuid4()) + with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) conn.execute( sql, ( - event_record["id"], + event_id, event_record["session_id"], - event_record["app_name"], - event_record["user_id"], event_record["invocation_id"], event_record["author"], - event_record["actions"], - event_record.get("long_running_tool_ids_json"), - event_record.get("branch"), timestamp_julian, - content_json, - grounding_metadata_json, - custom_metadata_json, - partial_int, - turn_complete_int, - interrupted_int, - event_record.get("error_code"), - event_record.get("error_message"), + event_data_json, ), ) conn.commit() @@ -503,15 +444,71 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record to store. + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_json. Notes: Uses Julian Day for timestamp. - JSON fields are serialized to TEXT. - Boolean fields converted to INTEGER (0/1/NULL). + event_json dict is serialized to TEXT as event_data column. """ await async_(self._append_event)(event_record) + def _append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Synchronous implementation of append_event_and_update_state.""" + import uuid + + timestamp_julian = _datetime_to_julian(event_record["timestamp"]) + event_data_json = to_json(event_record["event_json"]) + now_julian = _datetime_to_julian(datetime.now(timezone.utc)) + state_json = to_json(state) + event_id = str(uuid.uuid4()) + + insert_sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, author, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?) + """ + + update_sql = f""" + UPDATE {self._session_table} + SET state = ?, update_time = ? + WHERE id = ? 
+ """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute( + insert_sql, + ( + event_id, + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + timestamp_julian, + event_data_json, + ), + ) + conn.execute(update_sql, (state_json, now_julian, session_id)) + conn.commit() + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + Inserts the event and updates the session state + update_time in a + single transaction. Both operations succeed or fail together. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (temp: keys already + stripped by the service layer). + """ + await async_(self._append_event_and_update_state)(event_record, session_id, state) + def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -527,40 +524,24 @@ def _get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, app_name, user_id, invocation_id, author, actions, - long_running_tool_ids_json, branch, timestamp, content, - grounding_metadata, custom_metadata, partial, turn_complete, - interrupted, error_code, error_message + SELECT id, session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} """ with self._config.provide_connection() as conn: - self._enable_foreign_keys(conn) + self._apply_pragmas(conn) cursor = conn.execute(sql, params) rows = cursor.fetchall() return [ EventRecord( - id=row[0], session_id=row[1], - app_name=row[2], - user_id=row[3], - invocation_id=row[4], - author=row[5], - actions=bytes(row[6]), - long_running_tool_ids_json=row[7], - 
branch=row[8], - timestamp=_julian_to_datetime(row[9]), - content=from_json(row[10]) if row[10] else None, - grounding_metadata=from_json(row[11]) if row[11] else None, - custom_metadata=from_json(row[12]) if row[12] else None, - partial=_from_sqlite_bool(row[13]), - turn_complete=_from_sqlite_bool(row[14]), - interrupted=_from_sqlite_bool(row[15]), - error_code=row[16], - error_message=row[17], + invocation_id=row[2], + author=row[3], + timestamp=_julian_to_datetime(row[4]), + event_json=from_json(row[5]) if row[5] else {}, ) for row in rows ] @@ -580,13 +561,12 @@ async def get_events( Notes: Uses index on (session_id, timestamp ASC). - Parses JSON fields and converts BLOB actions to bytes. - Converts INTEGER booleans back to bool/None. + Parses event_data TEXT back to dict for event_json field. """ return await async_(self._get_events)(session_id, after_timestamp, limit) -class SqliteADKMemoryStore(BaseSyncADKMemoryStore["SqliteConfig"]): +class SqliteADKMemoryStore(BaseAsyncADKMemoryStore["SqliteConfig"]): """SQLite ADK memory store using synchronous SQLite driver. Implements memory entry storage for Google Agent Development Kit @@ -645,7 +625,7 @@ def __init__(self, config: "SqliteConfig") -> None: """ super().__init__(config) - def _get_create_memory_table_sql(self) -> str: + async def _get_create_memory_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for memory entries. Returns: @@ -737,7 +717,7 @@ def _enable_foreign_keys(self, connection: Any) -> None: """ connection.execute("PRAGMA foreign_keys = ON") - def create_tables(self) -> None: + def _create_tables(self) -> None: """Create the memory table and indexes if they don't exist. Skips table creation if memory store is disabled. 
@@ -747,9 +727,13 @@ def create_tables(self) -> None: with self._config.provide_session() as driver: self._enable_foreign_keys(driver.connection) - driver.execute_script(self._get_create_memory_table_sql()) + driver.execute_script(run_(self._get_create_memory_table_sql)()) + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication. Uses INSERT OR IGNORE to skip duplicates based on event_id @@ -833,7 +817,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object return inserted_count - def search_entries( + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": """Search memory entries by text query. 
@@ -863,6 +851,12 @@ def search_entries( logger.warning("FTS search failed; falling back to simple search: %s", exc) return self._search_entries_simple(query, app_name, user_id, effective_limit) + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT m.id, m.session_id, m.app_name, m.user_id, m.event_id, m.author, @@ -915,7 +909,7 @@ def _fetch_records(self, sql: str, params: "tuple[Any, ...]") -> "list[MemoryRec for row in rows ] - def delete_entries_by_session(self, session_id: str) -> int: + def _delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session. Args: @@ -934,7 +928,11 @@ def delete_entries_by_session(self, session_id: str) -> int: return deleted_count - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days. Used for TTL cleanup operations. 
@@ -956,3 +954,7 @@ def delete_entries_older_than(self, days: int) -> int: conn.commit() return deleted_count + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/config.py b/sqlspec/config.py index e3a96c35e..63d1ae2dd 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -35,7 +35,11 @@ __all__ = ( + "ADKCompressionConfig", "ADKConfig", + "ADKPartitionConfig", + "ADKRetentionConfig", + "ADKSqliteOptimizationConfig", "AsyncConfigT", "AsyncDatabaseConfig", "ConfigT", @@ -374,6 +378,178 @@ class FastAPIConfig(StarletteConfig): """ +class ADKPartitionConfig(TypedDict): + """Configuration for table partitioning and sharding strategies. + + Controls how ADK tables are partitioned across backends that support it. + Backends without native partitioning support ignore these settings. + + Example: + extension_config={ + "adk": { + "partitioning": { + "strategy": "range", + "partition_key": "created_at", + "interval": "month", + } + } + } + """ + + strategy: NotRequired[Literal["range", "list", "hash"]] + """Partitioning strategy. Default: None (no partitioning). + + - range: Partition by range of values (e.g., time-based) + - list: Partition by discrete value lists + - hash: Partition by hash of the partition key + + Supported by: PostgreSQL, MySQL 8+, Oracle, BigQuery, Spanner. + Ignored by: SQLite, DuckDB. + """ + + partition_key: NotRequired[str] + """Column name used as the partition key. + + For range partitioning with time-based data, this is typically a timestamp column + like 'created_at'. For hash partitioning, this is typically the primary key. + """ + + interval: NotRequired[str] + """Partition interval for range partitioning. + + Examples: 'day', 'week', 'month', 'year'. + Only meaningful when strategy is 'range'. 
+ """ + + +class ADKRetentionConfig(TypedDict): + """Configuration for data retention and TTL policies. + + Controls automatic cleanup of expired data. Backends with native TTL support + (CockroachDB Row-Level TTL, Spanner Row Deletion Policy) use database-level + enforcement. Others fall back to application-level sweep queries. + + Example: + extension_config={ + "adk": { + "retention": { + "session_ttl_seconds": 86400, + "event_ttl_seconds": 604800, + "memory_ttl_seconds": 0, + } + } + } + """ + + session_ttl_seconds: NotRequired[int] + """TTL for session records in seconds. Default: 0 (no expiry). + + When set, sessions older than this threshold are eligible for cleanup. + Backends with native TTL (CockroachDB, Spanner) enforce this at the database level. + Others require application-level cleanup via periodic sweep. + """ + + event_ttl_seconds: NotRequired[int] + """TTL for event records in seconds. Default: 0 (no expiry). + + When set, events older than this threshold are eligible for cleanup. + """ + + memory_ttl_seconds: NotRequired[int] + """TTL for memory entries in seconds. Default: 0 (no expiry). + + When set, memory entries older than this threshold are eligible for cleanup. + """ + + sweep_interval_seconds: NotRequired[int] + """Interval between application-level cleanup sweeps in seconds. Default: 3600 (1 hour). + + Only used when the backend does not support native TTL enforcement. + Set to 0 to disable automatic sweeps (manual cleanup only). + """ + + +class ADKCompressionConfig(TypedDict): + """Configuration for table-level compression. + + Controls compression of ADK table storage. Support and algorithms vary by backend. + + Example: + extension_config={ + "adk": { + "compression": { + "enabled": True, + "algorithm": "zstd", + } + } + } + """ + + enabled: NotRequired[bool] + """Enable table compression. Default: False. + + When True, adapters that support table-level compression will apply it + during table creation. 
+ """ + + algorithm: NotRequired[str] + """Compression algorithm name. Backend-specific. + + Examples: + - PostgreSQL (with TOAST): 'pglz', 'lz4' (PG14+) + - MySQL/InnoDB: 'zlib' + - Oracle: 'basic', 'oltp', 'query_high', 'archive_high' + - DuckDB: 'zstd', 'snappy' + + When omitted, the backend default is used. + """ + + level: NotRequired[int] + """Compression level (where supported). Higher levels trade CPU for space savings. + + Valid ranges depend on the algorithm and backend. + """ + + +class ADKSqliteOptimizationConfig(TypedDict): + """SQLite-specific PRAGMA optimization settings. + + Controls SQLite performance tuning parameters applied at connection time. + These settings are ignored by non-SQLite adapters. + + Example: + extension_config={ + "adk": { + "sqlite_optimization": { + "cache_size": -64000, + "mmap_size": 31457280, + "journal_size_limit": 67108864, + } + } + } + """ + + cache_size: NotRequired[int] + """SQLite page cache size. Default: -64000 (64 MB, negative means KiB). + + Larger caches reduce disk I/O for read-heavy workloads. + Negative values specify size in KiB; positive values specify page count. + """ + + mmap_size: NotRequired[int] + """SQLite memory-mapped I/O size in bytes. Default: 31457280 (30 MB). + + Enables memory-mapped I/O for faster reads. Set to 0 to disable. + """ + + journal_size_limit: NotRequired[int] + """SQLite journal file size limit in bytes. Default: 67108864 (64 MB). + + Limits the size of the WAL or rollback journal file. + Prevents unbounded journal growth in write-heavy workloads. + """ + + class ADKConfig(TypedDict): """Configuration options for ADK session and memory store extension. @@ -460,6 +636,27 @@ class ADKConfig(TypedDict): "tenant_acme_memories" """ + artifact_table: NotRequired[str] + """Name of the artifact versions table. 
Default: 'adk_artifact_versions' + + Examples: + "agent_artifacts" + "my_app_artifact_versions" + """ + + artifact_storage_uri: NotRequired[str] + """Base URI for artifact content storage. + + Points to a ``sqlspec/storage/`` backend where artifact binary content + is stored. Can be a direct URI (``s3://bucket/path``, ``file:///path``) + or a registered alias in the storage registry. + + Examples: + "s3://my-bucket/adk-artifacts/" + "file:///var/data/artifacts/" + "gcs://my-gcs-bucket/artifacts/" + """ + memory_use_fts: NotRequired[bool] """Enable full-text search when supported. Default: False. @@ -585,6 +782,81 @@ class ADKConfig(TypedDict): expires_index_options: NotRequired[str] """Adapter-specific options for the expires/index used in ADK stores.""" + # --- Capability-based configuration (Chapter 2: schema-capability-config) --- + + fts_language: NotRequired[str] + """Language configuration for full-text search indexing. Default: 'english'. + + Controls the language dictionary/stemmer used by FTS implementations: + - PostgreSQL: to_tsvector/to_tsquery language parameter + - SQLite FTS5: tokenizer language for unicode61/porter + - MySQL: FULLTEXT parser language (with ngram for CJK on 5.7.6+) + - Oracle: CTXSYS.CONTEXT lexer language + - Spanner: TOKENIZE_FULLTEXT language parameter + - DuckDB: FTS stemmer language + + Only takes effect when ``memory_use_fts`` is True. + + Common values: 'english', 'simple', 'german', 'french', 'spanish', + 'portuguese', 'italian', 'dutch', 'russian', 'chinese', 'japanese', 'korean'. + + Notes: + Available languages vary by backend. Backends that do not support the + specified language will fall back to 'simple' or 'english'. + """ + + schema_version: NotRequired[int] + """Explicit schema version for ADK tables. Default: None (auto-detect). + + When set, locks the ADK schema to a specific version. 
This is useful for: + - Preventing automatic schema upgrades in production + - Pinning to a known-good schema during testing + - Coordinating schema changes across multiple application instances + + When None, the ADK extension auto-detects the current schema version + and applies any pending upgrades during initialization. + + Notes: + Schema versions are monotonically increasing integers managed by + the ADK extension migration system. Setting this to a version + lower than the current database schema will raise a configuration + error at startup. + """ + + partitioning: NotRequired[ADKPartitionConfig] + """Table partitioning configuration. Default: None (no partitioning). + + Controls how ADK tables are partitioned for improved query performance + and data management at scale. See ``ADKPartitionConfig`` for options. + + Supported by: PostgreSQL, MySQL 8+, Oracle, BigQuery, Spanner. + Ignored by: SQLite, DuckDB. + """ + + retention: NotRequired[ADKRetentionConfig] + """Data retention and TTL configuration. Default: None (no automatic cleanup). + + Controls automatic expiry and cleanup of old session, event, and memory data. + See ``ADKRetentionConfig`` for options. + + Backends with native TTL (CockroachDB, Spanner) use database-level enforcement. + Others fall back to application-level sweep queries. + """ + + compression: NotRequired[ADKCompressionConfig] + """Table compression configuration. Default: None (no compression). + + Controls table-level compression for ADK tables. + See ``ADKCompressionConfig`` for options. + """ + + sqlite_optimization: NotRequired[ADKSqliteOptimizationConfig] + """SQLite-specific PRAGMA optimization settings. Default: None (SQLite defaults). + + Controls SQLite performance tuning parameters. Ignored by non-SQLite adapters. + See ``ADKSqliteOptimizationConfig`` for options. + """ + class EventsConfig(TypedDict): """Configuration options for the events extension. 
diff --git a/sqlspec/core/splitter.py b/sqlspec/core/splitter.py index 45908cf27..0f5414f62 100644 --- a/sqlspec/core/splitter.py +++ b/sqlspec/core/splitter.py @@ -623,6 +623,8 @@ def statement_terminators(self) -> "set[str]": _pattern_cache: LRUCache | None = None _result_cache: LRUCache | None = None _cache_lock = threading.Lock() +_unknown_dialect_warning_lock = threading.Lock() +_warned_unknown_dialects: set[str] = set() def _get_pattern_cache() -> LRUCache: @@ -653,6 +655,16 @@ def _get_result_cache() -> LRUCache: return _result_cache +def _warn_unknown_dialect_once(dialect: "str | None") -> None: + """Emit the generic splitter fallback warning once per dialect.""" + key = "" if dialect is None else dialect.lower() + with _unknown_dialect_warning_lock: + if key in _warned_unknown_dialects: + return + _warned_unknown_dialects.add(key) + logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect) + + @mypyc_attr(allow_interpreted_subclasses=False) class StatementSplitter: """SQL script splitter with caching and dialect support.""" @@ -933,7 +945,7 @@ def split_sql_script(script: str, dialect: str | None = None, strip_trailing_ter config = dialect_configs.get(dialect.lower()) if not config: - logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect) + _warn_unknown_dialect_once(dialect) config = GenericDialectConfig() splitter = StatementSplitter(config, strip_trailing_semicolon=strip_trailing_terminator) @@ -949,6 +961,8 @@ def clear_splitter_caches() -> None: result_cache = _get_result_cache() pattern_cache.clear() result_cache.clear() + with _unknown_dialect_warning_lock: + _warned_unknown_dialects.clear() def get_splitter_cache_stats() -> "dict[str, Any]": diff --git a/sqlspec/extensions/adk/__init__.py b/sqlspec/extensions/adk/__init__.py index c7877b1a5..7f6cfc586 100644 --- a/sqlspec/extensions/adk/__init__.py +++ b/sqlspec/extensions/adk/__init__.py @@ -1,20 +1,24 @@ -"""Google ADK session backend extension for SQLSpec. 
+"""Google ADK session, memory, and artifact backend extension for SQLSpec. -Provides session, event, and memory storage for Google Agent Development Kit using -SQLSpec database adapters. +Provides session, event, memory, and artifact storage for Google Agent Development Kit +using SQLSpec database adapters. Public API exports: - ADKConfig: TypedDict for extension config (type-safe configuration) - SQLSpecSessionService: Main service class implementing BaseSessionService - SQLSpecMemoryService: Main async service class implementing BaseMemoryService - SQLSpecSyncMemoryService: Sync memory service for sync adapters + - SQLSpecArtifactService: Artifact service implementing BaseArtifactService - BaseAsyncADKStore: Base class for async database store implementations - BaseSyncADKStore: Base class for sync database store implementations - BaseAsyncADKMemoryStore: Base class for async memory store implementations - BaseSyncADKMemoryStore: Base class for sync memory store implementations + - BaseAsyncADKArtifactStore: Base class for async artifact metadata stores + - BaseSyncADKArtifactStore: Base class for sync artifact metadata stores - SessionRecord: TypedDict for session database records - EventRecord: TypedDict for event database records - MemoryRecord: TypedDict for memory database records + - ArtifactRecord: TypedDict for artifact metadata database records Example (with extension_config): from sqlspec.adapters.asyncpg import AsyncpgConfig @@ -45,6 +49,12 @@ from sqlspec.config import ADKConfig from sqlspec.extensions.adk._types import EventRecord, SessionRecord +from sqlspec.extensions.adk.artifact import ( + ArtifactRecord, + BaseAsyncADKArtifactStore, + BaseSyncADKArtifactStore, + SQLSpecArtifactService, +) from sqlspec.extensions.adk.memory import ( BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore, @@ -57,12 +67,16 @@ __all__ = ( "ADKConfig", + "ArtifactRecord", + "BaseAsyncADKArtifactStore", "BaseAsyncADKMemoryStore", "BaseAsyncADKStore", + 
"BaseSyncADKArtifactStore", "BaseSyncADKMemoryStore", "BaseSyncADKStore", "EventRecord", "MemoryRecord", + "SQLSpecArtifactService", "SQLSpecMemoryService", "SQLSpecSessionService", "SQLSpecSyncMemoryService", diff --git a/sqlspec/extensions/adk/_types.py b/sqlspec/extensions/adk/_types.py index 651431651..3f11b62f0 100644 --- a/sqlspec/extensions/adk/_types.py +++ b/sqlspec/extensions/adk/_types.py @@ -27,25 +27,15 @@ class SessionRecord(TypedDict): class EventRecord(TypedDict): """Database record for an event. - Represents the schema for events stored in the database. - Follows the ADK Event model plus session metadata. + Stores the full ADK Event as a single JSON blob (``event_json``) alongside + a small number of indexed scalar columns used for query filtering. + + This design eliminates column drift with upstream ADK: new Event fields are + automatically captured in ``event_json`` without schema changes. """ - id: str - app_name: str - user_id: str session_id: str invocation_id: str author: str - branch: "str | None" - actions: bytes - long_running_tool_ids_json: "str | None" timestamp: datetime - content: "dict[str, Any] | None" - grounding_metadata: "dict[str, Any] | None" - custom_metadata: "dict[str, Any] | None" - partial: "bool | None" - turn_complete: "bool | None" - interrupted: "bool | None" - error_code: "str | None" - error_message: "str | None" + event_json: "dict[str, Any]" diff --git a/sqlspec/extensions/adk/artifact/__init__.py b/sqlspec/extensions/adk/artifact/__init__.py new file mode 100644 index 000000000..36c3f8478 --- /dev/null +++ b/sqlspec/extensions/adk/artifact/__init__.py @@ -0,0 +1,57 @@ +"""Google ADK artifact service extension for SQLSpec. + +Provides artifact versioning and storage for Google Agent Development Kit +using SQLSpec database adapters for metadata and ``sqlspec/storage/`` backends +for content. 
+ +Public API exports: + - SQLSpecArtifactService: Main service implementing BaseArtifactService + - BaseAsyncADKArtifactStore: Base class for async artifact metadata stores + - BaseSyncADKArtifactStore: Base class for sync artifact metadata stores + - ArtifactRecord: TypedDict for artifact metadata database records + +Example: + from sqlspec.adapters.asyncpg import AsyncpgConfig + from sqlspec.extensions.adk.artifact import SQLSpecArtifactService + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://..."}, + extension_config={ + "adk": { + "artifact_table": "adk_artifact_versions", + } + } + ) + + # Create an adapter-specific artifact store (e.g., AsyncpgADKArtifactStore) + # and ensure tables exist: + artifact_store = AsyncpgADKArtifactStore(config) + await artifact_store.ensure_table() + + # Create the service with a storage backend URI: + service = SQLSpecArtifactService( + store=artifact_store, + artifact_storage_uri="s3://my-bucket/adk-artifacts/", + ) + + # Save an artifact (returns version number starting from 0): + version = await service.save_artifact( + app_name="my_app", + user_id="user123", + filename="report.pdf", + artifact=part, + ) + + # Load artifact content: + loaded = await service.load_artifact( + app_name="my_app", + user_id="user123", + filename="report.pdf", + ) +""" + +from sqlspec.extensions.adk.artifact._types import ArtifactRecord +from sqlspec.extensions.adk.artifact.service import SQLSpecArtifactService +from sqlspec.extensions.adk.artifact.store import BaseAsyncADKArtifactStore, BaseSyncADKArtifactStore + +__all__ = ("ArtifactRecord", "BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore", "SQLSpecArtifactService") diff --git a/sqlspec/extensions/adk/artifact/_types.py b/sqlspec/extensions/adk/artifact/_types.py new file mode 100644 index 000000000..dcedffcf6 --- /dev/null +++ b/sqlspec/extensions/adk/artifact/_types.py @@ -0,0 +1,32 @@ +"""Type definitions for ADK artifact extension. 
class ArtifactRecord(TypedDict):
    """Database record for an artifact version.

    Represents the schema for artifact metadata stored in the database.
    Content is stored separately in object storage; this record tracks
    versioning, ownership, and the canonical URI pointing to the content.

    The composite key is (app_name, user_id, session_id, filename, version),
    where session_id may be NULL for user-scoped artifacts.
    """

    app_name: str
    user_id: str
    # None for user-scoped (session-independent) artifacts.
    session_id: "str | None"
    filename: str
    # 0-based, monotonically increasing per (app, user, session, filename).
    version: int
    mime_type: "str | None"
    # Full URI of the content blob in object storage (base URI + content path).
    canonical_uri: str
    custom_metadata: "dict[str, Any] | None"
    # Timezone-aware insertion timestamp.
    created_at: datetime
# Matches path-traversal components ("..") and NUL bytes.
_UNSAFE_PATH_CHARS = re.compile(r"(?:^|/)\.\.(?:/|$)|[\x00]")


def _sanitize_path_component(value: str) -> str:
    """Sanitize a path component to prevent directory traversal.

    Strips leading/trailing slashes and rejects components containing
    ``..`` traversal sequences or NUL bytes.

    Args:
        value: Raw path component.

    Returns:
        Sanitized path component.

    Raises:
        ValueError: If the value contains a ``..`` traversal sequence or a NUL byte.
    """
    value = value.strip("/")
    if _UNSAFE_PATH_CHARS.search(value):
        msg = f"Unsafe path component: {value!r}"
        raise ValueError(msg)
    return value


def _build_content_path(
    app_name: str, user_id: str, filename: str, version: int, session_id: "str | None" = None
) -> str:
    """Build the storage path for artifact content.

    Pattern:
        ``apps/{app_name}/users/{user_id}/[sessions/{session_id}/]artifacts/{filename}/v{version}``

    All path components are sanitized to prevent directory traversal.

    Args:
        app_name: Application name.
        user_id: User identifier.
        filename: Artifact filename.
        version: Version number.
        session_id: Optional session identifier.

    Returns:
        Sanitized storage path.

    Raises:
        ValueError: If any component fails sanitization.
    """
    parts = ["apps", _sanitize_path_component(app_name), "users", _sanitize_path_component(user_id)]
    if session_id is not None:
        parts.extend(["sessions", _sanitize_path_component(session_id)])
    parts.extend(["artifacts", _sanitize_path_component(filename), f"v{version}"])
    return "/".join(parts)


def _extract_mime_type(artifact: "types.Part | dict[str, Any]") -> "str | None":
    """Extract MIME type from an artifact Part.

    Checks ``inline_data.mime_type`` and ``file_data.mime_type`` on the Part.

    Args:
        artifact: ADK Part or dict representation.

    Returns:
        MIME type string, or None if not determinable.
    """
    if isinstance(artifact, dict):
        # Handle camelCase and snake_case keys
        inline = artifact.get("inline_data") or artifact.get("inlineData")
        if isinstance(inline, dict):
            return inline.get("mime_type") or inline.get("mimeType")
        file_data = artifact.get("file_data") or artifact.get("fileData")
        if isinstance(file_data, dict):
            return file_data.get("mime_type") or file_data.get("mimeType")
        return None

    # types.Part object
    if hasattr(artifact, "inline_data") and artifact.inline_data is not None:
        return getattr(artifact.inline_data, "mime_type", None)
    if hasattr(artifact, "file_data") and artifact.file_data is not None:
        return getattr(artifact.file_data, "mime_type", None)
    return None


def _serialize_artifact(artifact: "types.Part | dict[str, Any]") -> bytes:
    """Serialize an artifact Part to bytes for content storage.

    The artifact is serialized as JSON via ``model_dump(exclude_none=True)``.
    This preserves the full Part structure including text, inline_data,
    file_data, and any future Part fields.

    Args:
        artifact: ADK Part or dict representation.

    Returns:
        JSON-encoded bytes.
    """
    if isinstance(artifact, dict):
        return json.dumps(artifact, default=str).encode("utf-8")

    # Use Pydantic model serialization
    if hasattr(artifact, "model_dump"):
        data = artifact.model_dump(exclude_none=True)
        return json.dumps(data, default=str).encode("utf-8")

    # Fallback for unexpected types
    return json.dumps({"text": str(artifact)}).encode("utf-8")
def _deserialize_artifact(data: bytes) -> "types.Part":
    """Deserialize bytes back into an ADK Part.

    Inverse of ``_serialize_artifact``: parses the JSON payload and
    validates it back into a Part via Pydantic.

    Args:
        data: JSON-encoded bytes from content storage.

    Returns:
        Reconstructed Part object.
    """
    # Local import keeps google.genai off the module import path.
    from google.genai import types

    parsed = json.loads(data.decode("utf-8"))
    return types.Part.model_validate(parsed)


def _record_to_artifact_version(record: "ArtifactRecord") -> "ArtifactVersion":
    """Convert a database artifact record to an ADK ArtifactVersion.

    Args:
        record: Database artifact record.

    Returns:
        ArtifactVersion model instance.
    """
    from google.adk.artifacts.base_artifact_service import ArtifactVersion

    return ArtifactVersion(
        version=record["version"],
        canonical_uri=record["canonical_uri"],
        # Normalize NULL metadata to an empty dict for the ADK model.
        custom_metadata=record["custom_metadata"] or {},
        # ADK expects epoch seconds; created_at is stored timezone-aware.
        create_time=record["created_at"].timestamp(),
        mime_type=record["mime_type"],
    )


class SQLSpecArtifactService(BaseArtifactService):
    """SQLSpec-backed implementation of BaseArtifactService.

    Composes SQL metadata storage with ``sqlspec/storage/`` content backends
    to provide versioned artifact management for Google ADK.

    Metadata (version number, filename, MIME type, custom metadata, canonical
    URI) is stored in a SQL table managed by the artifact store. Content
    bytes are stored in object storage (S3, GCS, Azure, local filesystem)
    via the storage registry.

    Args:
        store: Artifact metadata store implementation.
        artifact_storage_uri: Base URI for content storage (e.g.,
            ``"s3://my-bucket/adk-artifacts/"``, ``"file:///var/data/artifacts/"``).
            Can also be a registered alias in the storage registry.
        registry: Storage registry to use. Defaults to the global singleton.

    Example:
        from sqlspec.adapters.asyncpg.adk.artifact_store import AsyncpgADKArtifactStore
        from sqlspec.extensions.adk.artifact import SQLSpecArtifactService

        artifact_store = AsyncpgADKArtifactStore(config)
        await artifact_store.ensure_table()

        service = SQLSpecArtifactService(
            store=artifact_store,
            artifact_storage_uri="s3://my-bucket/adk-artifacts/",
        )

        version = await service.save_artifact(
            app_name="my_app",
            user_id="user123",
            filename="output.png",
            artifact=part,
        )
    """

    def __init__(
        self, store: "BaseAsyncADKArtifactStore", artifact_storage_uri: str, registry: "StorageRegistry | None" = None
    ) -> None:
        self._store = store
        # Trailing slash is stripped so canonical URIs never contain "//".
        self._artifact_storage_uri = artifact_storage_uri.rstrip("/")
        self._registry = registry or storage_registry

    @property
    def store(self) -> "BaseAsyncADKArtifactStore":
        """Return the artifact metadata store."""
        return self._store

    @property
    def artifact_storage_uri(self) -> str:
        """Return the base URI for content storage."""
        return self._artifact_storage_uri

    async def save_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        filename: str,
        artifact: "types.Part | dict[str, Any]",
        session_id: "str | None" = None,
        custom_metadata: "dict[str, Any] | None" = None,
    ) -> int:
        """Save an artifact, returning the new version number.

        Writes content to object storage first, then inserts the metadata
        row. If content write succeeds but metadata insert fails, a stale
        content blob may be orphaned; it is not automatically cleaned up
        (eventual consistency is acceptable; an orphan sweep can be added
        later).

        Args:
            app_name: Application name.
            user_id: User identifier.
            filename: Artifact filename.
            artifact: ADK Part or dict to save.
            session_id: Session identifier (None for user-scoped).
            custom_metadata: Optional per-version metadata dict.

        Returns:
            The version number (0-based, incrementing).
        """
        from google.adk.artifacts.base_artifact_service import ensure_part

        # Normalize artifact to Part
        artifact_part: types.Part = ensure_part(artifact)

        # Determine the next version
        # NOTE(review): get_next_version + insert is not atomic — presumably the
        # store enforces uniqueness on (app, user, session, filename, version);
        # confirm concurrent writers cannot race to the same version.
        version = await self._store.get_next_version(
            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
        )

        # Build the content path and canonical URI
        content_path = _build_content_path(
            app_name=app_name, user_id=user_id, filename=filename, version=version, session_id=session_id
        )
        canonical_uri = f"{self._artifact_storage_uri}/{content_path}"

        # Serialize content
        content_bytes = _serialize_artifact(artifact_part)

        # Extract MIME type
        mime_type = _extract_mime_type(artifact_part)

        # Write content first (fail-fast before metadata)
        backend = self._registry.get(self._artifact_storage_uri)
        if hasattr(backend, "write_bytes_async"):
            await backend.write_bytes_async(content_path, content_bytes)
        else:
            # NOTE(review): sync fallback blocks the event loop — confirm this is
            # acceptable for expected payload sizes, or offload to a thread.
            backend.write_bytes_sync(content_path, content_bytes)

        # Insert metadata row
        from datetime import datetime, timezone

        record = ArtifactRecord(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            version=version,
            mime_type=mime_type,
            canonical_uri=canonical_uri,
            custom_metadata=custom_metadata,
            created_at=datetime.now(tz=timezone.utc),
        )
        await self._store.insert_artifact(record)

        log_with_context(
            logger,
            logging.DEBUG,
            "adk.artifact.save",
            app_name=app_name,
            user_id=user_id,
            filename=filename,
            version=version,
            session_id=session_id,
            mime_type=mime_type,
        )
        return version

    async def load_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        filename: str,
        session_id: "str | None" = None,
        version: "int | None" = None,
    ) -> "types.Part | None":
        """Load an artifact by reading metadata then fetching content.

        Args:
            app_name: Application name.
            user_id: User identifier.
            filename: Artifact filename.
            session_id: Session identifier (None for user-scoped).
            version: Specific version, or None for latest.

        Returns:
            Deserialized Part, or None if not found.
        """
        record = await self._store.get_artifact(
            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id, version=version
        )
        if record is None:
            log_with_context(
                logger,
                logging.DEBUG,
                "adk.artifact.load",
                app_name=app_name,
                filename=filename,
                version=version,
                found=False,
            )
            return None

        # Derive content path from canonical URI
        # NOTE(review): assumes canonical_uri starts with artifact_storage_uri —
        # records written under a different base URI would pass through unchanged.
        # Confirm this invariant or validate the prefix.
        content_path = record["canonical_uri"].removeprefix(self._artifact_storage_uri + "/")

        backend = self._registry.get(self._artifact_storage_uri)
        if hasattr(backend, "read_bytes_async"):
            content_bytes = await backend.read_bytes_async(content_path)
        else:
            content_bytes = backend.read_bytes_sync(content_path)

        log_with_context(
            logger,
            logging.DEBUG,
            "adk.artifact.load",
            app_name=app_name,
            filename=filename,
            version=record["version"],
            found=True,
        )
        return _deserialize_artifact(content_bytes)

    async def list_artifact_keys(self, *, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]":
        """List distinct artifact filenames.

        When ``session_id`` is provided, returns both session-scoped and
        user-scoped filenames. When None, returns only user-scoped filenames.

        Args:
            app_name: Application name.
            user_id: User identifier.
            session_id: Session identifier.

        Returns:
            List of artifact filenames.
        """
        keys = await self._store.list_artifact_keys(app_name=app_name, user_id=user_id, session_id=session_id)
        log_with_context(
            logger,
            logging.DEBUG,
            "adk.artifact.list_keys",
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            count=len(keys),
        )
        return keys

    async def delete_artifact(
        self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
    ) -> None:
        """Delete an artifact and all its versions.

        Deletes metadata rows first (fail-fast), then removes content
        objects from storage (best-effort).

        Args:
            app_name: Application name.
            user_id: User identifier.
            filename: Artifact filename.
            session_id: Session identifier (None for user-scoped).
        """
        deleted_records = await self._store.delete_artifact(
            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
        )

        # Best-effort content cleanup
        backend = self._registry.get(self._artifact_storage_uri)
        for record in deleted_records:
            content_path = record["canonical_uri"].removeprefix(self._artifact_storage_uri + "/")
            try:
                if hasattr(backend, "delete_async"):
                    await backend.delete_async(content_path)
                else:
                    backend.delete_sync(content_path)
            except Exception:
                # Deliberate: cleanup failures are logged, never raised — the
                # metadata delete already succeeded.
                log_with_context(
                    logger,
                    logging.WARNING,
                    "adk.artifact.delete.content_cleanup_failed",
                    canonical_uri=record["canonical_uri"],
                    version=record["version"],
                )

        log_with_context(
            logger,
            logging.DEBUG,
            "adk.artifact.delete",
            app_name=app_name,
            filename=filename,
            session_id=session_id,
            versions_deleted=len(deleted_records),
        )

    async def list_versions(
        self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
    ) -> "list[int]":
        """List all version numbers for an artifact.

        Args:
            app_name: Application name.
            user_id: User identifier.
            filename: Artifact filename.
            session_id: Session identifier (None for user-scoped).

        Returns:
            Sorted list of version numbers.
        """
        records = await self._store.list_artifact_versions(
            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
        )
        return [r["version"] for r in records]

    async def list_artifact_versions(
        self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
    ) -> "list[ArtifactVersion]":
        """List all versions with full metadata for an artifact.

        Args:
            app_name: Application name.
            user_id: User identifier.
            filename: Artifact filename.
            session_id: Session identifier (None for user-scoped).

        Returns:
            List of ArtifactVersion objects ordered by version ascending.
        """
        records = await self._store.list_artifact_versions(
            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
        )
        return [_record_to_artifact_version(r) for r in records]

    async def get_artifact_version(
        self,
        *,
        app_name: str,
        user_id: str,
        filename: str,
        session_id: "str | None" = None,
        version: "int | None" = None,
    ) -> "ArtifactVersion | None":
        """Get metadata for a specific artifact version.

        Args:
            app_name: Application name.
            user_id: User identifier.
            filename: Artifact filename.
            session_id: Session identifier (None for user-scoped).
            version: Version number, or None for latest.

        Returns:
            ArtifactVersion if found, None otherwise.
        """
        record = await self._store.get_artifact(
            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id, version=version
        )
        if record is None:
            return None
        return _record_to_artifact_version(record)
Content storage is handled separately by +``sqlspec/storage/`` backends; these stores only manage the relational +metadata rows. + +Adapter-specific subclasses (e.g., ``AsyncpgADKArtifactStore``) implement +the abstract methods with dialect-specific SQL. +""" + +import logging +import re +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast + +from sqlspec.observability import resolve_db_system +from sqlspec.utils.logging import get_logger, log_with_context + +if TYPE_CHECKING: + from sqlspec.config import ADKConfig, DatabaseConfigProtocol + from sqlspec.extensions.adk.artifact._types import ArtifactRecord + +ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]") + +logger = get_logger("sqlspec.extensions.adk.artifact.store") + +__all__ = ("BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore") + +VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") +MAX_TABLE_NAME_LENGTH: Final = 63 + + +def _validate_table_name(table_name: str) -> None: + """Validate table name for SQL safety. + + Args: + table_name: Table name to validate. + + Raises: + ValueError: If table name is invalid. + """ + if not table_name: + msg = "Table name cannot be empty" + raise ValueError(msg) + + if len(table_name) > MAX_TABLE_NAME_LENGTH: + msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" + raise ValueError(msg) + + if not VALID_TABLE_NAME_PATTERN.match(table_name): + msg = ( + f"Invalid table name: {table_name!r}. " + "Must start with letter/underscore and contain only alphanumeric characters and underscores" + ) + raise ValueError(msg) + + +class BaseAsyncADKArtifactStore(ABC, Generic[ConfigT]): + """Base class for async SQLSpec-backed ADK artifact metadata stores. + + Manages artifact version metadata in a SQL table. Content bytes are + stored externally via ``sqlspec/storage/`` backends and referenced + by canonical URI in each metadata row. 
+ + Subclasses must implement dialect-specific SQL queries. + + Args: + config: SQLSpec database configuration with extension_config["adk"] settings. + + Notes: + Configuration is read from config.extension_config["adk"]: + - artifact_table: Artifact versions table name (default: "adk_artifact_versions") + """ + + __slots__ = ("_artifact_table", "_config") + + def __init__(self, config: ConfigT) -> None: + """Initialize the async ADK artifact store. + + Args: + config: SQLSpec database configuration. + """ + self._config = config + adk_config = self._get_adk_config() + self._artifact_table: str = str(adk_config.get("artifact_table", "adk_artifact_versions")) + _validate_table_name(self._artifact_table) + + def _get_adk_config(self) -> "dict[str, Any]": + """Extract ADK configuration from extension_config. + + Returns: + Dict with ADK configuration values. + """ + extension_config = self._config.extension_config + return dict(cast("ADKConfig", extension_config.get("adk", {}))) + + @property + def config(self) -> ConfigT: + """Return the database configuration.""" + return self._config + + @property + def artifact_table(self) -> str: + """Return the artifact versions table name.""" + return self._artifact_table + + @abstractmethod + async def insert_artifact(self, record: "ArtifactRecord") -> None: + """Insert an artifact version metadata row. + + Args: + record: Artifact metadata record to insert. + """ + + @abstractmethod + async def get_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None, version: "int | None" = None + ) -> "ArtifactRecord | None": + """Get a specific artifact version's metadata. + + When ``version`` is None, returns the latest version. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + version: Specific version number, or None for latest. + + Returns: + Artifact record if found, None otherwise. 
+ """ + + @abstractmethod + async def list_artifact_keys(self, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]": + """List distinct artifact filenames. + + When ``session_id`` is provided, returns filenames from both + session-scoped and user-scoped artifacts. When None, returns + only user-scoped artifact filenames. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier (None for user-scoped only). + + Returns: + List of distinct artifact filenames. + """ + + @abstractmethod + async def list_artifact_versions( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """List all version records for an artifact, ordered by version ascending. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of artifact records ordered by version ascending. + """ + + @abstractmethod + async def delete_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """Delete all version records for an artifact and return them. + + The caller uses the returned records to clean up content from + object storage. Metadata is deleted first (fail-fast); content + cleanup is best-effort. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of deleted artifact records (needed for content cleanup). + """ + + @abstractmethod + async def get_next_version( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> int: + """Get the next version number for an artifact. + + Returns 0 if no versions exist (first version), otherwise + ``max(version) + 1``. + + Args: + app_name: Application name. + user_id: User identifier. 
+ filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + Next version number (0-based). + """ + + @abstractmethod + async def create_table(self) -> None: + """Create the artifact versions table if it does not exist.""" + + async def ensure_table(self) -> None: + """Create the artifact table and emit a standardized log entry.""" + await self.create_table() + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.table.ready", + db_system=resolve_db_system(type(self).__name__), + artifact_table=self._artifact_table, + ) + + +class BaseSyncADKArtifactStore(ABC, Generic[ConfigT]): + """Base class for sync SQLSpec-backed ADK artifact metadata stores. + + Synchronous counterpart of :class:`BaseAsyncADKArtifactStore`. + + Args: + config: SQLSpec database configuration with extension_config["adk"] settings. + """ + + __slots__ = ("_artifact_table", "_config") + + def __init__(self, config: ConfigT) -> None: + """Initialize the sync ADK artifact store. + + Args: + config: SQLSpec database configuration. + """ + self._config = config + adk_config = self._get_adk_config() + self._artifact_table: str = str(adk_config.get("artifact_table", "adk_artifact_versions")) + _validate_table_name(self._artifact_table) + + def _get_adk_config(self) -> "dict[str, Any]": + """Extract ADK configuration from extension_config. + + Returns: + Dict with ADK configuration values. + """ + extension_config = self._config.extension_config + return dict(cast("ADKConfig", extension_config.get("adk", {}))) + + @property + def config(self) -> ConfigT: + """Return the database configuration.""" + return self._config + + @property + def artifact_table(self) -> str: + """Return the artifact versions table name.""" + return self._artifact_table + + @abstractmethod + def insert_artifact(self, record: "ArtifactRecord") -> None: + """Insert an artifact version metadata row. + + Args: + record: Artifact metadata record to insert. 
+ """ + + @abstractmethod + def get_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None, version: "int | None" = None + ) -> "ArtifactRecord | None": + """Get a specific artifact version's metadata. + + When ``version`` is None, returns the latest version. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + version: Specific version number, or None for latest. + + Returns: + Artifact record if found, None otherwise. + """ + + @abstractmethod + def list_artifact_keys(self, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]": + """List distinct artifact filenames. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier (None for user-scoped only). + + Returns: + List of distinct artifact filenames. + """ + + @abstractmethod + def list_artifact_versions( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """List all version records for an artifact, ordered by version ascending. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of artifact records ordered by version ascending. + """ + + @abstractmethod + def delete_artifact( + self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None + ) -> "list[ArtifactRecord]": + """Delete all version records for an artifact and return them. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + List of deleted artifact records (needed for content cleanup). 
+ """ + + @abstractmethod + def get_next_version(self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None) -> int: + """Get the next version number for an artifact. + + Args: + app_name: Application name. + user_id: User identifier. + filename: Artifact filename. + session_id: Session identifier (None for user-scoped). + + Returns: + Next version number (0-based). + """ + + @abstractmethod + def create_table(self) -> None: + """Create the artifact versions table if it does not exist.""" + + def ensure_table(self) -> None: + """Create the artifact table and emit a standardized log entry.""" + self.create_table() + log_with_context( + logger, + logging.DEBUG, + "adk.artifact.table.ready", + db_system=resolve_db_system(type(self).__name__), + artifact_table=self._artifact_table, + ) diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py index d68cbc141..f1bff6ec1 100644 --- a/sqlspec/extensions/adk/converters.py +++ b/sqlspec/extensions/adk/converters.py @@ -1,20 +1,37 @@ -"""Conversion functions between ADK models and database records.""" +"""Conversion functions between ADK models and database records. + +Implements full-event JSON storage: the entire Event is serialized via +``Event.model_dump_json(exclude_none=True)`` into a single ``event_json`` +column, with a small set of indexed scalar columns extracted alongside for +query performance. Reconstruction uses ``Event.model_validate_json()``. + +Also provides scoped-state helpers that normalise ADK state prefixes +(``app:``, ``user:``, ``temp:``) so the shared service layer can split, +filter, and merge state before handing it to backend stores. 
def event_to_record(event: "Event", session_id: str) -> EventRecord:
    """Convert ADK Event to database record using full-event JSON storage.

    The entire Event is serialized into ``event_json`` via Pydantic's
    ``model_dump(exclude_none=True, mode="json")``. A small number of
    indexed scalar columns are extracted alongside for query performance.

    Args:
        event: ADK Event object.
        session_id: ID of the parent session.

    Returns:
        EventRecord for database storage.
    """
    return EventRecord(
        session_id=session_id,
        invocation_id=event.invocation_id,
        author=event.author,
        # Event.timestamp is epoch seconds; store timezone-aware UTC.
        timestamp=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
        event_json=event.model_dump(exclude_none=True, mode="json"),
    )
""" - actions = pickle.loads(record["actions"]) # noqa: S301 + return Event.model_validate(record["event_json"]) - long_running_tool_ids = None - if record["long_running_tool_ids_json"]: - long_running_tool_ids = set(json.loads(record["long_running_tool_ids_json"])) - return Event( - id=record["id"], - invocation_id=record["invocation_id"], - author=record["author"], - branch=record["branch"], - actions=actions, - timestamp=record["timestamp"].timestamp(), - content=_decode_content(record["content"]), - long_running_tool_ids=long_running_tool_ids, - partial=record["partial"], - turn_complete=record["turn_complete"], - error_code=record["error_code"], - error_message=record["error_message"], - interrupted=record["interrupted"], - grounding_metadata=_decode_grounding_metadata(record["grounding_metadata"]), - custom_metadata=record["custom_metadata"], - ) +# --------------------------------------------------------------------------- +# Scoped-state helpers +# --------------------------------------------------------------------------- + +def filter_temp_state(state: "dict[str, Any]") -> "dict[str, Any]": + """Return a copy of *state* with all ``temp:`` keys removed. -def _decode_content(content_dict: "dict[str, Any] | None") -> Any: - """Decode content dictionary from database to ADK Content object. + ``temp:`` keys are process-local/session-runtime state and must never be + written to persistent storage. Args: - content_dict: Content dictionary from database. + state: ADK state dictionary (may contain ``temp:`` prefixed keys). Returns: - ADK Content object or None. + A new dict without any ``temp:``-prefixed keys. """ - if not content_dict: - return None + return {k: v for k, v in state.items() if not k.startswith("temp:")} - return types.Content.model_validate(content_dict) - -def _decode_grounding_metadata(grounding_dict: "dict[str, Any] | None") -> Any: - """Decode grounding metadata dictionary from database to ADK object. 
+def split_scoped_state(state: "dict[str, Any]") -> "tuple[dict[str, Any], dict[str, Any], dict[str, Any]]": + """Split state into app-scoped, user-scoped, and session-scoped buckets. Args: - grounding_dict: Grounding metadata dictionary from database. + state: Full session state dict (temp: already stripped). Returns: - ADK GroundingMetadata object or None. + Tuple of (app_state, user_state, session_state). + app_state: keys starting with "app:" + user_state: keys starting with "user:" + session_state: all other keys """ - if not grounding_dict: - return None + app_state: dict[str, Any] = {} + user_state: dict[str, Any] = {} + session_state: dict[str, Any] = {} + for k, v in state.items(): + if k.startswith("app:"): + app_state[k] = v + elif k.startswith("user:"): + user_state[k] = v + else: + session_state[k] = v + return app_state, user_state, session_state + + +def merge_scoped_state( + session_state: "dict[str, Any]", + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, +) -> "dict[str, Any]": + """Merge scoped state buckets into a single state dict. + + Priority: session_state is base, app_state and user_state overlay. + This matches ADK's documented merge semantics on session load. + + Args: + session_state: Per-session state. + app_state: App-scoped state (shared across sessions for same app). + user_state: User-scoped state (shared across sessions for same app+user). - return types.GroundingMetadata.model_validate(grounding_dict) + Returns: + Merged state dict. 
+ """ + merged = dict(session_state) + if app_state: + merged.update(app_state) + if user_state: + merged.update(user_state) + return merged diff --git a/sqlspec/extensions/adk/memory/converters.py b/sqlspec/extensions/adk/memory/converters.py index 1dc0dbf83..fafea6d12 100644 --- a/sqlspec/extensions/adk/memory/converters.py +++ b/sqlspec/extensions/adk/memory/converters.py @@ -19,15 +19,22 @@ logger = get_logger("sqlspec.extensions.adk.memory.converters") -__all__ = ("event_to_memory_record", "extract_content_text", "record_to_memory_entry", "session_to_memory_records") +__all__ = ( + "event_to_memory_record", + "extract_content_text", + "memory_entry_to_record", + "record_to_memory_entry", + "records_to_memory_entries", + "session_to_memory_records", +) def extract_content_text(content: "types.Content") -> str: """Extract plain text from ADK Content for search indexing. Handles multi-modal Content.parts including text, function calls, - and other part types. Non-text parts are indexed by their type - for discoverability. + function responses, and other part types. Non-text parts are indexed + by their type for discoverability. Args: content: ADK Content object with parts list. @@ -91,6 +98,66 @@ def event_to_memory_record(event: "Event", session_id: str, app_name: str, user_ ) +def memory_entry_to_record( + entry: "MemoryEntry", app_name: str, user_id: str, extra_metadata: "dict[str, Any] | None" = None +) -> "MemoryRecord | None": + """Convert an ADK MemoryEntry to a database record. + + Serializes the entry's ``content`` to ``content_json``, extracts text + from ``content.parts`` for ``content_text``, and merges entry-level + ``custom_metadata`` with the optional ``extra_metadata`` parameter. + + Args: + entry: ADK MemoryEntry object. + app_name: Name of the application. + user_id: ID of the user. + extra_metadata: Optional call-level metadata to merge with the + entry's own ``custom_metadata``. 
+ + Returns: + MemoryRecord for database storage, or None if entry has no + indexable content. + """ + content_text = extract_content_text(entry.content) + if not content_text.strip(): + return None + + content_dict = entry.content.model_dump(exclude_none=True, mode="json") + + # Merge entry-level and call-level metadata + merged_metadata: dict[str, Any] | None = None + if entry.custom_metadata or extra_metadata: + merged_metadata = {} + if extra_metadata: + merged_metadata.update(extra_metadata) + if entry.custom_metadata: + merged_metadata.update(entry.custom_metadata) + + now = datetime.now(timezone.utc) + + # Parse timestamp from entry if available + timestamp = now + if entry.timestamp: + try: + timestamp = datetime.fromisoformat(entry.timestamp) + except (ValueError, TypeError): + timestamp = now + + return MemoryRecord( + id=entry.id or str(uuid.uuid4()), + session_id="", + app_name=app_name, + user_id=user_id, + event_id="", + author=entry.author or "", + timestamp=timestamp, + content_json=content_dict, + content_text=content_text, + metadata_json=merged_metadata, + inserted_at=now, + ) + + def session_to_memory_records(session: "Session") -> list["MemoryRecord"]: """Convert a completed ADK Session to a list of memory records. @@ -121,11 +188,14 @@ def session_to_memory_records(session: "Session") -> list["MemoryRecord"]: def record_to_memory_entry(record: "MemoryRecord") -> "MemoryEntry": """Convert a database record to an ADK MemoryEntry. + Preserves ``id`` and ``custom_metadata`` fields that were previously + dropped on readback. + Args: record: Memory database record. Returns: - ADK MemoryEntry object. + ADK MemoryEntry object with all available fields populated. 
""" from google.adk.memory.memory_entry import MemoryEntry from google.genai import types @@ -134,7 +204,13 @@ def record_to_memory_entry(record: "MemoryRecord") -> "MemoryEntry": timestamp_str = record["timestamp"].isoformat() if record["timestamp"] else None - return MemoryEntry(content=content, author=record["author"], timestamp=timestamp_str) + return MemoryEntry( + id=record["id"], + content=content, + author=record["author"], + timestamp=timestamp_str, + custom_metadata=record["metadata_json"] or {}, + ) def records_to_memory_entries(records: list["MemoryRecord"]) -> list["Any"]: diff --git a/sqlspec/extensions/adk/memory/service.py b/sqlspec/extensions/adk/memory/service.py index abc6360d5..a94c7e9cd 100644 --- a/sqlspec/extensions/adk/memory/service.py +++ b/sqlspec/extensions/adk/memory/service.py @@ -4,10 +4,17 @@ from google.adk.memory.base_memory_service import BaseMemoryService, SearchMemoryResponse -from sqlspec.extensions.adk.memory.converters import records_to_memory_entries, session_to_memory_records +from sqlspec.extensions.adk.memory.converters import ( + memory_entry_to_record, + records_to_memory_entries, + session_to_memory_records, +) from sqlspec.utils.logging import get_logger if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from google.adk.events.event import Event from google.adk.memory.memory_entry import MemoryEntry from google.adk.sessions import Session @@ -102,6 +109,98 @@ async def add_session_to_memory(self, session: "Session") -> None: "Stored %d memory entries for session %s (total events: %d)", inserted_count, session.id, len(records) ) + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: "Sequence[Event]", + session_id: "str | None" = None, + custom_metadata: "Mapping[str, object] | None" = None, + ) -> None: + """Add an explicit list of events to the memory service. 
+ + Same Event-to-MemoryRecord extraction logic as + ``add_session_to_memory``, but operates on a sequence of Events + directly (no Session wrapper needed). + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + events: The events to add to memory. + session_id: Optional session ID for memory scope/partitioning. + If None, memory entries are user-scoped only. + custom_metadata: Optional portable metadata stored in + ``MemoryRecord.metadata_json``. + """ + from sqlspec.extensions.adk.memory.converters import event_to_memory_record + + metadata_dict = dict(custom_metadata) if custom_metadata else None + records = [] + for event in events: + record = event_to_memory_record( + event=event, session_id=session_id or "", app_name=app_name, user_id=user_id + ) + if record is not None: + if metadata_dict: + record["metadata_json"] = metadata_dict + records.append(record) + + if not records: + logger.debug( + "No content to store for events (app=%s, user=%s, count=%d)", app_name, user_id, len(list(events)) + ) + return + + inserted_count = await self._store.insert_memory_entries(records) + logger.debug( + "Stored %d memory entries from %d events (app=%s, user=%s)", inserted_count, len(records), app_name, user_id + ) + + async def add_memory( + self, + *, + app_name: str, + user_id: str, + memories: "Sequence[MemoryEntry]", + custom_metadata: "Mapping[str, object] | None" = None, + ) -> None: + """Add explicit memory items directly to the memory service. + + Each entry's ``content`` is serialized to ``content_json``, text is + extracted from ``content.parts`` for ``content_text``, and + ``custom_metadata`` merges the entry-level ``entry.custom_metadata`` + with the call-level ``custom_metadata`` parameter. + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + memories: Explicit memory items to add. + custom_metadata: Optional portable metadata for memory writes. 
+ Merged with each entry's ``custom_metadata``. + """ + call_metadata = dict(custom_metadata) if custom_metadata else {} + records = [] + for entry in memories: + record = memory_entry_to_record( + entry=entry, app_name=app_name, user_id=user_id, extra_metadata=call_metadata + ) + if record is not None: + records.append(record) + + if not records: + logger.debug("No content to store for memories (app=%s, user=%s)", app_name, user_id) + return + + inserted_count = await self._store.insert_memory_entries(records) + logger.debug( + "Stored %d memory entries from %d memories (app=%s, user=%s)", + inserted_count, + len(records), + app_name, + user_id, + ) + async def search_memory(self, *, app_name: str, user_id: str, query: str) -> "SearchMemoryResponse": """Search memory entries by text query. diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 8656f4beb..132d9ad69 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -7,7 +7,7 @@ from google.adk.sessions.base_session_service import BaseSessionService, GetSessionConfig, ListSessionsResponse -from sqlspec.extensions.adk.converters import event_to_record, record_to_session +from sqlspec.extensions.adk.converters import event_to_record, filter_temp_state, record_to_session from sqlspec.utils.logging import get_logger, log_with_context if TYPE_CHECKING: @@ -80,8 +80,10 @@ async def create_session( if state is None: state = {} + persisted_state = filter_temp_state(state) + record = await self._store.create_session( - session_id=session_id, app_name=app_name, user_id=user_id, state=state + session_id=session_id, app_name=app_name, user_id=user_id, state=persisted_state ) log_with_context( logger, logging.DEBUG, "adk.session.create", app_name=app_name, session_id=session_id, has_state=bool(state) @@ -192,6 +194,11 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) async def append_event(self, session: "Session", event: 
"Event") -> "Event": """Append an event to a session. + Persists the event record and the post-append durable state + atomically via ``store.append_event_and_update_state()``. ``temp:`` + keys are stripped from the persisted state snapshot so they never + survive a reload. + Args: session: Session to append to. event: Event to append. @@ -204,11 +211,14 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": if event.partial: return event - event_record = event_to_record( - event=event, session_id=session.id, app_name=session.app_name, user_id=session.user_id - ) + event_record = event_to_record(event=event, session_id=session.id) - await self._store.append_event(event_record) + # Strip temp: keys before persisting state + durable_state = filter_temp_state(session.state) + + await self._store.append_event_and_update_state( + event_record=event_record, session_id=session.id, state=durable_state + ) log_with_context( logger, logging.DEBUG, diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 9903ee7b8..9f2ea8d1f 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -250,6 +250,24 @@ async def append_event(self, event_record: "EventRecord") -> None: """ raise NotImplementedError + @abstractmethod + async def append_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically append an event and update the session's durable state. + + This is the authoritative durable write boundary for post-creation + session mutations. The event insert and state update must succeed + together or fail together. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). 
+ """ + raise NotImplementedError + @abstractmethod async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -505,6 +523,24 @@ def create_event( """ raise NotImplementedError + @abstractmethod + def create_event_and_update_state( + self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + ) -> None: + """Atomically create an event and update the session's durable state. + + This is the authoritative durable write boundary for post-creation + session mutations. The event insert and state update must succeed + together or fail together. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). + """ + raise NotImplementedError + @abstractmethod def list_events(self, session_id: str) -> "list[EventRecord]": """List events for a session ordered by timestamp. 
diff --git a/tests/conftest.py b/tests/conftest.py index 0a98005ed..b3fabd251 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import logging import os import warnings @@ -63,6 +64,26 @@ def disable_spanner_builtin_metrics() -> "Generator[None, None, None]": yield +@pytest.fixture(scope="session", autouse=True) +def suppress_noisy_test_loggers() -> "Generator[None, None, None]": + """Lower especially noisy library loggers during test runs.""" + overrides = { + "httpx": logging.WARNING, + "httpcore": logging.WARNING, + "mysql.connector": logging.WARNING, + "asyncmy": logging.ERROR, + "sqlspec.migrations.tracker": logging.WARNING, + } + original_levels = {name: logging.getLogger(name).level for name in overrides} + for name, level in overrides.items(): + logging.getLogger(name).setLevel(level) + try: + yield + finally: + for name, level in original_levels.items(): + logging.getLogger(name).setLevel(level) + + @pytest.fixture(scope="session") def minio_client(minio_service: "MinioService", minio_default_bucket_name: str) -> Generator[Minio, None, None]: """Override pytest-databases minio_client to use new minio API with keyword arguments.""" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py index 7c0451ba8..e20536d83 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py @@ -10,6 +10,7 @@ if the driver is not installed. 
""" +import json from pathlib import Path from typing import Any @@ -22,16 +23,16 @@ @pytest.fixture() -def sqlite_store(tmp_path: Path) -> Any: +async def sqlite_store(tmp_path: Path) -> Any: """SQLite ADBC store fixture.""" db_path = tmp_path / "sqlite_test.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None: +async def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None: """Test SQLite dialect creates TEXT columns for JSON.""" with sqlite_store.config.provide_connection() as conn: cursor = conn.cursor() @@ -45,54 +46,61 @@ def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None: cursor.close() # type: ignore[no-untyped-call] -def test_sqlite_dialect_session_operations(sqlite_store: Any) -> None: +async def test_sqlite_dialect_session_operations(sqlite_store: Any) -> None: """Test SQLite dialect with full session CRUD.""" session_id = "sqlite-session-1" app_name = "test-app" user_id = "user-123" state = {"nested": {"key": "value"}, "count": 42} - created = sqlite_store.create_session(session_id, app_name, user_id, state) + created = await sqlite_store.create_session(session_id, app_name, user_id, state) assert created["id"] == session_id assert created["state"] == state - retrieved = sqlite_store.get_session(session_id) + retrieved = await sqlite_store.get_session(session_id) assert retrieved["state"] == state new_state = {"updated": True} - sqlite_store.update_session_state(session_id, new_state) + await sqlite_store.update_session_state(session_id, new_state) - updated = sqlite_store.get_session(session_id) + updated = await sqlite_store.get_session(session_id) assert updated["state"] == new_state -def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: +async def 
test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: """Test SQLite dialect with event operations.""" session_id = "sqlite-session-events" app_name = "test-app" user_id = "user-123" - sqlite_store.create_session(session_id, app_name, user_id, {}) + await sqlite_store.create_session(session_id, app_name, user_id, {}) - event_id = "event-1" - actions = b"pickled_actions_data" content = {"message": "Hello"} - event = sqlite_store.create_event( - event_id=event_id, session_id=session_id, app_name=app_name, user_id=user_id, actions=actions, content=content - ) + from datetime import datetime, timezone + + from sqlspec.extensions.adk import EventRecord - assert event["id"] == event_id - assert event["content"] == content + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "content": content, "app_name": app_name, "user_id": user_id}, + } + await sqlite_store.append_event(event_record) - events = sqlite_store.list_events(session_id) + events = await sqlite_store.get_events(session_id) assert len(events) == 1 - assert events[0]["content"] == content + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert retrieved_data["content"] == content @pytest.mark.postgres @pytest.mark.skipif(True, reason="Requires adbc-driver-postgresql and PostgreSQL server") -def test_postgresql_dialect_creates_jsonb_columns() -> None: +async def test_postgresql_dialect_creates_jsonb_columns() -> None: """Test PostgreSQL dialect creates JSONB columns. This test is skipped by default. 
To run: @@ -105,7 +113,7 @@ def test_postgresql_dialect_creates_jsonb_columns() -> None: connection_config={"driver_name": "postgresql", "uri": "postgresql://user:pass@localhost/testdb"} ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() with store.config.provide_connection() as conn: cursor = conn.cursor() @@ -127,7 +135,7 @@ def test_postgresql_dialect_creates_jsonb_columns() -> None: @pytest.mark.duckdb @pytest.mark.skipif(True, reason="Requires adbc-driver-duckdb") -def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None: +async def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None: """Test DuckDB dialect creates JSON columns. This test is skipped by default. To run: @@ -137,18 +145,18 @@ def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None: db_path = tmp_path / "duckdb_test.db" config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() session_id = "duckdb-session-1" state = {"analytics": {"count": 1000, "revenue": 50000.00}} - created = store.create_session(session_id, "app", "user", state) + created = await store.create_session(session_id, "app", "user", state) assert created["state"] == state @pytest.mark.snowflake @pytest.mark.skipif(True, reason="Requires adbc-driver-snowflake and Snowflake account") -def test_snowflake_dialect_creates_variant_columns() -> None: +async def test_snowflake_dialect_creates_variant_columns() -> None: """Test Snowflake dialect creates VARIANT columns. This test is skipped by default. 
To run: @@ -165,7 +173,7 @@ def test_snowflake_dialect_creates_variant_columns() -> None: } ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() with store.config.provide_connection() as conn: cursor = conn.cursor() @@ -185,7 +193,7 @@ def test_snowflake_dialect_creates_variant_columns() -> None: cursor.close() # type: ignore[no-untyped-call] -def test_sqlite_with_owner_id_column(tmp_path: Path) -> None: +async def test_sqlite_with_owner_id_column(tmp_path: Path) -> None: """Test SQLite with owner ID column creates proper constraints.""" db_path = tmp_path / "sqlite_fk_test.db" base_config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) @@ -205,16 +213,16 @@ def test_sqlite_with_owner_id_column(tmp_path: Path) -> None: extension_config={"adk": {"owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"}}, ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() - session = store.create_session("s1", "app", "user", {"data": "test"}, owner_id=1) + session = await store.create_session("s1", "app", "user", {"data": "test"}, owner_id=1) assert session["id"] == "s1" - retrieved = store.get_session("s1") + retrieved = await store.get_session("s1") assert retrieved is not None -def test_generic_dialect_fallback(tmp_path: Path) -> None: +async def test_generic_dialect_fallback(tmp_path: Path) -> None: """Test generic dialect is used for unknown drivers.""" db_path = tmp_path / "generic_test.db" @@ -223,7 +231,7 @@ def test_generic_dialect_fallback(tmp_path: Path) -> None: store = AdbcADKStore(config) assert store.dialect in ["sqlite", "generic"] - store.create_tables() + await store.create_tables() - session = store.create_session("generic-1", "app", "user", {"test": True}) + session = await store.create_session("generic-1", "app", "user", {"test": True}) assert session["state"]["test"] is True diff --git 
a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py index c87302f23..703d40437 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py @@ -89,54 +89,59 @@ def test_generic_sessions_ddl_contains_text() -> None: assert "TIMESTAMP" in ddl -def test_postgresql_events_ddl_contains_jsonb() -> None: - """Test PostgreSQL events DDL uses JSONB for content fields.""" +def test_postgresql_events_ddl_uses_jsonb() -> None: + """Test PostgreSQL events DDL uses JSONB for event_json.""" config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_postgresql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in ddl - assert "BYTEA" in ddl - assert "BOOLEAN" in ddl + assert "event_json" in ddl + assert "session_id" in ddl + assert "invocation_id" in ddl + assert "author" in ddl + assert "timestamp" in ddl.lower() -def test_sqlite_events_ddl_contains_text_and_integer() -> None: - """Test SQLite events DDL uses TEXT for JSON and INTEGER for booleans.""" +def test_sqlite_events_ddl_uses_text() -> None: + """Test SQLite events DDL uses TEXT for event_json.""" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_sqlite() # pyright: ignore[reportPrivateUsage] assert "TEXT" in ddl - assert "BLOB" in ddl - assert "INTEGER" in ddl + assert "event_json" in ddl + assert "session_id" in ddl + assert "REAL" in ddl # SQLite uses REAL for timestamps -def test_duckdb_events_ddl_contains_json_and_boolean() -> None: - """Test DuckDB events DDL uses JSON and BOOLEAN types.""" +def test_duckdb_events_ddl_uses_json() -> None: + """Test DuckDB events DDL uses JSON type for event_json.""" config = 
AdbcConfig(connection_config={"driver_name": "duckdb", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_duckdb() # pyright: ignore[reportPrivateUsage] assert "JSON" in ddl - assert "BOOLEAN" in ddl + assert "event_json" in ddl -def test_snowflake_events_ddl_contains_variant() -> None: - """Test Snowflake events DDL uses VARIANT for content.""" +def test_snowflake_events_ddl_uses_variant() -> None: + """Test Snowflake events DDL uses VARIANT for event_json.""" config = AdbcConfig(connection_config={"driver_name": "snowflake", "uri": "snowflake://test"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_snowflake() # pyright: ignore[reportPrivateUsage] assert "VARIANT" in ddl - assert "BINARY" in ddl + assert "event_json" in ddl -def test_ddl_dispatch_uses_correct_dialect() -> None: +async def test_ddl_dispatch_uses_correct_dialect() -> None: """Test that DDL dispatch selects correct dialect method.""" config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"}) store = AdbcADKStore(config) - sessions_ddl = store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] + sessions_ddl = await store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in sessions_ddl - events_ddl = store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] + events_ddl = await store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in events_ddl + assert "event_json" in events_ddl def test_owner_id_column_included_in_sessions_ddl() -> None: diff --git a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py index fc39cebb2..0e11dd0bb 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py @@ -1,5 +1,6 @@ """Tests for ADBC ADK store edge cases and error 
handling.""" +import json from pathlib import Path from typing import Any @@ -12,19 +13,19 @@ @pytest.fixture() -def adbc_store(tmp_path: Path) -> AdbcADKStore: +async def adbc_store(tmp_path: Path) -> AdbcADKStore: """Create ADBC ADK store with SQLite backend.""" db_path = tmp_path / "test_adk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_create_tables_idempotent(adbc_store: Any) -> None: +async def test_create_tables_idempotent(adbc_store: Any) -> None: """Test that create_tables can be called multiple times safely.""" - adbc_store.create_tables() - adbc_store.create_tables() + await adbc_store.create_tables() + await adbc_store.create_tables() def test_table_names_validation(tmp_path: Path) -> None: @@ -61,23 +62,23 @@ def test_table_names_validation(tmp_path: Path) -> None: AdbcADKStore(config) -def test_operations_before_create_tables(tmp_path: Path) -> None: +async def test_operations_before_create_tables(tmp_path: Path) -> None: """Test operations gracefully handle missing tables.""" db_path = tmp_path / "test_no_tables.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - session = store.get_session("nonexistent") + session = await store.get_session("nonexistent") assert session is None - sessions = store.list_sessions("app", "user") + sessions = await store.list_sessions("app", "user") assert sessions == [] - events = store.list_events("session") + events = await store.get_events("session") assert events == [] -def test_custom_table_names(tmp_path: Path) -> None: +async def test_custom_table_names(tmp_path: Path) -> None: """Test using custom table names.""" db_path = tmp_path / "test_custom.db" config = AdbcConfig( @@ -85,43 +86,56 @@ def test_custom_table_names(tmp_path: Path) -> None: extension_config={"adk": {"session_table": 
"custom_sessions", "events_table": "custom_events"}}, ) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() session_id = "test" - session = store.create_session(session_id, "app", "user", {"data": "test"}) + session = await store.create_session(session_id, "app", "user", {"data": "test"}) assert session["id"] == session_id - retrieved = store.get_session(session_id) + retrieved = await store.get_session(session_id) assert retrieved is not None -def test_unicode_in_fields(adbc_store: Any) -> None: +async def test_unicode_in_fields(adbc_store: Any) -> None: """Test Unicode characters in various fields.""" session_id = "unicode-session" - app_name = "测试应用" - user_id = "ユーザー123" - state = {"message": "Hello 世界", "emoji": "🎉"} + app_name = "\u6d4b\u8bd5\u5e94\u7528" + user_id = "\u30e6\u30fc\u30b6\u30fc123" + state = {"message": "Hello \u4e16\u754c"} - created_session = adbc_store.create_session(session_id, app_name, user_id, state) + created_session = await adbc_store.create_session(session_id, app_name, user_id, state) assert created_session["app_name"] == app_name assert created_session["user_id"] == user_id - assert created_session["state"]["message"] == "Hello 世界" - assert created_session["state"]["emoji"] == "🎉" - - event = adbc_store.create_event( - event_id="unicode-event", - session_id=session_id, - app_name=app_name, - user_id=user_id, - author="アシスタント", - content={"text": "こんにちは 🌍"}, - ) + assert created_session["state"]["message"] == "Hello \u4e16\u754c" + + from datetime import datetime, timezone + + from sqlspec.extensions.adk import EventRecord + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "unicode-event", + "content": {"text": "\u3053\u3093\u306b\u3061\u306f"}, + "app_name": app_name, + "user_id": user_id, + }, + } + await adbc_store.append_event(event_record) - 
assert event["author"] == "アシスタント" - assert event["content"]["text"] == "こんにちは 🌍" + events = await adbc_store.get_events(session_id) + assert len(events) == 1 + assert events[0]["author"] == "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8" + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert event_data["content"]["text"] == "\u3053\u3093\u306b\u3061\u306f" -def test_special_characters_in_json(adbc_store: Any) -> None: +async def test_special_characters_in_json(adbc_store: Any) -> None: """Test special characters in JSON fields.""" session_id = "special-chars" state = { @@ -131,115 +145,105 @@ def test_special_characters_in_json(adbc_store: Any) -> None: "tab": "Col1\tCol2", } - adbc_store.create_session(session_id, "app", "user", state) - retrieved = adbc_store.get_session(session_id) + await adbc_store.create_session(session_id, "app", "user", state) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state -def test_very_long_strings(adbc_store: Any) -> None: +async def test_very_long_strings(adbc_store: Any) -> None: """Test handling very long strings in VARCHAR fields.""" long_id = "x" * 127 long_app = "a" * 127 long_user = "u" * 127 - session = adbc_store.create_session(long_id, long_app, long_user, {}) + session = await adbc_store.create_session(long_id, long_app, long_user, {}) assert session["id"] == long_id assert session["app_name"] == long_app assert session["user_id"] == long_user -def test_session_state_with_deeply_nested_data(adbc_store: Any) -> None: +async def test_session_state_with_deeply_nested_data(adbc_store: Any) -> None: """Test deeply nested JSON structures.""" session_id = "deep-nest" deeply_nested = {"level1": {"level2": {"level3": {"level4": {"level5": {"value": "deep"}}}}}} - adbc_store.create_session(session_id, "app", "user", deeply_nested) - retrieved = adbc_store.get_session(session_id) + await 
adbc_store.create_session(session_id, "app", "user", deeply_nested) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"]["level1"]["level2"]["level3"]["level4"]["level5"]["value"] == "deep" -def test_concurrent_session_updates(adbc_store: Any) -> None: +async def test_concurrent_session_updates(adbc_store: Any) -> None: """Test multiple updates to the same session.""" session_id = "concurrent-test" - adbc_store.create_session(session_id, "app", "user", {"version": 1}) + await adbc_store.create_session(session_id, "app", "user", {"version": 1}) for i in range(10): - adbc_store.update_session_state(session_id, {"version": i + 2}) + await adbc_store.update_session_state(session_id, {"version": i + 2}) - final_session = adbc_store.get_session(session_id) + final_session = await adbc_store.get_session(session_id) assert final_session is not None assert final_session["state"]["version"] == 11 -def test_event_with_none_values(adbc_store: Any) -> None: - """Test creating event with explicit None values.""" +async def test_event_with_none_values(adbc_store: Any) -> None: + """Test creating event with explicit None values for optional fields.""" session_id = "none-test" - adbc_store.create_session(session_id, "app", "user", {}) - - event = adbc_store.create_event( - event_id="none-event", - session_id=session_id, - app_name="app", - user_id="user", - invocation_id=None, - author=None, - actions=None, - content=None, - grounding_metadata=None, - custom_metadata=None, - partial=None, - turn_complete=None, - interrupted=None, - error_code=None, - error_message=None, - ) + await adbc_store.create_session(session_id, "app", "user", {}) + + from datetime import datetime, timezone + + from sqlspec.extensions.adk import EventRecord + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "none-event", "app_name": 
"app", "user_id": "user"}, + } + await adbc_store.append_event(event_record) - assert event["invocation_id"] is None - assert event["author"] is None - assert event["actions"] == b"" - assert event["content"] is None - assert event["grounding_metadata"] is None - assert event["custom_metadata"] is None - assert event["partial"] is None - assert event["turn_complete"] is None - assert event["interrupted"] is None + events = await adbc_store.get_events(session_id) + assert len(events) == 1 + assert events[0]["session_id"] == session_id + assert "event_json" in events[0] -def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None: +async def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None: """Test listing sessions doesn't mix data across apps.""" user_id = "user-123" app1 = "app1" app2 = "app2" - adbc_store.create_session("s1", app1, user_id, {}) - adbc_store.create_session("s2", app1, user_id, {}) - adbc_store.create_session("s3", app2, user_id, {}) + await adbc_store.create_session("s1", app1, user_id, {}) + await adbc_store.create_session("s2", app1, user_id, {}) + await adbc_store.create_session("s3", app2, user_id, {}) - app1_sessions = adbc_store.list_sessions(app1, user_id) - app2_sessions = adbc_store.list_sessions(app2, user_id) + app1_sessions = await adbc_store.list_sessions(app1, user_id) + app2_sessions = await adbc_store.list_sessions(app2, user_id) assert len(app1_sessions) == 2 assert len(app2_sessions) == 1 -def test_delete_nonexistent_session(adbc_store: Any) -> None: +async def test_delete_nonexistent_session(adbc_store: Any) -> None: """Test deleting a session that doesn't exist.""" - adbc_store.delete_session("nonexistent-session") + await adbc_store.delete_session("nonexistent-session") -def test_update_nonexistent_session(adbc_store: Any) -> None: +async def test_update_nonexistent_session(adbc_store: Any) -> None: """Test updating a session that doesn't exist.""" - 
adbc_store.update_session_state("nonexistent-session", {"data": "test"}) + await adbc_store.update_session_state("nonexistent-session", {"data": "test"}) -def test_drop_and_recreate_tables(adbc_store: Any) -> None: +async def test_drop_and_recreate_tables(adbc_store: Any) -> None: """Test dropping and recreating tables.""" session_id = "test-session" - adbc_store.create_session(session_id, "app", "user", {"data": "test"}) + await adbc_store.create_session(session_id, "app", "user", {"data": "test"}) drop_sqls = adbc_store._get_drop_tables_sql() with adbc_store._config.provide_connection() as conn: @@ -251,19 +255,19 @@ def test_drop_and_recreate_tables(adbc_store: Any) -> None: finally: cursor.close() - adbc_store.create_tables() + await adbc_store.create_tables() - session = adbc_store.get_session(session_id) + session = await adbc_store.get_session(session_id) assert session is None -def test_json_with_escaped_characters(adbc_store: Any) -> None: +async def test_json_with_escaped_characters(adbc_store: Any) -> None: """Test JSON serialization of escaped characters.""" session_id = "escaped-json" state = {"escaped": r"test\nvalue\t", "quotes": r'"quoted"'} - adbc_store.create_session(session_id, "app", "user", state) - retrieved = adbc_store.get_session(session_id) + await adbc_store.create_session(session_id, "app", "user", state) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py index e18cd1496..8e54bf766 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py @@ -1,6 +1,8 @@ """Tests for ADBC ADK store event operations.""" -from datetime import datetime, timezone +import asyncio +import json +from datetime import datetime, 
timedelta, timezone from pathlib import Path from typing import Any @@ -8,313 +10,387 @@ from sqlspec.adapters.adbc import AdbcConfig from sqlspec.adapters.adbc.adk import AdbcADKStore +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] @pytest.fixture() -def adbc_store(tmp_path: Path) -> AdbcADKStore: +async def adbc_store(tmp_path: Path) -> AdbcADKStore: """Create ADBC ADK store with SQLite backend.""" db_path = tmp_path / "test_adk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store @pytest.fixture() -def session_fixture(adbc_store: Any) -> dict[str, str]: +async def session_fixture(adbc_store: Any) -> dict[str, str]: """Create a test session.""" session_id = "test-session" app_name = "test-app" user_id = "user-123" state = {"test": True} - adbc_store.create_session(session_id, app_name, user_id, state) + await adbc_store.create_session(session_id, app_name, user_id, state) return {"session_id": session_id, "app_name": app_name, "user_id": user_id} -def test_create_event(adbc_store: Any, session_fixture: Any) -> None: - """Test creating a new event.""" - event_id = "event-1" - event = adbc_store.create_event( - event_id=event_id, - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - author="user", - actions=b"serialized_actions", - content={"message": "Hello"}, +async def test_create_event(adbc_store: Any, session_fixture: Any) -> None: + """Test creating a new event returns 5-key EventRecord.""" + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-1", + "content": {"message": "Hello"}, + "app_name": session_fixture["app_name"], 
+ "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + assert events[0]["session_id"] == session_fixture["session_id"] + assert events[0]["author"] == "user" + assert events[0]["timestamp"] is not None + assert "event_json" in events[0] + + # Content is stored inside event_json + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) + assert event_data["content"] == {"message": "Hello"} - assert event["id"] == event_id - assert event["session_id"] == session_fixture["session_id"] - assert event["author"] == "user" - assert event["actions"] == b"serialized_actions" - assert event["content"] == {"message": "Hello"} - assert event["timestamp"] is not None - -def test_list_events(adbc_store: Any, session_fixture: Any) -> None: +async def test_list_events(adbc_store: Any, session_fixture: Any) -> None: """Test listing events for a session.""" - adbc_store.create_event( - event_id="event-1", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - author="user", - content={"seq": 1}, - ) - adbc_store.create_event( - event_id="event-2", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - author="assistant", - content={"seq": 2}, - ) - - events = adbc_store.list_events(session_fixture["session_id"]) + event1: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-1", + "content": {"seq": 1}, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + event2: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "assistant", 
+ "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-2", + "content": {"seq": 2}, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event1) + await adbc_store.append_event(event2) + + events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 2 - assert events[0]["id"] == "event-1" - assert events[1]["id"] == "event-2" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" -def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: +async def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: """Test listing events when none exist.""" - events = adbc_store.list_events(session_fixture["session_id"]) + events = await adbc_store.get_events(session_fixture["session_id"]) assert events == [] -def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with all optional fields.""" - timestamp = datetime.now(timezone.utc) - event = adbc_store.create_event( - event_id="full-event", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - invocation_id="invocation-123", - author="assistant", - actions=b"complex_action_data", - long_running_tool_ids_json='["tool1", "tool2"]', - branch="main", - timestamp=timestamp, - content={"text": "Response"}, - grounding_metadata={"sources": ["doc1", "doc2"]}, - custom_metadata={"custom": "data"}, - partial=True, - turn_complete=False, - interrupted=False, - error_code="NONE", - error_message="No errors", +async def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: + """Test creating event with all optional fields stored in event_json.""" + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "invocation-123", + "author": "assistant", + "timestamp": 
datetime.now(timezone.utc), + "event_json": { + "id": "full-event", + "content": {"text": "Response"}, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + "branch": "main", + "grounding_metadata": {"sources": ["doc1", "doc2"]}, + "custom_metadata": {"custom": "data"}, + "partial": True, + "turn_complete": False, + "interrupted": False, + "error_code": "NONE", + "error_message": "No errors", + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + + # Top-level indexed columns + assert events[0]["invocation_id"] == "invocation-123" + assert events[0]["author"] == "assistant" + + # Everything else is in event_json + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) - - assert event["invocation_id"] == "invocation-123" - assert event["author"] == "assistant" - assert event["actions"] == b"complex_action_data" - assert event["long_running_tool_ids_json"] == '["tool1", "tool2"]' - assert event["branch"] == "main" - assert event["content"] == {"text": "Response"} - assert event["grounding_metadata"] == {"sources": ["doc1", "doc2"]} - assert event["custom_metadata"] == {"custom": "data"} - assert event["partial"] is True - assert event["turn_complete"] is False - assert event["interrupted"] is False - assert event["error_code"] == "NONE" - assert event["error_message"] == "No errors" - - -def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None: + assert event_data["content"] == {"text": "Response"} + assert event_data["branch"] == "main" + assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]} + assert event_data["custom_metadata"] == {"custom": "data"} + assert event_data["partial"] is True + assert event_data["turn_complete"] is False + assert event_data["interrupted"] is False + assert event_data["error_code"] == "NONE" 
+ assert event_data["error_message"] == "No errors" + + +async def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None: """Test creating event with only required fields.""" - event = adbc_store.create_event( - event_id="minimal-event", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - assert event["id"] == "minimal-event" - assert event["session_id"] == session_fixture["session_id"] - assert event["app_name"] == session_fixture["app_name"] - assert event["user_id"] == session_fixture["user_id"] - assert event["author"] is None - assert event["actions"] == b"" - assert event["content"] is None - - -def test_event_boolean_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test event boolean field conversion.""" - event_true = adbc_store.create_event( - event_id="event-true", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - partial=True, - turn_complete=True, - interrupted=True, - ) - - assert event_true["partial"] is True - assert event_true["turn_complete"] is True - assert event_true["interrupted"] is True - - event_false = adbc_store.create_event( - event_id="event-false", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - partial=False, - turn_complete=False, - interrupted=False, - ) - - assert event_false["partial"] is False - assert event_false["turn_complete"] is False - assert event_false["interrupted"] is False - - event_none = adbc_store.create_event( - event_id="event-none", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - assert event_none["partial"] is None - assert event_none["turn_complete"] is None - assert event_none["interrupted"] is None - - -def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> 
None: - """Test event JSON field serialization and deserialization.""" + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "minimal-event", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + assert events[0]["session_id"] == session_fixture["session_id"] + assert "event_json" in events[0] + + +async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: + """Test event JSON field serialization and deserialization via event_json.""" complex_content = {"nested": {"data": "value"}, "list": [1, 2, 3], "null": None} complex_grounding = {"sources": [{"title": "Doc", "url": "http://example.com"}]} complex_custom = {"metadata": {"version": 1, "tags": ["tag1", "tag2"]}} - event = adbc_store.create_event( - event_id="json-event", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - content=complex_content, - grounding_metadata=complex_grounding, - custom_metadata=complex_custom, + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "json-event", + "content": complex_content, + "grounding_metadata": complex_grounding, + "custom_metadata": complex_custom, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) + assert 
event_data["content"] == complex_content + assert event_data["grounding_metadata"] == complex_grounding + assert event_data["custom_metadata"] == complex_custom - assert event["content"] == complex_content - assert event["grounding_metadata"] == complex_grounding - assert event["custom_metadata"] == complex_custom - - events = adbc_store.list_events(session_fixture["session_id"]) - retrieved = events[0] - - assert retrieved["content"] == complex_content - assert retrieved["grounding_metadata"] == complex_grounding - assert retrieved["custom_metadata"] == complex_custom - -def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: +async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: """Test that events are ordered by timestamp ASC.""" - import time - - adbc_store.create_event( - event_id="event-1", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - time.sleep(0.01) - - adbc_store.create_event( - event_id="event-2", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - time.sleep(0.01) - - adbc_store.create_event( - event_id="event-3", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - events = adbc_store.list_events(session_fixture["session_id"]) + ev1: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + await adbc_store.append_event(ev1) + + await asyncio.sleep(0.01) + + ev2: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-2", "app_name": 
session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + await adbc_store.append_event(ev2) + + await asyncio.sleep(0.01) + + ev3: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-3", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + await adbc_store.append_event(ev3) + + events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 3 - assert events[0]["id"] == "event-1" - assert events[1]["id"] == "event-2" - assert events[2]["id"] == "event-3" assert events[0]["timestamp"] < events[1]["timestamp"] assert events[1]["timestamp"] < events[2]["timestamp"] -def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, tmp_path: Path) -> None: +async def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, tmp_path: Path) -> None: """Test that deleting a session cascades to delete events. Note: SQLite with ADBC requires foreign key enforcement to be explicitly enabled for cascade deletes to work. This test manually enables it. 
""" - adbc_store.create_event( - event_id="event-1", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - adbc_store.create_event( - event_id="event-2", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - ) - - events_before = adbc_store.list_events(session_fixture["session_id"]) + ev1: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + ev2: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + } + await adbc_store.append_event(ev1) + await adbc_store.append_event(ev2) + + events_before = await adbc_store.get_events(session_fixture["session_id"]) assert len(events_before) == 2 - # For SQLite with separate connections per operation, we need to manually delete events - # or note that cascade deletes require persistent connections - # For this test, just verify the session deletion works - adbc_store.delete_session(session_fixture["session_id"]) + await adbc_store.delete_session(session_fixture["session_id"]) - # Session should be gone - session_after = adbc_store.get_session(session_fixture["session_id"]) + session_after = await adbc_store.get_session(session_fixture["session_id"]) assert session_after is None - # Events may still exist with ADBC SQLite due to FK enforcement across connections - # This is a known limitation when using ADBC with SQLite in-memory or file-based - # with separate connections per operation - -def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None: +async 
def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None: """Test creating event with empty actions bytes.""" - event = adbc_store.create_event( - event_id="empty-actions", - session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - actions=b"", + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "empty-actions", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + assert "event_json" in events[0] + + +async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) -> None: + """Test creating event with large content in event_json.""" + large_content = {"data": "x" * 10000} + + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "large-content", + "content": large_content, + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + await adbc_store.append_event(event_record) + + events = await adbc_store.get_events(session_fixture["session_id"]) + assert len(events) == 1 + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) - - assert event["actions"] == b"" - - events = adbc_store.list_events(session_fixture["session_id"]) - assert events[0]["actions"] == b"" - - -def test_event_with_large_actions(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with large actions BLOB.""" - large_actions = b"x" * 10000 - - event = adbc_store.create_event( - event_id="large-actions", - 
session_id=session_fixture["session_id"], - app_name=session_fixture["app_name"], - user_id=session_fixture["user_id"], - actions=large_actions, - ) - - assert event["actions"] == large_actions - assert len(event["actions"]) == 10000 + assert event_data["content"] == large_content + + +async def test_append_event_preserves_existing_session_state(adbc_store: Any, session_fixture: Any) -> None: + """append_event must not overwrite the durable session state.""" + event_record: EventRecord = { + "session_id": session_fixture["session_id"], + "invocation_id": "append-only", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "append-only-event", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + } + + await adbc_store.append_event(event_record) + + session = await adbc_store.get_session(session_fixture["session_id"]) + assert session is not None + assert session["state"] == {"test": True} + + +async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, session_fixture: Any) -> None: + """get_events must respect both after_timestamp and limit.""" + base_time = datetime(2026, 1, 1, tzinfo=timezone.utc) + event_records = [ + { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "user", + "timestamp": base_time, + "event_json": { + "id": "event-1", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + }, + { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "assistant", + "timestamp": base_time + timedelta(seconds=1), + "event_json": { + "id": "event-2", + "app_name": session_fixture["app_name"], + "user_id": session_fixture["user_id"], + }, + }, + { + "session_id": session_fixture["session_id"], + "invocation_id": "", + "author": "assistant", + "timestamp": base_time + timedelta(seconds=2), + "event_json": { + "id": "event-3", + "app_name": session_fixture["app_name"], + "user_id": 
session_fixture["user_id"], + }, + }, + ] + + for event_record in event_records: + await adbc_store.append_event(event_record) + + filtered_events = await adbc_store.get_events(session_fixture["session_id"], after_timestamp=base_time, limit=1) + + assert len(filtered_events) == 1 + filtered_event = filtered_events[0]["event_json"] + filtered_data = json.loads(filtered_event) if isinstance(filtered_event, str) else filtered_event + assert filtered_data["id"] == "event-2" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py index a417cd84f..0c3afa838 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py @@ -30,63 +30,63 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -def _build_store(tmp_path: Path) -> AdbcADKMemoryStore: +async def _build_store(tmp_path: Path) -> AdbcADKMemoryStore: db_path = tmp_path / "test_adk_memory.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKMemoryStore(config) - store.create_tables() + await store.create_tables() return store -def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None: +async def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None: """Insert memory entries, search by text, and skip duplicates.""" - store = _build_store(tmp_path) + store = await _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = store.insert_memory_entries([record1, record2]) + inserted = await store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = store.search_entries(query="espresso", 
app_name="app", user_id="user") + results = await store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = store.insert_memory_entries([record1]) + deduped = await store.insert_memory_entries([record1]) assert deduped == 0 -def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None: +async def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None: """Delete memory entries by session id.""" - store = _build_store(tmp_path) + store = await _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_by_session("s1") + deleted = await store.delete_entries_by_session("s1") assert deleted == 1 - remaining = store.search_entries(query="latte", app_name="app", user_id="user") + remaining = await store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None: +async def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None: """Delete memory entries older than a cutoff.""" - store = _build_store(tmp_path) + store = await _build_store(tmp_path) now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_older_than(30) + deleted = await 
store.delete_entries_older_than(30) assert deleted == 1 - remaining = store.search_entries(query="new", app_name="app", user_id="user") + remaining = await store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py index ce2a1bbfa..1b1752ae4 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py @@ -9,7 +9,7 @@ @pytest.fixture() -def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] +async def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] """Create ADBC ADK store with owner ID column (SQLite).""" db_path = tmp_path / "test_fk.db" config = AdbcConfig( @@ -29,21 +29,21 @@ def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] finally: cursor.close() # type: ignore[no-untyped-call] - store.create_tables() + await store.create_tables() return store @pytest.fixture() -def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def] +async def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def] """Create ADBC ADK store without owner ID column (SQLite).""" db_path = tmp_path / "test_no_fk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def] +async def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test creating session with owner ID value.""" session_id = "test-session-1" app_name = "test-app" @@ -51,32 +51,32 @@ def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-un state = {"key": "value"} tenant_id = 1 - 
session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id) + session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id) assert session["id"] == session_id assert session["state"] == state -def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def] +async def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test creating session without providing owner ID value still works.""" session_id = "test-session-2" app_name = "test-app" user_id = "user-123" state = {"key": "value"} - session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state) + session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id -def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def] +async def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def] """Test creating session when no FK column configured.""" session_id = "test-session-3" app_name = "test-app" user_id = "user-123" state = {"key": "value"} - session = adbc_store_no_fk.create_session(session_id, app_name, user_id, state) + session = await adbc_store_no_fk.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["state"] == state @@ -109,16 +109,16 @@ def test_owner_id_column_complex_ddl() -> None: assert store._owner_id_column_ddl == complex_ddl # pyright: ignore[reportPrivateUsage] -def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def] +async def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test sessions are properly isolated by tenant.""" app_name = "test-app" user_id = "user-123" - adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, 
owner_id=1) - adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2) + await adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, owner_id=1) + await adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2) - retrieved1 = adbc_store_with_fk.get_session("session-tenant1") - retrieved2 = adbc_store_with_fk.get_session("session-tenant2") + retrieved1 = await adbc_store_with_fk.get_session("session-tenant1") + retrieved2 = await adbc_store_with_fk.get_session("session-tenant2") assert retrieved1["state"]["data"] == "tenant1" assert retrieved2["state"]["data"] == "tenant2" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index 819002edc..b749461d7 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -12,23 +12,23 @@ @pytest.fixture() -def adbc_store(tmp_path: Path) -> AdbcADKStore: +async def adbc_store(tmp_path: Path) -> AdbcADKStore: """Create ADBC ADK store with SQLite backend.""" db_path = tmp_path / "test_adk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - store.create_tables() + await store.create_tables() return store -def test_create_session(adbc_store: Any) -> None: +async def test_create_session(adbc_store: Any) -> None: """Test creating a new session.""" session_id = "test-session-1" app_name = "test-app" user_id = "user-123" state = {"key": "value", "count": 42} - session = adbc_store.create_session(session_id, app_name, user_id, state) + session = await adbc_store.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["app_name"] == app_name @@ -38,82 +38,82 @@ def 
test_create_session(adbc_store: Any) -> None: assert session["update_time"] is not None -def test_get_session(adbc_store: Any) -> None: +async def test_get_session(adbc_store: Any) -> None: """Test retrieving a session by ID.""" session_id = "test-session-2" app_name = "test-app" user_id = "user-123" state = {"data": "test"} - adbc_store.create_session(session_id, app_name, user_id, state) - retrieved = adbc_store.get_session(session_id) + await adbc_store.create_session(session_id, app_name, user_id, state) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["id"] == session_id assert retrieved["state"] == state -def test_get_nonexistent_session(adbc_store: Any) -> None: +async def test_get_nonexistent_session(adbc_store: Any) -> None: """Test retrieving a session that doesn't exist.""" - result = adbc_store.get_session("nonexistent-id") + result = await adbc_store.get_session("nonexistent-id") assert result is None -def test_update_session_state(adbc_store: Any) -> None: +async def test_update_session_state(adbc_store: Any) -> None: """Test updating session state.""" session_id = "test-session-3" app_name = "test-app" user_id = "user-123" initial_state = {"version": 1} - adbc_store.create_session(session_id, app_name, user_id, initial_state) + await adbc_store.create_session(session_id, app_name, user_id, initial_state) new_state = {"version": 2, "updated": True} - adbc_store.update_session_state(session_id, new_state) + await adbc_store.update_session_state(session_id, new_state) - updated = adbc_store.get_session(session_id) + updated = await adbc_store.get_session(session_id) assert updated is not None assert updated["state"] == new_state assert updated["state"] != initial_state -def test_delete_session(adbc_store: Any) -> None: +async def test_delete_session(adbc_store: Any) -> None: """Test deleting a session.""" session_id = "test-session-4" app_name = "test-app" user_id = "user-123" state = {"data": "test"} 
- adbc_store.create_session(session_id, app_name, user_id, state) - assert adbc_store.get_session(session_id) is not None + await adbc_store.create_session(session_id, app_name, user_id, state) + assert await adbc_store.get_session(session_id) is not None - adbc_store.delete_session(session_id) - assert adbc_store.get_session(session_id) is None + await adbc_store.delete_session(session_id) + assert await adbc_store.get_session(session_id) is None -def test_list_sessions(adbc_store: Any) -> None: +async def test_list_sessions(adbc_store: Any) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" user_id = "user-123" - adbc_store.create_session("session-1", app_name, user_id, {"num": 1}) - adbc_store.create_session("session-2", app_name, user_id, {"num": 2}) - adbc_store.create_session("session-3", "other-app", user_id, {"num": 3}) + await adbc_store.create_session("session-1", app_name, user_id, {"num": 1}) + await adbc_store.create_session("session-2", app_name, user_id, {"num": 2}) + await adbc_store.create_session("session-3", "other-app", user_id, {"num": 3}) - sessions = adbc_store.list_sessions(app_name, user_id) + sessions = await adbc_store.list_sessions(app_name, user_id) assert len(sessions) == 2 session_ids = {s["id"] for s in sessions} assert session_ids == {"session-1", "session-2"} -def test_list_sessions_empty(adbc_store: Any) -> None: +async def test_list_sessions_empty(adbc_store: Any) -> None: """Test listing sessions when none exist.""" - sessions = adbc_store.list_sessions("nonexistent-app", "nonexistent-user") + sessions = await adbc_store.list_sessions("nonexistent-app", "nonexistent-user") assert sessions == [] -def test_session_state_with_complex_data(adbc_store: Any) -> None: +async def test_session_state_with_complex_data(adbc_store: Any) -> None: """Test session state with nested complex data structures.""" session_id = "complex-session" app_name = "test-app" @@ -125,41 +125,41 @@ def 
test_session_state_with_complex_data(adbc_store: Any) -> None: "null_value": None, } - session = adbc_store.create_session(session_id, app_name, user_id, complex_state) + session = await adbc_store.create_session(session_id, app_name, user_id, complex_state) assert session["state"] == complex_state - retrieved = adbc_store.get_session(session_id) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == complex_state -def test_session_state_empty_dict(adbc_store: Any) -> None: +async def test_session_state_empty_dict(adbc_store: Any) -> None: """Test creating session with empty state dictionary.""" session_id = "empty-state-session" app_name = "test-app" user_id = "user-123" empty_state: dict[str, Any] = {} - session = adbc_store.create_session(session_id, app_name, user_id, empty_state) + session = await adbc_store.create_session(session_id, app_name, user_id, empty_state) assert session["state"] == empty_state - retrieved = adbc_store.get_session(session_id) + retrieved = await adbc_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == empty_state -def test_multiple_users_same_app(adbc_store: Any) -> None: +async def test_multiple_users_same_app(adbc_store: Any) -> None: """Test sessions for multiple users in the same app.""" app_name = "test-app" user1 = "user-1" user2 = "user-2" - adbc_store.create_session("session-user1-1", app_name, user1, {"user": 1}) - adbc_store.create_session("session-user1-2", app_name, user1, {"user": 1}) - adbc_store.create_session("session-user2-1", app_name, user2, {"user": 2}) + await adbc_store.create_session("session-user1-1", app_name, user1, {"user": 1}) + await adbc_store.create_session("session-user1-2", app_name, user1, {"user": 1}) + await adbc_store.create_session("session-user2-1", app_name, user2, {"user": 2}) - user1_sessions = adbc_store.list_sessions(app_name, user1) - user2_sessions = adbc_store.list_sessions(app_name, user2) + 
user1_sessions = await adbc_store.list_sessions(app_name, user1) + user2_sessions = await adbc_store.list_sessions(app_name, user2) assert len(user1_sessions) == 2 assert len(user2_sessions) == 1 @@ -167,18 +167,18 @@ def test_multiple_users_same_app(adbc_store: Any) -> None: assert all(s["user_id"] == user2 for s in user2_sessions) -def test_session_ordering(adbc_store: Any) -> None: +async def test_session_ordering(adbc_store: Any) -> None: """Test that sessions are ordered by update_time DESC.""" app_name = "test-app" user_id = "user-123" - adbc_store.create_session("session-1", app_name, user_id, {"order": 1}) - adbc_store.create_session("session-2", app_name, user_id, {"order": 2}) - adbc_store.create_session("session-3", app_name, user_id, {"order": 3}) + await adbc_store.create_session("session-1", app_name, user_id, {"order": 1}) + await adbc_store.create_session("session-2", app_name, user_id, {"order": 2}) + await adbc_store.create_session("session-3", app_name, user_id, {"order": 3}) - adbc_store.update_session_state("session-1", {"order": 1, "updated": True}) + await adbc_store.update_session_state("session-1", {"order": 1, "updated": True}) - sessions = adbc_store.list_sessions(app_name, user_id) + sessions = await adbc_store.list_sessions(app_name, user_id) assert len(sessions) == 3 assert sessions[0]["id"] == "session-1" diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index d53323565..bd52cade2 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -1,11 +1,12 @@ """Integration tests for AsyncMY ADK session store.""" -import pickle +import json from datetime import datetime, timezone import pytest from sqlspec.adapters.asyncmy.adk.store import AsyncmyADKStore +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.xdist_group("mysql"), 
pytest.mark.asyncmy, pytest.mark.integration] @@ -51,13 +52,14 @@ async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) -> ORDER BY ORDINAL_POSITION """) event_columns = await cursor.fetchall() + event_col_names = [col[0] for col in event_columns] - actions_col = next(col for col in event_columns if col[0] == "actions") - assert actions_col[1] == "blob", "actions column must use BLOB type for pickled data" - - content_col = next((col for col in event_columns if col[0] == "content"), None) - if content_col: - assert content_col[1] == "json", "content column must use native JSON type" + # New 5-column schema: session_id, invocation_id, author, timestamp, event_json + assert "session_id" in event_col_names + assert "invocation_id" in event_col_names + assert "author" in event_col_names + assert "timestamp" in event_col_names + assert "event_json" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in timestamp_col[2].lower(), "timestamp must be TIMESTAMP(6) for microseconds" @@ -141,18 +143,14 @@ async def test_delete_session_cascade(asyncmy_adk_store: AsyncmyADKStore) -> Non await asyncmy_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event_record = { - "id": "event-001", + event_record: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "test_action"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello"}, + "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, } - await asyncmy_adk_store.append_event(event_record) # type: ignore[arg-type] + await asyncmy_adk_store.append_event(event_record) events_before = await asyncmy_adk_store.get_events(session_id) assert len(events_before) == 1 @@ -174,48 +172,39 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) 
-> None await asyncmy_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event1 = { - "id": "event-001", + event1: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "message", "content": "Hello"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello", "role": "user"}, - "partial": False, - "turn_complete": True, + "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, } - event2 = { - "id": "event-002", + event2: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-002", "author": "assistant", - "actions": pickle.dumps([{"type": "response", "content": "Hi there"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hi there", "role": "assistant"}, - "partial": False, - "turn_complete": True, + "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, } - await asyncmy_adk_store.append_event(event1) # type: ignore[arg-type] - await asyncmy_adk_store.append_event(event2) # type: ignore[arg-type] + await asyncmy_adk_store.append_event(event1) + await asyncmy_adk_store.append_event(event2) events = await asyncmy_adk_store.get_events(session_id) assert len(events) == 2 - assert events[0]["id"] == "event-001" - assert events[1]["id"] == "event-002" - assert events[0]["content"] is not None - assert events[1]["content"] is not None - assert events[0]["content"]["text"] == "Hello" - assert events[1]["content"]["text"] == "Hi there" - assert isinstance(events[0]["actions"], bytes) - assert pickle.loads(events[0]["actions"])[0]["type"] == "message" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" + # Content is inside event_json + event0_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else 
events[0]["event_json"] + ) + event1_data = ( + json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + ) + assert event0_data["content"]["text"] == "Hello" + assert event1_data["content"]["text"] == "Hi there" async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None: @@ -230,17 +219,14 @@ async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None: assert hasattr(created["create_time"], "microsecond") event_time = datetime.now(timezone.utc) - event = { - "id": "event-micro", + event: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-micro", "author": "system", - "actions": b"", "timestamp": event_time, + "event_json": {"app_name": app_name}, } - await asyncmy_adk_store.append_event(event) # type: ignore[arg-type] + await asyncmy_adk_store.append_event(event) events = await asyncmy_adk_store.get_events(session_id) assert len(events) == 1 diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py index e3092176b..b86fef3ae 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py @@ -30,63 +30,63 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -def _build_store(tmp_path: Path, worker_id: str) -> DuckdbADKMemoryStore: +async def _build_store(tmp_path: Path, worker_id: str) -> DuckdbADKMemoryStore: db_path = tmp_path / f"test_adk_memory_{worker_id}.duckdb" config = DuckDBConfig(connection_config={"database": str(db_path)}) store = DuckdbADKMemoryStore(config) - store.create_tables() + await store.create_tables() return store -def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path, worker_id: str) -> None: +async def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path, 
worker_id: str) -> None: """Insert memory entries, search by text, and skip duplicates.""" - store = _build_store(tmp_path, worker_id) + store = await _build_store(tmp_path, worker_id) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = store.insert_memory_entries([record1, record2]) + inserted = await store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = store.search_entries(query="espresso", app_name="app", user_id="user") + results = await store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = store.insert_memory_entries([record1]) + deduped = await store.insert_memory_entries([record1]) assert deduped == 0 -def test_duckdb_memory_store_delete_by_session(tmp_path: Path, worker_id: str) -> None: +async def test_duckdb_memory_store_delete_by_session(tmp_path: Path, worker_id: str) -> None: """Delete memory entries by session id.""" - store = _build_store(tmp_path, worker_id) + store = await _build_store(tmp_path, worker_id) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_by_session("s1") + deleted = await store.delete_entries_by_session("s1") assert deleted == 1 - remaining = store.search_entries(query="latte", app_name="app", user_id="user") + remaining = await store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -def 
test_duckdb_memory_store_delete_older_than(tmp_path: Path, worker_id: str) -> None: +async def test_duckdb_memory_store_delete_older_than(tmp_path: Path, worker_id: str) -> None: """Delete memory entries older than a cutoff.""" - store = _build_store(tmp_path, worker_id) + store = await _build_store(tmp_path, worker_id) now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_older_than(30) + deleted = await store.delete_entries_older_than(30) assert deleted == 1 - remaining = store.search_entries(query="new", app_name="app", user_id="user") + remaining = await store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index a685e08b8..bc67ad007 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -1,6 +1,7 @@ """Integration tests for DuckDB ADK session store.""" -from collections.abc import Generator +import json +from collections.abc import AsyncGenerator from datetime import datetime, timezone from pathlib import Path @@ -8,12 +9,13 @@ from sqlspec.adapters.duckdb.adk.store import DuckdbADKStore from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @pytest.fixture -def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "Generator[DuckdbADKStore, None, None]": +async def duckdb_adk_store(tmp_path: Path, worker_id: 
str) -> "AsyncGenerator[DuckdbADKStore, None]": """Create DuckDB ADK store with temporary file-based database. Args: @@ -34,27 +36,27 @@ def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "Generator[DuckdbADKStor extension_config={"adk": {"session_table": "test_sessions", "events_table": "test_events"}}, ) store = DuckdbADKStore(config) - store.create_tables() + await store.create_tables() yield store finally: if db_path.exists(): db_path.unlink() -def test_create_tables(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_create_tables(duckdb_adk_store: DuckdbADKStore) -> None: """Test table creation succeeds without errors.""" assert duckdb_adk_store.session_table == "test_sessions" assert duckdb_adk_store.events_table == "test_events" -def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" app_name = "test-app" user_id = "user-001" state = {"key": "value", "count": 42} - created_session = duckdb_adk_store.create_session( + created_session = await duckdb_adk_store.create_session( session_id=session_id, app_name=app_name, user_id=user_id, state=state ) @@ -65,49 +67,51 @@ def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: assert isinstance(created_session["create_time"], datetime) assert isinstance(created_session["update_time"], datetime) - retrieved_session = duckdb_adk_store.get_session(session_id) + retrieved_session = await duckdb_adk_store.get_session(session_id) assert retrieved_session is not None assert retrieved_session["id"] == session_id assert retrieved_session["state"] == state -def test_get_nonexistent_session(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_get_nonexistent_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test getting a non-existent session returns None.""" - result = 
duckdb_adk_store.get_session("nonexistent-session") + result = await duckdb_adk_store.get_session("nonexistent-session") assert result is None -def test_update_session_state(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_update_session_state(duckdb_adk_store: DuckdbADKStore) -> None: """Test updating session state.""" session_id = "session-002" initial_state = {"status": "active"} updated_state = {"status": "completed", "result": "success"} - duckdb_adk_store.create_session(session_id=session_id, app_name="test-app", user_id="user-002", state=initial_state) + await duckdb_adk_store.create_session( + session_id=session_id, app_name="test-app", user_id="user-002", state=initial_state + ) - session_before = duckdb_adk_store.get_session(session_id) + session_before = await duckdb_adk_store.get_session(session_id) assert session_before is not None assert session_before["state"] == initial_state - duckdb_adk_store.update_session_state(session_id, updated_state) + await duckdb_adk_store.update_session_state(session_id, updated_state) - session_after = duckdb_adk_store.get_session(session_id) + session_after = await duckdb_adk_store.get_session(session_id) assert session_after is not None assert session_after["state"] == updated_state assert session_after["update_time"] >= session_before["update_time"] -def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" user_id = "user-003" - duckdb_adk_store.create_session("session-1", app_name, user_id, {"num": 1}) - duckdb_adk_store.create_session("session-2", app_name, user_id, {"num": 2}) - duckdb_adk_store.create_session("session-3", app_name, user_id, {"num": 3}) - duckdb_adk_store.create_session("session-other", "other-app", user_id, {"num": 999}) + await duckdb_adk_store.create_session("session-1", app_name, user_id, {"num": 1}) + await 
duckdb_adk_store.create_session("session-2", app_name, user_id, {"num": 2}) + await duckdb_adk_store.create_session("session-3", app_name, user_id, {"num": 3}) + await duckdb_adk_store.create_session("session-other", "other-app", user_id, {"num": 999}) - sessions = duckdb_adk_store.list_sessions(app_name, user_id) + sessions = await duckdb_adk_store.list_sessions(app_name, user_id) assert len(sessions) == 3 session_ids = {s["id"] for s in sessions} @@ -116,192 +120,214 @@ def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None: assert all(s["user_id"] == user_id for s in sessions) -def test_list_sessions_empty(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_sessions_empty(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing sessions when none exist.""" - sessions = duckdb_adk_store.list_sessions("nonexistent-app", "nonexistent-user") + sessions = await duckdb_adk_store.list_sessions("nonexistent-app", "nonexistent-user") assert sessions == [] -def test_delete_session(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_delete_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test deleting a session.""" session_id = "session-to-delete" - duckdb_adk_store.create_session(session_id, "test-app", "user-004", {"data": "test"}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-004", {"data": "test"}) - assert duckdb_adk_store.get_session(session_id) is not None + assert await duckdb_adk_store.get_session(session_id) is not None - duckdb_adk_store.delete_session(session_id) + await duckdb_adk_store.delete_session(session_id) - assert duckdb_adk_store.get_session(session_id) is None + assert await duckdb_adk_store.get_session(session_id) is None -def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None: """Test deleting a session also deletes associated events.""" session_id = "session-with-events" - 
duckdb_adk_store.create_session(session_id, "test-app", "user-005", {"data": "test"}) - - event = duckdb_adk_store.create_event( - event_id="event-001", - session_id=session_id, - app_name="test-app", - user_id="user-005", - author="user", - actions=b"test-actions", - content={"message": "Hello"}, - ) + await duckdb_adk_store.create_session(session_id, "test-app", "user-005", {"data": "test"}) + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-001", + "content": {"message": "Hello"}, + "app_name": "test-app", + "user_id": "user-005", + }, + } + await duckdb_adk_store.append_event(event_record) - assert event["id"] == "event-001" - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert len(events) == 1 - duckdb_adk_store.delete_session(session_id) + await duckdb_adk_store.delete_session(session_id) - assert duckdb_adk_store.get_session(session_id) is None - events_after = duckdb_adk_store.list_events(session_id) + assert await duckdb_adk_store.get_session(session_id) is None + events_after = await duckdb_adk_store.get_events(session_id) assert len(events_after) == 0 -def test_create_and_get_event(duckdb_adk_store: DuckdbADKStore) -> None: - """Test creating and retrieving an event.""" +async def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: + """Test creating an event and verifying the returned 5-key EventRecord.""" session_id = "session-006" - duckdb_adk_store.create_session(session_id, "test-app", "user-006", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-006", {}) - event_id = "event-002" timestamp = datetime.now(timezone.utc) content = {"text": "Test message", "role": "user"} - custom_metadata = {"source": "test"} - - created_event = duckdb_adk_store.create_event( - event_id=event_id, - session_id=session_id, - app_name="test-app", - 
user_id="user-006", - author="user", - actions=b"pickled-actions", - content=content, - timestamp=timestamp, - custom_metadata=custom_metadata, - ) - - assert created_event["id"] == event_id - assert created_event["session_id"] == session_id - assert created_event["author"] == "user" - assert created_event["content"] == content - assert created_event["custom_metadata"] == custom_metadata - retrieved_event = duckdb_adk_store.get_event(event_id) - assert retrieved_event is not None - assert retrieved_event["id"] == event_id - assert retrieved_event["content"] == content + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "user", + "timestamp": timestamp, + "event_json": {"id": "event-002", "content": content, "app_name": "test-app", "user_id": "user-006"}, + } + await duckdb_adk_store.append_event(event_record) + events = await duckdb_adk_store.get_events(session_id) + assert len(events) == 1 + assert events[0]["session_id"] == session_id + assert events[0]["author"] == "user" -def test_get_nonexistent_event(duckdb_adk_store: DuckdbADKStore) -> None: - """Test getting a non-existent event returns None.""" - result = duckdb_adk_store.get_event("nonexistent-event") - assert result is None + # Content is stored inside event_json + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert event_data["content"] == content -def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing events for a session.""" session_id = "session-007" - duckdb_adk_store.create_session(session_id, "test-app", "user-007", {}) - - duckdb_adk_store.create_event( - event_id="event-1", - session_id=session_id, - app_name="test-app", - user_id="user-007", - author="user", - content={"message": "First"}, - ) - duckdb_adk_store.create_event( - event_id="event-2", - session_id=session_id, 
- app_name="test-app", - user_id="user-007", - author="assistant", - content={"message": "Second"}, - ) + await duckdb_adk_store.create_session(session_id, "test-app", "user-007", {}) + + event1: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "content": {"message": "First"}, "app_name": "test-app", "user_id": "user-007"}, + } + event2: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-2", + "content": {"message": "Second"}, + "app_name": "test-app", + "user_id": "user-007", + }, + } + await duckdb_adk_store.append_event(event1) + await duckdb_adk_store.append_event(event2) - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert len(events) == 2 - assert events[0]["id"] == "event-1" - assert events[1]["id"] == "event-2" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" assert events[0]["timestamp"] <= events[1]["timestamp"] -def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None: """Test listing events when none exist.""" session_id = "session-no-events" - duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert events == [] -def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: - """Test creating events with all optional fields.""" +async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: + """Test creating events with optional fields stored in event_json.""" session_id = "session-008" - 
duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) - - event = duckdb_adk_store.create_event( - event_id="event-full", - session_id=session_id, - app_name="test-app", - user_id="user-008", - author="assistant", - actions=b"actions-data", - content={"text": "Response"}, - invocation_id="inv-123", - branch="main", - grounding_metadata={"sources": ["doc1", "doc2"]}, - custom_metadata={"priority": "high"}, - partial=True, - turn_complete=False, - interrupted=False, - error_code=None, - error_message=None, - ) + await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) + + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "inv-123", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": { + "id": "event-full", + "content": {"text": "Response"}, + "app_name": "test-app", + "user_id": "user-008", + "branch": "main", + "grounding_metadata": {"sources": ["doc1", "doc2"]}, + "custom_metadata": {"priority": "high"}, + "partial": True, + "turn_complete": False, + "interrupted": False, + }, + } + await duckdb_adk_store.append_event(event_record) + + events = await duckdb_adk_store.get_events(session_id) + assert len(events) == 1 - assert event["invocation_id"] == "inv-123" - assert event["branch"] == "main" - assert event["grounding_metadata"] == {"sources": ["doc1", "doc2"]} - assert event["partial"] is True - assert event["turn_complete"] is False + # The 5-key record has invocation_id as a top-level indexed column + assert events[0]["invocation_id"] == "inv-123" - retrieved = duckdb_adk_store.get_event("event-full") - assert retrieved is not None - assert retrieved["grounding_metadata"] == {"sources": ["doc1", "doc2"]} + # Other fields are inside event_json + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert event_data["branch"] == "main" + assert event_data["grounding_metadata"] == 
{"sources": ["doc1", "doc2"]} + assert event_data["partial"] is True + assert event_data["turn_complete"] is False -def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: """Test events are ordered by timestamp ascending.""" session_id = "session-009" - duckdb_adk_store.create_session(session_id, "test-app", "user-009", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-009", {}) t1 = datetime.now(timezone.utc) t2 = datetime.now(timezone.utc) t3 = datetime.now(timezone.utc) - duckdb_adk_store.create_event( - event_id="event-middle", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t2 - ) - duckdb_adk_store.create_event( - event_id="event-last", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t3 - ) - duckdb_adk_store.create_event( - event_id="event-first", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t1 - ) + ev_middle: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": t2, + "event_json": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"}, + } + ev_last: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": t3, + "event_json": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"}, + } + ev_first: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "", + "timestamp": t1, + "event_json": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"}, + } + + await duckdb_adk_store.append_event(ev_middle) + await duckdb_adk_store.append_event(ev_last) + await duckdb_adk_store.append_event(ev_first) - events = duckdb_adk_store.list_events(session_id) + events = await duckdb_adk_store.get_events(session_id) assert len(events) == 3 - assert events[0]["id"] == "event-first" - assert events[1]["id"] == 
"event-middle" - assert events[2]["id"] == "event-last" + # Events should be ordered by timestamp ASC + event_ids = [] + for e in events: + data = json.loads(e["event_json"]) if isinstance(e["event_json"], str) else e["event_json"] + event_ids.append(data["id"]) + assert event_ids == ["event-first", "event-middle", "event-last"] -def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None: """Test session state with nested JSON structures.""" session_id = "session-complex" complex_state = { @@ -314,86 +340,84 @@ def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> No "flags": [True, False, True], } - duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state) + await duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state) - session = duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session(session_id) assert session is not None assert session["state"] == complex_state assert session["state"]["user"]["preferences"]["theme"] == "dark" assert session["state"]["conversation"]["turn_count"] == 5 -def test_empty_state(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_empty_state(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating session with empty state.""" session_id = "session-empty-state" - duckdb_adk_store.create_session(session_id, "test-app", "user-011", {}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-011", {}) - session = duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session(session_id) assert session is not None assert session["state"] == {} -def test_table_not_found_handling(tmp_path: Path, worker_id: str) -> None: +async def test_table_not_found_handling(tmp_path: Path, worker_id: str) -> None: """Test graceful handling when tables don't exist.""" db_path = tmp_path / 
f"test_no_tables_{worker_id}.duckdb" try: config = DuckDBConfig(connection_config={"database": str(db_path)}) store = DuckdbADKStore(config) - result = store.get_session("nonexistent") + result = await store.get_session("nonexistent") assert result is None - sessions = store.list_sessions("app", "user") + sessions = await store.list_sessions("app", "user") assert sessions == [] - events = store.list_events("session") + events = await store.get_events("session") assert events == [] finally: if db_path.exists(): db_path.unlink() -def test_binary_actions_data(duckdb_adk_store: DuckdbADKStore) -> None: - """Test storing and retrieving binary actions data.""" - session_id = "session-binary" - duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) +async def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: + """Test storing and retrieving event data via event_json.""" + session_id = "session-json-rt" + await duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) - binary_data = bytes(range(256)) + event_record: EventRecord = { + "session_id": session_id, + "invocation_id": "", + "author": "system", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-json", "content": {"data": "value"}, "app_name": "test-app", "user_id": "user-012"}, + } + await duckdb_adk_store.append_event(event_record) - event = duckdb_adk_store.create_event( - event_id="event-binary", - session_id=session_id, - app_name="test-app", - user_id="user-012", - author="system", - actions=binary_data, + events = await duckdb_adk_store.get_events(session_id) + assert len(events) == 1 + event_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) - - assert event["actions"] == binary_data - - retrieved = duckdb_adk_store.get_event("event-binary") - assert retrieved is not None - assert retrieved["actions"] == binary_data - assert len(retrieved["actions"]) == 256 + 
assert event_data["content"] == {"data": "value"} -def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None: +async def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None: """Test multiple updates to same session.""" session_id = "session-concurrent" - duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0}) + await duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0}) for i in range(10): - session = duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session(session_id) assert session is not None current_counter = session["state"]["counter"] - duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1}) + await duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1}) - final_session = duckdb_adk_store.get_session(session_id) + final_session = await duckdb_adk_store.get_session(session_id) assert final_session is not None assert final_session["state"]["counter"] == 10 -def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: """Test owner ID column with INTEGER type.""" db_path = tmp_path / f"test_owner_id_int_{worker_id}.duckdb" try: @@ -415,12 +439,12 @@ def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() assert store.owner_id_column_name == "tenant_id" assert store.owner_id_column_ddl == "tenant_id INTEGER NOT NULL REFERENCES tenants(id)" - session = store.create_session( + session = await store.create_session( session_id="session-tenant-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=1 ) @@ -436,7 +460,7 @@ def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None: db_path.unlink() -def 
test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: """Test owner ID column with DuckDB UBIGINT type.""" db_path = tmp_path / f"test_owner_id_ubigint_{worker_id}.duckdb" try: @@ -458,11 +482,11 @@ def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() assert store.owner_id_column_name == "owner_id" - session = store.create_session( + session = await store.create_session( session_id="session-user-1", app_name="test-app", user_id="user-001", @@ -482,7 +506,7 @@ def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None: db_path.unlink() -def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) -> None: """Test that FK constraint is enforced.""" db_path = tmp_path / f"test_owner_id_constraint_{worker_id}.duckdb" try: @@ -504,14 +528,14 @@ def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() - store.create_session( + await store.create_session( session_id="session-org-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=100 ) with pytest.raises(Exception) as exc_info: - store.create_session( + await store.create_session( session_id="session-org-invalid", app_name="test-app", user_id="user-002", @@ -525,7 +549,7 @@ def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) db_path.unlink() -def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None: """Test creating session without owner_id when column is configured but 
nullable.""" db_path = tmp_path / f"test_owner_id_nullable_{worker_id}.duckdb" try: @@ -546,22 +570,22 @@ def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() - session = store.create_session( + session = await store.create_session( session_id="session-no-fk", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=None ) assert session["id"] == "session-no-fk" - retrieved = store.get_session("session-no-fk") + retrieved = await store.get_session("session-no-fk") assert retrieved is not None finally: if db_path.exists(): db_path.unlink() -def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: """Test owner ID column with VARCHAR type.""" db_path = tmp_path / f"test_owner_id_varchar_{worker_id}.duckdb" try: @@ -583,9 +607,9 @@ def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() - session = store.create_session( + session = await store.create_session( session_id="session-company-1", app_name="test-app", user_id="user-001", @@ -605,7 +629,7 @@ def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None: db_path.unlink() -def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> None: """Test multiple sessions with same FK value.""" db_path = tmp_path / f"test_owner_id_multiple_{worker_id}.duckdb" try: @@ -627,10 +651,10 @@ def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> No }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() for i in range(5): - store.create_session( + await 
store.create_session( session_id=f"session-dept-{i}", app_name="test-app", user_id=f"user-{i}", @@ -648,7 +672,7 @@ def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> No db_path.unlink() -def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None: +async def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None: """Test querying sessions by FK column value.""" db_path = tmp_path / f"test_owner_id_query_{worker_id}.duckdb" try: @@ -670,11 +694,11 @@ def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None: }, ) store = DuckdbADKStore(config_with_extension) - store.create_tables() + await store.create_tables() - store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1) - store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1) - store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2) + await store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1) + await store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1) + await store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2) with config.provide_connection() as conn: cursor = conn.execute("SELECT id FROM sessions_with_project WHERE project_id = ? 
ORDER BY id", (1,)) diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index 5e6253a47..eb12c69f8 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -1,12 +1,13 @@ """Integration tests for MysqlConnector ADK session store.""" -import pickle +import json from datetime import datetime, timezone from typing import cast import pytest from sqlspec.adapters.mysqlconnector.adk.store import MysqlConnectorAsyncADKStore +from sqlspec.extensions.adk import EventRecord pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.integration] @@ -50,13 +51,14 @@ async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnect ORDER BY ORDINAL_POSITION """) event_columns = await cursor.fetchall() + event_col_names = [col[0] for col in event_columns] - actions_col = next(col for col in event_columns if col[0] == "actions") - assert actions_col[1] == "blob" - - content_col = next((col for col in event_columns if col[0] == "content"), None) - if content_col: - assert content_col[1] == "json" + # New 5-column schema: session_id, invocation_id, author, timestamp, event_json + assert "session_id" in event_col_names + assert "invocation_id" in event_col_names + assert "author" in event_col_names + assert "timestamp" in event_col_names + assert "event_json" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in cast("str", timestamp_col[2]).lower() @@ -142,18 +144,14 @@ async def test_delete_session_cascade(mysqlconnector_adk_store: MysqlConnectorAs await mysqlconnector_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event_record = { - "id": "event-001", + event_record: EventRecord = { "session_id": session_id, - "app_name": 
app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "test_action"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello"}, + "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, } - await mysqlconnector_adk_store.append_event(event_record) # type: ignore[arg-type] + await mysqlconnector_adk_store.append_event(event_record) events_before = await mysqlconnector_adk_store.get_events(session_id) assert len(events_before) == 1 @@ -175,48 +173,39 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy await mysqlconnector_adk_store.create_session(session_id, app_name, user_id, {"status": "active"}) - event1 = { - "id": "event-001", + event1: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-001", "author": "user", - "actions": pickle.dumps([{"type": "message", "content": "Hello"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hello", "role": "user"}, - "partial": False, - "turn_complete": True, + "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, } - event2 = { - "id": "event-002", + event2: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-002", "author": "assistant", - "actions": pickle.dumps([{"type": "response", "content": "Hi there"}]), "timestamp": datetime.now(timezone.utc), - "content": {"text": "Hi there", "role": "assistant"}, - "partial": False, - "turn_complete": True, + "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, } - await mysqlconnector_adk_store.append_event(event1) # type: ignore[arg-type] - await mysqlconnector_adk_store.append_event(event2) # type: ignore[arg-type] + await mysqlconnector_adk_store.append_event(event1) + await mysqlconnector_adk_store.append_event(event2) events = await 
mysqlconnector_adk_store.get_events(session_id) assert len(events) == 2 - assert events[0]["id"] == "event-001" - assert events[1]["id"] == "event-002" - content0 = events[0]["content"] - content1 = events[1]["content"] - assert content0 is not None and content0["text"] == "Hello" - assert content1 is not None and content1["text"] == "Hi there" - assert isinstance(events[0]["actions"], bytes) - assert pickle.loads(events[0]["actions"])[0]["type"] == "message" + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" + # Content is inside event_json + event0_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + event1_data = ( + json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + ) + assert event0_data["content"]["text"] == "Hello" + assert event1_data["content"]["text"] == "Hi there" async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsyncADKStore) -> None: @@ -230,17 +219,14 @@ async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsync assert hasattr(created["create_time"], "microsecond") event_time = datetime.now(timezone.utc) - event = { - "id": "event-micro", + event: EventRecord = { "session_id": session_id, - "app_name": app_name, - "user_id": user_id, "invocation_id": "inv-micro", "author": "system", - "actions": b"", "timestamp": event_time, + "event_json": {"app_name": app_name}, } - await mysqlconnector_adk_store.append_event(event) # type: ignore[arg-type] + await mysqlconnector_adk_store.append_event(event) events = await mysqlconnector_adk_store.get_events(session_id) assert len(events) == 1 diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py index 26d86165f..5a32a911b 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py +++ 
b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py @@ -302,14 +302,14 @@ async def test_inmemory_tables_functional_async(oracle_async_config: OracleAsync @pytest.mark.oracledb -def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None: +async def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None: """Test that in_memory=True works with sync store.""" config = OracleSyncConfig( connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": True}} ) store = OracleSyncADKStore(config) - store.create_tables() + await store.create_tables() try: with config.provide_connection() as conn: @@ -343,14 +343,14 @@ def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None: @pytest.mark.oracledb -def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None: +async def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None: """Test that in_memory=False works with sync store.""" config = OracleSyncConfig( connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": False}} ) store = OracleSyncADKStore(config) - store.create_tables() + await store.create_tables() try: with config.provide_connection() as conn: @@ -382,14 +382,14 @@ def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None: @pytest.mark.oracledb -def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) -> None: +async def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) -> None: """Test that INMEMORY tables work correctly in sync mode.""" config = OracleSyncConfig( connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": True}} ) store = OracleSyncADKStore(config) - store.create_tables() + await store.create_tables() try: session_id = "inmemory-sync-session" @@ -397,11 +397,11 @@ def 
test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) - user_id = "user-456" state = {"sync": True, "value": 99} - session = store.create_session(session_id, app_name, user_id, state) + session = await store.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["state"] == state - retrieved = store.get_session(session_id) + retrieved = await store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index 52cd230f4..1e6cb1f94 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -1,7 +1,7 @@ """Oracle-specific ADK store tests for LOB handling, JSON types, and FK columns.""" -import pickle -from collections.abc import AsyncGenerator, Generator +import json +from collections.abc import AsyncGenerator from datetime import datetime, timezone from typing import Any, cast from uuid import uuid4 @@ -62,11 +62,11 @@ async def oracle_async_store(oracle_async_config: "OracleAsyncConfig") -> "Async await _cleanup_async_store(store, oracle_async_config) -@pytest.fixture(scope="module") -def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "Generator[OracleSyncADKStore, None, None]": - """Create a sync Oracle ADK store with tables created once per module.""" +@pytest.fixture +async def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "AsyncGenerator[OracleSyncADKStore, None]": + """Create a sync Oracle ADK store with tables created per test.""" store = OracleSyncADKStore(oracle_sync_config) - store.create_tables() + await store.create_tables() try: yield store finally: @@ -140,7 +140,9 @@ async def oracle_store_with_fk( @pytest.fixture -def 
oracle_config_with_users_table(oracle_sync_config: "OracleSyncConfig") -> "Generator[OracleSyncConfig, None, None]": +async def oracle_config_with_users_table( + oracle_sync_config: "OracleSyncConfig", +) -> "AsyncGenerator[OracleSyncConfig, None]": """Create a users table for FK testing.""" with oracle_sync_config.provide_connection() as conn: cursor = conn.cursor() @@ -187,9 +189,9 @@ def oracle_config_with_users_table(oracle_sync_config: "OracleSyncConfig") -> "G @pytest.fixture -def oracle_store_sync_with_fk( +async def oracle_store_sync_with_fk( oracle_config_with_users_table: "OracleSyncConfig", -) -> "Generator[OracleSyncADKStore, None, None]": +) -> "AsyncGenerator[OracleSyncADKStore, None]": """Create a sync Oracle ADK store with owner_id FK column.""" config_with_extension = OracleSyncConfig( connection_config=oracle_config_with_users_table.connection_config, @@ -197,7 +199,7 @@ def oracle_store_sync_with_fk( ) store = OracleSyncADKStore(config_with_extension) _cleanup_sync_store(store, config_with_extension) - store.create_tables() + await store.create_tables() try: yield store finally: @@ -220,8 +222,8 @@ async def test_state_lob_deserialization(oracle_async_store: "OracleAsyncADKStor assert retrieved["state"]["large_field"] == "x" * 10000 -async def test_event_content_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event content CLOB is correctly deserialized.""" +async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_json LOB data is correctly deserialized.""" session_id = _unique_session_id("event-lob") app_name = "test-app" user_id = "user-123" @@ -229,169 +231,130 @@ async def test_event_content_lob_deserialization(oracle_async_store: "OracleAsyn await oracle_async_store.create_session(session_id, app_name, user_id, {}) content = {"message": "x" * 5000, "data": {"nested": True}} - grounding_metadata = {"sources": ["a" * 1000, "b" * 1000]} - custom_metadata 
= {"tags": ["tag1", "tag2"], "priority": "high"} + event_data = { + "content": content, + "app_name": app_name, + "user_id": user_id, + "grounding_metadata": {"sources": ["a" * 1000, "b" * 1000]}, + "custom_metadata": {"tags": ["tag1", "tag2"], "priority": "high"}, + } event_record: EventRecord = { - "id": "event-1", "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "", "author": "assistant", - "actions": pickle.dumps([{"name": "test", "args": {}}]), - "content": content, - "grounding_metadata": grounding_metadata, - "custom_metadata": custom_metadata, "timestamp": datetime.now(timezone.utc), - "partial": False, - "turn_complete": True, - "interrupted": False, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": event_data, } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["content"] == content - assert events[0]["grounding_metadata"] == grounding_metadata - assert events[0]["custom_metadata"] == custom_metadata + # event_json contains all the data + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert retrieved_data["content"] == content + assert retrieved_data["grounding_metadata"] == {"sources": ["a" * 1000, "b" * 1000]} + assert retrieved_data["custom_metadata"] == {"tags": ["tag1", "tag2"], "priority": "high"} -async def test_actions_blob_handling(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test actions BLOB is correctly read and unpickled.""" - session_id = _unique_session_id("actions-blob") +async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_json blob is correctly stored and retrieved.""" + session_id = _unique_session_id("event-json") app_name = "test-app" user_id = "user-123" 
await oracle_async_store.create_session(session_id, app_name, user_id, {}) - test_actions = [{"function": "test_func", "args": {"param": "value"}, "result": 42}] - actions_bytes = pickle.dumps(test_actions) + event_data = {"function": "test_func", "args": {"param": "value"}, "result": 42} event_record: EventRecord = { - "id": "event-actions", "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "", "author": "user", - "actions": actions_bytes, - "content": None, - "grounding_metadata": None, - "custom_metadata": None, "timestamp": datetime.now(timezone.utc), - "partial": None, - "turn_complete": None, - "interrupted": None, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": event_data, } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["actions"] == actions_bytes - unpickled = pickle.loads(events[0]["actions"]) - assert unpickled == test_actions + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert retrieved_data == event_data -def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None: +async def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None: """Test state CLOB/BLOB is correctly deserialized in sync mode.""" session_id = _unique_session_id("lob-session-sync") app_name = "test-app" user_id = "user-123" state = {"large_field": "y" * 10000, "nested": {"data": [4, 5, 6]}} - session = oracle_sync_store.create_session(session_id, app_name, user_id, state) + session = await oracle_sync_store.create_session(session_id, app_name, user_id, state) assert session["state"] == state - retrieved = oracle_sync_store.get_session(session_id) + retrieved = await 
oracle_sync_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state -async def test_boolean_fields_conversion(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test partial, turn_complete, interrupted converted to NUMBER(1).""" - session_id = _unique_session_id("bool-session") +async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test the new 5-column EventRecord contract with append_event.""" + session_id = _unique_session_id("5col-session") app_name = "test-app" user_id = "user-123" await oracle_async_store.create_session(session_id, app_name, user_id, {}) event_record: EventRecord = { - "id": "bool-event-1", "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "inv-001", "author": "assistant", - "actions": b"", - "content": None, - "grounding_metadata": None, - "custom_metadata": None, "timestamp": datetime.now(timezone.utc), - "partial": True, - "turn_complete": False, - "interrupted": True, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": {"content": {"text": "Hello"}, "partial": True, "turn_complete": False, "interrupted": True}, } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["partial"] is True - assert events[0]["turn_complete"] is False - assert events[0]["interrupted"] is True + assert events[0]["session_id"] == session_id + assert events[0]["invocation_id"] == "inv-001" + assert events[0]["author"] == "assistant" + + retrieved_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + ) + assert retrieved_data["partial"] is True + assert retrieved_data["turn_complete"] is False + assert retrieved_data["interrupted"] is True -async def 
test_boolean_fields_none_values(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test None values for boolean fields.""" - session_id = _unique_session_id("bool-none-session") +async def test_event_with_none_values(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event with minimal event_json content.""" + session_id = _unique_session_id("none-session") app_name = "test-app" user_id = "user-123" await oracle_async_store.create_session(session_id, app_name, user_id, {}) event_record: EventRecord = { - "id": "bool-event-none", "session_id": session_id, - "app_name": app_name, - "user_id": user_id, + "invocation_id": "", "author": "user", - "actions": b"", - "content": None, - "grounding_metadata": None, - "custom_metadata": None, "timestamp": datetime.now(timezone.utc), - "partial": None, - "turn_complete": None, - "interrupted": None, - "error_code": None, - "error_message": None, - "invocation_id": "", - "branch": None, - "long_running_tool_ids_json": None, + "event_json": {"app_name": app_name}, } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - assert events[0]["partial"] is None - assert events[0]["turn_complete"] is None - assert events[0]["interrupted"] is None async def test_create_session_with_owner_id(oracle_store_with_fk: "OracleAsyncADKStore") -> None: @@ -453,7 +416,7 @@ async def test_json_storage_type_detection(oracle_async_store: "OracleAsyncADKSt detector = cast("Any", oracle_async_store) storage_type = await detector._detect_json_storage_type() - assert storage_type in ["json", "blob_json", "clob_json", "blob_plain"] + assert storage_type in ["json", "blob_json", "blob_plain"] async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsyncADKStore") -> None: @@ -465,7 +428,7 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync "complex": { "nested": {"deep": {"structure": "value"}}, "array": [1, 
2, 3, {"key": "value"}], - "unicode": "こんにちは世界", + "unicode": "\u65e5\u672c\u8a9e\u30c6\u30b9\u30c8", "special_chars": "test@example.com | value > 100", } } @@ -476,10 +439,10 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync retrieved = await oracle_async_store.get_session(session_id) assert retrieved is not None assert retrieved["state"] == state - assert retrieved["state"]["complex"]["unicode"] == "こんにちは世界" + assert retrieved["state"]["complex"]["unicode"] == "\u65e5\u672c\u8a9e\u30c6\u30b9\u30c8" -def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyncADKStore") -> None: +async def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyncADKStore") -> None: """Test creating session with owner_id in sync mode.""" session_id = _unique_session_id("sync-fk") app_name = "test-app" @@ -487,10 +450,10 @@ def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyn state = {"data": "sync test"} owner_id = 100 - session = oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id) + session = await oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id) assert session["id"] == session_id assert session["state"] == state - retrieved = oracle_store_sync_with_fk.get_session(session_id) + retrieved = await oracle_store_sync_with_fk.get_session(session_id) assert retrieved is not None assert retrieved["id"] == session_id diff --git a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py index 789eff133..5e7fb0123 100644 --- a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py @@ -1,6 +1,6 @@ """Integration tests for Psycopg ADK store owner_id_column feature.""" -from collections.abc import 
AsyncGenerator, Generator +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any import pytest @@ -42,7 +42,7 @@ async def psycopg_async_store_with_fk(postgres_service: "PostgresService") -> "A @pytest.fixture -def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "Generator[Any, None, None]": +async def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "AsyncGenerator[Any, None]": """Create Psycopg sync ADK store with owner_id_column configured.""" config = PsycopgSyncConfig( connection_config={ @@ -57,7 +57,7 @@ def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "Generato }, ) store = PsycopgSyncADKStore(config) - store.create_tables() + await store.create_tables() yield store with config.provide_connection() as conn, conn.cursor() as cur: @@ -74,7 +74,7 @@ async def test_async_store_owner_id_column_initialization(psycopg_async_store_wi assert psycopg_async_store_with_fk.owner_id_column_name == "tenant_id" -def test_sync_store_owner_id_column_initialization(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: +async def test_sync_store_owner_id_column_initialization(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: """Test that owner_id_column is properly initialized in sync store.""" assert psycopg_sync_store_with_fk.owner_id_column_ddl == "account_id VARCHAR(64) NOT NULL" assert psycopg_sync_store_with_fk.owner_id_column_name == "account_id" @@ -105,7 +105,7 @@ async def test_async_store_inherits_owner_id_column(postgres_service: "PostgresS await config.close_pool() -def test_sync_store_inherits_owner_id_column(postgres_service: "PostgresService") -> None: +async def test_sync_store_inherits_owner_id_column(postgres_service: "PostgresService") -> None: """Test that sync store correctly inherits owner_id_column from base class.""" config = PsycopgSyncConfig( connection_config={ @@ -147,7 +147,7 @@ async def test_async_store_without_owner_id_column(postgres_service: 
"PostgresSe await config.close_pool() -def test_sync_store_without_owner_id_column(postgres_service: "PostgresService") -> None: +async def test_sync_store_without_owner_id_column(postgres_service: "PostgresService") -> None: """Test that sync store works without owner_id_column (default behavior).""" config = PsycopgSyncConfig( connection_config={ @@ -172,9 +172,9 @@ async def test_async_ddl_includes_owner_id_column(psycopg_async_store_with_fk: P assert "test_sessions_fk" in ddl -def test_sync_ddl_includes_owner_id_column(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: +async def test_sync_ddl_includes_owner_id_column(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: """Test that the DDL generation includes the owner_id_column.""" - ddl = psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] + ddl = await psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] assert "account_id VARCHAR(64) NOT NULL" in ddl assert "test_sessions_sync_fk" in ddl diff --git a/tests/integration/adapters/spanner/extensions/adk/conftest.py b/tests/integration/adapters/spanner/extensions/adk/conftest.py index 4ae8782d2..57ad9bace 100644 --- a/tests/integration/adapters/spanner/extensions/adk/conftest.py +++ b/tests/integration/adapters/spanner/extensions/adk/conftest.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING import pytest @@ -30,7 +30,7 @@ def spanner_adk_config(spanner_service: SpannerService, spanner_database: "Datab @pytest.fixture -def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> Generator[SpannerSyncADKStore, None, None]: +async def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> AsyncGenerator[SpannerSyncADKStore, None]: store = SpannerSyncADKStore(spanner_adk_config) - store.create_tables() + await store.create_tables() yield store diff --git 
a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index 98a313d15..b7cca39f2 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -1,75 +1,95 @@ -"""Integration tests for Spanner ADK store (sync).""" +"""Integration tests for Spanner ADK store.""" +import json +from datetime import datetime, timezone from typing import Any import pytest +from sqlspec.extensions.adk import EventRecord + pytestmark = [pytest.mark.spanner, pytest.mark.integration] -def test_create_and_get_session(spanner_adk_store: Any) -> None: +async def test_create_and_get_session(spanner_adk_store: Any) -> None: session_id = "session-create" - spanner_adk_store.delete_session(session_id) - created = spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) + await spanner_adk_store.delete_session(session_id) + created = await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) assert created["id"] == session_id - fetched = spanner_adk_store.get_session(session_id) + fetched = await spanner_adk_store.get_session(session_id) assert fetched is not None assert fetched["state"] == {"a": 1} -def test_update_session_state(spanner_adk_store: Any) -> None: +async def test_update_session_state(spanner_adk_store: Any) -> None: session_id = "session-update" - spanner_adk_store.delete_session(session_id) - spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) + await spanner_adk_store.delete_session(session_id) + await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) - spanner_adk_store.update_session_state(session_id, {"a": 2, "b": True}) + await spanner_adk_store.update_session_state(session_id, {"a": 2, "b": True}) - fetched = spanner_adk_store.get_session(session_id) + fetched = await spanner_adk_store.get_session(session_id) assert fetched is not 
None assert fetched["state"] == {"a": 2, "b": True} -def test_list_sessions(spanner_adk_store: Any) -> None: - spanner_adk_store.delete_session("session-list-1") - spanner_adk_store.delete_session("session-list-2") - spanner_adk_store.delete_session("session-list-3") - spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1}) - spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2}) - spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3}) +async def test_list_sessions(spanner_adk_store: Any) -> None: + await spanner_adk_store.delete_session("session-list-1") + await spanner_adk_store.delete_session("session-list-2") + await spanner_adk_store.delete_session("session-list-3") + await spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1}) + await spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2}) + await spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3}) - sessions = spanner_adk_store.list_sessions("app-list", "user1") + sessions = await spanner_adk_store.list_sessions("app-list", "user1") session_ids = {s["id"] for s in sessions} assert session_ids == {"session-list-1", "session-list-2"} -def test_delete_session(spanner_adk_store: Any) -> None: +async def test_delete_session(spanner_adk_store: Any) -> None: session_id = "session-delete" - spanner_adk_store.delete_session(session_id) - spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"}) - spanner_adk_store.delete_session(session_id) + await spanner_adk_store.delete_session(session_id) + await spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"}) + await spanner_adk_store.delete_session(session_id) - assert spanner_adk_store.get_session(session_id) is None + assert await spanner_adk_store.get_session(session_id) is None -def test_create_and_list_events(spanner_adk_store: Any) -> None: +async def 
test_create_and_list_events(spanner_adk_store: Any) -> None: session_id = "session-events" - spanner_adk_store.delete_session(session_id) - spanner_adk_store.create_session(session_id, "app", "user", {"x": 1}) - - spanner_adk_store.create_event("event-1", session_id, "app", "user", author="user", content={"msg": "hi"}) - spanner_adk_store.create_event( - "event-2", - session_id, - "app", - "user", - author="assistant", - content={"msg": "ok"}, - partial=False, - turn_complete=True, + await spanner_adk_store.delete_session(session_id) + await spanner_adk_store.create_session(session_id, "app", "user", {"x": 1}) + + event_one: EventRecord = { + "session_id": session_id, + "invocation_id": "event-1", + "author": "user", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1", "content": {"msg": "hi"}, "app_name": "app", "user_id": "user"}, + } + event_two: EventRecord = { + "session_id": session_id, + "invocation_id": "event-2", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-2", "content": {"msg": "ok"}, "app_name": "app", "user_id": "user"}, + } + + await spanner_adk_store.append_event(event_one) + await spanner_adk_store.append_event(event_two) + + events = await spanner_adk_store.get_events(session_id) + assert len(events) == 2 + assert events[0]["author"] == "user" + assert events[1]["author"] == "assistant" + + # Content is inside event_json in the new 5-column schema + event0_data = ( + json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] ) - - events = spanner_adk_store.list_events(session_id) - ids = [e["id"] for e in events] - assert ids == ["event-1", "event-2"] - assert events[0]["content"] == {"msg": "hi"} + event1_data = ( + json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + ) + assert event0_data["content"] == {"msg": "hi"} + assert event1_data["content"] == {"msg": 
"ok"} diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py index e3397dd8e..71645b388 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py @@ -30,64 +30,64 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -def test_sqlite_memory_store_insert_search_dedup() -> None: +async def test_sqlite_memory_store_insert_search_dedup() -> None: """Insert memory entries, search by text, and skip duplicates.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - store.create_tables() + await store.create_tables() now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = store.insert_memory_entries([record1, record2]) + inserted = await store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = store.search_entries(query="espresso", app_name="app", user_id="user") + results = await store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = store.insert_memory_entries([record1]) + deduped = await store.insert_memory_entries([record1]) assert deduped == 0 -def test_sqlite_memory_store_delete_by_session() -> None: +async def test_sqlite_memory_store_delete_by_session() -> None: """Delete memory entries by session id.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - store.create_tables() + await store.create_tables() now = 
datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_by_session("s1") + deleted = await store.delete_entries_by_session("s1") assert deleted == 1 - remaining = store.search_entries(query="latte", app_name="app", user_id="user") + remaining = await store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -def test_sqlite_memory_store_delete_older_than() -> None: +async def test_sqlite_memory_store_delete_older_than() -> None: """Delete memory entries older than a cutoff.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - store.create_tables() + await store.create_tables() now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - store.insert_memory_entries([record1, record2]) + await store.insert_memory_entries([record1, record2]) - deleted = store.delete_entries_older_than(30) + deleted = await store.delete_entries_older_than(30) assert deleted == 1 - remaining = store.search_entries(query="new", app_name="app", user_id="user") + remaining = await store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py index ec618e94e..b275006c5 100644 --- 
a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py +++ b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py @@ -2,7 +2,12 @@ from decimal import Decimal -from sqlspec.adapters.oracledb.adk.store import OracleAsyncADKStore, OracleSyncADKStore +from sqlspec.adapters.oracledb.adk.store import ( + JSONStorageType, + OracleAsyncADKStore, + OracleSyncADKStore, + _event_json_column_ddl, +) async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None: @@ -43,3 +48,29 @@ def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None: result = store._deserialize_state(payload) # type: ignore[attr-defined] assert result == {"state": 5.0} + + +def test_oracle_event_json_column_ddl_prefers_blob_over_clob() -> None: + assert _event_json_column_ddl(JSONStorageType.JSON_NATIVE) == "event_json JSON NOT NULL" + assert _event_json_column_ddl(JSONStorageType.BLOB_JSON) == "event_json BLOB CHECK (event_json IS JSON) NOT NULL" + assert _event_json_column_ddl(JSONStorageType.BLOB_PLAIN) == "event_json BLOB NOT NULL" + + +async def test_oracle_async_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: + store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] + store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] + + result = await store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + + assert isinstance(result, bytes) + assert b'"value":1' in result + + +def test_oracle_sync_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: + store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg] + store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] + + result = store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + + assert isinstance(result, bytes) + assert b'"value":1' in result diff --git a/tests/unit/adapters/test_psycopg/test_adk_store.py 
b/tests/unit/adapters/test_psycopg/test_adk_store.py new file mode 100644 index 000000000..a0d3eb6c1 --- /dev/null +++ b/tests/unit/adapters/test_psycopg/test_adk_store.py @@ -0,0 +1,109 @@ +"""Unit tests for psycopg ADK store sync wrappers.""" + +from datetime import datetime, timezone +from typing import Any + +from psycopg.types.json import Jsonb +from typing_extensions import Self + +from sqlspec.adapters.psycopg.adk.store import PsycopgSyncADKStore + + +class _DummyCursor: + def __init__(self, rows: "list[dict[str, Any]] | None" = None) -> None: + self.execute_calls: list[tuple[Any, Any]] = [] + self._rows = rows or [] + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + return None + + def execute(self, query: Any, params: Any) -> None: + self.execute_calls.append((query, params)) + + def fetchall(self) -> "list[dict[str, Any]]": + return self._rows + + +class _DummyConnection: + def __init__(self, cursor: _DummyCursor) -> None: + self._cursor = cursor + self.commit_called = False + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + return None + + def cursor(self) -> _DummyCursor: + return self._cursor + + def commit(self) -> None: + self.commit_called = True + + +class _DummyConfig: + def __init__(self, connection: _DummyConnection) -> None: + self._connection = connection + + def provide_connection(self) -> _DummyConnection: + return self._connection + + +def _build_store( + rows: "list[dict[str, Any]] | None" = None, +) -> "tuple[PsycopgSyncADKStore, _DummyCursor, _DummyConnection]": + cursor = _DummyCursor(rows) + connection = _DummyConnection(cursor) + store = PsycopgSyncADKStore.__new__(PsycopgSyncADKStore) # type: ignore[call-arg] + store._config = _DummyConfig(connection) # type: ignore[attr-defined] + store._events_table = "test_events" # type: ignore[attr-defined] + store._session_table = "test_sessions" # type: 
ignore[attr-defined] + store._owner_id_column_ddl = None # type: ignore[attr-defined] + store._owner_id_column_name = None # type: ignore[attr-defined] + return store, cursor, connection + + +def test_sync_append_event_inserts_without_session_update() -> None: + """append_event must insert a single event without writing session state.""" + store, cursor, connection = _build_store() + event_record = { + "session_id": "session-1", + "invocation_id": "", + "author": "assistant", + "timestamp": datetime.now(timezone.utc), + "event_json": {"id": "event-1"}, + } + + store._append_event(event_record) # type: ignore[arg-type] + + assert len(cursor.execute_calls) == 1 + _, params = cursor.execute_calls[0] + assert params[0] == "session-1" + assert isinstance(params[4], Jsonb) + assert connection.commit_called + + +def test_sync_get_events_passes_after_timestamp_and_limit() -> None: + """get_events must forward after_timestamp and limit to the sync query.""" + base_time = datetime(2026, 1, 1, tzinfo=timezone.utc) + rows = [ + { + "session_id": "session-1", + "invocation_id": "", + "author": "assistant", + "timestamp": base_time, + "event_json": {"id": "event-2"}, + } + ] + store, cursor, _ = _build_store(rows) + + result = store._get_events("session-1", after_timestamp=base_time, limit=1) + + assert len(cursor.execute_calls) == 1 + _, params = cursor.execute_calls[0] + assert params == ("session-1", base_time, 1) + assert result[0]["event_json"]["id"] == "event-2" diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py new file mode 100644 index 000000000..903358ae9 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -0,0 +1,478 @@ +"""Unit tests for ADK session/event converters and scoped state helpers. 
+ +Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul: +- EventRecord has exactly 5 keys (session_id, invocation_id, author, timestamp, event_json) +- event_to_record takes only (event, session_id), not (event, session_id, app_name, user_id) +- record_to_event uses Event.model_validate for full round-trip fidelity +- filter_temp_state, split_scoped_state, merge_scoped_state for scoped state handling +- session_to_record strips temp: keys from state +""" + +import importlib.util +from datetime import datetime, timezone + +import pytest + +if importlib.util.find_spec("google.genai") is None or importlib.util.find_spec("google.adk") is None: + pytest.skip("google-adk not installed", allow_module_level=True) + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.session import Session +from google.genai import types + +from sqlspec.extensions.adk.converters import ( + event_to_record, + filter_temp_state, + merge_scoped_state, + record_to_event, + record_to_session, + session_to_record, + split_scoped_state, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_event( + *, + event_id: str = "evt-1", + invocation_id: str = "inv-1", + author: str = "user", + text: "str | None" = None, + state_delta: "dict | None" = None, + branch: "str | None" = None, + partial: "bool | None" = None, + turn_complete: "bool | None" = None, + custom_metadata: "dict | None" = None, +) -> Event: + content = types.Content(parts=[types.Part(text=text)]) if text is not None else None + actions = EventActions(state_delta=state_delta or {}) + return Event( + id=event_id, + invocation_id=invocation_id, + author=author, + content=content, + actions=actions, + timestamp=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp(), + branch=branch, + 
partial=partial, + turn_complete=turn_complete, + custom_metadata=custom_metadata, + ) + + +def _make_session( + *, session_id: str = "session-1", app_name: str = "test-app", user_id: str = "user-1", state: "dict | None" = None +) -> Session: + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=state or {}, + last_update_time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp(), + ) + + +# --------------------------------------------------------------------------- +# filter_temp_state +# --------------------------------------------------------------------------- + + +def test_filter_temp_state_removes_temp_keys() -> None: + """temp:-prefixed keys are removed; all other keys are kept.""" + state = {"x": 1, "temp:y": 2, "app:z": 3, "user:w": 4} + result = filter_temp_state(state) + assert result == {"x": 1, "app:z": 3, "user:w": 4} + + +def test_filter_temp_state_empty_dict() -> None: + """Empty dict returns empty dict.""" + assert filter_temp_state({}) == {} + + +def test_filter_temp_state_all_temp_keys() -> None: + """Dict with only temp: keys returns empty dict.""" + state = {"temp:a": 1, "temp:b": 2, "temp:": 3} + assert filter_temp_state(state) == {} + + +def test_filter_temp_state_no_temp_keys() -> None: + """Dict with no temp: keys is returned unchanged.""" + state = {"x": 1, "app:y": 2, "user:z": 3} + result = filter_temp_state(state) + assert result == state + + +def test_filter_temp_state_does_not_mutate_input() -> None: + """Input dict is not mutated.""" + state = {"key": "v", "temp:remove": "gone"} + original = dict(state) + filter_temp_state(state) + assert state == original + + +# --------------------------------------------------------------------------- +# split_scoped_state +# --------------------------------------------------------------------------- + + +def test_split_scoped_state_separates_buckets() -> None: + """app:, user:, and plain keys go into the correct buckets.""" + state = {"app:shared": "a", 
"user:profile": "u", "session_key": "s", "another": "v"} + app, user, session = split_scoped_state(state) + assert app == {"app:shared": "a"} + assert user == {"user:profile": "u"} + assert session == {"session_key": "s", "another": "v"} + + +def test_split_scoped_state_empty() -> None: + """Empty state produces three empty dicts.""" + app, user, session = split_scoped_state({}) + assert app == {} + assert user == {} + assert session == {} + + +def test_split_scoped_state_only_app_keys() -> None: + """State with only app: keys puts everything in app bucket.""" + state = {"app:x": 1, "app:y": 2} + app, user, session = split_scoped_state(state) + assert app == {"app:x": 1, "app:y": 2} + assert user == {} + assert session == {} + + +def test_split_scoped_state_only_user_keys() -> None: + """State with only user: keys puts everything in user bucket.""" + state = {"user:a": "one", "user:b": "two"} + app, user, session = split_scoped_state(state) + assert app == {} + assert user == {"user:a": "one", "user:b": "two"} + assert session == {} + + +def test_split_scoped_state_only_session_keys() -> None: + """State with no prefix puts everything in session bucket.""" + state = {"key1": 1, "key2": 2} + app, user, session = split_scoped_state(state) + assert app == {} + assert user == {} + assert session == {"key1": 1, "key2": 2} + + +def test_split_scoped_state_preserves_full_key_names() -> None: + """Keys are not stripped of their prefix in the returned buckets.""" + state = {"app:my_key": "val", "user:my_key": "val2"} + app, user, _ = split_scoped_state(state) + assert "app:my_key" in app + assert "user:my_key" in user + + +# --------------------------------------------------------------------------- +# merge_scoped_state +# --------------------------------------------------------------------------- + + +def test_merge_scoped_state_combines_all_buckets() -> None: + """All three buckets appear in the merged result.""" + merged = merge_scoped_state(session_state={"key": "s"}, 
app_state={"app:x": "a"}, user_state={"user:y": "u"}) + assert merged == {"key": "s", "app:x": "a", "user:y": "u"} + + +def test_merge_scoped_state_overlay_priority_app_over_session() -> None: + """app_state overlays session_state for the same key.""" + merged = merge_scoped_state(session_state={"app:x": "old"}, app_state={"app:x": "new"}) + assert merged["app:x"] == "new" + + +def test_merge_scoped_state_overlay_priority_user_over_session() -> None: + """user_state overlays session_state for the same key.""" + merged = merge_scoped_state(session_state={"user:y": "session_val"}, user_state={"user:y": "user_val"}) + assert merged["user:y"] == "user_val" + + +def test_merge_scoped_state_no_app_no_user() -> None: + """Merging without app_state or user_state returns session_state copy.""" + session = {"key": "v", "other": 42} + merged = merge_scoped_state(session_state=session) + assert merged == session + + +def test_merge_scoped_state_empty_session_state() -> None: + """Empty session_state with app/user state returns combined app+user keys.""" + merged = merge_scoped_state(session_state={}, app_state={"app:a": 1}, user_state={"user:b": 2}) + assert merged == {"app:a": 1, "user:b": 2} + + +def test_merge_scoped_state_does_not_mutate_session_state() -> None: + """Input session_state dict is not mutated.""" + session = {"key": "v"} + original = dict(session) + merge_scoped_state(session_state=session, app_state={"app:x": 1}) + assert session == original + + +# --------------------------------------------------------------------------- +# event_to_record — signature and structure +# --------------------------------------------------------------------------- + + +def test_event_to_record_only_5_keys() -> None: + """EventRecord has exactly session_id, invocation_id, author, timestamp, event_json.""" + event = _make_event() + record = event_to_record(event, "session-1") + assert set(record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"} + + 
def test_event_to_record_signature_two_args_only() -> None:
    """event_to_record raises TypeError if called with extra positional args (old 4-arg signature)."""
    event = _make_event()
    with pytest.raises(TypeError):
        event_to_record(event, "session-1", "app-name", "user-id")  # type: ignore[call-arg]


def test_event_to_record_session_id_stored_correctly() -> None:
    """session_id in the record matches the argument passed."""
    event = _make_event(invocation_id="inv-abc", author="model")
    record = event_to_record(event, "my-session-id")
    assert record["session_id"] == "my-session-id"


def test_event_to_record_indexed_fields_match_event() -> None:
    """Indexed scalar columns (invocation_id, author, timestamp) match the source event."""
    event = _make_event(invocation_id="inv-xyz", author="tool")
    record = event_to_record(event, "s1")
    assert record["invocation_id"] == "inv-xyz"
    assert record["author"] == "tool"
    assert isinstance(record["timestamp"], datetime)


def test_event_to_record_event_json_matches_model_dump() -> None:
    """event_json in the record equals event.model_dump(exclude_none=True, mode='json')."""
    event = _make_event(text="hello", state_delta={"key": "val"}, custom_metadata={"foo": "bar"})
    record = event_to_record(event, "s1")
    expected_json = event.model_dump(exclude_none=True, mode="json")
    assert record["event_json"] == expected_json


def test_event_to_record_event_json_is_dict() -> None:
    """event_json field is a plain dict (not bytes, not string)."""
    event = _make_event()
    record = event_to_record(event, "s1")
    assert isinstance(record["event_json"], dict)


def test_event_to_record_actions_in_event_json_is_structured() -> None:
    """Actions are stored as a structured JSON dict in event_json, not as raw bytes.

    FIX: the previous version guarded the isinstance check with
    ``if "actions" in event_json:``, which made the test vacuously pass whenever
    the key was missing. The event is built with a non-empty ``state_delta``,
    so ``event.actions`` is a non-None EventActions instance and is retained by
    ``model_dump(exclude_none=True)`` — its presence can be asserted outright.
    """
    event = _make_event(state_delta={"x": "y"})
    record = event_to_record(event, "s1")
    event_json = record["event_json"]
    # actions must be present (state_delta is non-empty) and serialized as a dict
    assert "actions" in event_json
    assert isinstance(event_json["actions"], dict)


def test_event_to_record_timestamp_is_datetime() -> None:
    """timestamp column is a datetime object with timezone."""
    event = _make_event()
    record = event_to_record(event, "s1")
    assert isinstance(record["timestamp"], datetime)
    assert record["timestamp"].tzinfo is not None


# ---------------------------------------------------------------------------
# record_to_event — full round-trip fidelity
# ---------------------------------------------------------------------------


def test_record_to_event_full_roundtrip_basic() -> None:
    """Event -> record -> Event produces an identical object for basic fields."""
    original = _make_event(event_id="evt-rt", invocation_id="inv-rt", author="model")
    record = event_to_record(original, "s1")
    restored = record_to_event(record)

    assert restored.id == original.id
    assert restored.invocation_id == original.invocation_id
    assert restored.author == original.author


def test_record_to_event_roundtrip_preserves_content() -> None:
    """Content (parts) survives the round-trip."""
    original = _make_event(text="hello world", author="model")
    record = event_to_record(original, "s1")
    restored = record_to_event(record)

    assert restored.content is not None
    assert restored.content.parts is not None
    assert restored.content.parts[0].text == "hello world"


def test_record_to_event_roundtrip_preserves_actions() -> None:
    """EventActions (state_delta) survives the round-trip."""
    original = _make_event(state_delta={"key": "v1", "other": 42})
    record = event_to_record(original, "s1")
    restored = record_to_event(record)

    assert restored.actions is not None
    assert restored.actions.state_delta == {"key": "v1", "other": 42}


def test_record_to_event_roundtrip_preserves_custom_metadata() -> None:
    """custom_metadata survives the round-trip."""
    original = _make_event(custom_metadata={"tag": "v2", "score": 0.9})
    record = event_to_record(original, "s1")
    restored = record_to_event(record)

    assert restored.custom_metadata == {"tag": "v2", "score": 0.9}


def test_record_to_event_roundtrip_preserves_branch() -> None:
    """branch field survives the round-trip."""
    original = _make_event(branch="feature-branch")
    record = event_to_record(original, "s1")
    restored = record_to_event(record)

    assert restored.branch == "feature-branch"


def test_record_to_event_roundtrip_preserves_partial_flag() -> None:
    """partial flag survives the round-trip."""
    original = _make_event(partial=True)
    record = event_to_record(original, "s1")
    restored = record_to_event(record)

    assert restored.partial is True


def test_record_to_event_roundtrip_preserves_turn_complete() -> None:
    """turn_complete flag survives the round-trip."""
    original = _make_event(turn_complete=True)
    record = event_to_record(original, "s1")
    restored = record_to_event(record)

    assert restored.turn_complete is True


def test_record_to_event_roundtrip_preserves_timestamp() -> None:
    """timestamp survives the round-trip within float precision."""
    fixed_ts = datetime(2024, 6, 1, 10, 30, 0, tzinfo=timezone.utc).timestamp()
    event = Event(id="ts-evt", invocation_id="inv-1", author="user", actions=EventActions(), timestamp=fixed_ts)
    record = event_to_record(event, "s1")
    restored = record_to_event(record)

    assert abs(restored.timestamp - fixed_ts) < 1.0  # within 1 second


@pytest.mark.xfail(
    reason="ADK Event model uses extra='forbid' — unknown fields raise ValidationError. "
    "Future ADK versions that add fields will also update the model, so this is safe.",
    strict=True,
)
def test_record_to_event_with_extra_fields_in_event_json() -> None:
    """Events with extra/unknown fields in event_json are rejected by Event model."""
    event = _make_event(event_id="extra-fields-evt", author="tool")
    record = event_to_record(event, "s1")

    # Inject hypothetical future ADK field into event_json
    record["event_json"]["hypothetical_v3_field"] = "some_value"  # type: ignore[index]

    # This WILL raise because Event has extra='forbid'
    restored = record_to_event(record)
    assert restored.id == "extra-fields-evt"


# ---------------------------------------------------------------------------
# session_to_record — strips temp: keys
# ---------------------------------------------------------------------------


def test_session_to_record_strips_temp_keys_from_state() -> None:
    """session_to_record removes temp:-prefixed keys before persisting."""
    session = _make_session(state={"key": "v", "temp:x": "t", "app:y": "a"})
    record = session_to_record(session)
    assert "temp:x" not in record["state"]
    assert record["state"]["key"] == "v"
    assert record["state"]["app:y"] == "a"


def test_session_to_record_empty_state_stays_empty() -> None:
    """Empty state produces empty state in record."""
    session = _make_session(state={})
    record = session_to_record(session)
    assert record["state"] == {}


def test_session_to_record_all_temp_state_produces_empty() -> None:
    """Session state with only temp: keys produces empty state in record."""
    session = _make_session(state={"temp:a": 1, "temp:b": 2})
    record = session_to_record(session)
    assert record["state"] == {}


def test_session_to_record_no_temp_state_unchanged() -> None:
    """Session state with no temp: keys is stored without modification."""
    state = {"x": 1, "app:y": 2, "user:z": 3}
    session = _make_session(state=state)
    record = session_to_record(session)
    assert record["state"] == state


def test_session_to_record_includes_required_fields() -> None:
    """Session record includes id, app_name, user_id, state, create_time, update_time."""
    session = _make_session()
    record = session_to_record(session)
    assert "id" in record
    assert "app_name" in record
    assert "user_id" in record
    assert "state" in record
    assert "create_time" in record
    assert "update_time" in record


# ---------------------------------------------------------------------------
# record_to_session — integrates with record_to_event
# ---------------------------------------------------------------------------


def test_record_to_session_with_events_round_trip() -> None:
    """Sessions with events reconstruct correctly using record_to_session."""
    from sqlspec.extensions.adk._types import SessionRecord

    session_record = SessionRecord(
        id="s1",
        app_name="app",
        user_id="u1",
        state={"key": "val"},
        create_time=datetime.now(timezone.utc),
        update_time=datetime.now(timezone.utc),
    )
    event = _make_event(text="hello", author="user")
    event_record = event_to_record(event, "s1")

    session = record_to_session(session_record, [event_record])

    assert session.id == "s1"
    assert session.app_name == "app"
    assert session.user_id == "u1"
    assert session.state == {"key": "val"}
    assert len(session.events) == 1
    assert session.events[0].id == event.id


def test_record_to_session_empty_events() -> None:
    """Sessions without events reconstruct with empty events list."""
    from sqlspec.extensions.adk._types import SessionRecord

    session_record = SessionRecord(
        id="s2",
        app_name="app",
        user_id="u2",
        state={},
        create_time=datetime.now(timezone.utc),
        update_time=datetime.now(timezone.utc),
    )
    session = record_to_session(session_record, [])
    assert session.events == []
"""Unit tests for SQLSpecSessionService — state persistence fix.

Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul:
- append_event calls append_event_and_update_state (not the old append_event)
- temp: keys are stripped before persisting session state
- partial events are not persisted
- create_session strips temp: keys from initial state

The store is mocked — no database required.
"""

import importlib.util
from datetime import datetime, timezone
from typing import Any

import pytest

if importlib.util.find_spec("google.genai") is None or importlib.util.find_spec("google.adk") is None:
    pytest.skip("google-adk not installed", allow_module_level=True)

from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.session import Session

from sqlspec.extensions.adk.service import SQLSpecSessionService

# ---------------------------------------------------------------------------
# Mock store
# ---------------------------------------------------------------------------


class MockStore:
    """Hand-rolled store double that records calls to store methods.

    Store methods are implemented as plain ``async def`` functions (no mock
    library involved) so ``await`` works out of the box, and call arguments
    are captured in per-method lists for assertion.
    """

    def __init__(self) -> None:
        # Track calls to the new combined method
        self.append_event_and_update_state_calls: list[dict[str, Any]] = []
        self.append_event_and_update_state_called = False

        # Track calls to create_session
        self.create_session_calls: list[dict[str, Any]] = []

        # Provide a get_session that returns a minimal session record
        self._session_record = {
            "id": "s1",
            "app_name": "app",
            "user_id": "u1",
            "state": {},
            "create_time": datetime.now(timezone.utc),
            "update_time": datetime.now(timezone.utc),
        }

    async def append_event_and_update_state(self, event_record: Any, session_id: str, state: "dict[str, Any]") -> None:
        # Capture arguments so tests can assert on the persisted state/record.
        self.append_event_and_update_state_called = True
        self.append_event_and_update_state_calls.append({
            "event_record": event_record,
            "session_id": session_id,
            "state": state,
        })

    async def get_session(self, session_id: str) -> "dict[str, Any] | None":
        # Always returns the canned record regardless of session_id — enough
        # for these unit tests, which never exercise a missing session.
        return self._session_record

    async def create_session(
        self, *, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]"
    ) -> "dict[str, Any]":
        self.create_session_calls.append({
            "session_id": session_id,
            "app_name": app_name,
            "user_id": user_id,
            "state": state,
        })
        return {
            "id": session_id,
            "app_name": app_name,
            "user_id": user_id,
            "state": state,
            "create_time": datetime.now(timezone.utc),
            "update_time": datetime.now(timezone.utc),
        }

    # Old method — should NOT be called by the new service
    async def append_event(self, event_record: Any) -> None:
        raise AssertionError("append_event (old method) must not be called — use append_event_and_update_state")

    async def get_events(self, *, session_id: str, after_timestamp: Any = None, limit: Any = None) -> list:
        return []

    async def list_sessions(self, *, app_name: str, user_id: "str | None" = None) -> list:
        return []

    async def delete_session(self, session_id: str) -> None:
        pass


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_session(
    *,
    session_id: str = "s1",
    app_name: str = "app",
    user_id: str = "u1",
    state: "dict[str, Any] | None" = None,
) -> Session:
    """Build a minimal ADK Session with a current last_update_time."""
    return Session(
        id=session_id,
        app_name=app_name,
        user_id=user_id,
        state=state or {},
        last_update_time=datetime.now(timezone.utc).timestamp(),
    )


def _make_event(
    *,
    invocation_id: str = "inv-1",
    author: str = "model",
    state_delta: "dict[str, Any] | None" = None,
    partial: bool = False,
) -> Event:
    """Build a minimal ADK Event carrying the given state_delta."""
    actions = EventActions(state_delta=state_delta or {})
    return Event(
        invocation_id=invocation_id,
        author=author,
        actions=actions,
        timestamp=datetime.now(timezone.utc).timestamp(),
        partial=partial,
    )


# ---------------------------------------------------------------------------
# append_event — calls append_event_and_update_state
# ---------------------------------------------------------------------------


@pytest.mark.anyio
async def test_append_event_calls_append_event_and_update_state() -> None:
    """append_event must call append_event_and_update_state, not the old append_event."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    session = _make_session(state={"key": "v0"})
    event = _make_event(state_delta={"key": "v1"})

    await service.append_event(session, event)

    assert store.append_event_and_update_state_called, (
        "append_event_and_update_state was never called — state will not be persisted"
    )


@pytest.mark.anyio
async def test_append_event_persists_updated_state() -> None:
    """append_event persists the state AFTER applying event.actions.state_delta."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    session = _make_session(state={"key": "v0"})
    event = _make_event(state_delta={"key": "v1"})

    await service.append_event(session, event)

    assert store.append_event_and_update_state_called
    last_call = store.append_event_and_update_state_calls[-1]
    # The persisted state must reflect the mutation from state_delta
    assert last_call["state"]["key"] == "v1"


@pytest.mark.anyio
async def test_append_event_strips_temp_from_persisted_state() -> None:
    """temp: keys are removed before state persistence."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    session = _make_session(state={"key": "v", "temp:transient": "should_not_persist"})
    event = _make_event()

    await service.append_event(session, event)

    assert store.append_event_and_update_state_called
    last_call = store.append_event_and_update_state_calls[-1]
    persisted_state = last_call["state"]
    assert "temp:transient" not in persisted_state
    assert persisted_state["key"] == "v"


@pytest.mark.anyio
async def test_append_event_strips_temp_state_delta_from_persisted_state() -> None:
    """temp: keys added via state_delta are also stripped before persisting."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    # Session state has temp: key added by an agent via state_delta
    session = _make_session(state={"regular": "v"})
    event = _make_event(state_delta={"temp:output": "transient", "regular": "updated"})

    await service.append_event(session, event)

    last_call = store.append_event_and_update_state_calls[-1]
    persisted_state = last_call["state"]
    assert "temp:output" not in persisted_state
    assert persisted_state["regular"] == "updated"


@pytest.mark.anyio
async def test_append_event_skips_partial_events() -> None:
    """Partial events are not persisted to the store."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    session = _make_session()
    partial_event = _make_event(partial=True)

    result = await service.append_event(session, partial_event)

    assert not store.append_event_and_update_state_called, (
        "append_event_and_update_state must NOT be called for partial events"
    )
    assert result.partial is True


@pytest.mark.anyio
async def test_append_event_passes_correct_session_id_to_store() -> None:
    """append_event_and_update_state receives the correct session_id."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    session = _make_session(session_id="my-unique-session-id")
    event = _make_event()

    await service.append_event(session, event)

    last_call = store.append_event_and_update_state_calls[-1]
    assert last_call["session_id"] == "my-unique-session-id"


@pytest.mark.anyio
async def test_append_event_event_record_has_5_keys() -> None:
    """The event_record passed to the store has exactly 5 keys (new schema)."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    session = _make_session()
    event = _make_event()

    await service.append_event(session, event)

    last_call = store.append_event_and_update_state_calls[-1]
    event_record = last_call["event_record"]
    assert set(event_record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"}


@pytest.mark.anyio
async def test_append_event_returns_the_event() -> None:
    """append_event returns the event after persisting."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]
    session = _make_session()
    event = _make_event(author="model")

    result = await service.append_event(session, event)

    assert result is not None
    assert result.author == "model"


# ---------------------------------------------------------------------------
# create_session — strips temp: keys from initial state
# ---------------------------------------------------------------------------


@pytest.mark.anyio
async def test_create_session_strips_temp_keys_from_initial_state() -> None:
    """create_session filters temp: keys before passing state to the store."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]

    await service.create_session(app_name="app", user_id="u1", state={"x": 1, "temp:y": 2, "app:z": 3})

    assert len(store.create_session_calls) == 1
    persisted_state = store.create_session_calls[0]["state"]
    assert "temp:y" not in persisted_state
    assert persisted_state["x"] == 1
    assert persisted_state["app:z"] == 3


@pytest.mark.anyio
async def test_create_session_with_only_temp_state_persists_empty() -> None:
    """create_session with only temp: state persists empty state dict."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]

    await service.create_session(app_name="app", user_id="u1", state={"temp:only": "gone"})

    assert store.create_session_calls[0]["state"] == {}


@pytest.mark.anyio
async def test_create_session_none_state_persists_empty() -> None:
    """create_session with state=None persists empty state dict."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]

    await service.create_session(app_name="app", user_id="u1")

    assert store.create_session_calls[0]["state"] == {}


@pytest.mark.anyio
async def test_create_session_generates_uuid_if_no_session_id() -> None:
    """create_session generates a UUID if no session_id is provided."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]

    session = await service.create_session(app_name="app", user_id="u1")

    assert session.id is not None
    assert len(session.id) > 0


@pytest.mark.anyio
async def test_create_session_uses_provided_session_id() -> None:
    """create_session uses the caller-provided session_id."""
    store = MockStore()
    service = SQLSpecSessionService(store)  # type: ignore[arg-type]

    session = await service.create_session(app_name="app", user_id="u1", session_id="my-id")

    assert session.id == "my-id"
"""Smoke tests verifying all shipped ADK store classes are instantiable (not abstract).

Every shipped store class must be concrete — no unsatisfied abstract methods.
This catches bugs where stores have method signature mismatches with the base
class, such as cockroach, mysqlconnector sync, pymysql, and spanner stores
that are missing abstract method implementations added to the base contract.

NOTE: Some stores WILL fail this test currently — that is expected and
documents one of the bugs the ADK Clean-Break Overhaul (Ch1) is fixing.
"""

import importlib

import pytest

# Session stores (async)
ASYNC_SESSION_STORES = [
    "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKStore",
    "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKStore",
    "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKStore",
    "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKStore",
    "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKStore",
    "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKStore",
    "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKStore",
    "sqlspec.adapters.psqlpy.adk.store.PsqlpyADKStore",
    "sqlspec.adapters.psycopg.adk.store.PsycopgAsyncADKStore",
    # sqlite uses BaseAsyncADKStore despite being backed by a sync driver
    "sqlspec.adapters.sqlite.adk.store.SqliteADKStore",
]

# Session stores (sync)
SYNC_SESSION_STORES = [
    "sqlspec.adapters.adbc.adk.store.AdbcADKStore",
    "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgSyncADKStore",
    "sqlspec.adapters.duckdb.adk.store.DuckdbADKStore",
    "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorSyncADKStore",
    "sqlspec.adapters.oracledb.adk.store.OracleSyncADKStore",
    "sqlspec.adapters.psycopg.adk.store.PsycopgSyncADKStore",
    "sqlspec.adapters.pymysql.adk.store.PyMysqlADKStore",
    "sqlspec.adapters.spanner.adk.store.SpannerSyncADKStore",
]

# Memory stores (async)
ASYNC_MEMORY_STORES = [
    "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKMemoryStore",
    "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKMemoryStore",
    "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKMemoryStore",
    "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKMemoryStore",
    "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKMemoryStore",
    "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKMemoryStore",
    "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKMemoryStore",
    "sqlspec.adapters.psqlpy.adk.store.PsqlpyADKMemoryStore",
    "sqlspec.adapters.psycopg.adk.store.PsycopgAsyncADKMemoryStore",
]

# Memory stores (sync)
SYNC_MEMORY_STORES = [
    "sqlspec.adapters.adbc.adk.store.AdbcADKMemoryStore",
    "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgSyncADKMemoryStore",
    "sqlspec.adapters.duckdb.adk.store.DuckdbADKMemoryStore",
    "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorSyncADKMemoryStore",
    "sqlspec.adapters.oracledb.adk.store.OracleSyncADKMemoryStore",
    "sqlspec.adapters.psycopg.adk.store.PsycopgSyncADKMemoryStore",
    "sqlspec.adapters.pymysql.adk.store.PyMysqlADKMemoryStore",
    "sqlspec.adapters.spanner.adk.store.SpannerSyncADKMemoryStore",
    "sqlspec.adapters.sqlite.adk.store.SqliteADKMemoryStore",
]

ALL_STORE_CLASSES = ASYNC_SESSION_STORES + SYNC_SESSION_STORES + ASYNC_MEMORY_STORES + SYNC_MEMORY_STORES


@pytest.mark.parametrize("class_path", ALL_STORE_CLASSES)
def test_store_has_no_abstract_methods(class_path: str) -> None:
    """Every shipped store class must be concrete (no unsatisfied abstract methods).

    A class with entries in ``__abstractmethods__`` cannot be instantiated and
    signals that the concrete store is missing one or more method implementations
    required by its base class contract.

    Args:
        class_path: Dotted ``module.ClassName`` path of the store under test.
    """
    module_path, class_name = class_path.rsplit(".", 1)
    try:
        module = importlib.import_module(module_path)
    except ImportError:
        pytest.skip(f"Module {module_path} not importable (missing optional dependency)")
    cls = getattr(module, class_name)
    # FIX: __abstractmethods__ is a frozenset[str] (set by ABCMeta), so annotate
    # and default accordingly — the old `set[str]` / `set()` pair was a type error
    # under strict checkers. An empty frozenset is falsy, so behavior is unchanged.
    abstract: frozenset[str] = getattr(cls, "__abstractmethods__", frozenset())
    assert not abstract, f"{class_path} has unsatisfied abstract methods: {abstract}"
"https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" }, + { url = "https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" }, ] [package.optional-dependencies] agent-engines = [ + { name = "aiohttp" }, { name = "cloudpickle" }, { name = "google-cloud-iam" }, { name = "google-cloud-logging" }, @@ -7071,7 +7072,7 @@ wheels = [ [[package]] name = "sqlspec" -version = "0.41.1" +version = "0.42.0" source = { editable = "." } dependencies = [ { name = "mypy-extensions" }, @@ -7121,6 +7122,7 @@ cockroachdb = [ ] duckdb = [ { name = "duckdb" }, + { name = "pytz" }, ] fastapi = [ { name = "fastapi" }, @@ -7397,6 +7399,7 @@ requires-dist = [ { name = "pydantic-extra-types", marker = "extra == 'pydantic'" }, { name = "pymssql", marker = "extra == 'pymssql'" }, { name = "pymysql", marker = "extra == 'pymysql'" }, + { name = "pytz", marker = "extra == 'duckdb'" }, { name = "rich-click", specifier = ">=1.9.0" }, { name = "sqlglot", specifier = ">=30.0.0" }, { name = "sqlglot", extras = ["c"], marker = "extra == 'mypyc'", specifier = ">=30.0.0" },