diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index f04a450c4..d4381dd65 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -108,7 +108,7 @@ jobs:
- name: Test
env:
PYTHONFAULTHANDLER: "1"
- PYTEST_ADDOPTS: "--max-worker-restart=0 -s"
+ PYTEST_ADDOPTS: "--max-worker-restart=0"
run: timeout 900s uv run pytest -n 2 --dist=loadgroup
# test-linux-freethreaded:
diff --git a/.gitignore b/.gitignore
index 20a641773..442abbaa4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -68,3 +68,4 @@ uv.toml
.geminiignore
.beads/
tools/scripts/profiles/*.prof
+.agents/
diff --git a/docs/_tapes/migration_workflow.tape b/docs/_tapes/migration_workflow.tape
index f7703a6f6..d08a78408 100644
--- a/docs/_tapes/migration_workflow.tape
+++ b/docs/_tapes/migration_workflow.tape
@@ -28,7 +28,7 @@ Type "# Initialize the migration environment"
Enter
Sleep 500ms
-Type "sqlspec db init"
+Type "sqlspec init"
Enter
Sleep 3s
@@ -36,7 +36,7 @@ Type "# Create a new migration"
Enter
Sleep 500ms
-Type 'sqlspec db create-migration -m "add users table"'
+Type 'sqlspec create-migration -m "add users table"'
Enter
Sleep 3s
@@ -44,7 +44,7 @@ Type "# Apply the migration"
Enter
Sleep 500ms
-Type "sqlspec db upgrade"
+Type "sqlspec upgrade"
Enter
Sleep 3s
@@ -52,7 +52,7 @@ Type "# Check current revision"
Enter
Sleep 500ms
-Type "sqlspec db show-current-revision"
+Type "sqlspec show-current-revision"
Enter
Sleep 3s
diff --git a/docs/extensions/adk/adapters.rst b/docs/extensions/adk/adapters.rst
index 5ff658122..72b0d06b7 100644
--- a/docs/extensions/adk/adapters.rst
+++ b/docs/extensions/adk/adapters.rst
@@ -10,11 +10,36 @@ Choosing an Adapter
Use async adapters for best performance with ADK runners:
-- **PostgreSQL**: ``asyncpg`` (recommended), ``psycopg`` (async mode)
-- **SQLite**: ``aiosqlite``
-- **MySQL**: ``asyncmy``
+- **PostgreSQL** (recommended): ``asyncpg``, ``psycopg`` (async mode), ``psqlpy``
+- **CockroachDB**: ``cockroach_asyncpg``, ``cockroach_psycopg`` (full FTS support)
+- **MySQL/MariaDB**: ``asyncmy``
+- **SQLite**: ``aiosqlite`` (development and single-process)
+- **Oracle**: ``oracledb``
+- **DuckDB**: ``duckdb`` (analytics; reduced-scope for ADK)
+- **ADBC**: ``adbc`` (Arrow-native, driver-agnostic)
+- **Spanner**: ``spanner`` (Google Cloud, globally distributed)
-Sync adapters work but require wrapping with ``anyio`` for async ADK runners.
+Sync adapters (``psycopg`` sync mode, ``sqlite``, ``mysqlconnector``, ``pymysql``)
+work but require wrapping with ``anyio`` for async ADK runners.
+
+Each Adapter Provides
+=====================
+
+Every adapter with ADK support ships three store classes:
+
+- **Session store** (e.g., ``AsyncpgADKStore``) -- sessions and events.
+- **Memory store** (e.g., ``AsyncpgADKMemoryStore``) -- long-term memory with FTS.
+- **Artifact store** (e.g., ``AsyncpgADKArtifactStore``) -- artifact metadata.
+
+Import from the adapter's ``adk`` subpackage:
+
+.. code-block:: python
+
+ from sqlspec.adapters.asyncpg.adk import (
+ AsyncpgADKStore,
+ AsyncpgADKMemoryStore,
+ AsyncpgADKArtifactStore,
+ )
Example
=======
@@ -30,6 +55,6 @@ Example
See Also
========
-- :doc:`backends` for the full adapter support matrix.
+- :doc:`backends` for the full support matrix and backend-specific notes.
- :doc:`/usage/drivers_and_querying` for adapter configuration patterns.
- :doc:`/reference/adapters` for the complete adapter API.
diff --git a/docs/extensions/adk/api.rst b/docs/extensions/adk/api.rst
index 95d2431d4..e95a28a31 100644
--- a/docs/extensions/adk/api.rst
+++ b/docs/extensions/adk/api.rst
@@ -19,8 +19,20 @@ Services
:show-inheritance:
:no-index:
-Base Stores
-===========
+.. autoclass:: sqlspec.extensions.adk.memory.SQLSpecSyncMemoryService
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :no-index:
+
+.. autoclass:: SQLSpecArtifactService
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :no-index:
+
+Session Stores
+==============
.. autoclass:: BaseAsyncADKStore
:members:
@@ -34,6 +46,9 @@ Base Stores
:show-inheritance:
:no-index:
+Memory Stores
+=============
+
.. autoclass:: BaseAsyncADKMemoryStore
:members:
:undoc-members:
@@ -45,3 +60,57 @@ Base Stores
:undoc-members:
:show-inheritance:
:no-index:
+
+Artifact Stores
+===============
+
+.. autoclass:: BaseAsyncADKArtifactStore
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :no-index:
+
+.. autoclass:: BaseSyncADKArtifactStore
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :no-index:
+
+Record Types
+============
+
+.. autoclass:: SessionRecord
+ :members:
+ :show-inheritance:
+ :no-index:
+
+.. autoclass:: EventRecord
+ :members:
+ :show-inheritance:
+ :no-index:
+
+.. autoclass:: MemoryRecord
+ :members:
+ :show-inheritance:
+ :no-index:
+
+.. autoclass:: ArtifactRecord
+ :members:
+ :show-inheritance:
+ :no-index:
+
+Configuration
+=============
+
+.. autoclass:: ADKConfig
+ :members:
+ :show-inheritance:
+ :no-index:
+
+Converters
+==========
+
+.. automodule:: sqlspec.extensions.adk.converters
+ :members:
+ :undoc-members:
+ :no-index:
diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst
index e550d0024..1f770f68d 100644
--- a/docs/extensions/adk/backends.rst
+++ b/docs/extensions/adk/backends.rst
@@ -2,51 +2,249 @@
Backends
========
-ADK stores are implemented per adapter. Use the backend config helpers when
-connecting to multiple databases or configuring advanced options.
+ADK stores are implemented per adapter. Each backend has different capabilities
+for session, event, memory, and artifact storage. Use the support matrix below
+to select the right backend for your deployment.
-Example
-=======
+.. _adk-support-matrix:
-.. literalinclude:: /examples/extensions/adk/backend_config.py
- :language: python
- :caption: ``adk backend config``
- :start-after: # start-example
- :end-before: # end-example
- :dedent: 4
- :no-upgrade:
+Support Matrix
+==============
-Supported Backends
-==================
+The table below classifies every backend by its ADK support level.
.. list-table::
:header-rows: 1
+ :widths: 20 15 15 15 15 20
* - Adapter
- Status
+ - Session/Event
+ - Memory (FTS)
+ - Artifacts
+ - Notes
* - asyncpg
- - Production
- * - psycopg
- - Production
+ - Recommended
+ - Full
+ - Full
+ - Full
+ - Best async PostgreSQL driver.
+ * - psycopg (async)
+ - Recommended
+ - Full
+ - Full
+ - Full
+ - Supports both sync and async modes.
* - psqlpy
- - Production
+ - Supported
+ - Full
+ - Full
+ - Full
+ - Rust-backed PostgreSQL driver.
+ * - cockroach_asyncpg
+ - Supported
+ - Full
+ - Full
+ - Full
+ - CockroachDB with full FTS support.
+ * - cockroach_psycopg
+ - Supported
+ - Full
+ - Full
+ - Full
+ - CockroachDB with full FTS support.
* - asyncmy
- - Production
- * - sqlite
- - Production
+ - Supported
+ - Full
+ - Full
+ - Full
+ - MySQL/MariaDB async driver.
+ * - mysqlconnector
+ - Supported
+ - Full
+ - Full
+ - Full
+ - MySQL/MariaDB sync driver.
+ * - pymysql
+ - Supported
+ - Full
+ - Full
+ - Full
+ - MySQL/MariaDB sync driver.
* - aiosqlite
- - Production
+ - Supported
+ - Full
+ - Full
+ - Full
+ - SQLite async, ideal for development.
+ * - sqlite
+ - Supported
+ - Full
+ - Full
+ - Full
+ - SQLite sync with thread-local pools.
* - oracledb
- - Production
+ - Supported
+ - Full
+ - Full
+ - Full
+ - Oracle Database driver.
* - duckdb
- - Production (analytics)
- * - bigquery
- - Production
+ - Reduced-scope
+ - Full
+ - Limited
+ - Full
+ - Analytics-oriented; no concurrent writes.
* - adbc
- - Production
+ - Supported
+ - Full
+ - Full
+ - Full
+ - Arrow-native database connectivity.
+ * - spanner
+ - Supported
+ - Full
+ - Full
+ - Full
+ - Google Cloud Spanner (cloud-managed).
+
+Status Definitions
+------------------
+
+**Recommended**
+ Production-grade, fully tested, actively optimized. Start here unless you
+ have a specific reason not to.
+
+**Supported**
+ Fully implemented and tested. Works correctly for all ADK operations.
+
+**Reduced-scope**
+ Implemented with known limitations. Specific features may be absent or
+ behave differently. See backend-specific notes.
+
+**Removed**
+ Previously available but no longer supported. See the removal notice for
+ migration guidance.
+
+Removed Backends
+----------------
+
+**BigQuery** was removed from the ADK backend surface. BigQuery's batch-oriented
+architecture is incompatible with the low-latency, transactional write patterns
+that ADK session and event storage require. If you were using BigQuery for ADK
+storage, migrate to PostgreSQL (asyncpg or psycopg) or any other supported
+backend.
+
+Backend Details
+===============
+
+PostgreSQL Family
+-----------------
+
+PostgreSQL backends (asyncpg, psycopg, psqlpy) provide the fullest feature set:
+
+- Native ``JSONB`` storage for session state and event JSON.
+- Full-text search via ``tsvector`` for memory entries.
+- ``UPSERT`` and ``RETURNING`` clauses for atomic operations.
+- ``append_event_and_update_state()`` executes as a single transaction.
+
+**Recommended for production deployments.**
+
+CockroachDB
+------------
+
+CockroachDB backends (cockroach_asyncpg, cockroach_psycopg) provide full ADK
+support including full-text search. CockroachDB is a distributed SQL database
+compatible with the PostgreSQL wire protocol.
+
+- Full FTS support for memory search.
+- Distributed transactions for session and event atomicity.
+- Horizontal scalability for high-throughput agent deployments.
+
+MySQL Family
+------------
+
+MySQL backends (asyncmy, mysqlconnector, pymysql) provide full ADK support:
+
+- JSON column storage for session state and event records.
+- Full-text search on ``InnoDB`` tables for memory entries.
+- Transactional writes for ``append_event_and_update_state()``.
+
+SQLite
+------
+
+SQLite backends (aiosqlite, sqlite) are ideal for local development, testing,
+and single-process deployments:
+
+- JSON1 extension for state and event storage.
+- FTS5 virtual tables for memory full-text search.
+- File-based or in-memory operation.
+
+.. note::
+
+ SQLite does not support concurrent writers. Use a server-backed database
+ for production multi-process deployments.
+
+Oracle
+------
+
+Oracle Database (oracledb) provides full ADK support:
+
+- Native JSON column support (Oracle 21c+).
+- Oracle Text for full-text search on memory entries.
+- Full transactional support for atomic operations.
+
+DuckDB
+------
+
+DuckDB provides session and event storage but has limitations:
+
+- Optimized for analytics, not OLTP workloads.
+- Single-writer constraint limits concurrent access.
+- Memory search capabilities are limited compared to server databases.
+
+**Best suited for analytics pipelines and offline agent evaluation.**
+
+ADBC
+----
+
+ADBC (Arrow Database Connectivity) provides a driver-agnostic interface:
+
+- Works with any ADBC-compatible driver (PostgreSQL, SQLite, DuckDB, etc.).
+- Arrow-native data transfer for high-throughput event ingestion.
+- Backend capabilities depend on the underlying database driver.
+
+Spanner
+-------
+
+Google Cloud Spanner provides globally distributed ADK storage:
+
+- Cloud-managed, horizontally scalable.
+- Full-text search support for memory entries.
+- Strong consistency across regions.
+- Suitable for multi-region agent deployments.
+
+Configuration
+=============
+
+All backends are configured through ``extension_config["adk"]``:
+
+.. code-block:: python
+
+ from sqlspec.adapters.asyncpg import AsyncpgConfig
-Notes
-=====
+ config = AsyncpgConfig(
+ connection_config={"dsn": "postgresql://localhost/mydb"},
+ extension_config={
+ "adk": {
+ "session_table": "adk_sessions",
+ "events_table": "adk_events",
+ "memory_table": "adk_memory_entries",
+ "memory_use_fts": True,
+ "artifact_table": "adk_artifact_versions",
+ "owner_id_column": "tenant_id INTEGER NOT NULL",
+ }
+ },
+ )
-- Use async backends for ADK runners; sync backends can be wrapped with anyio.
-- Backend stores expose ``create_tables`` to bootstrap schema.
+See :doc:`adapters` for adapter-specific configuration patterns.
diff --git a/docs/extensions/adk/index.rst b/docs/extensions/adk/index.rst
index d83c786c1..de771ccc8 100644
--- a/docs/extensions/adk/index.rst
+++ b/docs/extensions/adk/index.rst
@@ -2,8 +2,23 @@
Google ADK Extension
====================
-SQLSpec provides an ADK extension for session, event, and memory storage with
-SQL-backed persistence.
+SQLSpec provides a full-featured backend for
+`Google Agent Development Kit <https://google.github.io/adk-docs/>`_,
+covering session, event, memory, and artifact storage with SQL-backed
+persistence across 14 database adapters.
+
+Key capabilities:
+
+- **Session and event storage** with atomic ``append_event_and_update_state()``
+ ensuring events and state are always consistent.
+- **Full-event JSON storage** (EventRecord) that captures the entire ADK Event
+ in a single column, eliminating schema drift with upstream ADK releases.
+- **Scoped state semantics** (``app:``, ``user:``, ``temp:``) for controlling
+ state visibility and persistence across sessions.
+- **Memory service** with database-native full-text search (tsvector, FTS5,
+ InnoDB FT) for long-term agent context.
+- **Artifact service** with append-only versioning, SQL metadata, and pluggable
+ object storage backends.
Choose a guide
==============
@@ -22,45 +37,45 @@ Choose a guide
:link: quickstart
:link-type: doc
- Persist memory and sessions with minimal setup.
+ Persist sessions, memory, and artifacts with minimal setup.
- .. grid-item-card:: API Reference
- :link: api
+ .. grid-item-card:: Support Matrix
+ :link: backends
:link-type: doc
- Interfaces, stores, and configuration helpers.
+ See which backends are recommended, supported, or reduced-scope.
.. grid-item-card:: Adapters
:link: adapters
:link-type: doc
- Configure supported SQLSpec adapters.
+ Configure supported SQLSpec adapters for ADK.
- .. grid-item-card:: Backends
- :link: backends
+ .. grid-item-card:: Schema
+ :link: schema
:link-type: doc
- Storage backends and connection profiles.
+ Table layouts, EventRecord, scoped state, and artifact metadata.
- .. grid-item-card:: Migrations
- :link: migrations
+ .. grid-item-card:: API Reference
+ :link: api
:link-type: doc
- Apply schema changes safely over time.
+ Services, stores, and record types.
- .. grid-item-card:: Schema
- :link: schema
+ .. grid-item-card:: Migrations
+ :link: migrations
:link-type: doc
- Table layouts for sessions and memory records.
+ Apply schema changes safely over time.
.. toctree::
:hidden:
installation
quickstart
- api
- adapters
backends
- migrations
+ adapters
schema
+ api
+ migrations
diff --git a/docs/extensions/adk/installation.rst b/docs/extensions/adk/installation.rst
index 8304d72f7..347f8aa13 100644
--- a/docs/extensions/adk/installation.rst
+++ b/docs/extensions/adk/installation.rst
@@ -6,7 +6,7 @@ Install SQLSpec with a database adapter and the Google ADK SDK.
.. tab-set::
- .. tab-item:: PostgreSQL
+ .. tab-item:: PostgreSQL (recommended)
.. tab-set::
@@ -90,6 +90,34 @@ Install SQLSpec with a database adapter and the Google ADK SDK.
pdm add "sqlspec[asyncmy,adk]"
+ .. tab-item:: CockroachDB
+
+ .. tab-set::
+
+ .. tab-item:: uv
+
+ .. code-block:: bash
+
+ uv add "sqlspec[cockroach-asyncpg,adk]"
+
+ .. tab-item:: pip
+
+ .. code-block:: bash
+
+ pip install "sqlspec[cockroach-asyncpg,adk]"
+
+ .. tab-item:: Poetry
+
+ .. code-block:: bash
+
+ poetry add "sqlspec[cockroach-asyncpg,adk]"
+
+ .. tab-item:: PDM
+
+ .. code-block:: bash
+
+ pdm add "sqlspec[cockroach-asyncpg,adk]"
+
.. tab-item:: DuckDB
.. tab-set::
@@ -123,11 +151,17 @@ What This Provides
The ``adk`` extra includes the Google ADK SDK (``google-genai``). SQLSpec provides:
-- **Session Store** - Persist ADK agent sessions to your database.
-- **Memory Store** - Store agent memory for context across conversations.
-- **Event Store** - Log agent events for observability.
+- **Session Service** -- Persist ADK agent sessions and events to your database
+ with atomic ``append_event_and_update_state()`` writes.
+- **Memory Service** -- Store agent memory with database-native full-text search
+ for context retrieval across conversations.
+- **Artifact Service** -- Version and store binary artifacts with SQL metadata
+ and pluggable object storage backends.
+- **Event Storage** -- Full-event JSON storage (EventRecord) that captures the
+ entire ADK Event without schema drift.
Next Steps
----------
-Proceed to :doc:`quickstart` to set up stores for your ADK agent.
+Proceed to :doc:`quickstart` to set up stores for your ADK agent, or see
+:doc:`backends` for the full support matrix.
diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst
index 948969ae8..9359fbcd3 100644
--- a/docs/extensions/adk/migrations.rst
+++ b/docs/extensions/adk/migrations.rst
@@ -5,4 +5,63 @@ Migrations
ADK stores use standard SQLSpec migrations. Generate migrations for the database
used by your ADK backend, then run them with the SQLSpec migration CLI.
+Schema Bootstrapping
+====================
+
+You can programmatically create ADK tables with ``create_tables()`` /
+``ensure_tables()``:
+
+.. code-block:: python
+
+ await session_store.ensure_tables()
+ await memory_store.ensure_tables()
+ await artifact_store.ensure_table()
+
+Alternatively, configure SQLSpec migrations on the database config and run the
+migration CLI ahead of deployment:
+
+.. code-block:: python
+
+ from sqlspec.adapters.asyncpg import AsyncpgConfig
+
+ config = AsyncpgConfig(
+ connection_config={"dsn": "postgresql://localhost/app"},
+ migration_config={"script_location": "migrations/postgres"},
+ )
+
+.. code-block:: console
+
+ sqlspec upgrade
+
+Use the programmatic table-creation path when you want the store to bootstrap
+its own schema. Use migrations when you want schema changes tracked and applied
+through your deployment workflow.
+
+.. note::
+
+ The migration CLI resolves configuration from ``--config``,
+ ``SQLSPEC_CONFIG``, or ``[tool.sqlspec]`` in ``pyproject.toml``.
+
+ When ``extension_config["adk"]`` is present, ADK extension migrations are
+ auto-included. Use ``migration_config={"exclude_extensions": ["adk"]}``
+ to skip only ADK extension migrations, or
+ ``migration_config={"include_extensions": ["adk"]}`` to opt in explicitly
+ by extension name. Use ``migration_config={"enabled": False}`` to disable
+ migrations entirely for a given database config.
+
+Clean-Break Migration Notes
+============================
+
+If you are upgrading from a pre-clean-break version of the ADK extension,
+note the following schema changes:
+
+- **Events table**: The column layout changed to full-event JSON storage.
+ The ``event_json`` column now stores the entire ADK Event as a JSON blob.
+ Individual event columns (``content``, ``actions``, ``branch``, etc.) have
+ been replaced by indexed scalar columns (``invocation_id``, ``author``,
+ ``timestamp``) plus ``event_json``.
+- **Artifact table**: New table (``adk_artifact_versions``) for artifact
+ metadata. Create this table when enabling the artifact service.
+- **BigQuery**: Removed. Migrate to PostgreSQL or any other supported backend.
+
See :doc:`/usage/migrations` for the full workflow and commands.
diff --git a/docs/extensions/adk/quickstart.rst b/docs/extensions/adk/quickstart.rst
index 93c829100..69bc0e6b6 100644
--- a/docs/extensions/adk/quickstart.rst
+++ b/docs/extensions/adk/quickstart.rst
@@ -2,56 +2,147 @@
Quickstart
==========
-Wire SQLSpec stores into your ADK agent to persist sessions and memory across restarts.
+Wire SQLSpec stores into your ADK agent to persist sessions, events, memory,
+and artifacts across restarts.
How It Works
============
-1. Create a SQLSpec database config.
-2. Initialize ADK stores (session, memory, event) backed by that config.
-3. Pass the stores to your ADK agent.
+1. Create a SQLSpec database config with ADK extension settings.
+2. Initialize the appropriate stores (session, memory, artifact).
+3. Pass the service wrappers to your ADK agent.
-Session Store
-=============
+Session Service
+===============
-The session store persists agent state between conversations. When a user returns,
-the agent can resume from where it left off.
+The session service persists agent state and events between conversations.
+When a user returns, the agent can resume from where it left off.
-.. literalinclude:: /examples/extensions/adk/memory_store.py
- :language: python
- :caption: ``adk session store``
- :start-after: # start-example
- :end-before: # end-example
- :dedent: 4
- :no-upgrade:
+.. code-block:: python
+
+ from sqlspec.adapters.asyncpg import AsyncpgConfig
+ from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore
+ from sqlspec.extensions.adk import SQLSpecSessionService
+
+ config = AsyncpgConfig(
+ connection_config={"dsn": "postgresql://localhost/mydb"},
+ extension_config={
+ "adk": {
+ "session_table": "adk_sessions",
+ "events_table": "adk_events",
+ }
+ },
+ )
+
+ store = AsyncpgADKStore(config)
+ await store.ensure_tables()
+
+ session_service = SQLSpecSessionService(store)
+
+ # Create a session with scoped state
+ session = await session_service.create_session(
+ app_name="my_agent",
+ user_id="user_123",
+ state={
+ "app:model": "gemini-2.0", # shared across all sessions
+ "user:name": "Alice", # shared across user's sessions
+ "conversation_turn": 0, # session-local
+ "temp:scratch": "...", # runtime-only, never persisted
+ },
+ )
+
+Events are persisted automatically when you use the session service with an
+ADK runner. Each call to ``append_event()`` atomically stores the event and
+updates the session's durable state via ``append_event_and_update_state()``.
+
+Scoped State
+------------
+
+State keys use prefixes to control their scope and persistence:
+
+- ``app:`` -- shared across all sessions for the same application.
+- ``user:`` -- shared across all sessions for the same user.
+- ``temp:`` -- runtime-only, stripped before every write to storage.
+- *(no prefix)* -- private to the current session.
+
+See :ref:`scoped-state` for full details.
+
+Memory Service
+==============
+
+The memory service retains context that the agent can reference later. This
+enables long-term memory across sessions with full-text search.
+
+.. code-block:: python
+
+ from sqlspec.adapters.asyncpg.adk import AsyncpgADKMemoryStore
+ from sqlspec.extensions.adk import SQLSpecMemoryService
+
+ memory_store = AsyncpgADKMemoryStore(config)
+ await memory_store.ensure_tables()
+
+ memory_service = SQLSpecMemoryService(memory_store)
+
+Enable full-text search by setting ``memory_use_fts: True`` in the ADK config.
+This creates database-native FTS indexes (tsvector, FTS5, InnoDB FT) for
+efficient memory retrieval.
+
+Artifact Service
+================
+
+The artifact service stores binary artifacts (files, images, reports) with
+automatic versioning. Metadata lives in SQL; content lives in object storage.
-Memory Store Integration
-========================
+.. code-block:: python
+
+ from sqlspec.adapters.asyncpg.adk import AsyncpgADKArtifactStore
+ from sqlspec.extensions.adk import SQLSpecArtifactService
+
+ artifact_store = AsyncpgADKArtifactStore(config)
+ await artifact_store.ensure_table()
+
+ artifact_service = SQLSpecArtifactService(
+ store=artifact_store,
+ artifact_storage_uri="s3://my-bucket/adk-artifacts/",
+ )
-The memory store retains context that the agent can reference later. This enables
-long-term memory across sessions.
+ # Save an artifact (returns version number starting from 0)
+ version = await artifact_service.save_artifact(
+ app_name="my_agent",
+ user_id="user_123",
+ filename="report.pdf",
+        artifact=part,  # a google.genai ``Part`` built earlier — TODO show its construction
+ )
-.. literalinclude:: /examples/extensions/adk/tool_integration.py
- :language: python
- :caption: ``adk memory integration``
- :start-after: # start-example
- :end-before: # end-example
- :dedent: 4
- :no-upgrade:
+ # Load the latest version
+ loaded = await artifact_service.load_artifact(
+ app_name="my_agent",
+ user_id="user_123",
+ filename="report.pdf",
+ )
Schema Setup
============
-Stores create their tables automatically on first use. For production, run migrations
-ahead of time:
+You can programmatically create ADK tables ahead of first use with
+``ensure_tables()`` / ``ensure_table()``:
.. code-block:: python
- await session_store.create_tables()
- await memory_store.create_tables()
+ await session_store.ensure_tables()
+ await memory_store.ensure_tables()
+ await artifact_store.ensure_table()
+
+Alternatively, configure SQLSpec migrations for your database and run the
+migration CLI as part of deployment:
+
+.. code-block:: console
+
+ sqlspec upgrade
Next Steps
==========
-- :doc:`backends` for adapter-specific configuration.
-- :doc:`schema` for table layouts and indexes.
+- :doc:`backends` for the full support matrix and backend-specific details.
+- :doc:`schema` for table layouts, EventRecord format, and scoped state semantics.
+- :doc:`api` for the complete API reference.
diff --git a/docs/extensions/adk/schema.rst b/docs/extensions/adk/schema.rst
index 18a1482bd..82bdd5ef2 100644
--- a/docs/extensions/adk/schema.rst
+++ b/docs/extensions/adk/schema.rst
@@ -2,7 +2,308 @@
Schema
======
-ADK stores create tables for sessions, events, and memory entries. Table names
-and schemas are configurable via the store config.
+ADK stores create tables for sessions, events, memory entries, and artifact
+metadata. Table names are configurable via ``extension_config["adk"]``.
-Use ``create_tables()`` on a store to apply the schema.
+You can programmatically create the schema with ``create_tables()`` or
+``ensure_tables()`` on a store. For managed deployments, configure SQLSpec
+migrations for the target database and run ``sqlspec upgrade`` instead.
+
+.. contents:: On this page
+ :local:
+ :depth: 2
+
+Sessions Table
+==============
+
+The sessions table stores agent session metadata and durable state.
+
+Default name: ``adk_sessions``
+
+.. list-table::
+ :header-rows: 1
+
+ * - Column
+ - Type
+ - Description
+ * - ``id``
+ - ``VARCHAR`` / ``TEXT``
+ - Primary key. UUID assigned by the service layer.
+ * - ``app_name``
+ - ``VARCHAR`` / ``TEXT``
+ - Application identifier.
+ * - ``user_id``
+ - ``VARCHAR`` / ``TEXT``
+ - User identifier.
+ * - ``state``
+ - ``JSONB`` / ``JSON`` / ``TEXT``
+ - Durable session state (see :ref:`scoped-state`).
+ * - ``create_time``
+ - ``TIMESTAMP``
+ - When the session was created (UTC).
+ * - ``update_time``
+ - ``TIMESTAMP``
+ - Last state update time (UTC).
+
+An optional ``owner_id`` column can be added via ``owner_id_column`` in the ADK
+config for multi-tenant deployments.
+
+.. _event-record:
+
+Events Table (EventRecord)
+==========================
+
+The events table uses **full-event JSON storage**: the entire ADK ``Event`` is
+serialized into a single ``event_json`` column alongside a small set of indexed
+scalar columns used for query filtering.
+
+This design eliminates column drift with upstream ADK releases. New ``Event``
+fields are automatically captured in ``event_json`` without schema changes.
+
+Default name: ``adk_events``
+
+.. list-table::
+ :header-rows: 1
+
+ * - Column
+ - Type
+ - Description
+ * - ``session_id``
+ - ``VARCHAR`` / ``TEXT``
+ - Foreign key to the sessions table.
+ * - ``invocation_id``
+ - ``VARCHAR`` / ``TEXT``
+ - ADK invocation identifier (indexed for filtering).
+ * - ``author``
+ - ``VARCHAR`` / ``TEXT``
+ - Event author: ``"user"``, ``"agent"``, or ``"system"``.
+ * - ``timestamp``
+ - ``TIMESTAMP``
+ - Event timestamp (UTC, indexed for range queries).
+ * - ``event_json``
+ - ``JSONB`` / ``JSON`` / ``TEXT``
+ - Full ADK Event serialized via ``Event.model_dump()``.
+
+**Serialization and reconstruction:**
+
+Events are converted to records via ``event_to_record()``, which calls
+``event.model_dump(exclude_none=True, mode="json")`` to produce the JSON blob.
+Reconstruction is lossless: ``record_to_event()`` restores the full ``Event``
+via ``Event.model_validate()``.
+
+.. code-block:: python
+
+ from sqlspec.extensions.adk.converters import event_to_record, record_to_event
+
+ # Serialize: Event -> EventRecord
+ record = event_to_record(event=adk_event, session_id="sess_123")
+
+ # Reconstruct: EventRecord -> Event
+ restored_event = record_to_event(record)
+
+.. _scoped-state:
+
+Scoped State Semantics
+======================
+
+ADK uses key prefixes to scope state visibility across sessions. SQLSpec
+respects these prefixes when persisting and loading state.
+
+.. list-table::
+ :header-rows: 1
+
+ * - Prefix
+ - Scope
+ - Persisted
+ - Description
+ * - ``app:``
+ - Application
+ - Yes
+ - Shared across all sessions for the same ``app_name``.
+ * - ``user:``
+ - User
+ - Yes
+ - Shared across all sessions for the same ``app_name`` + ``user_id``.
+ * - ``temp:``
+ - Runtime
+ - **No**
+ - Process-local state. Stripped before every write to storage.
+ * - *(no prefix)*
+ - Session
+ - Yes
+ - Private to a single session.
+
+**How scoped state is handled:**
+
+1. On ``create_session()``, the service strips ``temp:`` keys before the
+ initial ``INSERT``.
+
+2. On ``append_event()``, the service calls ``filter_temp_state()`` to produce
+ a durable state snapshot, then calls ``append_event_and_update_state()`` to
+ atomically persist the event and the state update.
+
+3. On ``get_session()``, state is loaded from the database. Since ``temp:``
+ keys were never written, they are absent from the loaded state.
+
+.. code-block:: python
+
+ from sqlspec.extensions.adk.converters import filter_temp_state, split_scoped_state
+
+ state = {
+ "app:model_version": "v2",
+ "user:preferences": {"theme": "dark"},
+ "temp:scratch_pad": "...",
+ "conversation_turn": 5,
+ }
+
+ # Strip temp keys before persisting
+ durable = filter_temp_state(state)
+ # {"app:model_version": "v2", "user:preferences": {...}, "conversation_turn": 5}
+
+ # Split into scoped buckets
+ app_state, user_state, session_state = split_scoped_state(durable)
+ # app_state: {"app:model_version": "v2"}
+ # user_state: {"user:preferences": {"theme": "dark"}}
+ # session_state: {"conversation_turn": 5}
+
+.. _append-event-contract:
+
+The ``append_event_and_update_state()`` Contract
+=================================================
+
+This method is the **authoritative durable write boundary** for post-creation
+session mutations. It atomically:
+
+1. Inserts the event record into the events table.
+2. Updates the session's durable state in the sessions table.
+
+Both operations succeed together or fail together within a single database
+transaction.
+
+.. code-block:: python
+
+ # Called by SQLSpecSessionService.append_event() internally:
+ await store.append_event_and_update_state(
+ event_record=event_record,
+ session_id=session.id,
+ state=durable_state, # temp: keys already stripped
+ )
+
+**Why this matters:**
+
+- Prevents state from advancing without the corresponding event being recorded.
+- Prevents orphaned events that reference a stale session state.
+- Ensures that on session reload, the state always reflects all persisted events.
+
+Every backend store implements this as a single transaction (or equivalent
+atomic operation for the backend's concurrency model).
+
+Memory Table
+============
+
+The memory table stores long-term context entries that agents can search and
+reference across sessions.
+
+Default name: ``adk_memory_entries``
+
+.. list-table::
+ :header-rows: 1
+
+ * - Column
+ - Type
+ - Description
+ * - ``id``
+ - ``VARCHAR`` / ``TEXT``
+ - Primary key.
+ * - ``session_id``
+ - ``VARCHAR`` / ``TEXT``
+ - Session that produced this memory.
+ * - ``app_name``
+ - ``VARCHAR`` / ``TEXT``
+ - Application identifier.
+ * - ``user_id``
+ - ``VARCHAR`` / ``TEXT``
+ - User identifier.
+ * - ``content_text``
+ - ``TEXT``
+ - Searchable text content (used by FTS).
+ * - ``content_json``
+ - ``JSONB`` / ``JSON`` / ``TEXT``
+ - Structured content.
+ * - ``inserted_at``
+ - ``TIMESTAMP``
+ - When the entry was created.
+
+When ``memory_use_fts`` is enabled in the ADK config, backends create
+full-text search indexes on ``content_text`` using the database's native
+FTS engine (tsvector, FTS5, InnoDB FT, etc.).
+
+.. _artifact-schema:
+
+Artifact Metadata Table
+=======================
+
+The artifact table stores versioning metadata for binary artifacts. Content
+bytes are stored separately in object storage; this table tracks ownership,
+versioning, and canonical URIs.
+
+Default name: ``adk_artifact_versions``
+
+.. list-table::
+ :header-rows: 1
+
+ * - Column
+ - Type
+ - Description
+ * - ``app_name``
+ - ``VARCHAR`` / ``TEXT``
+ - Application identifier.
+ * - ``user_id``
+ - ``VARCHAR`` / ``TEXT``
+ - User identifier.
+ * - ``session_id``
+ - ``VARCHAR`` / ``TEXT`` (nullable)
+ - Session identifier. NULL for user-scoped artifacts.
+ * - ``filename``
+ - ``VARCHAR`` / ``TEXT``
+ - Artifact filename.
+ * - ``version``
+ - ``INTEGER``
+ - Monotonically increasing version (starts at 0).
+ * - ``mime_type``
+ - ``VARCHAR`` / ``TEXT`` (nullable)
+ - MIME type of the artifact content.
+ * - ``canonical_uri``
+ - ``VARCHAR`` / ``TEXT``
+ - URI pointing to content in object storage.
+ * - ``custom_metadata``
+ - ``JSONB`` / ``JSON`` / ``TEXT`` (nullable)
+ - User-defined metadata.
+ * - ``created_at``
+ - ``TIMESTAMP``
+ - When this version was created.
+
+The composite key is ``(app_name, user_id, session_id, filename, version)``.
+
+Table Name Configuration
+========================
+
+All table names are configurable:
+
+.. code-block:: python
+
+ config = AsyncpgConfig(
+ connection_config={"dsn": "postgresql://..."},
+ extension_config={
+ "adk": {
+ "session_table": "my_sessions", # default: "adk_sessions"
+ "events_table": "my_events", # default: "adk_events"
+ "memory_table": "my_memory", # default: "adk_memory_entries"
+ "artifact_table": "my_artifacts", # default: "adk_artifact_versions"
+ }
+ },
+ )
+
+Table names are validated on store initialization: they must start with a
+letter or underscore, contain only alphanumeric characters and underscores,
+and be at most 63 characters long.
diff --git a/docs/reference/extensions/adk.rst b/docs/reference/extensions/adk.rst
index aa0ca33a9..3c08b6ed7 100644
--- a/docs/reference/extensions/adk.rst
+++ b/docs/reference/extensions/adk.rst
@@ -2,7 +2,7 @@
Google ADK
==========
-Session, event, and memory storage backends for
+Session, event, memory, and artifact storage backends for
`Google Agent Development Kit `_.
Session Service
@@ -23,6 +23,13 @@ Memory Services
:members:
:show-inheritance:
+Artifact Service
+================
+
+.. autoclass:: sqlspec.extensions.adk.SQLSpecArtifactService
+ :members:
+ :show-inheritance:
+
Store Base Classes
==================
@@ -45,6 +52,17 @@ Memory Store Base Classes
:members:
:show-inheritance:
+Artifact Store Base Classes
+===========================
+
+.. autoclass:: sqlspec.extensions.adk.BaseAsyncADKArtifactStore
+ :members:
+ :show-inheritance:
+
+.. autoclass:: sqlspec.extensions.adk.BaseSyncADKArtifactStore
+ :members:
+ :show-inheritance:
+
Record Types
============
@@ -59,3 +77,14 @@ Record Types
.. autoclass:: sqlspec.extensions.adk.MemoryRecord
:members:
:show-inheritance:
+
+.. autoclass:: sqlspec.extensions.adk.ArtifactRecord
+ :members:
+ :show-inheritance:
+
+Configuration
+=============
+
+.. autoclass:: sqlspec.extensions.adk.ADKConfig
+ :members:
+ :show-inheritance:
diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst
index f1ae3cef9..1bbdfe740 100644
--- a/docs/usage/cli.rst
+++ b/docs/usage/cli.rst
@@ -4,15 +4,18 @@ Command Line Interface
SQLSpec includes a CLI for managing migrations and inspecting configuration. Use it
when you want a fast, explicit workflow without additional tooling.
+Configuration can come from ``--config``, ``SQLSPEC_CONFIG``, or
+``[tool.sqlspec]`` in ``pyproject.toml``.
+
Core Commands
-------------
.. code-block:: console
- sqlspec db init
- sqlspec db create-migration -m "add users"
- sqlspec db upgrade
- sqlspec db downgrade
+ sqlspec init
+ sqlspec create-migration -m "add users"
+ sqlspec upgrade
+ sqlspec downgrade
Common Options
--------------
@@ -28,7 +31,7 @@ Tips
----
- Run ``sqlspec --help`` to see global options.
-- Run ``sqlspec db --help`` to see migration command details.
+- Run ``sqlspec upgrade --help`` to see command-specific migration options.
Related Guides
--------------
diff --git a/docs/usage/migrations.rst b/docs/usage/migrations.rst
index 7d4aa5826..2e5cd1b62 100644
--- a/docs/usage/migrations.rst
+++ b/docs/usage/migrations.rst
@@ -21,9 +21,9 @@ Common Commands
.. code-block:: console
- sqlspec db init
- sqlspec db create-migration -m "add users"
- sqlspec db upgrade
+ sqlspec init
+ sqlspec create-migration -m "add users"
+ sqlspec upgrade
Configuration
-------------
@@ -31,6 +31,9 @@ Configuration
Set ``migration_config`` on your database configuration to customize script
locations, version table names, and extension migration behavior.
+The migration CLI resolves config from ``--config``, ``SQLSPEC_CONFIG``, or
+``[tool.sqlspec]`` in ``pyproject.toml``.
+
.. code-block:: python
from sqlspec.adapters.duckdb import DuckDBConfig
@@ -38,7 +41,7 @@ locations, version table names, and extension migration behavior.
config = DuckDBConfig(
connection_config={"database": "/tmp/analytics.db"},
migration_config={
- "migration_dir": "migrations/duckdb",
+ "script_location": "migrations/duckdb",
"version_table": "_schema_versions",
},
)
@@ -60,11 +63,17 @@ For async configs, ``migrate_up()`` returns an awaitable:
config = AsyncpgConfig(
connection_config={"dsn": "postgresql://localhost/app"},
- migration_config={"migration_dir": "migrations/postgres"},
+ migration_config={"script_location": "migrations/postgres"},
)
await config.migrate_up()
+Extension migrations are auto-included when the corresponding entry exists in
+``extension_config``. Use ``migration_config["exclude_extensions"]`` to skip a
+specific extension, ``migration_config["include_extensions"]`` to opt in
+explicitly by extension name, or ``migration_config["enabled"] = False`` to
+disable migrations entirely for a database config.
+
Logging and Echo Controls
-------------------------
diff --git a/pyproject.toml b/pyproject.toml
index cd9e7484a..53240438b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }]
name = "sqlspec"
readme = "README.md"
requires-python = ">=3.10, <4.0"
-version = "0.41.1"
+version = "0.42.0"
[project.urls]
Discord = "https://discord.gg/litestar"
@@ -43,7 +43,7 @@ attrs = ["attrs", "cattrs"]
bigquery = ["google-cloud-bigquery", "google-cloud-storage"]
cloud-sql = ["cloud-sql-python-connector"]
cockroachdb = ["psycopg[binary,pool]", "asyncpg"]
-duckdb = ["duckdb"]
+duckdb = ["duckdb", "pytz"]
fastapi = ["fastapi"]
flask = ["flask"]
fsspec = ["fsspec"]
@@ -254,7 +254,7 @@ opt_level = "3" # Maximum optimization (0-3)
allow_dirty = true
commit = false
commit_args = "--no-verify"
-current_version = "0.41.1"
+current_version = "0.42.0"
ignore_missing_files = false
ignore_missing_version = false
message = "chore(release): bump to v{new_version}"
@@ -296,7 +296,7 @@ version = "{current_version}"
"""
[tool.codespell]
-ignore-words-list = "te,ECT,SELCT,froms,ccompiler"
+ignore-words-list = "te,ECT,SELCT,froms,ccompiler,BRIN"
skip = 'uv.lock'
[tool.coverage.run]
diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py
index 65e5c4975..50d6c72f4 100644
--- a/sqlspec/adapters/adbc/adk/store.py
+++ b/sqlspec/adapters/adbc/adk/store.py
@@ -1,12 +1,14 @@
"""ADBC ADK store for Google Agent Development Kit session/event storage."""
+import contextlib
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Final
-from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json
+from sqlspec.utils.sync_tools import async_, run_
if TYPE_CHECKING:
from sqlspec.adapters.adbc.config import AdbcConfig
@@ -25,16 +27,23 @@
ADBC_TABLE_NOT_FOUND_PATTERNS: Final = ("no such table", "table or view does not exist", "relation does not exist")
-class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]):
+class AdbcADKStore(BaseAsyncADKStore["AdbcConfig"]):
"""ADBC synchronous ADK store for Arrow Database Connectivity.
Implements session and event storage for Google Agent Development Kit
using ADBC. ADBC provides a vendor-neutral API with Arrow-native data
transfer across multiple databases (PostgreSQL, SQLite, DuckDB, etc.).
+ Events use a 5-column contract: session_id, invocation_id, author,
+ timestamp, and event_json. The full ADK Event payload is stored as a
+ single JSON blob in event_json using a dialect-appropriate column type
+ (JSONB for PostgreSQL, JSON for DuckDB, VARIANT for Snowflake, TEXT for
+ SQLite and generic fallback).
+
Provides:
- - Session state management with JSON serialization (TEXT storage)
- - Event history tracking with BLOB-serialized actions
+ - Session state management with JSON serialization
+ - Event history tracking via single event_json blob
+ - Atomic event insert + session state update
- Timezone-aware timestamps
- Foreign key constraints with cascade delete
- Database-agnostic SQL (supports multiple backends)
@@ -60,12 +69,9 @@ class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]):
store.ensure_tables()
Notes:
- - TEXT for JSON storage (compatible across all ADBC backends)
- - BLOB for pre-serialized actions from Google ADK
+ - Dialect-appropriate JSON type for event_json storage
- TIMESTAMP for timezone-aware timestamps (driver-dependent precision)
- - INTEGER for booleans (0/1/NULL)
- - Parameter style varies by backend (?, $1, :name, etc.)
- - Uses dialect-agnostic SQL for maximum compatibility
+ - Parameter style: ``?`` universally across ADBC backends
- State and JSON fields use to_json/from_json for serialization
- ADBC drivers handle parameter binding automatically
- Configuration is read from config.extension_config["adk"]
@@ -171,7 +177,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
return None
return from_json(str(data)) # type: ignore[no-any-return]
- def _get_create_sessions_table_sql(self) -> str:
+ async def _get_create_sessions_table_sql(self) -> str:
"""Get CREATE TABLE SQL for sessions with dialect dispatch.
Returns:
@@ -277,7 +283,7 @@ def _get_sessions_ddl_generic(self) -> str:
)
"""
- def _get_create_events_table_sql(self) -> str:
+ async def _get_create_events_table_sql(self) -> str:
"""Get CREATE TABLE SQL for events with dialect dispatch.
Returns:
@@ -298,27 +304,17 @@ def _get_events_ddl_postgresql(self) -> str:
Returns:
SQL to create events table optimized for PostgreSQL.
+
+ Notes:
+ Uses JSONB for event_json to enable indexing and query support.
"""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256),
- author VARCHAR(256),
- actions BYTEA,
- long_running_tool_ids_json TEXT,
- branch VARCHAR(256),
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(256) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
)
"""
@@ -328,27 +324,17 @@ def _get_events_ddl_sqlite(self) -> str:
Returns:
SQL to create events table optimized for SQLite.
+
+ Notes:
+ Uses TEXT for event_json (SQLite has no native JSON column type).
"""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
- app_name TEXT NOT NULL,
- user_id TEXT NOT NULL,
- invocation_id TEXT,
- author TEXT,
- actions BLOB,
- long_running_tool_ids_json TEXT,
- branch TEXT,
+ invocation_id TEXT NOT NULL,
+ author TEXT NOT NULL,
timestamp REAL NOT NULL,
- content TEXT,
- grounding_metadata TEXT,
- custom_metadata TEXT,
- partial INTEGER,
- turn_complete INTEGER,
- interrupted INTEGER,
- error_code TEXT,
- error_message TEXT,
+ event_json TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
)
"""
@@ -358,27 +344,17 @@ def _get_events_ddl_duckdb(self) -> str:
Returns:
SQL to create events table optimized for DuckDB.
+
+ Notes:
+ Uses JSON for event_json (DuckDB native JSON type).
"""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256),
- author VARCHAR(256),
- actions BLOB,
- long_running_tool_ids_json VARCHAR,
- branch VARCHAR(256),
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(256) NOT NULL,
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSON NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
)
"""
@@ -388,27 +364,17 @@ def _get_events_ddl_snowflake(self) -> str:
Returns:
SQL to create events table optimized for Snowflake.
+
+ Notes:
+ Uses VARIANT for event_json (Snowflake semi-structured type).
"""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR PRIMARY KEY,
session_id VARCHAR NOT NULL,
- app_name VARCHAR NOT NULL,
- user_id VARCHAR NOT NULL,
- invocation_id VARCHAR,
- author VARCHAR,
- actions BINARY,
- long_running_tool_ids_json VARCHAR,
- branch VARCHAR,
+ invocation_id VARCHAR NOT NULL,
+ author VARCHAR NOT NULL,
timestamp TIMESTAMP_TZ NOT NULL DEFAULT CURRENT_TIMESTAMP(),
- content VARIANT,
- grounding_metadata VARIANT,
- custom_metadata VARIANT,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR,
- error_message VARCHAR,
+ event_json VARIANT NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id)
)
"""
@@ -418,27 +384,17 @@ def _get_events_ddl_generic(self) -> str:
Returns:
SQL to create events table using generic types.
+
+ Notes:
+ Uses TEXT for event_json (maximum portability).
"""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256),
- author VARCHAR(256),
- actions BLOB,
- long_running_tool_ids_json TEXT,
- branch VARCHAR(256),
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(256) NOT NULL,
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content TEXT,
- grounding_metadata TEXT,
- custom_metadata TEXT,
- partial INTEGER,
- turn_complete INTEGER,
- interrupted INTEGER,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
)
"""
@@ -455,14 +411,14 @@ def _get_drop_tables_sql(self) -> "list[str]":
"""
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
"""Create both sessions and events tables if they don't exist."""
with self._config.provide_connection() as conn:
cursor = conn.cursor()
try:
self._enable_foreign_keys(cursor, conn)
- cursor.execute(self._get_create_sessions_table_sql())
+ cursor.execute(run_(self._get_create_sessions_table_sql)())
conn.commit()
sessions_idx_app_user = (
@@ -479,7 +435,7 @@ def create_tables(self) -> None:
cursor.execute(sessions_idx_update)
conn.commit()
- cursor.execute(self._get_create_events_table_sql())
+ cursor.execute(run_(self._get_create_events_table_sql)())
conn.commit()
events_idx = (
@@ -491,6 +447,10 @@ def create_tables(self) -> None:
finally:
cursor.close() # type: ignore[no-untyped-call]
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None:
"""Enable foreign key constraints for SQLite.
@@ -508,7 +468,7 @@ def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None:
except Exception:
logger.debug("Foreign key enforcement not supported or already enabled")
- def create_session(
+ def _create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
"""Create a new session.
@@ -548,9 +508,19 @@ def create_session(
finally:
cursor.close() # type: ignore[no-untyped-call]
- return self.get_session(session_id) # type: ignore[return-value]
+ result = self._get_session(session_id)
+ if result is None:
+ msg = "Failed to fetch created session"
+ raise RuntimeError(msg)
+ return result
+
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session."""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
- def get_session(self, session_id: str) -> "SessionRecord | None":
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
"""Get session by ID.
Args:
@@ -594,7 +564,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
return None
raise
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID."""
+ return await async_(self._get_session)(session_id)
+
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
"""Update session state.
Args:
@@ -620,7 +594,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
finally:
cursor.close() # type: ignore[no-untyped-call]
- def delete_session(self, session_id: str) -> None:
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state."""
+ await async_(self._update_session_state)(session_id, state)
+
+ def _delete_session(self, session_id: str) -> None:
"""Delete session and all associated events (cascade).
Args:
@@ -640,7 +618,11 @@ def delete_session(self, session_id: str) -> None:
finally:
cursor.close() # type: ignore[no-untyped-call]
- def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and associated events."""
+ await async_(self._delete_session)(session_id)
+
+ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
"""List sessions for an app, optionally filtered by user.
Args:
@@ -696,149 +678,144 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess
return []
raise
- def create_event(
- self,
- event_id: str,
- session_id: str,
- app_name: str,
- user_id: str,
- author: "str | None" = None,
- actions: "bytes | None" = None,
- content: "dict[str, Any] | None" = None,
- **kwargs: Any,
- ) -> "EventRecord":
- """Create a new event.
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app."""
+ return await async_(self._list_sessions)(app_name, user_id)
- Args:
- event_id: Unique event identifier.
- session_id: Session identifier.
- app_name: Application name.
- user_id: User identifier.
- author: Event author (user/assistant/system).
- actions: Pickled actions object.
- content: Event content (JSON).
- **kwargs: Additional optional fields.
-
- Returns:
- Created event record.
+ def _insert_event(self, event_record: "EventRecord") -> None:
+ """Insert an event record into the events table.
- Notes:
- Uses CURRENT_TIMESTAMP for timestamp if not provided.
- JSON fields are serialized to JSON strings.
- Boolean fields are converted to INTEGER (0/1).
+ Args:
+ event_record: Event record to store.
"""
- content_json = self._serialize_json_field(content)
- grounding_metadata_json = self._serialize_json_field(kwargs.get("grounding_metadata"))
- custom_metadata_json = self._serialize_json_field(kwargs.get("custom_metadata"))
-
- partial_int = self._to_int_bool(kwargs.get("partial"))
- turn_complete_int = self._to_int_bool(kwargs.get("turn_complete"))
- interrupted_int = self._to_int_bool(kwargs.get("interrupted"))
-
+ event_json = self._serialize_json_field(event_record["event_json"])
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (?, ?, ?, ?, ?)
"""
- timestamp = kwargs.get("timestamp")
- if timestamp is None:
- timestamp = datetime.now(timezone.utc)
-
with self._config.provide_connection() as conn:
cursor = conn.cursor()
try:
cursor.execute(
sql,
(
- event_id,
- session_id,
- app_name,
- user_id,
- kwargs.get("invocation_id"),
- author,
- actions,
- kwargs.get("long_running_tool_ids_json"),
- kwargs.get("branch"),
- timestamp,
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- partial_int,
- turn_complete_int,
- interrupted_int,
- kwargs.get("error_code"),
- kwargs.get("error_message"),
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json,
),
)
conn.commit()
finally:
cursor.close() # type: ignore[no-untyped-call]
- events = self.list_events(session_id)
- for event in events:
- if event["id"] == event_id:
- return event
+ def _append_event_and_update_state(
+ self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically insert an event and update the session's durable state.
- msg = f"Failed to retrieve created event {event_id}"
- raise RuntimeError(msg)
+ The event insert and state update are executed within a single
+ connection and committed together. If either statement fails the
+ transaction is rolled back so the two writes remain consistent.
- def list_events(self, session_id: str) -> "list[EventRecord]":
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot (``temp:`` keys already
+ stripped by the service layer).
+ """
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (?, ?, ?, ?, ?)
+ """
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = ?, update_time = CURRENT_TIMESTAMP
+ WHERE id = ?
+ """
+ state_json = self._serialize_state(state)
+ event_json = self._serialize_json_field(event_record["event_json"])
+
+ with self._config.provide_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ insert_sql,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json,
+ ),
+ )
+ cursor.execute(update_sql, (state_json, session_id))
+ conn.commit()
+ except Exception:
+ with contextlib.suppress(Exception):
+ conn.rollback()
+ raise
+ finally:
+ cursor.close() # type: ignore[no-untyped-call]
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state."""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
+
+ def _get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
"""List events for a session ordered by timestamp.
Args:
session_id: Session identifier.
+ after_timestamp: Only return events after this time.
+ limit: Maximum number of events to return.
Returns:
List of event records ordered by timestamp ASC.
Notes:
Uses index on (session_id, timestamp ASC).
- JSON fields deserialized from JSON strings.
- Converts INTEGER booleans to Python bool.
+ Returns the 5-column EventRecord (session_id, invocation_id,
+ author, timestamp, event_json).
"""
+ where_clauses = ["session_id = ?"]
+ params: list[Any] = [session_id]
+
+ if after_timestamp is not None:
+ where_clauses.append("timestamp > ?")
+ params.append(after_timestamp)
+
+ where_clause = " AND ".join(where_clauses)
+ limit_clause = f" LIMIT {limit}" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
- WHERE session_id = ?
- ORDER BY timestamp ASC
+ WHERE {where_clause}
+ ORDER BY timestamp ASC{limit_clause}
"""
try:
with self._config.provide_connection() as conn:
cursor = conn.cursor()
try:
- cursor.execute(sql, (session_id,))
+ cursor.execute(sql, params)
rows = cursor.fetchall()
return [
EventRecord(
- id=row[0],
- session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=bytes(row[6]) if row[6] is not None else b"",
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=row[9],
- content=self._deserialize_json_field(row[10]),
- grounding_metadata=self._deserialize_json_field(row[11]),
- custom_metadata=self._deserialize_json_field(row[12]),
- partial=self._from_int_bool(row[13]),
- turn_complete=self._from_int_bool(row[14]),
- interrupted=self._from_int_bool(row[15]),
- error_code=row[16],
- error_message=row[17],
+ session_id=row[0],
+ invocation_id=row[1],
+ author=row[2],
+ timestamp=row[3],
+ event_json=self._deserialize_json_field(row[4]) or {},
)
for row in rows
]
@@ -850,36 +827,22 @@ def list_events(self, session_id: str) -> "list[EventRecord]":
return []
raise
- @staticmethod
- def _to_int_bool(value: "bool | None") -> "int | None":
- """Convert Python boolean to INTEGER (0/1).
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session."""
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
- Args:
- value: Python boolean value or None.
-
- Returns:
- 1 for True, 0 for False, None for None.
- """
- if value is None:
- return None
- return 1 if value else 0
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ self._insert_event(event_record)
- @staticmethod
- def _from_int_bool(value: "int | None") -> "bool | None":
- """Convert INTEGER to Python boolean.
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session."""
+ await async_(self._append_event)(event_record)
- Args:
- value: INTEGER value (0, 1, or None).
- Returns:
- Python boolean or None.
- """
- if value is None:
- return None
- return bool(value)
-
-
-class AdbcADKMemoryStore(BaseSyncADKMemoryStore["AdbcConfig"]):
+class AdbcADKMemoryStore(BaseAsyncADKMemoryStore["AdbcConfig"]):
"""ADBC synchronous ADK memory store for Arrow Database Connectivity."""
__slots__ = ("_dialect",)
@@ -924,7 +887,7 @@ def _decode_timestamp(self, value: Any) -> datetime:
return datetime.fromisoformat(value)
return datetime.fromisoformat(str(value))
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
if self._dialect == DIALECT_POSTGRESQL:
return self._get_memory_ddl_postgresql()
if self._dialect == DIALECT_SQLITE:
@@ -1028,14 +991,14 @@ def _get_memory_ddl_generic(self) -> str:
def _get_drop_memory_table_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._memory_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
if not self._enabled:
return
with self._config.provide_connection() as conn:
cursor = conn.cursor()
try:
- cursor.execute(self._get_create_memory_table_sql())
+ cursor.execute(run_(self._get_create_memory_table_sql)())
conn.commit()
idx_app_user = (
@@ -1053,7 +1016,11 @@ def create_tables(self) -> None:
finally:
cursor.close() # type: ignore[no-untyped-call]
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -1159,7 +1126,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
if not self._enabled:
@@ -1199,7 +1170,13 @@ def search_entries(
return self._rows_to_records(rows)
- def delete_entries_by_session(self, session_id: str) -> int:
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
+ def _delete_entries_by_session(self, session_id: str) -> int:
use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB}
if use_returning:
sql = f"DELETE FROM {self._memory_table} WHERE session_id = ? RETURNING 1"
@@ -1218,7 +1195,11 @@ def delete_entries_by_session(self, session_id: str) -> int:
finally:
cursor.close() # type: ignore[no-untyped-call]
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
cutoff = self._encode_timestamp(datetime.now(timezone.utc) - timedelta(days=days))
use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB}
if use_returning:
@@ -1238,6 +1219,10 @@ def delete_entries_older_than(self, days: int) -> int:
finally:
cursor.close() # type: ignore[no-untyped-call]
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
+
def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]":
records: list[MemoryRecord] = []
for row in rows:
diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py
index 7a63095f1..f68805297 100644
--- a/sqlspec/adapters/aiosqlite/adk/store.py
+++ b/sqlspec/adapters/aiosqlite/adk/store.py
@@ -52,34 +52,6 @@ def _julian_to_datetime(julian: float) -> datetime:
return datetime.fromtimestamp(timestamp, tz=timezone.utc)
-def _to_sqlite_bool(value: "bool | None") -> "int | None":
- """Convert Python bool to SQLite INTEGER.
-
- Args:
- value: Boolean value or None.
-
- Returns:
- 1 for True, 0 for False, None for None.
- """
- if value is None:
- return None
- return 1 if value else 0
-
-
-def _from_sqlite_bool(value: "int | None") -> "bool | None":
- """Convert SQLite INTEGER to Python bool.
-
- Args:
- value: Integer value (0/1) or None.
-
- Returns:
- True for 1, False for 0, None for None.
- """
- if value is None:
- return None
- return bool(value)
-
-
class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
"""Aiosqlite ADK store using asynchronous SQLite driver.
@@ -88,10 +60,11 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
Provides:
- Session state management with JSON storage (as TEXT)
- - Event history tracking with BLOB-serialized actions
+ - Event history tracking with full-event JSON storage
- Julian Day timestamps (REAL) for efficient date operations
- Foreign key constraints with cascade delete
- - Efficient upserts using INSERT OR REPLACE
+ - Atomic event+state writes via append_event_and_update_state
+ - PRAGMA optimization profile applied to every connection
Args:
config: AiosqliteConfig with extension_config["adk"] settings.
@@ -114,9 +87,8 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
Notes:
- JSON stored as TEXT with SQLSpec serializers (msgspec/orjson/stdlib)
- - BOOLEAN as INTEGER (0/1, with None for NULL)
- Timestamps as REAL (Julian day: julianday('now'))
- - BLOB for pre-serialized actions from Google ADK
+ - Full event stored as JSON TEXT in event_data column
- PRAGMA foreign_keys = ON (enable per connection)
- Configuration is read from config.extension_config["adk"]
"""
@@ -136,6 +108,22 @@ def __init__(self, config: "AiosqliteConfig") -> None:
"""
super().__init__(config)
+ async def _apply_pragmas(self, connection: Any) -> None:
+ """Apply PRAGMA optimization profile for this connection.
+
+ Args:
+ connection: Aiosqlite connection.
+
+ Notes:
+ Enables foreign keys and applies performance PRAGMAs.
+ cache_size, mmap_size, and journal_size_limit are applied
+ unconditionally; they chiefly benefit file-based databases.
+ """
+ await connection.execute("PRAGMA foreign_keys = ON")
+ await connection.execute("PRAGMA cache_size = -64000")
+ await connection.execute("PRAGMA mmap_size = 30000000")
+ await connection.execute("PRAGMA journal_size_limit = 67108864")
+
async def _get_create_sessions_table_sql(self) -> str:
"""Get SQLite CREATE TABLE SQL for sessions.
@@ -170,9 +158,8 @@ async def _get_create_events_table_sql(self) -> str:
SQL statement to create adk_events table with indexes.
Notes:
- - TEXT for IDs, strings, and JSON content
- - BLOB for pickled actions
- - INTEGER for booleans (0/1/NULL)
+ - TEXT for IDs and scalar columns (only session_id is indexed)
+ - TEXT for full event JSON (event_data)
- REAL for Julian Day timestamps
- Foreign key to sessions with CASCADE delete
- Index on (session_id, timestamp ASC)
@@ -181,22 +168,10 @@ async def _get_create_events_table_sql(self) -> str:
CREATE TABLE IF NOT EXISTS {self._events_table} (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
- app_name TEXT NOT NULL,
- user_id TEXT NOT NULL,
- invocation_id TEXT NOT NULL,
- author TEXT NOT NULL,
- actions BLOB NOT NULL,
- long_running_tool_ids_json TEXT,
- branch TEXT,
+ invocation_id TEXT,
+ author TEXT,
timestamp REAL NOT NULL,
- content TEXT,
- grounding_metadata TEXT,
- custom_metadata TEXT,
- partial INTEGER,
- turn_complete INTEGER,
- interrupted INTEGER,
- error_code TEXT,
- error_message TEXT,
+ event_data TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
@@ -215,21 +190,10 @@ def _get_drop_tables_sql(self) -> "list[str]":
"""
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- async def _enable_foreign_keys(self, connection: Any) -> None:
- """Enable foreign key constraints for this connection.
-
- Args:
- connection: Aiosqlite connection.
-
- Notes:
- SQLite requires PRAGMA foreign_keys = ON per connection.
- """
- await connection.execute("PRAGMA foreign_keys = ON")
-
async def create_tables(self) -> None:
"""Create both sessions and events tables if they don't exist."""
async with self._config.provide_session() as driver:
- await self._enable_foreign_keys(driver.connection)
+ await self._apply_pragmas(driver.connection)
await driver.execute_script(await self._get_create_sessions_table_sql())
await driver.execute_script(await self._get_create_events_table_sql())
@@ -250,11 +214,11 @@ async def create_session(
Notes:
Uses Julian Day for create_time and update_time.
- State is JSON-serialized before insertion.
+ State is always JSON-serialized (empty dict becomes '{}', never NULL).
"""
now = datetime.now(timezone.utc)
now_julian = _datetime_to_julian(now)
- state_json = to_json(state) if state else None
+ state_json = to_json(state)
params: tuple[Any, ...]
if self._owner_id_column_name:
@@ -272,7 +236,7 @@ async def create_session(
params = (session_id, app_name, user_id, state_json, now_julian, now_julian)
async with self._config.provide_connection() as conn:
- await self._enable_foreign_keys(conn)
+ await self._apply_pragmas(conn)
await conn.execute(sql, params)
await conn.commit()
@@ -300,7 +264,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
"""
async with self._config.provide_connection() as conn:
- await self._enable_foreign_keys(conn)
+ await self._apply_pragmas(conn)
cursor = await conn.execute(sql, (session_id,))
row = await cursor.fetchone()
@@ -326,9 +290,10 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
Notes:
This replaces the entire state dictionary.
Updates update_time to current Julian Day.
+ Empty dict is serialized as '{}', never NULL.
"""
now_julian = _datetime_to_julian(datetime.now(timezone.utc))
- state_json = to_json(state) if state else None
+ state_json = to_json(state)
sql = f"""
UPDATE {self._session_table}
@@ -337,7 +302,7 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
"""
async with self._config.provide_connection() as conn:
- await self._enable_foreign_keys(conn)
+ await self._apply_pragmas(conn)
await conn.execute(sql, (state_json, now_julian, session_id))
await conn.commit()
@@ -372,7 +337,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
params = (app_name, user_id)
async with self._config.provide_connection() as conn:
- await self._enable_foreign_keys(conn)
+ await self._apply_pragmas(conn)
cursor = await conn.execute(sql, params)
rows = await cursor.fetchall()
@@ -400,7 +365,7 @@ async def delete_session(self, session_id: str) -> None:
sql = f"DELETE FROM {self._session_table} WHERE id = ?"
async with self._config.provide_connection() as conn:
- await self._enable_foreign_keys(conn)
+ await self._apply_pragmas(conn)
await conn.execute(sql, (session_id,))
await conn.commit()
@@ -408,63 +373,88 @@ async def append_event(self, event_record: EventRecord) -> None:
"""Append an event to a session.
Args:
- event_record: Event record to store.
+ event_record: Event record with 5 keys: session_id, invocation_id,
+ author, timestamp, event_json.
Notes:
Uses Julian Day for timestamp.
- JSON fields are serialized to TEXT.
- Boolean fields converted to INTEGER (0/1/NULL).
+ The event_json dict is serialized to TEXT and stored in the event_data column.
"""
- timestamp_julian = _datetime_to_julian(event_record["timestamp"])
+ import uuid
- content_json = to_json(event_record.get("content")) if event_record.get("content") else None
- grounding_metadata_json = (
- to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None
- )
- custom_metadata_json = (
- to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None
- )
-
- partial_int = _to_sqlite_bool(event_record.get("partial"))
- turn_complete_int = _to_sqlite_bool(event_record.get("turn_complete"))
- interrupted_int = _to_sqlite_bool(event_record.get("interrupted"))
+ timestamp_julian = _datetime_to_julian(event_record["timestamp"])
+ event_data_json = to_json(event_record["event_json"])
+ event_id = str(uuid.uuid4())
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
- )
+ id, session_id, invocation_id, author, timestamp, event_data
+ ) VALUES (?, ?, ?, ?, ?, ?)
"""
async with self._config.provide_connection() as conn:
- await self._enable_foreign_keys(conn)
+ await self._apply_pragmas(conn)
await conn.execute(
sql,
(
- event_record["id"],
+ event_id,
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ timestamp_julian,
+ event_data_json,
+ ),
+ )
+ await conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state.
+
+ Inserts the event and updates the session state + update_time in a
+ single transaction. Both operations succeed or fail together.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot ("temp:"-prefixed keys
+ already stripped by the service layer).
+ """
+ import uuid
+
+ timestamp_julian = _datetime_to_julian(event_record["timestamp"])
+ event_data_json = to_json(event_record["event_json"])
+ now_julian = _datetime_to_julian(datetime.now(timezone.utc))
+ state_json = to_json(state)
+ event_id = str(uuid.uuid4())
+
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ id, session_id, invocation_id, author, timestamp, event_data
+ ) VALUES (?, ?, ?, ?, ?, ?)
+ """
+
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = ?, update_time = ?
+ WHERE id = ?
+ """
+
+ async with self._config.provide_connection() as conn:
+ await self._apply_pragmas(conn)
+ await conn.execute(
+ insert_sql,
+ (
+ event_id,
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
timestamp_julian,
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- partial_int,
- turn_complete_int,
- interrupted_int,
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_data_json,
),
)
+ await conn.execute(update_sql, (state_json, now_julian, session_id))
await conn.commit()
async def get_events(
@@ -482,8 +472,7 @@ async def get_events(
Notes:
Uses index on (session_id, timestamp ASC).
- Parses JSON fields and converts BLOB actions to bytes.
- Converts INTEGER booleans back to bool/None.
+ Parses event_data TEXT back to dict for event_json field.
"""
where_clauses = ["session_id = ?"]
params: list[Any] = [session_id]
@@ -496,40 +485,24 @@ async def get_events(
limit_clause = f" LIMIT {limit}" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT id, session_id, invocation_id, author, timestamp, event_data
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
"""
async with self._config.provide_connection() as conn:
- await self._enable_foreign_keys(conn)
+ await self._apply_pragmas(conn)
cursor = await conn.execute(sql, params)
rows = await cursor.fetchall()
return [
EventRecord(
- id=row[0],
session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=bytes(row[6]),
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=_julian_to_datetime(row[9]),
- content=from_json(row[10]) if row[10] else None,
- grounding_metadata=from_json(row[11]) if row[11] else None,
- custom_metadata=from_json(row[12]) if row[12] else None,
- partial=_from_sqlite_bool(row[13]),
- turn_complete=_from_sqlite_bool(row[14]),
- interrupted=_from_sqlite_bool(row[15]),
- error_code=row[16],
- error_message=row[17],
+ invocation_id=row[2],
+ author=row[3],
+ timestamp=_julian_to_datetime(row[4]),
+ event_json=from_json(row[5]) if row[5] else {},
)
for row in rows
]
diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py
index ff74e6851..446defc78 100644
--- a/sqlspec/adapters/asyncmy/adk/store.py
+++ b/sqlspec/adapters/asyncmy/adk/store.py
@@ -27,36 +27,16 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
Implements session and event storage for Google Agent Development Kit
using MySQL/MariaDB via the AsyncMy driver. Provides:
- Session state management with JSON storage
- - Event history tracking with BLOB-serialized actions
+ - Full-event JSON storage (single ``event_json`` column)
+ - Atomic event-append + state-update in one transaction
- Microsecond-precision timestamps
- Foreign key constraints with cascade delete
- Efficient upserts using ON DUPLICATE KEY UPDATE
- Args:
- config: AsyncmyConfig with extension_config["adk"] settings.
-
- Example:
- from sqlspec.adapters.asyncmy import AsyncmyConfig
- from sqlspec.adapters.asyncmy.adk import AsyncmyADKStore
-
- config = AsyncmyConfig(
- connection_config={"host": "localhost", ...},
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
- }
- }
- )
- store = AsyncmyADKStore(config)
- await store.ensure_tables()
-
Notes:
- MySQL JSON type used (not JSONB) - requires MySQL 5.7.8+
- TIMESTAMP(6) provides microsecond precision
- InnoDB engine required for foreign key support
- - State merging handled at application level
- Configuration is read from config.extension_config["adk"]
"""
@@ -67,12 +47,6 @@ def __init__(self, config: "AsyncmyConfig") -> None:
Args:
config: AsyncmyConfig instance.
-
- Notes:
- Configuration is read from config.extension_config["adk"]:
- - session_table: Sessions table name (default: "adk_sessions")
- - events_table: Events table name (default: "adk_events")
- - owner_id_column: Optional owner FK column DDL (default: None)
"""
super().__init__(config)
@@ -88,10 +62,6 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]"
Returns:
Tuple of (column_definition, foreign_key_constraint)
-
- Example:
- Input: "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
- Output: ("tenant_id BIGINT NOT NULL", "FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE")
"""
references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE)
@@ -110,16 +80,6 @@ async def _get_create_sessions_table_sql(self) -> str:
Returns:
SQL statement to create adk_sessions table with indexes.
-
- Notes:
- - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names)
- - JSON type for state storage (MySQL 5.7.8+)
- - TIMESTAMP(6) with microsecond precision
- - AUTO-UPDATE on update_time
- - Composite index on (app_name, user_id) for listing
- - Index on update_time DESC for recent session queries
- - Optional owner ID column for multi-tenancy
- - MySQL requires explicit FOREIGN KEY syntax (inline REFERENCES is ignored)
"""
owner_id_col = ""
fk_constraint = ""
@@ -151,34 +111,18 @@ async def _get_create_events_table_sql(self) -> str:
SQL statement to create adk_events table with indexes.
Notes:
- - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256),
- branch(256), error_code(256), error_message(1024)
- - BLOB for pickled actions (up to 64KB)
- - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
- - BOOLEAN for partial, turn_complete, interrupted
- - Foreign key to sessions with CASCADE delete
- - Index on (session_id, timestamp ASC) for ordered event retrieval
+ Post clean-break schema: 5 columns only.
+ - session_id, invocation_id, author: scalar columns (session_id indexed)
+ - timestamp: microsecond-precision TIMESTAMP
+ - event_json: full Event as native JSON
"""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
invocation_id VARCHAR(256) NOT NULL,
- author VARCHAR(256) NOT NULL,
- actions BLOB NOT NULL,
- long_running_tool_ids_json JSON,
- branch VARCHAR(256),
+ author VARCHAR(128) NOT NULL,
timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSON NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE,
INDEX idx_{self._events_table}_session (session_id, timestamp ASC)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
@@ -189,10 +133,6 @@ def _get_drop_tables_sql(self) -> "list[str]":
Returns:
List of SQL statements to drop tables and indexes.
-
- Notes:
- Order matters: drop events table (child) before sessions (parent).
- MySQL automatically drops indexes when dropping tables.
"""
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
@@ -216,11 +156,6 @@ async def create_session(
Returns:
Created session record.
-
- Notes:
- Uses INSERT with UTC_TIMESTAMP(6) for create_time and update_time.
- State is JSON-serialized before insertion.
- If owner_id_column is configured, owner_id must be provided.
"""
state_json = to_json(state)
@@ -252,10 +187,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
Returns:
Session record or None if not found.
-
- Notes:
- MySQL returns datetime objects for TIMESTAMP columns.
- JSON is parsed from database storage.
"""
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -292,10 +223,6 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
Args:
session_id: Session identifier.
state: New state dictionary (replaces existing state).
-
- Notes:
- This replaces the entire state dictionary.
- Uses update_time auto-update trigger.
"""
state_json = to_json(state)
@@ -314,9 +241,6 @@ async def delete_session(self, session_id: str) -> None:
Args:
session_id: Session identifier.
-
- Notes:
- Foreign key constraint ensures events are cascade-deleted.
"""
sql = f"DELETE FROM {self._session_table} WHERE id = %s"
@@ -333,9 +257,6 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
Returns:
List of session records ordered by update_time DESC.
-
- Notes:
- Uses composite index on (app_name, user_id) when user_id is provided.
"""
if user_id is None:
sql = f"""
@@ -379,55 +300,72 @@ async def append_event(self, event_record: EventRecord) -> None:
"""Append an event to a session.
Args:
- event_record: Event record to store.
-
- Notes:
- Uses UTC_TIMESTAMP(6) for timestamp if not provided.
- JSON fields are serialized before insertion.
+ event_record: Event record with 5 keys (session_id, invocation_id,
+ author, timestamp, event_json).
"""
- content_json = to_json(event_record.get("content")) if event_record.get("content") else None
- grounding_metadata_json = (
- to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None
- )
- custom_metadata_json = (
- to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None
- )
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
"""
async with self._config.provide_connection() as conn, conn.cursor() as cursor:
await cursor.execute(
sql,
(
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
event_record["timestamp"],
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_json_str,
+ ),
+ )
+ await conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state.
+
+ The event insert and state update succeed together or fail together
+ within a single transaction.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot.
+ """
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
+ state_json = to_json(state)
+
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """
+
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = %s
+ WHERE id = %s
+ """
+
+ async with self._config.provide_connection() as conn, conn.cursor() as cursor:
+ await cursor.execute(
+ insert_sql,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json_str,
),
)
+ await cursor.execute(update_sql, (state_json, session_id))
await conn.commit()
async def get_events(
@@ -442,10 +380,6 @@ async def get_events(
Returns:
List of event records ordered by timestamp ASC.
-
- Notes:
- Uses index on (session_id, timestamp ASC).
- Parses JSON fields and converts BLOB actions to bytes.
"""
where_clauses = ["session_id = %s"]
params: list[Any] = [session_id]
@@ -458,10 +392,7 @@ async def get_events(
limit_clause = f" LIMIT {limit}" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
@@ -474,24 +405,11 @@ async def get_events(
return [
EventRecord(
- id=row[0],
- session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=bytes(row[6]),
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=row[9],
- content=from_json(row[10]) if row[10] and isinstance(row[10], str) else row[10],
- grounding_metadata=from_json(row[11]) if row[11] and isinstance(row[11], str) else row[11],
- custom_metadata=from_json(row[12]) if row[12] and isinstance(row[12], str) else row[12],
- partial=row[13],
- turn_complete=row[14],
- interrupted=row[15],
- error_code=row[16],
- error_message=row[17],
+ session_id=row[0],
+ invocation_id=row[1],
+ author=row[2],
+ timestamp=row[3],
+ event_json=from_json(row[4]) if isinstance(row[4], str) else row[4],
)
for row in rows
]
diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py
index 4c4624a87..76cffccab 100644
--- a/sqlspec/adapters/asyncpg/adk/store.py
+++ b/sqlspec/adapters/asyncpg/adk/store.py
@@ -1,6 +1,6 @@
"""AsyncPG ADK store for Google Agent Development Kit session/event storage."""
-from typing import TYPE_CHECKING, Any, Final, cast
+from typing import TYPE_CHECKING, Any, Final
import asyncpg
@@ -21,87 +21,32 @@
class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
- """PostgreSQL ADK store base class for all PostgreSQL drivers.
+ """PostgreSQL ADK store using asyncpg driver.
Implements session and event storage for Google Agent Development Kit
- using PostgreSQL via any PostgreSQL driver (AsyncPG, Psycopg, Psqlpy).
- All drivers share the same SQL dialect and parameter style ($1, $2, etc).
+ using PostgreSQL via asyncpg. Events are stored as a single JSONB blob
+ (``event_json``) alongside scalar columns, indexed on (session_id, timestamp).
Provides:
- - Session state management with JSONB storage and merge operations
- - Event history tracking with BYTEA-serialized actions
+ - Session state management with JSONB storage
+ - Full-fidelity event storage via ``event_json`` JSONB column
+ - Atomic ``append_event_and_update_state`` for durable session mutations
- Microsecond-precision timestamps with TIMESTAMPTZ
- Foreign key constraints with cascade delete
- - Efficient upserts using ON CONFLICT
- GIN indexes for JSONB queries
- HOT updates with FILLFACTOR 80
- - Optional user FK column for multi-tenancy
+ - Optional owner ID column for multi-tenancy
Args:
config: PostgreSQL database config with extension_config["adk"] settings.
-
- Example:
- from sqlspec.adapters.asyncpg import AsyncpgConfig
- from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore
-
- config = AsyncpgConfig(
- connection_config={"dsn": "postgresql://..."},
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
- }
- }
- )
- store = AsyncpgADKStore(config)
- await store.ensure_tables()
-
- Notes:
- - PostgreSQL JSONB type used for state (more efficient than JSON)
- - AsyncPG automatically converts Python dicts to/from JSONB (no manual serialization)
- - TIMESTAMPTZ provides timezone-aware microsecond precision
- - State merging uses `state || $1::jsonb` operator for efficiency
- - BYTEA for pre-serialized actions from Google ADK (not pickled here)
- - GIN index on state for JSONB queries (partial index)
- - FILLFACTOR 80 leaves space for HOT updates
- - Generic over PostgresConfigT to support all PostgreSQL drivers
- - Owner ID column enables multi-tenant isolation with referential integrity
- - Configuration is read from config.extension_config["adk"]
"""
__slots__ = ()
def __init__(self, config: AsyncConfigT) -> None:
- """Initialize AsyncPG ADK store.
-
- Args:
- config: PostgreSQL database config.
-
- Notes:
- Configuration is read from config.extension_config["adk"]:
- - session_table: Sessions table name (default: "adk_sessions")
- - events_table: Events table name (default: "adk_events")
- - owner_id_column: Optional owner FK column DDL (default: None)
- """
super().__init__(config)
async def _get_create_sessions_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for sessions.
-
- Returns:
- SQL statement to create adk_sessions table with indexes.
-
- Notes:
- - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names)
- - JSONB type for state storage with default empty object
- - TIMESTAMPTZ with microsecond precision
- - FILLFACTOR 80 for HOT updates (reduces table bloat)
- - Composite index on (app_name, user_id) for listing
- - Index on update_time DESC for recent session queries
- - Partial GIN index on state for JSONB queries (only non-empty)
- - Optional owner ID column for multi-tenancy or owner references
- """
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
@@ -128,61 +73,24 @@ async def _get_create_sessions_table_sql(self) -> str:
"""
async def _get_create_events_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for events.
-
- Returns:
- SQL statement to create adk_events table with indexes.
-
- Notes:
- - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256),
- branch(256), error_code(256), error_message(1024)
- - BYTEA for pickled actions (no size limit)
- - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
- - BOOLEAN for partial, turn_complete, interrupted
- - Foreign key to sessions with CASCADE delete
- - Index on (session_id, timestamp ASC) for ordered event retrieval
- """
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256),
- author VARCHAR(256),
- actions BYTEA,
- long_running_tool_ids_json JSONB,
- branch VARCHAR(256),
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(256) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
- );
+ ) WITH (fillfactor = 80);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
ON {self._events_table}(session_id, timestamp ASC);
"""
def _get_drop_tables_sql(self) -> "list[str]":
- """Get PostgreSQL DROP TABLE SQL statements.
-
- Returns:
- List of SQL statements to drop tables and indexes.
-
- Notes:
- Order matters: drop events table (child) before sessions (parent).
- PostgreSQL automatically drops indexes when dropping tables.
- """
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
async def create_tables(self) -> None:
- """Create both sessions and events tables if they don't exist."""
async with self.config.provide_session() as driver:
await driver.execute_script(await self._get_create_sessions_table_sql())
await driver.execute_script(await self._get_create_events_table_sql())
@@ -190,23 +98,6 @@ async def create_tables(self) -> None:
async def create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
- """Create a new session.
-
- Args:
- session_id: Unique session identifier.
- app_name: Application name.
- user_id: User identifier.
- state: Initial session state.
- owner_id: Optional owner ID value for owner_id_column (if configured).
-
- Returns:
- Created session record.
-
- Notes:
- Uses CURRENT_TIMESTAMP for create_time and update_time.
- State is passed as dict and asyncpg converts to JSONB automatically.
- If owner_id_column is configured, owner_id value must be provided.
- """
async with self.config.provide_connection() as conn:
if self._owner_id_column_name:
sql = f"""
@@ -225,18 +116,6 @@ async def create_session(
return await self.get_session(session_id) # type: ignore[return-value]
async def get_session(self, session_id: str) -> "SessionRecord | None":
- """Get session by ID.
-
- Args:
- session_id: Session identifier.
-
- Returns:
- Session record or None if not found.
-
- Notes:
- PostgreSQL returns datetime objects for TIMESTAMPTZ columns.
- JSONB is automatically parsed by asyncpg.
- """
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
@@ -262,16 +141,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
return None
async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
- """Update session state.
-
- Args:
- session_id: Session identifier.
- state: New state dictionary (replaces existing state).
-
- Notes:
- This replaces the entire state dictionary.
- Uses CURRENT_TIMESTAMP for update_time.
- """
sql = f"""
UPDATE {self._session_table}
SET state = $1, update_time = CURRENT_TIMESTAMP
@@ -282,32 +151,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
await conn.execute(sql, state, session_id)
async def delete_session(self, session_id: str) -> None:
- """Delete session and all associated events (cascade).
-
- Args:
- session_id: Session identifier.
-
- Notes:
- Foreign key constraint ensures events are cascade-deleted.
- """
sql = f"DELETE FROM {self._session_table} WHERE id = $1"
async with self.config.provide_connection() as conn:
await conn.execute(sql, session_id)
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
- """List sessions for an app, optionally filtered by user.
-
- Args:
- app_name: Application name.
- user_id: User identifier. If None, lists all sessions for the app.
-
- Returns:
- List of session records ordered by update_time DESC.
-
- Notes:
- Uses composite index on (app_name, user_id) when user_id is provided.
- """
if user_id is None:
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -344,70 +193,50 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
return []
async def append_event(self, event_record: EventRecord) -> None:
- """Append an event to a session.
-
- Args:
- event_record: Event record to store.
-
- Notes:
- Uses CURRENT_TIMESTAMP for timestamp if not provided.
- JSONB fields are passed as dicts and asyncpg converts automatically.
- """
- content_json = event_record.get("content")
- grounding_metadata_json = event_record.get("grounding_metadata")
- custom_metadata_json = event_record.get("custom_metadata")
-
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES ($1, $2, $3, $4, $5)
"""
async with self.config.provide_connection() as conn:
await conn.execute(
sql,
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
- event_record.get("invocation_id"),
- event_record.get("author"),
- event_record.get("actions"),
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_record["event_json"],
+ )
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES ($1, $2, $3, $4, $5)
+ """
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = $1, update_time = CURRENT_TIMESTAMP
+ WHERE id = $2
+ """
+
+ async with self.config.provide_connection() as conn, conn.transaction():
+ await conn.execute(
+ insert_sql,
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
event_record["timestamp"],
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_record["event_json"],
)
+ await conn.execute(update_sql, state, session_id)
async def get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
- """Get events for a session.
-
- Args:
- session_id: Session identifier.
- after_timestamp: Only return events after this time.
- limit: Maximum number of events to return.
-
- Returns:
- List of event records ordered by timestamp ASC.
-
- Notes:
- Uses index on (session_id, timestamp ASC).
- Parses JSONB fields and converts BYTEA actions to bytes.
- """
where_clauses = ["session_id = $1"]
params: list[Any] = [session_id]
@@ -421,10 +250,7 @@ async def get_events(
params.append(limit)
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
@@ -436,24 +262,11 @@ async def get_events(
return [
EventRecord(
- id=row["id"],
session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
invocation_id=row["invocation_id"],
author=row["author"],
- actions=bytes(row["actions"]) if row["actions"] else b"",
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ event_json=row["event_json"],
)
for row in rows
]
@@ -475,66 +288,14 @@ class AsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["AsyncpgConfig"]):
Args:
config: AsyncpgConfig with extension_config["adk"] settings.
-
- Example:
- from sqlspec.adapters.asyncpg import AsyncpgConfig
- from sqlspec.adapters.asyncpg.adk.store import AsyncpgADKMemoryStore
-
- config = AsyncpgConfig(
- connection_config={"dsn": "postgresql://..."},
- extension_config={
- "adk": {
- "memory_table": "adk_memory_entries",
- "memory_use_fts": True,
- "memory_max_results": 20,
- }
- }
- )
- store = AsyncpgADKMemoryStore(config)
- await store.ensure_tables()
-
- Notes:
- - JSONB type for content_json and metadata_json
- - TIMESTAMPTZ with microsecond precision
- - GIN index on content_text tsvector for FTS queries
- - Composite index on (app_name, user_id) for filtering
- - event_id UNIQUE constraint for deduplication
- - Configuration is read from config.extension_config["adk"]
"""
__slots__ = ()
def __init__(self, config: "AsyncpgConfig") -> None:
- """Initialize AsyncPG ADK memory store.
-
- Args:
- config: AsyncpgConfig instance.
-
- Notes:
- Configuration is read from config.extension_config["adk"]:
- - memory_table: Memory table name (default: "adk_memory_entries")
- - memory_use_fts: Enable full-text search when supported (default: False)
- - memory_max_results: Max search results (default: 20)
- - owner_id_column: Optional owner FK column DDL (default: None)
- - enable_memory: Whether memory is enabled (default: True)
- """
super().__init__(config)
async def _get_create_memory_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for memory entries.
-
- Returns:
- SQL statement to create memory table with indexes.
-
- Notes:
- - VARCHAR(128) for IDs and names
- - JSONB for content and metadata storage
- - TIMESTAMPTZ with microsecond precision
- - UNIQUE constraint on event_id for deduplication
- - Composite index on (app_name, user_id, timestamp DESC)
- - GIN index on content_text tsvector for FTS
- - Optional owner ID column for multi-tenancy
- """
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
@@ -570,21 +331,9 @@ async def _get_create_memory_table_sql(self) -> str:
"""
def _get_drop_memory_table_sql(self) -> "list[str]":
- """Get PostgreSQL DROP TABLE SQL statements.
-
- Returns:
- List of SQL statements to drop the memory table.
-
- Notes:
- PostgreSQL automatically drops indexes when dropping tables.
- """
return [f"DROP TABLE IF EXISTS {self._memory_table}"]
async def create_tables(self) -> None:
- """Create the memory table and indexes if they don't exist.
-
- Skips table creation if memory store is disabled.
- """
if not self._enabled:
return
@@ -592,21 +341,6 @@ async def create_tables(self) -> None:
await driver.execute_script(await self._get_create_memory_table_sql())
async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
- """Bulk insert memory entries with deduplication.
-
- Uses UPSERT pattern (ON CONFLICT DO NOTHING) to skip duplicates
- based on event_id unique constraint.
-
- Args:
- entries: List of memory records to insert.
- owner_id: Optional owner ID value for owner_id_column (if configured).
-
- Returns:
- Number of entries actually inserted (excludes duplicates).
-
- Raises:
- RuntimeError: If memory store is disabled.
- """
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -673,19 +407,6 @@ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "
async def search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
- """Search memory entries by text query.
-
- Uses the configured search strategy (simple ILIKE or FTS).
-
- Args:
- query: Text query to search for.
- app_name: Application name to filter by.
- user_id: User ID to filter by.
- limit: Maximum number of results (defaults to max_results config).
-
- Returns:
- List of memory records.
- """
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -693,6 +414,8 @@ async def search_entries(
if not query:
return []
+ from typing import cast
+
limit_value = limit or self._max_results
if self._use_fts:
sql = f"""
@@ -717,7 +440,6 @@ async def search_entries(
return [cast("MemoryRecord", dict(row)) for row in rows]
async def delete_entries_by_session(self, session_id: str) -> int:
- """Delete all memory entries for a specific session."""
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -731,7 +453,6 @@ async def delete_entries_by_session(self, session_id: str) -> int:
return 0
async def delete_entries_older_than(self, days: int) -> int:
- """Delete memory entries older than specified days."""
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
diff --git a/sqlspec/adapters/bigquery/adk/__init__.py b/sqlspec/adapters/bigquery/adk/__init__.py
deleted file mode 100644
index 6d11c84b8..000000000
--- a/sqlspec/adapters/bigquery/adk/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""BigQuery ADK store for Google Agent Development Kit session/event storage."""
-
-from sqlspec.adapters.bigquery.adk.store import BigQueryADKMemoryStore, BigQueryADKStore
-
-__all__ = ("BigQueryADKMemoryStore", "BigQueryADKStore")
diff --git a/sqlspec/adapters/bigquery/adk/store.py b/sqlspec/adapters/bigquery/adk/store.py
deleted file mode 100644
index 61339a7a4..000000000
--- a/sqlspec/adapters/bigquery/adk/store.py
+++ /dev/null
@@ -1,827 +0,0 @@
-"""BigQuery ADK store for Google Agent Development Kit session/event storage."""
-
-from collections.abc import Mapping
-from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, cast
-
-from google.api_core.exceptions import NotFound
-from google.cloud.bigquery import QueryJobConfig, ScalarQueryParameter
-
-from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
-from sqlspec.utils.serializers import from_json, to_json
-from sqlspec.utils.sync_tools import async_, run_
-
-if TYPE_CHECKING:
- from sqlspec.adapters.bigquery.config import BigQueryConfig
- from sqlspec.extensions.adk import MemoryRecord
-
-
-__all__ = ("BigQueryADKMemoryStore", "BigQueryADKStore")
-
-
-class BigQueryADKStore(BaseAsyncADKStore["BigQueryConfig"]):
- """BigQuery ADK store using synchronous BigQuery client with async wrapper.
-
- Implements session and event storage for Google Agent Development Kit
- using Google Cloud BigQuery. Uses BigQuery's native JSON type for state/metadata
- storage and async_() wrapper to provide async interface.
-
- Provides:
- - Serverless, scalable session state management with JSON storage
- - Event history tracking optimized for analytics
- - Microsecond-precision timestamps with TIMESTAMP type
- - Cost-optimized queries with partitioning and clustering
- - Efficient JSON handling with BigQuery's JSON type
- - Manual cascade delete pattern (no foreign key support)
-
- Args:
- config: BigQueryConfig with extension_config["adk"] settings.
-
- Example:
- from sqlspec.adapters.bigquery import BigQueryConfig
- from sqlspec.adapters.bigquery.adk import BigQueryADKStore
-
- config = BigQueryConfig(
- connection_config={
- "project": "my-project",
- "dataset_id": "my_dataset",
- },
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "tenant_id INT64 NOT NULL"
- }
- }
- )
- store = BigQueryADKStore(config)
- await store.ensure_tables()
-
- Notes:
- - JSON type for state, content, and metadata (native BigQuery JSON)
- - BYTES for pre-serialized actions from Google ADK
- - TIMESTAMP for timezone-aware microsecond precision
- - Partitioned by DATE(create_time) for cost optimization
- - Clustered by app_name, user_id for query performance
- - Uses to_json/from_json for serialization to JSON columns
- - BigQuery has eventual consistency - handle appropriately
- - No true foreign keys but implements cascade delete pattern
- - Configuration is read from config.extension_config["adk"]
- """
-
- __slots__ = ("_dataset_id",)
-
- def __init__(self, config: "BigQueryConfig") -> None:
- """Initialize BigQuery ADK store.
-
- Args:
- config: BigQueryConfig instance.
-
- Notes:
- Configuration is read from config.extension_config["adk"]:
- - session_table: Sessions table name (default: "adk_sessions")
- - events_table: Events table name (default: "adk_events")
- - owner_id_column: Optional owner FK column DDL (default: None)
- """
- super().__init__(config)
- self._dataset_id = config.connection_config.get("dataset_id")
-
- def _get_full_table_name(self, table_name: str) -> str:
- """Get fully qualified table name for BigQuery.
-
- Args:
- table_name: Base table name.
-
- Returns:
- Fully qualified table name with backticks.
-
- Notes:
- BigQuery requires backtick-quoted identifiers for table names.
- Format: `project.dataset.table` or `dataset.table`
- """
- if self._dataset_id:
- return f"`{self._dataset_id}.{table_name}`"
- return f"`{table_name}`"
-
- async def _get_create_sessions_table_sql(self) -> str:
- """Get BigQuery CREATE TABLE SQL for sessions.
-
- Returns:
- SQL statement to create adk_sessions table.
-
- Notes:
- - STRING for IDs and names
- - JSON type for state storage (native BigQuery JSON)
- - TIMESTAMP for timezone-aware microsecond precision
- - Partitioned by DATE(create_time) for cost optimization
- - Clustered by app_name, user_id for query performance
- - No indexes needed (BigQuery auto-optimizes)
- - Optional owner ID column for multi-tenant scenarios
- - Note: BigQuery doesn't enforce FK constraints
- """
- owner_id_line = ""
- if self._owner_id_column_ddl:
- owner_id_line = f",\n {self._owner_id_column_ddl}"
-
- table_name = self._get_full_table_name(self._session_table)
- return f"""
- CREATE TABLE IF NOT EXISTS {table_name} (
- id STRING NOT NULL,
- app_name STRING NOT NULL,
- user_id STRING NOT NULL{owner_id_line},
- state JSON NOT NULL,
- create_time TIMESTAMP NOT NULL,
- update_time TIMESTAMP NOT NULL
- )
- PARTITION BY DATE(create_time)
- CLUSTER BY app_name, user_id
- """
-
- async def _get_create_events_table_sql(self) -> str:
- """Get BigQuery CREATE TABLE SQL for events.
-
- Returns:
- SQL statement to create adk_events table.
-
- Notes:
- - STRING for IDs and text fields
- - BYTES for pickled actions
- - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
- - BOOL for boolean flags
- - TIMESTAMP for timezone-aware timestamps
- - Partitioned by DATE(timestamp) for cost optimization
- - Clustered by session_id, timestamp for ordered retrieval
- """
- table_name = self._get_full_table_name(self._events_table)
- return f"""
- CREATE TABLE IF NOT EXISTS {table_name} (
- id STRING NOT NULL,
- session_id STRING NOT NULL,
- app_name STRING NOT NULL,
- user_id STRING NOT NULL,
- invocation_id STRING,
- author STRING,
- actions BYTES,
- long_running_tool_ids_json JSON,
- branch STRING,
- timestamp TIMESTAMP NOT NULL,
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOL,
- turn_complete BOOL,
- interrupted BOOL,
- error_code STRING,
- error_message STRING
- )
- PARTITION BY DATE(timestamp)
- CLUSTER BY session_id, timestamp
- """
-
- def _get_drop_tables_sql(self) -> "list[str]":
- """Get BigQuery DROP TABLE SQL statements.
-
- Returns:
- List of SQL statements to drop tables.
-
- Notes:
- Order matters: drop events table before sessions table.
- BigQuery uses IF EXISTS for idempotent drops.
- """
- events_table = self._get_full_table_name(self._events_table)
- sessions_table = self._get_full_table_name(self._session_table)
- return [f"DROP TABLE IF EXISTS {events_table}", f"DROP TABLE IF EXISTS {sessions_table}"]
-
- def _create_tables(self) -> None:
- """Synchronous implementation of create_tables."""
- with self._config.provide_session() as driver:
- driver.execute_script(run_(self._get_create_sessions_table_sql)())
- driver.execute_script(run_(self._get_create_events_table_sql)())
-
- async def create_tables(self) -> None:
- """Create both sessions and events tables if they don't exist."""
- await async_(self._create_tables)()
-
- def _create_session(
- self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
- ) -> SessionRecord:
- """Synchronous implementation of create_session."""
- now = datetime.now(timezone.utc)
- state_json = to_json(state) if state else "{}"
-
- table_name = self._get_full_table_name(self._session_table)
-
- if self._owner_id_column_name:
- sql = f"""
- INSERT INTO {table_name} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time)
- VALUES (@id, @app_name, @user_id, @owner_id, JSON(@state), @create_time, @update_time)
- """
-
- params = [
- ScalarQueryParameter("id", "STRING", session_id),
- ScalarQueryParameter("app_name", "STRING", app_name),
- ScalarQueryParameter("user_id", "STRING", user_id),
- ScalarQueryParameter("owner_id", "STRING", str(owner_id) if owner_id is not None else None),
- ScalarQueryParameter("state", "STRING", state_json),
- ScalarQueryParameter("create_time", "TIMESTAMP", now),
- ScalarQueryParameter("update_time", "TIMESTAMP", now),
- ]
- else:
- sql = f"""
- INSERT INTO {table_name} (id, app_name, user_id, state, create_time, update_time)
- VALUES (@id, @app_name, @user_id, JSON(@state), @create_time, @update_time)
- """
-
- params = [
- ScalarQueryParameter("id", "STRING", session_id),
- ScalarQueryParameter("app_name", "STRING", app_name),
- ScalarQueryParameter("user_id", "STRING", user_id),
- ScalarQueryParameter("state", "STRING", state_json),
- ScalarQueryParameter("create_time", "TIMESTAMP", now),
- ScalarQueryParameter("update_time", "TIMESTAMP", now),
- ]
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- conn.query(sql, job_config=job_config).result()
-
- return SessionRecord(
- id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now
- )
-
- async def create_session(
- self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
- ) -> SessionRecord:
- """Create a new session.
-
- Args:
- session_id: Unique session identifier.
- app_name: Application name.
- user_id: User identifier.
- state: Initial session state.
- owner_id: Optional owner ID value for owner_id_column (if configured).
-
- Returns:
- Created session record.
-
- Notes:
- Uses CURRENT_TIMESTAMP() for timestamps.
- State is JSON-serialized then stored in JSON column.
- If owner_id_column is configured, owner_id value must be provided.
- BigQuery doesn't enforce FK constraints, but column is useful for JOINs.
- """
- return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
-
- def _get_session(self, session_id: str) -> "SessionRecord | None":
- """Synchronous implementation of get_session."""
- table_name = self._get_full_table_name(self._session_table)
- sql = f"""
- SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time
- FROM {table_name}
- WHERE id = @session_id
- """
-
- params = [ScalarQueryParameter("session_id", "STRING", session_id)]
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- query_job = conn.query(sql, job_config=job_config)
- results = list(query_job.result())
-
- if not results:
- return None
-
- row = results[0]
- return SessionRecord(
- id=row.id,
- app_name=row.app_name,
- user_id=row.user_id,
- state=from_json(row.state) if row.state else {},
- create_time=row.create_time,
- update_time=row.update_time,
- )
-
- async def get_session(self, session_id: str) -> "SessionRecord | None":
- """Get session by ID.
-
- Args:
- session_id: Session identifier.
-
- Returns:
- Session record or None if not found.
-
- Notes:
- BigQuery returns datetime objects for TIMESTAMP columns.
- JSON_VALUE extracts string representation for parsing.
- """
- return await async_(self._get_session)(session_id)
-
- def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
- """Synchronous implementation of update_session_state."""
- now = datetime.now(timezone.utc)
- state_json = to_json(state) if state else "{}"
-
- table_name = self._get_full_table_name(self._session_table)
- sql = f"""
- UPDATE {table_name}
- SET state = JSON(@state), update_time = @update_time
- WHERE id = @session_id
- """
-
- params = [
- ScalarQueryParameter("state", "STRING", state_json),
- ScalarQueryParameter("update_time", "TIMESTAMP", now),
- ScalarQueryParameter("session_id", "STRING", session_id),
- ]
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- conn.query(sql, job_config=job_config).result()
-
- async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
- """Update session state.
-
- Args:
- session_id: Session identifier.
- state: New state dictionary (replaces existing state).
-
- Notes:
- Replaces entire state dictionary.
- Updates update_time to CURRENT_TIMESTAMP().
- """
- await async_(self._update_session_state)(session_id, state)
-
- def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]":
- """Synchronous implementation of list_sessions."""
- table_name = self._get_full_table_name(self._session_table)
-
- if user_id is None:
- sql = f"""
- SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time
- FROM {table_name}
- WHERE app_name = @app_name
- ORDER BY update_time DESC
- """
- params = [ScalarQueryParameter("app_name", "STRING", app_name)]
- else:
- sql = f"""
- SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time
- FROM {table_name}
- WHERE app_name = @app_name AND user_id = @user_id
- ORDER BY update_time DESC
- """
- params = [
- ScalarQueryParameter("app_name", "STRING", app_name),
- ScalarQueryParameter("user_id", "STRING", user_id),
- ]
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- query_job = conn.query(sql, job_config=job_config)
- results = list(query_job.result())
-
- return [
- SessionRecord(
- id=row.id,
- app_name=row.app_name,
- user_id=row.user_id,
- state=from_json(row.state) if row.state else {},
- create_time=row.create_time,
- update_time=row.update_time,
- )
- for row in results
- ]
-
- async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
- """List sessions for an app, optionally filtered by user.
-
- Args:
- app_name: Application name.
- user_id: User identifier. If None, lists all sessions for the app.
-
- Returns:
- List of session records ordered by update_time DESC.
-
- Notes:
- Uses clustering on (app_name, user_id) when user_id is provided for efficiency.
- """
- return await async_(self._list_sessions)(app_name, user_id)
-
- def _delete_session(self, session_id: str) -> None:
- """Synchronous implementation of delete_session."""
- events_table = self._get_full_table_name(self._events_table)
- sessions_table = self._get_full_table_name(self._session_table)
-
- params = [ScalarQueryParameter("session_id", "STRING", session_id)]
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- conn.query(f"DELETE FROM {events_table} WHERE session_id = @session_id", job_config=job_config).result()
- conn.query(f"DELETE FROM {sessions_table} WHERE id = @session_id", job_config=job_config).result()
-
- async def delete_session(self, session_id: str) -> None:
- """Delete session and all associated events.
-
- Args:
- session_id: Session identifier.
-
- Notes:
- BigQuery doesn't support foreign keys, so we manually delete events first.
- Uses two separate DELETE statements in sequence.
- """
- await async_(self._delete_session)(session_id)
-
- def _append_event(self, event_record: EventRecord) -> None:
- """Synchronous implementation of append_event."""
- table_name = self._get_full_table_name(self._events_table)
-
- content_json = to_json(event_record.get("content")) if event_record.get("content") else None
- grounding_metadata_json = (
- to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None
- )
- custom_metadata_json = (
- to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None
- )
-
- sql = f"""
- INSERT INTO {table_name} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- @id, @session_id, @app_name, @user_id, @invocation_id, @author, @actions,
- @long_running_tool_ids_json, @branch, @timestamp,
- {"JSON(@content)" if content_json else "NULL"},
- {"JSON(@grounding_metadata)" if grounding_metadata_json else "NULL"},
- {"JSON(@custom_metadata)" if custom_metadata_json else "NULL"},
- @partial, @turn_complete, @interrupted, @error_code, @error_message
- )
- """
-
- actions_value = event_record.get("actions")
- params = [
- ScalarQueryParameter("id", "STRING", event_record["id"]),
- ScalarQueryParameter("session_id", "STRING", event_record["session_id"]),
- ScalarQueryParameter("app_name", "STRING", event_record["app_name"]),
- ScalarQueryParameter("user_id", "STRING", event_record["user_id"]),
- ScalarQueryParameter("invocation_id", "STRING", event_record.get("invocation_id")),
- ScalarQueryParameter("author", "STRING", event_record.get("author")),
- ScalarQueryParameter(
- "actions",
- "BYTES",
- actions_value.decode("latin1") if isinstance(actions_value, bytes) else actions_value,
- ),
- ScalarQueryParameter(
- "long_running_tool_ids_json", "STRING", event_record.get("long_running_tool_ids_json")
- ),
- ScalarQueryParameter("branch", "STRING", event_record.get("branch")),
- ScalarQueryParameter("timestamp", "TIMESTAMP", event_record["timestamp"]),
- ScalarQueryParameter("partial", "BOOL", event_record.get("partial")),
- ScalarQueryParameter("turn_complete", "BOOL", event_record.get("turn_complete")),
- ScalarQueryParameter("interrupted", "BOOL", event_record.get("interrupted")),
- ScalarQueryParameter("error_code", "STRING", event_record.get("error_code")),
- ScalarQueryParameter("error_message", "STRING", event_record.get("error_message")),
- ]
-
- if content_json:
- params.append(ScalarQueryParameter("content", "STRING", content_json))
- if grounding_metadata_json:
- params.append(ScalarQueryParameter("grounding_metadata", "STRING", grounding_metadata_json))
- if custom_metadata_json:
- params.append(ScalarQueryParameter("custom_metadata", "STRING", custom_metadata_json))
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- conn.query(sql, job_config=job_config).result()
-
- async def append_event(self, event_record: EventRecord) -> None:
- """Append an event to a session.
-
- Args:
- event_record: Event record to store.
-
- Notes:
- Uses BigQuery TIMESTAMP for timezone-aware timestamps.
- JSON fields are serialized to STRING then cast to JSON.
- Boolean fields stored natively as BOOL.
- """
- await async_(self._append_event)(event_record)
-
- def _get_events(
- self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
- ) -> "list[EventRecord]":
- """Synchronous implementation of get_events."""
- table_name = self._get_full_table_name(self._events_table)
-
- where_clauses = ["session_id = @session_id"]
- params: list[ScalarQueryParameter] = [ScalarQueryParameter("session_id", "STRING", session_id)]
-
- if after_timestamp is not None:
- where_clauses.append("timestamp > @after_timestamp")
- params.append(ScalarQueryParameter("after_timestamp", "TIMESTAMP", after_timestamp))
-
- where_clause = " AND ".join(where_clauses)
- limit_clause = f" LIMIT {limit}" if limit else ""
-
- sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp,
- JSON_VALUE(content) as content,
- JSON_VALUE(grounding_metadata) as grounding_metadata,
- JSON_VALUE(custom_metadata) as custom_metadata,
- partial, turn_complete, interrupted, error_code, error_message
- FROM {table_name}
- WHERE {where_clause}
- ORDER BY timestamp ASC{limit_clause}
- """
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- query_job = conn.query(sql, job_config=job_config)
- results = list(query_job.result())
-
- return [
- EventRecord(
- id=row.id,
- session_id=row.session_id,
- app_name=row.app_name,
- user_id=row.user_id,
- invocation_id=row.invocation_id,
- author=row.author,
- actions=bytes(row.actions) if row.actions else b"",
- long_running_tool_ids_json=row.long_running_tool_ids_json,
- branch=row.branch,
- timestamp=row.timestamp,
- content=from_json(row.content) if row.content else None,
- grounding_metadata=from_json(row.grounding_metadata) if row.grounding_metadata else None,
- custom_metadata=from_json(row.custom_metadata) if row.custom_metadata else None,
- partial=row.partial,
- turn_complete=row.turn_complete,
- interrupted=row.interrupted,
- error_code=row.error_code,
- error_message=row.error_message,
- )
- for row in results
- ]
-
- async def get_events(
- self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
- ) -> "list[EventRecord]":
- """Get events for a session.
-
- Args:
- session_id: Session identifier.
- after_timestamp: Only return events after this time.
- limit: Maximum number of events to return.
-
- Returns:
- List of event records ordered by timestamp ASC.
-
- Notes:
- Uses clustering on (session_id, timestamp) for efficient retrieval.
- Parses JSON fields and converts BYTES actions to bytes.
- """
- return await async_(self._get_events)(session_id, after_timestamp, limit)
-
-
-class BigQueryADKMemoryStore(BaseAsyncADKMemoryStore["BigQueryConfig"]):
- """BigQuery ADK memory store using synchronous BigQuery client with async wrapper."""
-
- __slots__ = ("_dataset_id",)
-
- def __init__(self, config: "BigQueryConfig") -> None:
- """Initialize BigQuery ADK memory store."""
- super().__init__(config)
- self._dataset_id = config.connection_config.get("dataset_id")
-
- def _get_full_table_name(self, table_name: str) -> str:
- """Get fully qualified table name for BigQuery."""
- if self._dataset_id:
- return f"`{self._dataset_id}.{table_name}`"
- return f"`{table_name}`"
-
- async def _get_create_memory_table_sql(self) -> str:
- """Get BigQuery CREATE TABLE SQL for memory entries."""
- owner_id_line = ""
- if self._owner_id_column_ddl:
- owner_id_line = f",\n {self._owner_id_column_ddl}"
-
- table_name = self._get_full_table_name(self._memory_table)
- fts_index = ""
- if self._use_fts:
- fts_index = f"""
- CREATE SEARCH INDEX idx_{self._memory_table}_fts
- ON {table_name}(content_text)
- """
-
- return f"""
- CREATE TABLE IF NOT EXISTS {table_name} (
- id STRING NOT NULL,
- session_id STRING NOT NULL,
- app_name STRING NOT NULL,
- user_id STRING NOT NULL,
- event_id STRING NOT NULL,
- author STRING{owner_id_line},
- timestamp TIMESTAMP NOT NULL,
- content_json JSON NOT NULL,
- content_text STRING NOT NULL,
- metadata_json JSON,
- inserted_at TIMESTAMP NOT NULL
- )
- PARTITION BY DATE(timestamp)
- CLUSTER BY app_name, user_id;
- {fts_index}
- """
-
- def _get_drop_memory_table_sql(self) -> "list[str]":
- """Get BigQuery DROP TABLE SQL statements."""
- table_name = self._get_full_table_name(self._memory_table)
- return [f"DROP TABLE IF EXISTS {table_name}"]
-
- def _create_tables(self) -> None:
- """Synchronous implementation of create_tables."""
- with self._config.provide_session() as driver:
- driver.execute_script(run_(self._get_create_memory_table_sql)())
-
- async def create_tables(self) -> None:
- """Create the memory table if it doesn't exist."""
- if not self._enabled:
- return
- await async_(self._create_tables)()
-
- def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
- """Synchronous implementation of insert_memory_entries."""
- table_name = self._get_full_table_name(self._memory_table)
- inserted_count = 0
-
- with self._config.provide_connection() as conn:
- for entry in entries:
- content_json = to_json(entry["content_json"])
- metadata_json = to_json(entry["metadata_json"]) if entry["metadata_json"] is not None else None
- metadata_expr = "JSON(@metadata_json)" if metadata_json is not None else "NULL"
-
- owner_column = f", {self._owner_id_column_name}" if self._owner_id_column_name else ""
- owner_value = ", @owner_id" if self._owner_id_column_name else ""
-
- sql = f"""
- MERGE {table_name} T
- USING (SELECT @event_id AS event_id) S
- ON T.event_id = S.event_id
- WHEN NOT MATCHED THEN
- INSERT (id, session_id, app_name, user_id, event_id, author{owner_column},
- timestamp, content_json, content_text, metadata_json, inserted_at)
- VALUES (@id, @session_id, @app_name, @user_id, @event_id, @author{owner_value},
- @timestamp, JSON(@content_json), @content_text, {metadata_expr}, @inserted_at)
- """
-
- params = [
- ScalarQueryParameter("id", "STRING", entry["id"]),
- ScalarQueryParameter("session_id", "STRING", entry["session_id"]),
- ScalarQueryParameter("app_name", "STRING", entry["app_name"]),
- ScalarQueryParameter("user_id", "STRING", entry["user_id"]),
- ScalarQueryParameter("event_id", "STRING", entry["event_id"]),
- ScalarQueryParameter("author", "STRING", entry["author"]),
- ScalarQueryParameter("timestamp", "TIMESTAMP", entry["timestamp"]),
- ScalarQueryParameter("content_json", "STRING", content_json),
- ScalarQueryParameter("content_text", "STRING", entry["content_text"]),
- ScalarQueryParameter("inserted_at", "TIMESTAMP", entry["inserted_at"]),
- ]
-
- if self._owner_id_column_name:
- params.append(ScalarQueryParameter("owner_id", "STRING", str(owner_id) if owner_id else None))
- if metadata_json is not None:
- params.append(ScalarQueryParameter("metadata_json", "STRING", metadata_json))
-
- job_config = QueryJobConfig(query_parameters=params)
- job = conn.query(sql, job_config=job_config)
- job.result()
- if job.num_dml_affected_rows:
- inserted_count += int(job.num_dml_affected_rows)
-
- return inserted_count
-
- async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
- """Bulk insert memory entries with deduplication."""
- if not self._enabled:
- msg = "Memory store is disabled"
- raise RuntimeError(msg)
-
- if not entries:
- return 0
-
- return await async_(self._insert_memory_entries)(entries, owner_id)
-
- def _search_entries(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]":
- """Synchronous implementation of search_entries."""
- table_name = self._get_full_table_name(self._memory_table)
- base_params = [
- ScalarQueryParameter("app_name", "STRING", app_name),
- ScalarQueryParameter("user_id", "STRING", user_id),
- ScalarQueryParameter("limit", "INT64", limit),
- ]
-
- if self._use_fts:
- sql = f"""
- SELECT id, session_id, app_name, user_id, event_id, author,
- timestamp, content_json, content_text, metadata_json, inserted_at
- FROM {table_name}
- WHERE app_name = @app_name
- AND user_id = @user_id
- AND SEARCH(content_text, @query)
- ORDER BY timestamp DESC
- LIMIT @limit
- """
- params = [*base_params, ScalarQueryParameter("query", "STRING", query)]
- else:
- sql = f"""
- SELECT id, session_id, app_name, user_id, event_id, author,
- timestamp, content_json, content_text, metadata_json, inserted_at
- FROM {table_name}
- WHERE app_name = @app_name
- AND user_id = @user_id
- AND LOWER(content_text) LIKE LOWER(@pattern)
- ORDER BY timestamp DESC
- LIMIT @limit
- """
- pattern = f"%{query}%"
- params = [*base_params, ScalarQueryParameter("pattern", "STRING", pattern)]
-
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- rows = conn.query(sql, job_config=job_config).result()
- return _rows_to_records(rows)
-
- async def search_entries(
- self, query: str, app_name: str, user_id: str, limit: "int | None" = None
- ) -> "list[MemoryRecord]":
- """Search memory entries by text query."""
- if not self._enabled:
- msg = "Memory store is disabled"
- raise RuntimeError(msg)
-
- effective_limit = limit if limit is not None else self._max_results
-
- try:
- return await async_(self._search_entries)(query, app_name, user_id, effective_limit)
- except NotFound:
- return []
-
- def _delete_entries_by_session(self, session_id: str) -> int:
- table_name = self._get_full_table_name(self._memory_table)
- sql = f"DELETE FROM {table_name} WHERE session_id = @session_id"
- params = [ScalarQueryParameter("session_id", "STRING", session_id)]
- with self._config.provide_connection() as conn:
- job_config = QueryJobConfig(query_parameters=params)
- job = conn.query(sql, job_config=job_config)
- job.result()
- return int(job.num_dml_affected_rows or 0)
-
- async def delete_entries_by_session(self, session_id: str) -> int:
- """Delete all memory entries for a specific session."""
- return await async_(self._delete_entries_by_session)(session_id)
-
- def _delete_entries_older_than(self, days: int) -> int:
- table_name = self._get_full_table_name(self._memory_table)
- sql = f"""
- DELETE FROM {table_name}
- WHERE inserted_at < TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {days} DAY)
- """
- with self._config.provide_connection() as conn:
- job = conn.query(sql)
- job.result()
- return int(job.num_dml_affected_rows or 0)
-
- async def delete_entries_older_than(self, days: int) -> int:
- """Delete memory entries older than specified days."""
- return await async_(self._delete_entries_older_than)(days)
-
-
-def _decode_json_field(value: Any) -> "dict[str, Any] | None":
- if value is None:
- return None
- if isinstance(value, str):
- return cast("dict[str, Any]", from_json(value))
- if isinstance(value, Mapping):
- return dict(value)
- return None
-
-
-def _rows_to_records(rows: Any) -> "list[MemoryRecord]":
- return [
- {
- "id": row["id"],
- "session_id": row["session_id"],
- "app_name": row["app_name"],
- "user_id": row["user_id"],
- "event_id": row["event_id"],
- "author": row["author"],
- "timestamp": row["timestamp"],
- "content_json": _decode_json_field(row["content_json"]) or {},
- "content_text": row["content_text"],
- "metadata_json": _decode_json_field(row["metadata_json"]),
- "inserted_at": row["inserted_at"],
- }
- for row in rows
- ]
diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py
index c81e5f6cc..bedaaa788 100644
--- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py
+++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py
@@ -7,6 +7,8 @@
from sqlspec.utils.logging import get_logger
if TYPE_CHECKING:
+ from datetime import datetime
+
from sqlspec.adapters.cockroach_asyncpg.config import CockroachAsyncpgConfig
from sqlspec.extensions.adk import MemoryRecord
@@ -17,7 +19,19 @@
class CockroachAsyncpgADKStore(BaseAsyncADKStore["CockroachAsyncpgConfig"]):
- """CockroachDB ADK store using asyncpg driver."""
+ """CockroachDB ADK store using asyncpg driver.
+
+ Implements session and event storage for Google Agent Development Kit
+ using CockroachDB via asyncpg in PostgreSQL compatibility mode.
+ Events are stored as a single JSONB blob (``event_json``) alongside
+ indexed scalar columns for efficient querying.
+
+ CockroachDB-specific differences from native PostgreSQL:
+ - No FILLFACTOR (CockroachDB uses a different storage engine)
+ - No BRIN indexes (different physical storage layout)
+ - GIN/Inverted indexes on JSONB are fully supported (v23.1+)
+ - Native tsvector/tsquery FTS with GIN is supported (v23.1+)
+ """
__slots__ = ()
@@ -44,34 +58,28 @@ async def _get_create_sessions_table_sql(self) -> str:
CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time
ON {self._session_table}(update_time DESC);
+
+ CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state
+ ON {self._session_table} USING GIN (state)
+ WHERE state != '{{}}'::jsonb;
"""
async def _get_create_events_table_sql(self) -> str:
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
invocation_id VARCHAR(256) NOT NULL,
author VARCHAR(256) NOT NULL,
- actions BYTEA NOT NULL,
- long_running_tool_ids_json JSONB,
- branch VARCHAR(256),
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
ON {self._events_table}(session_id, timestamp ASC);
+
+ CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json
+ ON {self._events_table} USING GIN (event_json);
"""
def _get_drop_tables_sql(self) -> "list[str]":
@@ -181,72 +189,77 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
async def append_event(self, event_record: EventRecord) -> None:
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES ($1, $2, $3, $4, $5)
"""
async with self._config.provide_connection() as conn:
await conn.execute(
sql,
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
event_record["timestamp"],
- event_record.get("content"),
- event_record.get("grounding_metadata"),
- event_record.get("custom_metadata"),
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_record["event_json"],
+ )
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES ($1, $2, $3, $4, $5)
+ """
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = $1, update_time = CURRENT_TIMESTAMP
+ WHERE id = $2
+ """
+
+ async with self._config.provide_connection() as conn, conn.transaction():
+ await conn.execute(
+ insert_sql,
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_record["event_json"],
)
+ await conn.execute(update_sql, state, session_id)
+
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ where_clauses = ["session_id = $1"]
+ params: list[Any] = [session_id]
+
+ if after_timestamp is not None:
+ where_clauses.append(f"timestamp > ${len(params) + 1}")
+ params.append(after_timestamp)
+
+ where_clause = " AND ".join(where_clauses)
+ limit_clause = f" LIMIT ${len(params) + 1}" if limit else ""
+ if limit:
+ params.append(limit)
- async def list_events(self, session_id: str) -> "list[EventRecord]":
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
- WHERE session_id = $1
- ORDER BY timestamp ASC
+ WHERE {where_clause}
+ ORDER BY timestamp ASC{limit_clause}
"""
async with self._config.provide_connection() as conn:
- rows = await conn.fetch(sql, session_id)
+ rows = await conn.fetch(sql, *params)
return [
EventRecord(
- id=row["id"],
session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
invocation_id=row["invocation_id"],
author=row["author"],
- actions=bytes(row["actions"]),
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ event_json=row["event_json"],
)
for row in rows
]
@@ -265,6 +278,13 @@ async def _get_create_memory_table_sql(self) -> str:
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
+ fts_index = ""
+ if self._use_fts:
+ fts_index = f"""
+ CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts
+ ON {self._memory_table} USING GIN (to_tsvector('english', content_text));
+ """
+
return f"""
CREATE TABLE IF NOT EXISTS {self._memory_table} (
id VARCHAR(128) PRIMARY KEY,
@@ -285,6 +305,7 @@ async def _get_create_memory_table_sql(self) -> str:
CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session
ON {self._memory_table}(session_id);
+ {fts_index}
"""
def _get_drop_memory_table_sql(self) -> "list[str]":
@@ -371,18 +392,27 @@ async def search_entries(
return []
effective_limit = limit if limit is not None else self._max_results
- if self._use_fts:
- logger.debug("CockroachDB full-text search not supported; using simple search")
- sql = f"""
- SELECT * FROM {self._memory_table}
- WHERE app_name = $1 AND user_id = $2 AND content_text ILIKE $3
- ORDER BY timestamp DESC
- LIMIT $4
- """
+ if self._use_fts:
+ sql = f"""
+ SELECT * FROM {self._memory_table}
+ WHERE app_name = $1 AND user_id = $2
+ AND to_tsvector('english', content_text) @@ plainto_tsquery('english', $3)
+ ORDER BY timestamp DESC
+ LIMIT $4
+ """
+ params: tuple[Any, ...] = (app_name, user_id, query, effective_limit)
+ else:
+ sql = f"""
+ SELECT * FROM {self._memory_table}
+ WHERE app_name = $1 AND user_id = $2 AND content_text ILIKE $3
+ ORDER BY timestamp DESC
+ LIMIT $4
+ """
+ params = (app_name, user_id, f"%{query}%", effective_limit)
async with self._config.provide_connection() as conn:
- rows = await conn.fetch(sql, app_name, user_id, f"%{query}%", effective_limit)
+ rows = await conn.fetch(sql, *params)
return [cast("MemoryRecord", dict(row)) for row in rows]
@@ -403,8 +433,8 @@ async def delete_entries_older_than(self, days: int) -> int:
sql = f"""
DELETE FROM {self._memory_table}
- WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL $1 DAY)
+ WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL '{days} days')
"""
async with self._config.provide_connection() as conn:
- result = await conn.execute(sql, days)
+ result = await conn.execute(sql)
return int(result.split()[-1]) if result else 0
diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py
index 97ee76bc6..d2906c7ea 100644
--- a/sqlspec/adapters/cockroach_psycopg/adk/store.py
+++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py
@@ -6,11 +6,14 @@
from psycopg import sql as pg_sql
from psycopg.types.json import Jsonb
-from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.logging import get_logger
+from sqlspec.utils.sync_tools import async_, run_
if TYPE_CHECKING:
+ from datetime import datetime
+
from sqlspec.adapters.cockroach_psycopg.config import CockroachPsycopgAsyncConfig, CockroachPsycopgSyncConfig
from sqlspec.extensions.adk import MemoryRecord
@@ -59,7 +62,19 @@ def _build_insert_params_with_owner(entry: "MemoryRecord", owner_id: "object | N
class CockroachPsycopgAsyncADKStore(BaseAsyncADKStore["CockroachPsycopgAsyncConfig"]):
- """CockroachDB ADK store using psycopg async driver."""
+ """CockroachDB ADK store using psycopg async driver.
+
+ Implements session and event storage for Google Agent Development Kit
+ using CockroachDB via psycopg in PostgreSQL compatibility mode.
+ Events are stored as a single JSONB blob (``event_json``) alongside
+ indexed scalar columns for efficient querying.
+
+ CockroachDB-specific differences from native PostgreSQL:
+ - No FILLFACTOR (CockroachDB uses a different storage engine)
+ - SQL strings require ``.encode()`` for the cockroach-psycopg driver
+ - GIN/Inverted indexes on JSONB are fully supported (v23.1+)
+ - Native tsvector/tsquery FTS with GIN is supported (v23.1+)
+ """
__slots__ = ()
@@ -67,7 +82,6 @@ def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None:
super().__init__(config)
async def _get_create_sessions_table_sql(self) -> str:
- """Get CockroachDB CREATE TABLE SQL for sessions."""
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
@@ -87,35 +101,28 @@ async def _get_create_sessions_table_sql(self) -> str:
CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time
ON {self._session_table}(update_time DESC);
+
+ CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state
+ ON {self._session_table} USING GIN (state)
+ WHERE state != '{{}}'::jsonb;
"""
async def _get_create_events_table_sql(self) -> str:
- """Get CockroachDB CREATE TABLE SQL for events."""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
invocation_id VARCHAR(256) NOT NULL,
author VARCHAR(256) NOT NULL,
- actions BYTEA NOT NULL,
- long_running_tool_ids_json JSONB,
- branch VARCHAR(256),
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
ON {self._events_table}(session_id, timestamp ASC);
+
+ CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json
+ ON {self._events_table} USING GIN (event_json);
"""
def _get_drop_tables_sql(self) -> "list[str]":
@@ -239,77 +246,91 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
async def append_event(self, event_record: EventRecord) -> None:
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
"""
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
+
async with self._config.provide_connection() as conn, conn.cursor() as cur:
await cur.execute(
sql.encode(),
(
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
event_record["timestamp"],
- event_record.get("content"),
- event_record.get("grounding_metadata"),
- event_record.get("custom_metadata"),
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ jsonb_value,
+ ),
+ )
+ await conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = %s, update_time = CURRENT_TIMESTAMP
+ WHERE id = %s
+ """
+
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
+
+ async with self._config.provide_connection() as conn, conn.cursor() as cur:
+ await cur.execute(
+ insert_sql.encode(),
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ jsonb_value,
),
)
+ await cur.execute(update_sql.encode(), (Jsonb(state), session_id))
await conn.commit()
- async def list_events(self, session_id: str) -> "list[EventRecord]":
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ where_clauses = ["session_id = %s"]
+ params: list[Any] = [session_id]
+
+ if after_timestamp is not None:
+ where_clauses.append("timestamp > %s")
+ params.append(after_timestamp)
+
+ where_clause = " AND ".join(where_clauses)
+ limit_clause = " LIMIT %s" if limit else ""
+ if limit:
+ params.append(limit)
+
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
- WHERE session_id = %s
- ORDER BY timestamp ASC
+ WHERE {where_clause}
+ ORDER BY timestamp ASC{limit_clause}
"""
try:
async with self._config.provide_connection() as conn, conn.cursor() as cur:
- await cur.execute(sql.encode(), (session_id,))
+ await cur.execute(sql.encode(), tuple(params))
rows = await cur.fetchall()
return [
EventRecord(
- id=row["id"],
session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
invocation_id=row["invocation_id"],
author=row["author"],
- actions=bytes(row["actions"]),
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ event_json=row["event_json"],
)
for row in rows
]
@@ -317,15 +338,26 @@ async def list_events(self, session_id: str) -> "list[EventRecord]":
return []
-class CockroachPsycopgSyncADKStore(BaseSyncADKStore["CockroachPsycopgSyncConfig"]):
- """CockroachDB ADK store using psycopg sync driver."""
+class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig"]):
+ """CockroachDB ADK store using psycopg sync driver.
+
+ Implements session and event storage for Google Agent Development Kit
+ using CockroachDB via psycopg in PostgreSQL compatibility mode (sync).
+ Events are stored as a single JSONB blob (``event_json``) alongside
+ indexed scalar columns for efficient querying.
+
+ CockroachDB-specific differences from native PostgreSQL:
+ - No FILLFACTOR (CockroachDB uses a different storage engine)
+ - SQL strings require ``.encode()`` for the cockroach-psycopg driver
+ - GIN/Inverted indexes on JSONB are fully supported (v23.1+)
+ """
__slots__ = ()
def __init__(self, config: "CockroachPsycopgSyncConfig") -> None:
super().__init__(config)
- def _get_create_sessions_table_sql(self) -> str:
+ async def _get_create_sessions_table_sql(self) -> str:
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
@@ -345,45 +377,43 @@ def _get_create_sessions_table_sql(self) -> str:
CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time
ON {self._session_table}(update_time DESC);
+
+ CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state
+ ON {self._session_table} USING GIN (state)
+ WHERE state != '{{}}'::jsonb;
"""
- def _get_create_events_table_sql(self) -> str:
+ async def _get_create_events_table_sql(self) -> str:
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
invocation_id VARCHAR(256) NOT NULL,
author VARCHAR(256) NOT NULL,
- actions BYTEA NOT NULL,
- long_running_tool_ids_json JSONB,
- branch VARCHAR(256),
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
ON {self._events_table}(session_id, timestamp ASC);
+
+ CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json
+ ON {self._events_table} USING GIN (event_json);
"""
def _get_drop_tables_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_sessions_table_sql())
- driver.execute_script(self._get_create_events_table_sql())
+ driver.execute_script(run_(self._get_create_sessions_table_sql)())
+ driver.execute_script(run_(self._get_create_events_table_sql)())
- def create_session(
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
state_json = Jsonb(state)
@@ -406,13 +436,19 @@ def create_session(
cur.execute(sql.encode(), params)
conn.commit()
- result = self.get_session(session_id)
+ result = self._get_session(session_id)
if result is None:
msg = "Session creation failed"
raise RuntimeError(msg)
return result
- def get_session(self, session_id: str) -> "SessionRecord | None":
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session."""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
+
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
@@ -438,7 +474,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
except errors.UndefinedTable:
return None
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID."""
+ return await async_(self._get_session)(session_id)
+
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
sql = f"""
UPDATE {self._session_table}
SET state = %s, update_time = CURRENT_TIMESTAMP
@@ -449,14 +489,22 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
cur.execute(sql.encode(), (Jsonb(state), session_id))
conn.commit()
- def delete_session(self, session_id: str) -> None:
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state."""
+ await async_(self._update_session_state)(session_id, state)
+
+ def _delete_session(self, session_id: str) -> None:
sql = f"DELETE FROM {self._session_table} WHERE id = %s"
with self._config.provide_connection() as conn, conn.cursor() as cur:
cur.execute(sql.encode(), (session_id,))
conn.commit()
- def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and associated events."""
+ await async_(self._delete_session)(session_id)
+
+ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
if user_id is None:
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -493,86 +541,123 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess
except errors.UndefinedTable:
return []
- def append_event(self, event_record: EventRecord) -> None:
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app."""
+ return await async_(self._list_sessions)(app_name, user_id)
+
+ def _append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = %s, update_time = CURRENT_TIMESTAMP
+ WHERE id = %s
+ """
+
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
+
+ with self._config.provide_connection() as conn, conn.cursor() as cur:
+ cur.execute(
+ insert_sql.encode(),
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ jsonb_value,
+ ),
+ )
+ cur.execute(update_sql.encode(), (Jsonb(state), session_id))
+ conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state."""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
+
+ def _insert_event(self, event_record: EventRecord) -> None:
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
"""
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
+
with self._config.provide_connection() as conn, conn.cursor() as cur:
cur.execute(
sql.encode(),
(
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
event_record["timestamp"],
- event_record.get("content"),
- event_record.get("grounding_metadata"),
- event_record.get("custom_metadata"),
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ jsonb_value,
),
)
conn.commit()
- def list_events(self, session_id: str) -> "list[EventRecord]":
+ def _get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ where_clauses = ["session_id = %s"]
+ params: list[Any] = [session_id]
+
+ if after_timestamp is not None:
+ where_clauses.append("timestamp > %s")
+ params.append(after_timestamp)
+
+ where_clause = " AND ".join(where_clauses)
+ limit_clause = " LIMIT %s" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
- WHERE session_id = %s
- ORDER BY timestamp ASC
+ WHERE {where_clause}
+ ORDER BY timestamp ASC{limit_clause}
"""
+ if limit:
+ params.append(limit)
try:
with self._config.provide_connection() as conn, conn.cursor() as cur:
- cur.execute(sql.encode(), (session_id,))
+ cur.execute(sql.encode(), tuple(params))
rows = cur.fetchall()
return [
EventRecord(
- id=row["id"],
session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
invocation_id=row["invocation_id"],
author=row["author"],
- actions=bytes(row["actions"]),
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ event_json=row["event_json"],
)
for row in rows
]
except errors.UndefinedTable:
return []
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session."""
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
+
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ self._insert_event(event_record)
+
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session."""
+ await async_(self._append_event)(event_record)
+
class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgAsyncConfig"]):
"""CockroachDB ADK memory store using psycopg async driver."""
@@ -587,6 +672,13 @@ async def _get_create_memory_table_sql(self) -> str:
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
+ fts_index = ""
+ if self._use_fts:
+ fts_index = f"""
+ CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts
+ ON {self._memory_table} USING GIN (to_tsvector('english', content_text));
+ """
+
return f"""
CREATE TABLE IF NOT EXISTS {self._memory_table} (
id VARCHAR(128) PRIMARY KEY,
@@ -607,6 +699,7 @@ async def _get_create_memory_table_sql(self) -> str:
CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session
ON {self._memory_table}(session_id);
+ {fts_index}
"""
def _get_drop_memory_table_sql(self) -> "list[str]":
@@ -675,16 +768,25 @@ async def search_entries(
return []
effective_limit = limit if limit is not None else self._max_results
+
if self._use_fts:
- logger.debug("CockroachDB full-text search not supported; using simple search")
+ sql = f"""
+ SELECT * FROM {self._memory_table}
+ WHERE app_name = %s AND user_id = %s
+ AND to_tsvector('english', content_text) @@ plainto_tsquery('english', %s)
+ ORDER BY timestamp DESC
+ LIMIT %s
+ """
+ else:
+ sql = f"""
+ SELECT * FROM {self._memory_table}
+ WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s
+ ORDER BY timestamp DESC
+ LIMIT %s
+ """
- sql = f"""
- SELECT * FROM {self._memory_table}
- WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s
- ORDER BY timestamp DESC
- LIMIT %s
- """
- params = (app_name, user_id, f"%{query}%", effective_limit)
+ search_param = query if self._use_fts else f"%{query}%"
+ params = (app_name, user_id, search_param, effective_limit)
try:
async with self._config.provide_connection() as conn, conn.cursor() as cur:
@@ -714,15 +816,15 @@ async def delete_entries_older_than(self, days: int) -> int:
sql = f"""
DELETE FROM {self._memory_table}
- WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL %s DAY)
+ WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days'
"""
async with self._config.provide_connection() as conn, conn.cursor() as cur:
- await cur.execute(sql.encode(), (days,))
+ await cur.execute(sql.encode())
await conn.commit()
return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0
-class CockroachPsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["CockroachPsycopgSyncConfig"]):
+class CockroachPsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgSyncConfig"]):
"""CockroachDB ADK memory store using psycopg sync driver."""
__slots__ = ()
@@ -730,11 +832,18 @@ class CockroachPsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["CockroachPsycop
def __init__(self, config: "CockroachPsycopgSyncConfig") -> None:
super().__init__(config)
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
+ fts_index = ""
+ if self._use_fts:
+ fts_index = f"""
+ CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts
+ ON {self._memory_table} USING GIN (to_tsvector('english', content_text));
+ """
+
return f"""
CREATE TABLE IF NOT EXISTS {self._memory_table} (
id VARCHAR(128) PRIMARY KEY,
@@ -755,19 +864,24 @@ def _get_create_memory_table_sql(self) -> str:
CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session
ON {self._memory_table}(session_id);
+ {fts_index}
"""
def _get_drop_memory_table_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._memory_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
if not self._enabled:
return
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_memory_table_sql())
+ driver.execute_script(run_(self._get_create_memory_table_sql)())
+
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -810,7 +924,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
inserted_count += cur.rowcount
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
if not self._enabled:
@@ -821,16 +939,25 @@ def search_entries(
return []
effective_limit = limit if limit is not None else self._max_results
+
if self._use_fts:
- logger.debug("CockroachDB full-text search not supported; using simple search")
+ sql = f"""
+ SELECT * FROM {self._memory_table}
+ WHERE app_name = %s AND user_id = %s
+ AND to_tsvector('english', content_text) @@ plainto_tsquery('english', %s)
+ ORDER BY timestamp DESC
+ LIMIT %s
+ """
+ else:
+ sql = f"""
+ SELECT * FROM {self._memory_table}
+ WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s
+ ORDER BY timestamp DESC
+ LIMIT %s
+ """
- sql = f"""
- SELECT * FROM {self._memory_table}
- WHERE app_name = %s AND user_id = %s AND content_text ILIKE %s
- ORDER BY timestamp DESC
- LIMIT %s
- """
- params = (app_name, user_id, f"%{query}%", effective_limit)
+ search_param = query if self._use_fts else f"%{query}%"
+ params = (app_name, user_id, search_param, effective_limit)
try:
with self._config.provide_connection() as conn, conn.cursor() as cur:
@@ -842,7 +969,13 @@ def search_entries(
return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows]
- def delete_entries_by_session(self, session_id: str) -> int:
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
+ def _delete_entries_by_session(self, session_id: str) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -853,16 +986,24 @@ def delete_entries_by_session(self, session_id: str) -> int:
conn.commit()
return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
sql = f"""
DELETE FROM {self._memory_table}
- WHERE inserted_at < (CURRENT_TIMESTAMP - INTERVAL %s DAY)
+ WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days'
"""
with self._config.provide_connection() as conn, conn.cursor() as cur:
- cur.execute(sql.encode(), (days,))
+ cur.execute(sql.encode())
conn.commit()
return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0
+
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py
index 4db753b58..9c8d00715 100644
--- a/sqlspec/adapters/duckdb/adk/store.py
+++ b/sqlspec/adapters/duckdb/adk/store.py
@@ -16,10 +16,11 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Final, cast
-from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json
+from sqlspec.utils.sync_tools import async_
if TYPE_CHECKING:
from sqlspec.adapters.duckdb.config import DuckDBConfig
@@ -33,15 +34,16 @@
DUCKDB_TABLE_NOT_FOUND_ERROR: Final = "does not exist"
-class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]):
+class DuckdbADKStore(BaseAsyncADKStore["DuckDBConfig"]):
"""DuckDB ADK store for Google Agent Development Kit.
Implements session and event storage for Google Agent Development Kit
- using DuckDB's synchronous driver. Provides:
+ using DuckDB's synchronous driver with async wrappers via ``async_()``.
+ Provides:
- Session state management with native JSON type
- - Event history tracking with BLOB-serialized actions
- - Native TIMESTAMP type support
- - Foreign key constraints (manual cascade in delete_session)
+ - Event history with single JSON blob (event_json) plus indexed scalars
+ - Native TIMESTAMPTZ type support
+ - Manual cascade delete (DuckDB has no FK CASCADE)
- Columnar storage for analytical queries
Args:
@@ -62,20 +64,12 @@ class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]):
}
)
store = DuckdbADKStore(config)
- store.ensure_tables()
-
- session = store.create_session(
- session_id="session-123",
- app_name="my-app",
- user_id="user-456",
- state={"context": "conversation"}
- )
+ await store.ensure_tables()
Notes:
- - Uses DuckDB native JSON type (not JSONB)
- - TIMESTAMP for date/time storage with microsecond precision
- - BLOB for binary actions data
- - BOOLEAN native type support
+ - Uses DuckDB native JSON type for event_json and state
+ - TIMESTAMPTZ for date/time storage with microsecond precision
+ - event_json stores the full ADK Event as a single JSON blob
- Columnar storage provides excellent analytical query performance
- DuckDB doesn't support CASCADE in foreign keys (manual cascade required)
- Optimized for OLAP workloads; for high-concurrency writes use PostgreSQL
@@ -98,7 +92,7 @@ def __init__(self, config: "DuckDBConfig") -> None:
"""
super().__init__(config)
- def _get_create_sessions_table_sql(self) -> str:
+ async def _get_create_sessions_table_sql(self) -> str:
"""Get DuckDB CREATE TABLE SQL for sessions.
Returns:
@@ -107,7 +101,7 @@ def _get_create_sessions_table_sql(self) -> str:
Notes:
- VARCHAR for IDs and names
- JSON type for state storage (DuckDB native)
- - TIMESTAMP for create_time and update_time
+ - TIMESTAMPTZ for create_time and update_time
- CURRENT_TIMESTAMP for defaults
- Optional owner ID column for multi-tenant scenarios
- Composite index on (app_name, user_id) for listing
@@ -123,48 +117,34 @@ def _get_create_sessions_table_sql(self) -> str:
app_name VARCHAR NOT NULL,
user_id VARCHAR NOT NULL{owner_id_line},
state JSON NOT NULL,
- create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
- update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+ create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user ON {self._session_table}(app_name, user_id);
CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC);
"""
- def _get_create_events_table_sql(self) -> str:
+ async def _get_create_events_table_sql(self) -> str:
"""Get DuckDB CREATE TABLE SQL for events.
Returns:
SQL statement to create adk_events table with indexes.
Notes:
- - VARCHAR for string fields
- - BLOB for pickled actions
- - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
- - BOOLEAN for flags
+ - 5-column schema: session_id, invocation_id, author, timestamp, event_json
+ - event_json stores the full ADK Event as a single JSON blob
+ - No decomposed columns -- eliminates column drift with upstream ADK
- Foreign key constraint (DuckDB doesn't support CASCADE)
- Index on (session_id, timestamp ASC) for ordered event retrieval
- Manual cascade delete required in delete_session method
"""
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR PRIMARY KEY,
session_id VARCHAR NOT NULL,
- app_name VARCHAR NOT NULL,
- user_id VARCHAR NOT NULL,
- invocation_id VARCHAR,
- author VARCHAR,
- actions BLOB,
- long_running_tool_ids_json JSON,
- branch VARCHAR,
- timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR,
- error_message VARCHAR,
+ invocation_id VARCHAR NOT NULL,
+ author VARCHAR NOT NULL,
+ timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ event_json JSON NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id)
);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC);
@@ -182,31 +162,53 @@ def _get_drop_tables_sql(self) -> "list[str]":
"""
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- def create_tables(self) -> None:
- """Create both sessions and events tables if they don't exist."""
+ def _create_tables(self) -> None:
+ """Synchronous implementation of create_tables."""
with self._config.provide_connection() as conn:
- conn.execute(self._get_create_sessions_table_sql())
- conn.execute(self._get_create_events_table_sql())
-
- def create_session(
- self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
- ) -> SessionRecord:
- """Create a new session.
+ conn.execute(self.__get_create_sessions_table_sql_sync())
+ conn.execute(self.__get_create_events_table_sql_sync())
- Args:
- session_id: Unique session identifier.
- app_name: Application name.
- user_id: User identifier.
- state: Initial session state.
- owner_id: Optional owner ID value for owner_id_column (if configured).
+ def __get_create_sessions_table_sql_sync(self) -> str:
+ """Synchronous version of DDL generation for use in _create_tables."""
+ owner_id_line = ""
+ if self._owner_id_column_ddl:
+ owner_id_line = f",\n {self._owner_id_column_ddl}"
- Returns:
- Created session record.
+ return f"""
+ CREATE TABLE IF NOT EXISTS {self._session_table} (
+ id VARCHAR PRIMARY KEY,
+ app_name VARCHAR NOT NULL,
+ user_id VARCHAR NOT NULL{owner_id_line},
+ state JSON NOT NULL,
+ create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+ );
+ CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user ON {self._session_table}(app_name, user_id);
+ CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC);
+ """
- Notes:
- Uses current UTC timestamp for create_time and update_time.
- State is JSON-serialized using SQLSpec serializers.
+ def __get_create_events_table_sql_sync(self) -> str:
+ """Synchronous version of DDL generation for use in _create_tables."""
+ return f"""
+ CREATE TABLE IF NOT EXISTS {self._events_table} (
+ session_id VARCHAR NOT NULL,
+ invocation_id VARCHAR NOT NULL,
+ author VARCHAR NOT NULL,
+ timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ event_json JSON NOT NULL,
+ FOREIGN KEY (session_id) REFERENCES {self._session_table}(id)
+ );
+ CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC);
"""
+
+ async def create_tables(self) -> None:
+ """Create both sessions and events tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Synchronous implementation of create_session."""
now = datetime.now(timezone.utc)
state_json = to_json(state)
@@ -233,19 +235,29 @@ def create_session(
id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now
)
- def get_session(self, session_id: str) -> "SessionRecord | None":
- """Get session by ID.
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session.
Args:
- session_id: Session identifier.
+ session_id: Unique session identifier.
+ app_name: Application name.
+ user_id: User identifier.
+ state: Initial session state.
+ owner_id: Optional owner ID value for owner_id_column (if configured).
Returns:
- Session record or None if not found.
+ Created session record.
Notes:
- DuckDB returns datetime objects for TIMESTAMP columns.
- JSON is parsed from database storage.
+ Uses current UTC timestamp for create_time and update_time.
+ State is JSON-serialized using SQLSpec serializers.
"""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
+
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
+ """Synchronous implementation of get_session."""
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
@@ -277,17 +289,23 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
return None
raise
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
- """Update session state.
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID.
Args:
session_id: Session identifier.
- state: New state dictionary (replaces existing state).
+
+ Returns:
+ Session record or None if not found.
Notes:
- This replaces the entire state dictionary.
- Update time is automatically set to current UTC timestamp.
+ DuckDB returns datetime objects for TIMESTAMPTZ columns.
+ JSON is parsed from database storage.
"""
+ return await async_(self._get_session)(session_id)
+
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Synchronous implementation of update_session_state."""
now = datetime.now(timezone.utc)
state_json = to_json(state)
@@ -301,15 +319,21 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
conn.execute(sql, (state_json, now, session_id))
conn.commit()
- def delete_session(self, session_id: str) -> None:
- """Delete session and all associated events.
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state.
Args:
session_id: Session identifier.
+ state: New state dictionary (replaces existing state).
Notes:
- DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first.
+ This replaces the entire state dictionary.
+ Update time is automatically set to current UTC timestamp.
"""
+ await async_(self._update_session_state)(session_id, state)
+
+ def _delete_session(self, session_id: str) -> None:
+ """Synchronous implementation of delete_session."""
delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = ?"
delete_session_sql = f"DELETE FROM {self._session_table} WHERE id = ?"
@@ -318,19 +342,19 @@ def delete_session(self, session_id: str) -> None:
conn.execute(delete_session_sql, (session_id,))
conn.commit()
- def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
- """List sessions for an app, optionally filtered by user.
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and all associated events.
Args:
- app_name: Application name.
- user_id: User identifier. If None, lists all sessions for the app.
-
- Returns:
- List of session records ordered by update_time DESC.
+ session_id: Session identifier.
Notes:
- Uses composite index on (app_name, user_id) when user_id is provided.
+ DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first.
"""
+ await async_(self._delete_session)(session_id)
+
+ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]":
+ """Synchronous implementation of list_sessions."""
if user_id is None:
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -369,194 +393,136 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess
return []
raise
- def create_event(
- self,
- event_id: str,
- session_id: str,
- app_name: str,
- user_id: str,
- author: "str | None" = None,
- actions: "bytes | None" = None,
- content: "dict[str, Any] | None" = None,
- **kwargs: Any,
- ) -> EventRecord:
- """Create a new event.
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app, optionally filtered by user.
Args:
- event_id: Unique event identifier.
- session_id: Session identifier.
app_name: Application name.
- user_id: User identifier.
- author: Event author (user/assistant/system).
- actions: Pickled actions object.
- content: Event content (JSON).
- **kwargs: Additional optional fields.
+ user_id: User identifier. If None, lists all sessions for the app.
Returns:
- Created event record.
+ List of session records ordered by update_time DESC.
Notes:
- Uses current UTC timestamp if not provided in kwargs.
- JSON fields are serialized using SQLSpec serializers.
+ Uses composite index on (app_name, user_id) when user_id is provided.
"""
- timestamp = kwargs.get("timestamp", datetime.now(timezone.utc))
- content_json = to_json(content) if content else None
- grounding_metadata = kwargs.get("grounding_metadata")
- grounding_metadata_json = to_json(grounding_metadata) if grounding_metadata else None
- custom_metadata = kwargs.get("custom_metadata")
- custom_metadata_json = to_json(custom_metadata) if custom_metadata else None
+ return await async_(self._list_sessions)(app_name, user_id)
+
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ event_json_str = to_json(event_record["event_json"])
sql = f"""
- INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ INSERT INTO {self._events_table}
+ (session_id, invocation_id, author, timestamp, event_json)
+ VALUES (?, ?, ?, ?, ?)
"""
with self._config.provide_connection() as conn:
conn.execute(
sql,
(
- event_id,
- session_id,
- app_name,
- user_id,
- kwargs.get("invocation_id"),
- author,
- actions,
- kwargs.get("long_running_tool_ids_json"),
- kwargs.get("branch"),
- timestamp,
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- kwargs.get("partial"),
- kwargs.get("turn_complete"),
- kwargs.get("interrupted"),
- kwargs.get("error_code"),
- kwargs.get("error_message"),
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json_str,
),
)
conn.commit()
- return EventRecord(
- id=event_id,
- session_id=session_id,
- app_name=app_name,
- user_id=user_id,
- invocation_id=kwargs.get("invocation_id", ""),
- author=author or "",
- actions=actions or b"",
- long_running_tool_ids_json=kwargs.get("long_running_tool_ids_json"),
- branch=kwargs.get("branch"),
- timestamp=timestamp,
- content=content,
- grounding_metadata=grounding_metadata,
- custom_metadata=custom_metadata,
- partial=kwargs.get("partial"),
- turn_complete=kwargs.get("turn_complete"),
- interrupted=kwargs.get("interrupted"),
- error_code=kwargs.get("error_code"),
- error_message=kwargs.get("error_message"),
- )
-
- def get_event(self, event_id: str) -> "EventRecord | None":
- """Get event by ID.
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session.
Args:
- event_id: Event identifier.
+ event_record: Event record with 5 keys (session_id, invocation_id,
+ author, timestamp, event_json).
+ """
+ await async_(self._append_event)(event_record)
- Returns:
- Event record or None if not found.
+ def _append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Synchronous implementation of append_event_and_update_state."""
+ now = datetime.now(timezone.utc)
+ state_json = to_json(state)
+ event_json_str = to_json(event_record["event_json"])
+
+ insert_sql = f"""
+ INSERT INTO {self._events_table}
+ (session_id, invocation_id, author, timestamp, event_json)
+ VALUES (?, ?, ?, ?, ?)
"""
- sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- FROM {self._events_table}
+
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = ?, update_time = ?
WHERE id = ?
"""
- try:
- with self._config.provide_connection() as conn:
- cursor = conn.execute(sql, (event_id,))
- row = cursor.fetchone()
-
- if row is None:
- return None
+ with self._config.provide_connection() as conn:
+ conn.execute(
+ insert_sql,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json_str,
+ ),
+ )
+ conn.execute(update_sql, (state_json, now, session_id))
+ conn.commit()
- return EventRecord(
- id=row[0],
- session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=bytes(row[6]) if row[6] else b"",
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=row[9],
- content=from_json(row[10]) if row[10] else None,
- grounding_metadata=from_json(row[11]) if row[11] else None,
- custom_metadata=from_json(row[12]) if row[12] else None,
- partial=row[13],
- turn_complete=row[14],
- interrupted=row[15],
- error_code=row[16],
- error_message=row[17],
- )
- except Exception as e:
- if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e):
- return None
- raise
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state.
- def list_events(self, session_id: str) -> "list[EventRecord]":
- """List events for a session ordered by timestamp.
+ The event insert and state update succeed together or fail together
+ within a single DuckDB transaction.
Args:
- session_id: Session identifier.
-
- Returns:
- List of event records ordered by timestamp ASC.
+ event_record: Event record to store (5-key shape).
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot (``temp:`` keys already
+ stripped by the service layer).
"""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
+
+ def _get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Synchronous implementation of get_events."""
+ where_clauses = ["session_id = ?"]
+ params: list[Any] = [session_id]
+
+ if after_timestamp is not None:
+ where_clauses.append("timestamp > ?")
+ params.append(after_timestamp)
+
+ where_clause = " AND ".join(where_clauses)
+ limit_clause = f" LIMIT {limit}" if limit else ""
+
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
- WHERE session_id = ?
- ORDER BY timestamp ASC
+ WHERE {where_clause}
+ ORDER BY timestamp ASC{limit_clause}
"""
try:
with self._config.provide_connection() as conn:
- cursor = conn.execute(sql, (session_id,))
+ cursor = conn.execute(sql, params)
rows = cursor.fetchall()
return [
EventRecord(
- id=row[0],
- session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=bytes(row[6]) if row[6] else b"",
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=row[9],
- content=from_json(row[10]) if row[10] else None,
- grounding_metadata=from_json(row[11]) if row[11] else None,
- custom_metadata=from_json(row[12]) if row[12] else None,
- partial=row[13],
- turn_complete=row[14],
- interrupted=row[15],
- error_code=row[16],
- error_message=row[17],
+ session_id=row[0],
+ invocation_id=row[1],
+ author=row[2],
+ timestamp=row[3],
+ event_json=from_json(row[4]) if isinstance(row[4], str) else row[4],
)
for row in rows
]
@@ -565,14 +531,30 @@ def list_events(self, session_id: str) -> "list[EventRecord]":
return []
raise
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session.
-class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]):
- """DuckDB ADK memory store using synchronous DuckDB driver.
+ Args:
+ session_id: Session identifier.
+ after_timestamp: Only return events after this time.
+ limit: Maximum number of events to return.
+
+ Returns:
+ List of event records ordered by timestamp ASC.
+ """
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
+
+
+class DuckdbADKMemoryStore(BaseAsyncADKMemoryStore["DuckDBConfig"]):
+ """DuckDB ADK memory store using synchronous DuckDB driver with async wrappers.
Implements memory entry storage for Google Agent Development Kit
- using DuckDB's synchronous driver. Provides:
+ using DuckDB's synchronous driver with async wrappers via ``async_()``.
+ Provides:
- Session memory storage with native JSON type
- - Simple ILIKE search
+ - Simple ILIKE search or BM25 full-text search via FTS extension
- Native TIMESTAMP type support
- Deduplication via event_id unique constraint
- Efficient upserts using INSERT OR IGNORE
@@ -595,13 +577,15 @@ class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]):
}
)
store = DuckdbADKMemoryStore(config)
- store.ensure_tables()
+ await store.ensure_tables()
Notes:
- Uses DuckDB native JSON type (not JSONB)
- TIMESTAMP for date/time storage with microsecond precision
- event_id UNIQUE constraint for deduplication
- Composite index on (app_name, user_id, timestamp DESC)
+ - FTS uses match_bm25() for BM25-ranked results (not @@ operator)
+ - FTS index is refreshed after inserts, not on every search
- Columnar storage provides excellent analytical query performance
- Optimized for OLAP workloads; for high-concurrency writes use PostgreSQL
- Configuration is read from config.extension_config["adk"]
@@ -644,12 +628,19 @@ def _create_fts_index(self, conn: Any) -> None:
return
try:
- conn.execute(f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text')")
+ conn.execute(
+ f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text', "
+ f"stemmer='porter', stopwords='english', strip_accents=1, lower=1)"
+ )
except Exception as exc:
logger.debug("Failed to create DuckDB FTS index: %s", exc)
def _refresh_fts_index(self, conn: Any) -> None:
- """Rebuild the FTS index to reflect recent changes."""
+ """Rebuild the FTS index to reflect recent inserts.
+
+ DuckDB FTS indexes do not auto-update. This must be called after
+ insert/update/delete operations, NOT on every search.
+ """
if not self._ensure_fts_extension(conn):
return
@@ -657,11 +648,14 @@ def _refresh_fts_index(self, conn: Any) -> None:
conn.execute(f"PRAGMA drop_fts_index('{self._memory_table}')")
try:
- conn.execute(f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text')")
+ conn.execute(
+ f"PRAGMA create_fts_index('{self._memory_table}', 'id', 'content_text', "
+ f"overwrite=1, stemmer='porter', stopwords='english', strip_accents=1, lower=1)"
+ )
except Exception as exc:
logger.debug("Failed to refresh DuckDB FTS index: %s", exc)
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
"""Get DuckDB CREATE TABLE SQL for memory entries.
Returns:
@@ -697,18 +691,51 @@ def _get_drop_memory_table_sql(self) -> "list[str]":
"""Get DuckDB DROP TABLE SQL statements."""
return [f"DROP TABLE IF EXISTS {self._memory_table}"]
- def create_tables(self) -> None:
- """Create the memory table and indexes if they don't exist."""
+ def _create_tables(self) -> None:
+ """Synchronous implementation of create_tables."""
if not self._enabled:
return
+ ddl = self.__get_create_memory_table_sql_sync()
with self._config.provide_connection() as conn:
- conn.execute(self._get_create_memory_table_sql())
+ conn.execute(ddl)
if self._use_fts:
self._create_fts_index(conn)
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
- """Bulk insert memory entries with deduplication."""
+ def __get_create_memory_table_sql_sync(self) -> str:
+ """Synchronous version of DDL generation for use in _create_tables."""
+ owner_id_line = ""
+ if self._owner_id_column_ddl:
+ owner_id_line = f",\n {self._owner_id_column_ddl}"
+
+ return f"""
+ CREATE TABLE IF NOT EXISTS {self._memory_table} (
+ id VARCHAR(128) PRIMARY KEY,
+ session_id VARCHAR(128) NOT NULL,
+ app_name VARCHAR(128) NOT NULL,
+ user_id VARCHAR(128) NOT NULL,
+ event_id VARCHAR(128) NOT NULL UNIQUE,
+ author VARCHAR(256){owner_id_line},
+ timestamp TIMESTAMP NOT NULL,
+ content_json JSON NOT NULL,
+ content_text TEXT NOT NULL,
+ metadata_json JSON,
+ inserted_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+ );
+
+ CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time
+ ON {self._memory_table}(app_name, user_id, timestamp DESC);
+
+ CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session
+ ON {self._memory_table}(session_id);
+ """
+
+ async def create_tables(self) -> None:
+ """Create the memory table and indexes if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Synchronous implementation of insert_memory_entries."""
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -770,12 +797,24 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
result = conn.execute(sql, params)
inserted_count += len(result.fetchall())
conn.commit()
+
+ # Refresh FTS index after inserts, not on search
+ if self._use_fts and inserted_count > 0:
+ self._refresh_fts_index(conn)
+
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication.
+
+ After successful inserts, refreshes the FTS index if FTS is enabled.
+ """
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
- """Search memory entries by text query."""
+ """Synchronous implementation of search_entries."""
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -785,13 +824,19 @@ def search_entries(
limit_value = limit or self._max_results
if self._use_fts:
+ # Use match_bm25() -- the correct DuckDB FTS syntax
sql = f"""
- SELECT * FROM {self._memory_table}
- WHERE app_name = ? AND user_id = ? AND content_text @@ ?
- ORDER BY timestamp DESC
+ SELECT m.*
+ FROM {self._memory_table} m
+ JOIN (
+ SELECT id, fts_main_{self._memory_table}.match_bm25(id, ?, fields := 'content_text') AS score
+ FROM {self._memory_table}
+ ) fts ON m.id = fts.id
+ WHERE m.app_name = ? AND m.user_id = ? AND fts.score IS NOT NULL
+ ORDER BY fts.score DESC
LIMIT ?
"""
- params = (app_name, user_id, query, limit_value)
+ params = (query, app_name, user_id, limit_value)
else:
sql = f"""
SELECT * FROM {self._memory_table}
@@ -814,13 +859,20 @@ def search_entries(
if isinstance(metadata_value, (str, bytes)):
record["metadata_json"] = from_json(metadata_value)
records.append(record)
- if self._use_fts:
- with self._config.provide_connection() as conn:
- self._refresh_fts_index(conn)
return records
- def delete_entries_by_session(self, session_id: str) -> int:
- """Delete all memory entries for a specific session."""
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query.
+
+ When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results.
+ Falls back to ILIKE for simple substring matching.
+ """
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
+ def _delete_entries_by_session(self, session_id: str) -> int:
+ """Synchronous implementation of delete_entries_by_session."""
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -830,10 +882,16 @@ def delete_entries_by_session(self, session_id: str) -> int:
result = conn.execute(sql, (session_id,))
deleted_count = len(result.fetchall())
conn.commit()
+ if self._use_fts and deleted_count > 0:
+ self._refresh_fts_index(conn)
return deleted_count
- def delete_entries_older_than(self, days: int) -> int:
- """Delete memory entries older than specified days."""
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
+ """Synchronous implementation of delete_entries_older_than."""
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -847,4 +905,10 @@ def delete_entries_older_than(self, days: int) -> int:
result = conn.execute(sql)
deleted_count = len(result.fetchall())
conn.commit()
+ if self._use_fts and deleted_count > 0:
+ self._refresh_fts_index(conn)
return deleted_count
+
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py
index 3aed6258e..1a25702f7 100644
--- a/sqlspec/adapters/mysqlconnector/adk/store.py
+++ b/sqlspec/adapters/mysqlconnector/adk/store.py
@@ -5,9 +5,10 @@
import mysql.connector
-from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.serializers import from_json, to_json
+from sqlspec.utils.sync_tools import async_, run_
if TYPE_CHECKING:
from datetime import datetime
@@ -38,8 +39,61 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]":
return (col_def, fk_constraint)
+def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") -> str:
+ """Generate shared MySQL sessions CREATE TABLE DDL."""
+ owner_id_col = ""
+ fk_constraint = ""
+
+ if owner_id_column_ddl:
+ col_def, fk_def = _parse_owner_id_column_for_mysql(owner_id_column_ddl)
+ owner_id_col = f"{col_def},"
+ if fk_def:
+ fk_constraint = f",\n {fk_def}"
+
+ return f"""
+ CREATE TABLE IF NOT EXISTS {session_table} (
+ id VARCHAR(128) PRIMARY KEY,
+ app_name VARCHAR(128) NOT NULL,
+ user_id VARCHAR(128) NOT NULL,
+ {owner_id_col}
+ state JSON NOT NULL,
+ create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
+ update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
+ INDEX idx_{session_table}_app_user (app_name, user_id),
+ INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint}
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
+ """
+
+
+def _mysql_events_ddl(events_table: str, session_table: str) -> str:
+ """Generate shared MySQL events CREATE TABLE DDL (post clean-break, 5 columns)."""
+ return f"""
+ CREATE TABLE IF NOT EXISTS {events_table} (
+ session_id VARCHAR(128) NOT NULL,
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(128) NOT NULL,
+ timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
+ event_json JSON NOT NULL,
+ FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE,
+ INDEX idx_{events_table}_session (session_id, timestamp ASC)
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
+ """
+
+
class MysqlConnectorAsyncADKStore(BaseAsyncADKStore["MysqlConnectorAsyncConfig"]):
- """MySQL/MariaDB ADK store using mysql-connector async driver."""
+ """MySQL/MariaDB ADK store using mysql-connector async driver.
+
+ Provides:
+ - Session state management with JSON storage
+ - Full-event JSON storage (single ``event_json`` column)
+ - Atomic event-append + state-update in one transaction
+ - Microsecond-precision timestamps
+ - Foreign key constraints with cascade delete
+
+ Notes:
+ - Uses ``cast()`` extensively because mysql-connector returns ``Any`` types
+ - Configuration is read from config.extension_config["adk"]
+ """
__slots__ = ()
@@ -50,54 +104,10 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]"
return _parse_owner_id_column_for_mysql(column_ddl)
async def _get_create_sessions_table_sql(self) -> str:
- owner_id_col = ""
- fk_constraint = ""
-
- if self._owner_id_column_ddl:
- col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl)
- owner_id_col = f"{col_def},"
- if fk_def:
- fk_constraint = f",\n {fk_def}"
-
- return f"""
- CREATE TABLE IF NOT EXISTS {self._session_table} (
- id VARCHAR(128) PRIMARY KEY,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- {owner_id_col}
- state JSON NOT NULL,
- create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
- update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
- INDEX idx_{self._session_table}_app_user (app_name, user_id),
- INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint}
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
- """
+ return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl)
async def _get_create_events_table_sql(self) -> str:
- return f"""
- CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
- session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256) NOT NULL,
- author VARCHAR(256) NOT NULL,
- actions BLOB NOT NULL,
- long_running_tool_ids_json JSON,
- branch VARCHAR(256),
- timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
- FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE,
- INDEX idx_{self._events_table}_session (session_id, timestamp ASC)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
- """
+ return _mysql_events_ddl(self._events_table, self._session_table)
def _get_drop_tables_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
@@ -242,23 +252,19 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
raise
async def append_event(self, event_record: EventRecord) -> None:
- content_json = to_json(event_record.get("content")) if event_record.get("content") else None
- grounding_metadata_json = (
- to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None
- )
- custom_metadata_json = (
- to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None
- )
+ """Append an event to a session.
+
+ Args:
+ event_record: Event record with 5 keys (session_id, invocation_id,
+ author, timestamp, event_json).
+ """
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
"""
async with self._config.provide_connection() as conn:
@@ -267,26 +273,60 @@ async def append_event(self, event_record: EventRecord) -> None:
await cursor.execute(
sql,
(
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
event_record["timestamp"],
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_json_str,
+ ),
+ )
+ finally:
+ await cursor.close()
+ await conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state.
+
+ The event insert and state update succeed together or fail together
+ within a single transaction.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot.
+ """
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
+ state_json = to_json(state)
+
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """
+
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = %s
+ WHERE id = %s
+ """
+
+ async with self._config.provide_connection() as conn:
+ cursor = await conn.cursor()
+ try:
+ await cursor.execute(
+ insert_sql,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json_str,
),
)
+ await cursor.execute(update_sql, (state_json, session_id))
finally:
await cursor.close()
await conn.commit()
@@ -294,6 +334,16 @@ async def append_event(self, event_record: EventRecord) -> None:
async def get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
+ """Get events for a session.
+
+ Args:
+ session_id: Session identifier.
+ after_timestamp: Only return events after this time.
+ limit: Maximum number of events to return.
+
+ Returns:
+ List of event records ordered by timestamp ASC.
+ """
where_clauses = ["session_id = %s"]
params: list[Any] = [session_id]
@@ -305,10 +355,7 @@ async def get_events(
limit_clause = f" LIMIT {limit}" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
@@ -325,30 +372,11 @@ async def get_events(
return [
EventRecord(
- id=cast("str", row[0]),
- session_id=cast("str", row[1]),
- app_name=cast("str", row[2]),
- user_id=cast("str", row[3]),
- invocation_id=cast("str", row[4]),
- author=cast("str", row[5]),
- actions=bytes(cast("bytes", row[6])),
- long_running_tool_ids_json=cast("str | None", row[7]),
- branch=cast("str | None", row[8]),
- timestamp=cast("datetime", row[9]),
- content=from_json(row[10])
- if row[10] and isinstance(row[10], str)
- else cast("dict[str, Any] | None", row[10]),
- grounding_metadata=from_json(row[11])
- if row[11] and isinstance(row[11], str)
- else cast("dict[str, Any] | None", row[11]),
- custom_metadata=from_json(row[12])
- if row[12] and isinstance(row[12], str)
- else cast("dict[str, Any] | None", row[12]),
- partial=cast("bool | None", row[13]),
- turn_complete=cast("bool | None", row[14]),
- interrupted=cast("bool | None", row[15]),
- error_code=cast("str | None", row[16]),
- error_message=cast("str | None", row[17]),
+ session_id=cast("str", row[0]),
+ invocation_id=cast("str", row[1]),
+ author=cast("str", row[2]),
+ timestamp=cast("datetime", row[3]),
+ event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]),
)
for row in rows
]
@@ -358,8 +386,20 @@ async def get_events(
raise
-class MysqlConnectorSyncADKStore(BaseSyncADKStore["MysqlConnectorSyncConfig"]):
- """MySQL/MariaDB ADK store using mysql-connector sync driver."""
+class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]):
+ """MySQL/MariaDB ADK store using mysql-connector sync driver.
+
+ Provides:
+ - Session state management with JSON storage
+ - Full-event JSON storage (single ``event_json`` column)
+ - Atomic event-append + state-update in one transaction
+ - Microsecond-precision timestamps
+ - Foreign key constraints with cascade delete
+
+ Notes:
+ - Uses ``cast()`` extensively because mysql-connector returns ``Any`` types
+ - Configuration is read from config.extension_config["adk"]
+ """
__slots__ = ()
@@ -369,65 +409,25 @@ def __init__(self, config: "MysqlConnectorSyncConfig") -> None:
def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]":
return _parse_owner_id_column_for_mysql(column_ddl)
- def _get_create_sessions_table_sql(self) -> str:
- owner_id_col = ""
- fk_constraint = ""
-
- if self._owner_id_column_ddl:
- col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl)
- owner_id_col = f"{col_def},"
- if fk_def:
- fk_constraint = f",\n {fk_def}"
-
- return f"""
- CREATE TABLE IF NOT EXISTS {self._session_table} (
- id VARCHAR(128) PRIMARY KEY,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- {owner_id_col}
- state JSON NOT NULL,
- create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
- update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
- INDEX idx_{self._session_table}_app_user (app_name, user_id),
- INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint}
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
- """
+ async def _get_create_sessions_table_sql(self) -> str:
+ return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl)
- def _get_create_events_table_sql(self) -> str:
- return f"""
- CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
- session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256) NOT NULL,
- author VARCHAR(256) NOT NULL,
- actions BLOB NOT NULL,
- long_running_tool_ids_json JSON,
- branch VARCHAR(256),
- timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
- FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE,
- INDEX idx_{self._events_table}_session (session_id, timestamp ASC)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
- """
+ async def _get_create_events_table_sql(self) -> str:
+ return _mysql_events_ddl(self._events_table, self._session_table)
def _get_drop_tables_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_sessions_table_sql())
- driver.execute_script(self._get_create_events_table_sql())
+ driver.execute_script(run_(self._get_create_sessions_table_sql)())
+ driver.execute_script(run_(self._get_create_events_table_sql)())
+
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
- def create_session(
+ def _create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
state_json = to_json(state)
@@ -454,13 +454,19 @@ def create_session(
cursor.close()
conn.commit()
- result = self.get_session(session_id)
+ result = self._get_session(session_id)
if result is None:
msg = "Failed to fetch created session"
raise RuntimeError(msg)
return result
- def get_session(self, session_id: str) -> "SessionRecord | None":
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session."""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
+
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
@@ -494,7 +500,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
return None
raise
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID."""
+ return await async_(self._get_session)(session_id)
+
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
state_json = to_json(state)
sql = f"""
@@ -511,7 +521,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
cursor.close()
conn.commit()
- def delete_session(self, session_id: str) -> None:
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state."""
+ await async_(self._update_session_state)(session_id, state)
+
+ def _delete_session(self, session_id: str) -> None:
sql = f"DELETE FROM {self._session_table} WHERE id = %s"
with self._config.provide_connection() as conn:
@@ -522,7 +536,11 @@ def delete_session(self, session_id: str) -> None:
cursor.close()
conn.commit()
- def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and associated events."""
+ await async_(self._delete_session)(session_id)
+
+ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
if user_id is None:
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -565,24 +583,71 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess
return []
raise
- def append_event(self, event_record: EventRecord) -> None:
- content_json = to_json(event_record.get("content")) if event_record.get("content") else None
- grounding_metadata_json = (
- to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None
- )
- custom_metadata_json = (
- to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None
- )
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app."""
+ return await async_(self._list_sessions)(app_name, user_id)
+
+ def _append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically create an event and update the session's durable state.
+
+ The event insert and state update succeed together or fail together
+ within a single transaction.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot.
+ """
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
+ state_json = to_json(state)
+
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """
+
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = %s
+ WHERE id = %s
+ """
+
+ with self._config.provide_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ insert_sql,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json_str,
+ ),
+ )
+ cursor.execute(update_sql, (state_json, session_id))
+ finally:
+ cursor.close()
+ conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state."""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
+
+ def _insert_event(self, event_record: EventRecord) -> None:
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
"""
with self._config.provide_connection() as conn:
@@ -591,33 +656,30 @@ def append_event(self, event_record: EventRecord) -> None:
cursor.execute(
sql,
(
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
event_record["timestamp"],
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_json_str,
),
)
finally:
cursor.close()
conn.commit()
- def get_events(
+ def _get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
+ """List events for a session ordered by timestamp.
+
+ Args:
+ session_id: Session identifier.
+ after_timestamp: Only return events after this time.
+ limit: Maximum number of events to return.
+
+ Returns:
+ List of event records ordered by timestamp ASC.
+ """
where_clauses = ["session_id = %s"]
params: list[Any] = [session_id]
@@ -626,53 +688,32 @@ def get_events(
params.append(after_timestamp)
where_clause = " AND ".join(where_clauses)
- limit_clause = f" LIMIT {limit}" if limit else ""
-
+ limit_clause = " LIMIT %s" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
"""
+ if limit:
+ params.append(limit)
try:
with self._config.provide_connection() as conn:
cursor = conn.cursor()
try:
- cursor.execute(sql, params)
+ cursor.execute(sql, tuple(params))
rows = cursor.fetchall()
finally:
cursor.close()
return [
EventRecord(
- id=cast("str", row[0]),
- session_id=cast("str", row[1]),
- app_name=cast("str", row[2]),
- user_id=cast("str", row[3]),
- invocation_id=cast("str", row[4]),
- author=cast("str", row[5]),
- actions=bytes(cast("bytes", row[6])),
- long_running_tool_ids_json=cast("str | None", row[7]),
- branch=cast("str | None", row[8]),
- timestamp=cast("datetime", row[9]),
- content=from_json(row[10])
- if row[10] and isinstance(row[10], str)
- else cast("dict[str, Any] | None", row[10]),
- grounding_metadata=from_json(row[11])
- if row[11] and isinstance(row[11], str)
- else cast("dict[str, Any] | None", row[11]),
- custom_metadata=from_json(row[12])
- if row[12] and isinstance(row[12], str)
- else cast("dict[str, Any] | None", row[12]),
- partial=cast("bool | None", row[13]),
- turn_complete=cast("bool | None", row[14]),
- interrupted=cast("bool | None", row[15]),
- error_code=cast("str | None", row[16]),
- error_message=cast("str | None", row[17]),
+ session_id=cast("str", row[0]),
+ invocation_id=cast("str", row[1]),
+ author=cast("str", row[2]),
+ timestamp=cast("datetime", row[3]),
+ event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]),
)
for row in rows
]
@@ -681,6 +722,20 @@ def get_events(
return []
raise
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session."""
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
+
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ self._insert_event(event_record)
+
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session."""
+ await async_(self._append_event)(event_record)
+
class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorAsyncConfig"]):
"""MySQL/MariaDB ADK memory store using mysql-connector async driver."""
@@ -871,7 +926,7 @@ async def delete_entries_older_than(self, days: int) -> int:
await cursor.close()
-class MysqlConnectorSyncADKMemoryStore(BaseSyncADKMemoryStore["MysqlConnectorSyncConfig"]):
+class MysqlConnectorSyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorSyncConfig"]):
"""MySQL/MariaDB ADK memory store using mysql-connector sync driver."""
__slots__ = ()
@@ -879,7 +934,7 @@ class MysqlConnectorSyncADKMemoryStore(BaseSyncADKMemoryStore["MysqlConnectorSyn
def __init__(self, config: "MysqlConnectorSyncConfig") -> None:
super().__init__(config)
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
owner_id_line = ""
fk_constraint = ""
if self._owner_id_column_ddl:
@@ -913,14 +968,18 @@ def _get_create_memory_table_sql(self) -> str:
def _get_drop_memory_table_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._memory_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
if not self._enabled:
return
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_memory_table_sql())
+ driver.execute_script(run_(self._get_create_memory_table_sql)())
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -986,7 +1045,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
conn.commit()
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
if not self._enabled:
@@ -1026,7 +1089,13 @@ def search_entries(
return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows]
- def delete_entries_by_session(self, session_id: str) -> int:
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
+ def _delete_entries_by_session(self, session_id: str) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -1041,7 +1110,11 @@ def delete_entries_by_session(self, session_id: str) -> int:
finally:
cursor.close()
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -1058,3 +1131,7 @@ def delete_entries_older_than(self, days: int) -> int:
return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
finally:
cursor.close()
+
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py
index 4aa9b2811..56248882d 100644
--- a/sqlspec/adapters/oracledb/adk/store.py
+++ b/sqlspec/adapters/oracledb/adk/store.py
@@ -12,10 +12,11 @@
OracledbSyncDataDictionary,
OracleVersionInfo,
)
-from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json
+from sqlspec.utils.sync_tools import async_, run_
from sqlspec.utils.type_guards import is_async_readable, is_readable
if TYPE_CHECKING:
@@ -93,43 +94,28 @@ def storage_type_from_version(version_info: "OracleVersionInfo | None") -> JSONS
return _storage_type_from_version(version_info)
-def _to_oracle_bool(value: "bool | None") -> "int | None":
- """Convert Python boolean to Oracle NUMBER(1).
+def _event_json_column_ddl(storage_type: JSONStorageType) -> str:
+ """Return the DDL fragment for the event_json column.
- Args:
- value: Python boolean value or None.
-
- Returns:
- 1 for True, 0 for False, None for None.
+ For JSON_NATIVE (Oracle 21c+) we use the native JSON type.
+ For older versions we use BLOB since Oracle recommends BLOB over CLOB for
+ JSON storage. BLOB_JSON gets a CHECK constraint; BLOB_PLAIN does not.
"""
- if value is None:
- return None
- return 1 if value else 0
+ if storage_type == JSONStorageType.JSON_NATIVE:
+ return "event_json JSON NOT NULL"
+ if storage_type == JSONStorageType.BLOB_JSON:
+ return "event_json BLOB CHECK (event_json IS JSON) NOT NULL"
+ return "event_json BLOB NOT NULL"
-def _from_oracle_bool(value: "int | None") -> "bool | None":
- """Convert Oracle NUMBER(1) to Python boolean.
-
- Args:
- value: Oracle NUMBER value (0, 1, or None).
+def _oracle_text_value(value: Any) -> str:
+ """Normalize Oracle VARCHAR2 values back to Python strings.
- Returns:
- Python boolean or None.
+ Oracle stores empty strings as ``NULL``. The ADK event contract allows
+ empty strings for fields like ``invocation_id``, so reads coerce ``NULL``
+ back to ``""``.
"""
- if value is None:
- return None
- return bool(value)
-
-
-def _coerce_bytes_payload(value: Any) -> bytes:
- """Coerce a LOB payload into bytes."""
- if value is None:
- return b""
- if isinstance(value, bytes):
- return value
- if isinstance(value, str):
- return value.encode("utf-8")
- return str(value).encode("utf-8")
+ return "" if value is None else str(value)
class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]):
@@ -138,7 +124,8 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]):
Implements session and event storage for Google Agent Development Kit
using Oracle Database via the python-oracledb async driver. Provides:
- Session state management with version-specific JSON storage
- - Event history tracking with BLOB-serialized actions
+ - Full-fidelity event storage via ``event_json`` column
+ - Atomic ``append_event_and_update_state`` for durable session mutations
- TIMESTAMP WITH TIME ZONE for timezone-aware timestamps
- Foreign key constraints with cascade delete
- Efficient upserts using MERGE statement
@@ -146,28 +133,10 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]):
Args:
config: OracleAsyncConfig with extension_config["adk"] settings.
- Example:
- from sqlspec.adapters.oracledb import OracleAsyncConfig
- from sqlspec.adapters.oracledb.adk import OracleAsyncADKStore
-
- config = OracleAsyncConfig(
- connection_config={"dsn": "oracle://..."},
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "tenant_id NUMBER(10) REFERENCES tenants(id)"
- }
- }
- )
- store = OracleAsyncADKStore(config)
- await store.ensure_tables()
-
Notes:
- JSON storage type detected based on Oracle version (21c+, 12c+, legacy)
- - BLOB for pre-serialized actions from Google ADK
+ - event_json stored as JSON (21c+) or BLOB (older versions)
- TIMESTAMP WITH TIME ZONE for timezone-aware timestamps
- - NUMBER(1) for booleans (0/1/NULL)
- Named parameters using :param_name
- State merging handled at application level
- owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types
@@ -223,10 +192,9 @@ async def _detect_json_storage_type(self) -> JSONStorageType:
Notes:
Queries product_component_version to determine Oracle version.
- Oracle 21c+ with compatible >= 20: Native JSON type
- - Oracle 12c+: BLOB with IS JSON constraint (preferred)
- - Oracle 11g and earlier: BLOB without constraint
+ - Oracle 12c+: BLOB with IS JSON constraint
+ - Oracle 11g and earlier: plain BLOB
- BLOB is preferred over CLOB for 12c+ as per Oracle recommendations.
Result is cached in self._json_storage_type.
"""
if self._json_storage_type is not None:
@@ -296,55 +264,40 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]":
return from_json(str(data)) # type: ignore[no-any-return]
- async def _serialize_json_field(self, value: Any) -> "str | bytes | None":
- """Serialize optional JSON field for event storage.
-
- Args:
- value: Value to serialize (dict or None).
-
- Returns:
- Serialized JSON or None.
- """
- if value is None:
+ async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
+ """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values."""
+ if data is None:
return None
+ return await self._deserialize_state(data)
+ async def _serialize_event_json(self, event_json: Any) -> "str | bytes":
+ """Serialize event_json to the configured Oracle JSON storage format."""
storage_type = await self._detect_json_storage_type()
-
if storage_type == JSONStorageType.JSON_NATIVE:
- return to_json(value)
+ return to_json(event_json)
+ return to_json(event_json, as_bytes=True)
- return to_json(value, as_bytes=True)
-
- async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
- """Deserialize optional JSON field from database.
+ async def _read_event_json(self, data: Any) -> str:
+ """Read event_json from database, handling LOB types.
Args:
- data: Data from database (may be LOB, str, bytes, dict, or None).
+ data: Data from database (may be LOB, str, or dict).
Returns:
- Deserialized dictionary or None.
-
- Notes:
- Oracle JSON type may return dict directly.
+ JSON string.
"""
- if data is None:
- return None
-
if is_async_readable(data):
data = await data.read()
elif is_readable(data):
data = data.read()
if isinstance(data, dict):
- return cast("dict[str, Any]", _coerce_decimal_values(data))
+ return to_json(data)
if isinstance(data, bytes):
- return from_json(data) # type: ignore[no-any-return]
-
- if isinstance(data, str):
- return from_json(data) # type: ignore[no-any-return]
+ return data.decode("utf-8")
- return from_json(str(data)) # type: ignore[no-any-return]
+ return str(data)
def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str:
"""Get Oracle CREATE TABLE SQL for sessions with specified storage type.
@@ -406,54 +359,27 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType)
def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str:
"""Get Oracle CREATE TABLE SQL for events with specified storage type.
+ The events table uses the new 5-column contract: session_id, invocation_id,
+ author, timestamp, and event_json. The event_json column stores the full
+ ADK Event as JSON (21c+) or BLOB (older versions).
+
Args:
storage_type: JSON storage type to use.
Returns:
SQL statement to create adk_events table.
"""
- if storage_type == JSONStorageType.JSON_NATIVE:
- json_columns = """
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- long_running_tool_ids_json JSON
- """
- elif storage_type == JSONStorageType.BLOB_JSON:
- json_columns = """
- content BLOB CHECK (content IS JSON),
- grounding_metadata BLOB CHECK (grounding_metadata IS JSON),
- custom_metadata BLOB CHECK (custom_metadata IS JSON),
- long_running_tool_ids_json BLOB CHECK (long_running_tool_ids_json IS JSON)
- """
- else:
- json_columns = """
- content BLOB,
- grounding_metadata BLOB,
- custom_metadata BLOB,
- long_running_tool_ids_json BLOB
- """
-
+ event_json_col = _event_json_column_ddl(storage_type)
inmemory_clause = " INMEMORY PRIORITY HIGH" if self._in_memory else ""
return f"""
BEGIN
EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} (
- id VARCHAR2(128) PRIMARY KEY,
session_id VARCHAR2(128) NOT NULL,
- app_name VARCHAR2(128) NOT NULL,
- user_id VARCHAR2(128) NOT NULL,
invocation_id VARCHAR2(256),
author VARCHAR2(256),
- actions BLOB,
- branch VARCHAR2(256),
timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL,
- {json_columns},
- partial NUMBER(1),
- turn_complete NUMBER(1),
- interrupted NUMBER(1),
- error_code VARCHAR2(256),
- error_message VARCHAR2(1024),
+ {event_json_col},
CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id)
REFERENCES {self._session_table}(id) ON DELETE CASCADE
){inmemory_clause}';
@@ -753,28 +679,14 @@ async def append_event(self, event_record: EventRecord) -> None:
"""Append an event to a session.
Args:
- event_record: Event record to store.
-
- Notes:
- Uses SYSTIMESTAMP for timestamp if not provided.
- JSON fields are serialized using version-appropriate format.
- Boolean fields are converted to NUMBER(1).
+ event_record: Event record with 5 keys: session_id, invocation_id,
+ author, timestamp, event_json.
"""
- content_data = await self._serialize_json_field(event_record.get("content"))
- grounding_metadata_data = await self._serialize_json_field(event_record.get("grounding_metadata"))
- custom_metadata_data = await self._serialize_json_field(event_record.get("custom_metadata"))
-
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ session_id, invocation_id, author, timestamp, event_json
) VALUES (
- :id, :session_id, :app_name, :user_id, :invocation_id, :author, :actions,
- :long_running_tool_ids_json, :branch, :timestamp, :content,
- :grounding_metadata, :custom_metadata, :partial, :turn_complete,
- :interrupted, :error_code, :error_message
+ :session_id, :invocation_id, :author, :timestamp, :event_json
)
"""
@@ -783,26 +695,58 @@ async def append_event(self, event_record: EventRecord) -> None:
await cursor.execute(
sql,
{
- "id": event_record["id"],
"session_id": event_record["session_id"],
- "app_name": event_record["app_name"],
- "user_id": event_record["user_id"],
"invocation_id": event_record["invocation_id"],
"author": event_record["author"],
- "actions": event_record["actions"],
- "long_running_tool_ids_json": event_record.get("long_running_tool_ids_json"),
- "branch": event_record.get("branch"),
"timestamp": event_record["timestamp"],
- "content": content_data,
- "grounding_metadata": grounding_metadata_data,
- "custom_metadata": custom_metadata_data,
- "partial": _to_oracle_bool(event_record.get("partial")),
- "turn_complete": _to_oracle_bool(event_record.get("turn_complete")),
- "interrupted": _to_oracle_bool(event_record.get("interrupted")),
- "error_code": event_record.get("error_code"),
- "error_message": event_record.get("error_message"),
+ "event_json": await self._serialize_event_json(event_record["event_json"]),
+ },
+ )
+ await conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state.
+
+ Both the event insert and session state update are executed within a
+ single transaction so they succeed or fail together.
+
+ Args:
+ event_record: Event record with 5 keys: session_id, invocation_id,
+ author, timestamp, event_json.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot (``temp:`` keys already
+ stripped by the service layer).
+ """
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (
+ :session_id, :invocation_id, :author, :timestamp, :event_json
+ )
+ """
+
+ state_data = await self._serialize_state(state)
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = :state, update_time = SYSTIMESTAMP
+ WHERE id = :id
+ """
+
+ async with self._config.provide_connection() as conn:
+ cursor = conn.cursor()
+ await cursor.execute(
+ insert_sql,
+ {
+ "session_id": event_record["session_id"],
+ "invocation_id": event_record["invocation_id"],
+ "author": event_record["author"],
+ "timestamp": event_record["timestamp"],
+ "event_json": await self._serialize_event_json(event_record["event_json"]),
},
)
+ await cursor.execute(update_sql, {"state": state_data, "id": session_id})
await conn.commit()
async def get_events(
@@ -817,11 +761,6 @@ async def get_events(
Returns:
List of event records ordered by timestamp ASC.
-
- Notes:
- Uses index on (session_id, timestamp ASC).
- JSON fields deserialized using version-appropriate format.
- Converts BLOB actions to bytes and NUMBER(1) booleans to Python bool.
"""
where_clauses = ["session_id = :session_id"]
@@ -837,10 +776,7 @@ async def get_events(
limit_clause = f" FETCH FIRST {limit} ROWS ONLY"
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
@@ -852,43 +788,16 @@ async def get_events(
await cursor.execute(sql, params)
rows = await cursor.fetchall()
- results = []
- for row in rows:
- actions_blob = row[6]
- if is_async_readable(actions_blob):
- actions_data = await actions_blob.read()
- elif is_readable(actions_blob):
- actions_data = actions_blob.read()
- else:
- actions_data = actions_blob
-
- content = await self._deserialize_json_field(row[10])
- grounding_metadata = await self._deserialize_json_field(row[11])
- custom_metadata = await self._deserialize_json_field(row[12])
-
- results.append(
- EventRecord(
- id=row[0],
- session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=_coerce_bytes_payload(actions_data),
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=row[9],
- content=content,
- grounding_metadata=grounding_metadata,
- custom_metadata=custom_metadata,
- partial=_from_oracle_bool(row[13]),
- turn_complete=_from_oracle_bool(row[14]),
- interrupted=_from_oracle_bool(row[15]),
- error_code=row[16],
- error_message=row[17],
- )
+ return [
+ EventRecord(
+ session_id=row[0],
+ invocation_id=_oracle_text_value(row[1]),
+ author=_oracle_text_value(row[2]),
+ timestamp=row[3],
+ event_json=await self._deserialize_json_field(row[4]) or {},
)
- return results
+ for row in rows
+ ]
except oracledb.DatabaseError as e:
error_obj = e.args[0] if e.args else None
if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR:
@@ -896,13 +805,14 @@ async def get_events(
raise
-class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]):
+class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]):
"""Oracle synchronous ADK store using oracledb sync driver.
Implements session and event storage for Google Agent Development Kit
using Oracle Database via the python-oracledb synchronous driver. Provides:
- Session state management with version-specific JSON storage
- - Event history tracking with BLOB-serialized actions
+ - Full-fidelity event storage via ``event_json`` column
+ - Atomic ``create_event_and_update_state`` for durable session mutations
- TIMESTAMP WITH TIME ZONE for timezone-aware timestamps
- Foreign key constraints with cascade delete
- Efficient upserts using MERGE statement
@@ -910,28 +820,10 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]):
Args:
config: OracleSyncConfig with extension_config["adk"] settings.
- Example:
- from sqlspec.adapters.oracledb import OracleSyncConfig
- from sqlspec.adapters.oracledb.adk import OracleSyncADKStore
-
- config = OracleSyncConfig(
- connection_config={"dsn": "oracle://..."},
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "account_id NUMBER(19) REFERENCES accounts(id)"
- }
- }
- )
- store = OracleSyncADKStore(config)
- store.ensure_tables()
-
Notes:
- JSON storage type detected based on Oracle version (21c+, 12c+, legacy)
- - BLOB for pre-serialized actions from Google ADK
+ - event_json stored as JSON (21c+) or BLOB (older versions)
- TIMESTAMP WITH TIME ZONE for timezone-aware timestamps
- - NUMBER(1) for booleans (0/1/NULL)
- Named parameters using :param_name
- State merging handled at application level
- owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types
@@ -960,7 +852,7 @@ def __init__(self, config: "OracleSyncConfig") -> None:
adk_config = config.extension_config.get("adk", {})
self._in_memory: bool = bool(adk_config.get("in_memory", False))
- def _get_create_sessions_table_sql(self) -> str:
+ async def _get_create_sessions_table_sql(self) -> str:
"""Get Oracle CREATE TABLE SQL for sessions table.
Auto-detects optimal JSON storage type based on Oracle version.
@@ -969,7 +861,7 @@ def _get_create_sessions_table_sql(self) -> str:
storage_type = self._detect_json_storage_type()
return self._get_create_sessions_table_sql_for_type(storage_type)
- def _get_create_events_table_sql(self) -> str:
+ async def _get_create_events_table_sql(self) -> str:
"""Get Oracle CREATE TABLE SQL for events table.
Auto-detects optimal JSON storage type based on Oracle version.
@@ -987,10 +879,9 @@ def _detect_json_storage_type(self) -> JSONStorageType:
Notes:
Queries product_component_version to determine Oracle version.
- Oracle 21c+ with compatible >= 20: Native JSON type
- - Oracle 12c+: BLOB with IS JSON constraint (preferred)
- - Oracle 11g and earlier: BLOB without constraint
+ - Oracle 12c+: BLOB with IS JSON constraint
+ - Oracle 11g and earlier: plain BLOB
- BLOB is preferred over CLOB for 12c+ as per Oracle recommendations.
Result is cached in self._json_storage_type.
"""
if self._json_storage_type is not None:
@@ -1058,53 +949,38 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]":
return from_json(str(data)) # type: ignore[no-any-return]
- def _serialize_json_field(self, value: Any) -> "str | bytes | None":
- """Serialize optional JSON field for event storage.
-
- Args:
- value: Value to serialize (dict or None).
-
- Returns:
- Serialized JSON or None.
- """
- if value is None:
+ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
+ """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values."""
+ if data is None:
return None
+ return self._deserialize_state(data)
+ def _serialize_event_json(self, event_json: Any) -> "str | bytes":
+ """Serialize event_json to the configured Oracle JSON storage format."""
storage_type = self._detect_json_storage_type()
-
if storage_type == JSONStorageType.JSON_NATIVE:
- return to_json(value)
+ return to_json(event_json)
+ return to_json(event_json, as_bytes=True)
- return to_json(value, as_bytes=True)
-
- def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
- """Deserialize optional JSON field from database.
+ def _read_event_json(self, data: Any) -> str:
+ """Read event_json from database, handling LOB types.
Args:
- data: Data from database (may be LOB, str, bytes, dict, or None).
+ data: Data from database (may be LOB, str, or dict).
Returns:
- Deserialized dictionary or None.
-
- Notes:
- Oracle JSON type may return dict directly.
+ JSON string.
"""
- if data is None:
- return None
-
if is_readable(data):
data = data.read()
if isinstance(data, dict):
- return cast("dict[str, Any]", _coerce_decimal_values(data))
+ return to_json(data)
if isinstance(data, bytes):
- return from_json(data) # type: ignore[no-any-return]
+ return data.decode("utf-8")
- if isinstance(data, str):
- return from_json(data) # type: ignore[no-any-return]
-
- return from_json(str(data)) # type: ignore[no-any-return]
+ return str(data)
def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str:
"""Get Oracle CREATE TABLE SQL for sessions with specified storage type.
@@ -1166,54 +1042,27 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType)
def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str:
"""Get Oracle CREATE TABLE SQL for events with specified storage type.
+ The events table uses the new 5-column contract: session_id, invocation_id,
+ author, timestamp, and event_json. The event_json column stores the full
+ ADK Event as JSON (21c+) or BLOB (older versions).
+
Args:
storage_type: JSON storage type to use.
Returns:
SQL statement to create adk_events table.
"""
- if storage_type == JSONStorageType.JSON_NATIVE:
- json_columns = """
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- long_running_tool_ids_json JSON
- """
- elif storage_type == JSONStorageType.BLOB_JSON:
- json_columns = """
- content BLOB CHECK (content IS JSON),
- grounding_metadata BLOB CHECK (grounding_metadata IS JSON),
- custom_metadata BLOB CHECK (custom_metadata IS JSON),
- long_running_tool_ids_json BLOB CHECK (long_running_tool_ids_json IS JSON)
- """
- else:
- json_columns = """
- content BLOB,
- grounding_metadata BLOB,
- custom_metadata BLOB,
- long_running_tool_ids_json BLOB
- """
-
+ event_json_col = _event_json_column_ddl(storage_type)
inmemory_clause = " INMEMORY PRIORITY HIGH" if self._in_memory else ""
return f"""
BEGIN
EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} (
- id VARCHAR2(128) PRIMARY KEY,
session_id VARCHAR2(128) NOT NULL,
- app_name VARCHAR2(128) NOT NULL,
- user_id VARCHAR2(128) NOT NULL,
invocation_id VARCHAR2(256),
author VARCHAR2(256),
- actions BLOB,
- branch VARCHAR2(256),
timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL,
- {json_columns},
- partial NUMBER(1),
- turn_complete NUMBER(1),
- interrupted NUMBER(1),
- error_code VARCHAR2(256),
- error_message VARCHAR2(1024),
+ {event_json_col},
CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id)
REFERENCES {self._session_table}(id) ON DELETE CASCADE
){inmemory_clause}';
@@ -1298,7 +1147,7 @@ def _get_drop_tables_sql(self) -> "list[str]":
""",
]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
"""Create both sessions and events tables if they don't exist.
Notes:
@@ -1315,7 +1164,11 @@ def create_tables(self) -> None:
events_sql = SQL(self._get_create_events_table_sql_for_type(storage_type))
driver.execute_script(events_sql)
- def create_session(
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
"""Create a new session.
@@ -1361,9 +1214,19 @@ def create_session(
cursor.execute(sql, params)
conn.commit()
- return self.get_session(session_id) # type: ignore[return-value]
+ result = self._get_session(session_id)
+ if result is None:
+ msg = "Failed to fetch created session"
+ raise RuntimeError(msg)
+ return result
- def get_session(self, session_id: str) -> "SessionRecord | None":
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session."""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
+
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
"""Get session by ID.
Args:
@@ -1410,7 +1273,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
return None
raise
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID."""
+ return await async_(self._get_session)(session_id)
+
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
"""Update session state.
Args:
@@ -1435,7 +1302,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
cursor.execute(sql, {"state": state_data, "id": session_id})
conn.commit()
- def delete_session(self, session_id: str) -> None:
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state."""
+ await async_(self._update_session_state)(session_id, state)
+
+ def _delete_session(self, session_id: str) -> None:
"""Delete session and all associated events (cascade).
Args:
@@ -1451,7 +1322,11 @@ def delete_session(self, session_id: str) -> None:
cursor.execute(sql, {"id": session_id})
conn.commit()
- def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and associated events."""
+ await async_(self._delete_session)(session_id)
+
+ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
"""List sessions for an app, optionally filtered by user.
Args:
@@ -1510,159 +1385,147 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess
return []
raise
- def create_event(
- self,
- event_id: str,
- session_id: str,
- app_name: str,
- user_id: str,
- author: "str | None" = None,
- actions: "bytes | None" = None,
- content: "dict[str, Any] | None" = None,
- **kwargs: Any,
- ) -> "EventRecord":
- """Create a new event.
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app."""
+ return await async_(self._list_sessions)(app_name, user_id)
- Args:
- event_id: Unique event identifier.
- session_id: Session identifier.
- app_name: Application name.
- user_id: User identifier.
- author: Event author (user/assistant/system).
- actions: Pickled actions object.
- content: Event content (JSONB/JSON).
- **kwargs: Additional optional fields.
+ def _append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically create an event and update the session's durable state.
- Returns:
- Created event record.
+ Both the event insert and session state update are executed within a
+ single transaction so they succeed or fail together.
- Notes:
- Uses SYSTIMESTAMP for timestamp if not provided.
- JSON fields are serialized using version-appropriate format.
- Boolean fields are converted to NUMBER(1).
+ Args:
+ event_record: Event record with 5 keys: session_id, invocation_id,
+ author, timestamp, event_json.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot (``temp:`` keys already
+ stripped by the service layer).
"""
- content_data = self._serialize_json_field(content)
- grounding_metadata_data = self._serialize_json_field(kwargs.get("grounding_metadata"))
- custom_metadata_data = self._serialize_json_field(kwargs.get("custom_metadata"))
-
- sql = f"""
+ insert_sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ session_id, invocation_id, author, timestamp, event_json
) VALUES (
- :id, :session_id, :app_name, :user_id, :invocation_id, :author, :actions,
- :long_running_tool_ids_json, :branch, :timestamp, :content,
- :grounding_metadata, :custom_metadata, :partial, :turn_complete,
- :interrupted, :error_code, :error_message
+ :session_id, :invocation_id, :author, :timestamp, :event_json
)
"""
+ state_data = self._serialize_state(state)
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = :state, update_time = SYSTIMESTAMP
+ WHERE id = :id
+ """
+
with self._config.provide_connection() as conn:
cursor = conn.cursor()
cursor.execute(
- sql,
+ insert_sql,
{
- "id": event_id,
- "session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
- "invocation_id": kwargs.get("invocation_id"),
- "author": author,
- "actions": actions,
- "long_running_tool_ids_json": kwargs.get("long_running_tool_ids_json"),
- "branch": kwargs.get("branch"),
- "timestamp": kwargs.get("timestamp"),
- "content": content_data,
- "grounding_metadata": grounding_metadata_data,
- "custom_metadata": custom_metadata_data,
- "partial": _to_oracle_bool(kwargs.get("partial")),
- "turn_complete": _to_oracle_bool(kwargs.get("turn_complete")),
- "interrupted": _to_oracle_bool(kwargs.get("interrupted")),
- "error_code": kwargs.get("error_code"),
- "error_message": kwargs.get("error_message"),
+ "session_id": event_record["session_id"],
+ "invocation_id": event_record["invocation_id"],
+ "author": event_record["author"],
+ "timestamp": event_record["timestamp"],
+ "event_json": self._serialize_event_json(event_record["event_json"]),
},
)
+ cursor.execute(update_sql, {"state": state_data, "id": session_id})
conn.commit()
- events = self.list_events(session_id)
- for event in events:
- if event["id"] == event_id:
- return event
-
- msg = f"Failed to retrieve created event {event_id}"
- raise RuntimeError(msg)
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state."""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
- def list_events(self, session_id: str) -> "list[EventRecord]":
+ def _get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
"""List events for a session ordered by timestamp.
Args:
session_id: Session identifier.
+ after_timestamp: Only return events after this time.
+ limit: Maximum number of events to return.
Returns:
List of event records ordered by timestamp ASC.
-
- Notes:
- Uses index on (session_id, timestamp ASC).
- JSON fields deserialized using version-appropriate format.
- Converts BLOB actions to bytes and NUMBER(1) booleans to Python bool.
"""
+ where_clauses = ["session_id = :session_id"]
+ params: dict[str, Any] = {"session_id": session_id}
+
+ if after_timestamp is not None:
+ where_clauses.append("timestamp > :after_timestamp")
+ params["after_timestamp"] = after_timestamp
+
+ where_clause = " AND ".join(where_clauses)
+ limit_clause = f" FETCH FIRST {limit} ROWS ONLY" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
- WHERE session_id = :session_id
- ORDER BY timestamp ASC
+ WHERE {where_clause}
+ ORDER BY timestamp ASC{limit_clause}
"""
try:
with self._config.provide_connection() as conn:
cursor = conn.cursor()
- cursor.execute(sql, {"session_id": session_id})
+ cursor.execute(sql, params)
rows = cursor.fetchall()
- results = []
- for row in rows:
- actions_blob = row[6]
- actions_data = actions_blob.read() if is_readable(actions_blob) else actions_blob
-
- content = self._deserialize_json_field(row[10])
- grounding_metadata = self._deserialize_json_field(row[11])
- custom_metadata = self._deserialize_json_field(row[12])
-
- results.append(
- EventRecord(
- id=row[0],
- session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=_coerce_bytes_payload(actions_data),
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=row[9],
- content=content,
- grounding_metadata=grounding_metadata,
- custom_metadata=custom_metadata,
- partial=_from_oracle_bool(row[13]),
- turn_complete=_from_oracle_bool(row[14]),
- interrupted=_from_oracle_bool(row[15]),
- error_code=row[16],
- error_message=row[17],
- )
+ return [
+ EventRecord(
+ session_id=row[0],
+ invocation_id=_oracle_text_value(row[1]),
+ author=_oracle_text_value(row[2]),
+ timestamp=row[3],
+ event_json=self._deserialize_json_field(row[4]) or {},
)
- return results
+ for row in rows
+ ]
except oracledb.DatabaseError as e:
error_obj = e.args[0] if e.args else None
if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR:
return []
raise
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session."""
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
+
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (
+ :session_id, :invocation_id, :author, :timestamp, :event_json
+ )
+ """
+
+ with self._config.provide_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ sql,
+ {
+ "session_id": event_record["session_id"],
+ "invocation_id": event_record["invocation_id"],
+ "author": event_record["author"],
+ "timestamp": event_record["timestamp"],
+ "event_json": self._serialize_event_json(event_record["event_json"]),
+ },
+ )
+ conn.commit()
+
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session."""
+ await async_(self._append_event)(event_record)
+
ORACLE_DUPLICATE_KEY_ERROR: Final = 1
@@ -2030,7 +1893,7 @@ async def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]":
return records
-class OracleSyncADKMemoryStore(BaseSyncADKMemoryStore["OracleSyncConfig"]):
+class OracleSyncADKMemoryStore(BaseAsyncADKMemoryStore["OracleSyncConfig"]):
"""Oracle ADK memory store using sync oracledb driver."""
__slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info")
@@ -2081,7 +1944,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
return _extract_json_value(data)
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
storage_type = self._detect_json_storage_type()
return self._get_create_memory_table_sql_for_type(storage_type)
@@ -2196,12 +2059,16 @@ def _get_drop_memory_table_sql(self) -> "list[str]":
""",
]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
if not self._enabled:
return
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_memory_table_sql())
+ driver.execute_script(run_(self._get_create_memory_table_sql)())
+
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") -> bool:
"""Execute an insert and skip duplicate key errors."""
@@ -2214,7 +2081,7 @@ def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]")
raise
return True
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -2261,7 +2128,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
if not self._enabled:
@@ -2280,6 +2151,12 @@ def search_entries(
return []
raise
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]":
sql = f"""
SELECT id, session_id, app_name, user_id, event_id, author,
@@ -2326,7 +2203,7 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit:
rows = cursor.fetchall()
return self._rows_to_records(rows)
- def delete_entries_by_session(self, session_id: str) -> int:
+ def _delete_entries_by_session(self, session_id: str) -> int:
sql = f"DELETE FROM {self._memory_table} WHERE session_id = :session_id"
with self._config.provide_connection() as conn:
cursor = conn.cursor()
@@ -2334,7 +2211,11 @@ def delete_entries_by_session(self, session_id: str) -> int:
conn.commit()
return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
sql = f"""
DELETE FROM {self._memory_table}
WHERE inserted_at < SYSTIMESTAMP - NUMTODSINTERVAL(:days, 'DAY')
@@ -2345,6 +2226,10 @@ def delete_entries_older_than(self, days: int) -> int:
conn.commit()
return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
+
def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]":
records: list[MemoryRecord] = []
for row in rows:
diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py
index c1629e05d..9cfdcea27 100644
--- a/sqlspec/adapters/psqlpy/adk/store.py
+++ b/sqlspec/adapters/psqlpy/adk/store.py
@@ -29,79 +29,28 @@ class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]):
Implements session and event storage for Google Agent Development Kit
using PostgreSQL via the high-performance Rust-based psqlpy driver.
+ Events are stored as a single JSONB blob (``event_json``) alongside
+ indexed scalar columns for efficient querying.
Provides:
- Session state management with JSONB storage
- - Event history tracking with BYTEA-serialized actions
+ - Full-fidelity event storage via ``event_json`` JSONB column
+ - Atomic ``append_event_and_update_state`` for durable session mutations
- Microsecond-precision timestamps with TIMESTAMPTZ
- Foreign key constraints with cascade delete
- - Efficient upserts using ON CONFLICT
- GIN indexes for JSONB queries
- HOT updates with FILLFACTOR 80
Args:
config: PsqlpyConfig with extension_config["adk"] settings.
-
- Example:
- from sqlspec.adapters.psqlpy import PsqlpyConfig
- from sqlspec.adapters.psqlpy.adk import PsqlpyADKStore
-
- config = PsqlpyConfig(
- connection_config={"dsn": "postgresql://..."},
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
- }
- }
- )
- store = PsqlpyADKStore(config)
- await store.ensure_tables()
-
- Notes:
- - PostgreSQL JSONB type used for state (more efficient than JSON)
- - Psqlpy automatically converts Python dicts to/from JSONB
- - TIMESTAMPTZ provides timezone-aware microsecond precision
- - BYTEA for pre-serialized actions from Google ADK
- - GIN index on state for JSONB queries (partial index)
- - FILLFACTOR 80 leaves space for HOT updates
- - Uses PostgreSQL numeric parameter style ($1, $2, $3)
- - Configuration is read from config.extension_config["adk"]
"""
__slots__ = ()
def __init__(self, config: "PsqlpyConfig") -> None:
- """Initialize Psqlpy ADK store.
-
- Args:
- config: PsqlpyConfig instance.
-
- Notes:
- Configuration is read from config.extension_config["adk"]:
- - session_table: Sessions table name (default: "adk_sessions")
- - events_table: Events table name (default: "adk_events")
- - owner_id_column: Optional owner FK column DDL (default: None)
- """
super().__init__(config)
async def _get_create_sessions_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for sessions.
-
- Returns:
- SQL statement to create adk_sessions table with indexes.
-
- Notes:
- - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names)
- - JSONB type for state storage with default empty object
- - TIMESTAMPTZ with microsecond precision
- - FILLFACTOR 80 for HOT updates (reduces table bloat)
- - Composite index on (app_name, user_id) for listing
- - Index on update_time DESC for recent session queries
- - Partial GIN index on state for JSONB queries (only non-empty)
- - Optional owner ID column for multi-tenancy or user references
- """
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
@@ -128,66 +77,24 @@ async def _get_create_sessions_table_sql(self) -> str:
"""
async def _get_create_events_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for events.
-
- Returns:
- SQL statement to create adk_events table with indexes.
-
- Notes:
- - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256),
- branch(256), error_code(256), error_message(1024)
- - BYTEA for pre-serialized actions (no size limit)
- - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
- - BOOLEAN for partial, turn_complete, interrupted
- - Foreign key to sessions with CASCADE delete
- - Index on (session_id, timestamp ASC) for ordered event retrieval
- """
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256),
- author VARCHAR(256),
- actions BYTEA,
- long_running_tool_ids_json JSONB,
- branch VARCHAR(256),
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(256) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
- );
+ ) WITH (fillfactor = 80);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
ON {self._events_table}(session_id, timestamp ASC);
"""
def _get_drop_tables_sql(self) -> "list[str]":
- """Get PostgreSQL DROP TABLE SQL statements.
-
- Returns:
- List of SQL statements to drop tables and indexes.
-
- Notes:
- Order matters: drop events table (child) before sessions (parent).
- PostgreSQL automatically drops indexes when dropping tables.
- """
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
async def create_tables(self) -> None:
- """Create both sessions and events tables if they don't exist.
-
- Notes:
- Uses driver.execute_script() which handles multiple statements.
- Creates sessions table first, then events table (FK dependency).
- """
async with self._config.provide_session() as driver:
await driver.execute_script(await self._get_create_sessions_table_sql())
await driver.execute_script(await self._get_create_events_table_sql())
@@ -195,23 +102,6 @@ async def create_tables(self) -> None:
async def create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
- """Create a new session.
-
- Args:
- session_id: Unique session identifier.
- app_name: Application name.
- user_id: User identifier.
- state: Initial session state.
- owner_id: Optional owner ID value for owner_id_column (if configured).
-
- Returns:
- Created session record.
-
- Notes:
- Uses CURRENT_TIMESTAMP for create_time and update_time.
- State is passed as dict and psqlpy converts to JSONB automatically.
- If owner_id_column is configured, owner_id value must be provided.
- """
async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue]
if self._owner_id_column_name:
sql = f"""
@@ -230,19 +120,6 @@ async def create_session(
return await self.get_session(session_id) # type: ignore[return-value]
async def get_session(self, session_id: str) -> "SessionRecord | None":
- """Get session by ID.
-
- Args:
- session_id: Session identifier.
-
- Returns:
- Session record or None if not found.
-
- Notes:
- PostgreSQL returns datetime objects for TIMESTAMPTZ columns.
- JSONB is automatically parsed by psqlpy to Python dicts.
- Returns None if table doesn't exist (catches database errors).
- """
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
@@ -273,17 +150,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
raise
async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
- """Update session state.
-
- Args:
- session_id: Session identifier.
- state: New state dictionary (replaces existing state).
-
- Notes:
- This replaces the entire state dictionary.
- Uses CURRENT_TIMESTAMP for update_time.
- Psqlpy automatically converts dict to JSONB.
- """
sql = f"""
UPDATE {self._session_table}
SET state = $1, update_time = CURRENT_TIMESTAMP
@@ -294,33 +160,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
await conn.execute(sql, [state, session_id])
async def delete_session(self, session_id: str) -> None:
- """Delete session and all associated events (cascade).
-
- Args:
- session_id: Session identifier.
-
- Notes:
- Foreign key constraint ensures events are cascade-deleted.
- """
sql = f"DELETE FROM {self._session_table} WHERE id = $1"
async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue]
await conn.execute(sql, [session_id])
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
- """List sessions for an app, optionally filtered by user.
-
- Args:
- app_name: Application name.
- user_id: User identifier. If None, lists all sessions for the app.
-
- Returns:
- List of session records ordered by update_time DESC.
-
- Notes:
- Uses composite index on (app_name, user_id) when user_id is provided.
- Returns empty list if table doesn't exist.
- """
if user_id is None:
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -361,74 +206,54 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
raise
async def append_event(self, event_record: EventRecord) -> None:
- """Append an event to a session.
-
- Args:
- event_record: Event record to store.
-
- Notes:
- Uses CURRENT_TIMESTAMP for timestamp if not provided.
- JSONB fields are passed as dicts and psqlpy converts automatically.
- BYTEA actions field stores pre-serialized data from Google ADK.
- """
- content_json = event_record.get("content")
- grounding_metadata_json = event_record.get("grounding_metadata")
- custom_metadata_json = event_record.get("custom_metadata")
-
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES ($1, $2, $3, $4, $5)
"""
async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue]
await conn.execute(
sql,
[
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
- event_record.get("invocation_id"),
- event_record.get("author"),
- event_record.get("actions"),
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_record["event_json"],
+ ],
+ )
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES ($1, $2, $3, $4, $5)
+ """
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = $1, update_time = CURRENT_TIMESTAMP
+ WHERE id = $2
+ """
+
+ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue]
+ await conn.execute(
+ insert_sql,
+ [
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
event_record["timestamp"],
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_record["event_json"],
],
)
+ await conn.execute(update_sql, [state, session_id])
async def get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
- """Get events for a session.
-
- Args:
- session_id: Session identifier.
- after_timestamp: Only return events after this time.
- limit: Maximum number of events to return.
-
- Returns:
- List of event records ordered by timestamp ASC.
-
- Notes:
- Uses index on (session_id, timestamp ASC).
- Parses JSONB fields and converts BYTEA actions to bytes.
- Returns empty list if table doesn't exist.
- """
where_clauses = ["session_id = $1"]
params: list[Any] = [session_id]
@@ -442,10 +267,7 @@ async def get_events(
params.append(limit)
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
@@ -458,24 +280,11 @@ async def get_events(
return [
EventRecord(
- id=row["id"],
session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
invocation_id=row["invocation_id"],
author=row["author"],
- actions=bytes(row["actions"]) if row["actions"] else b"",
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ event_json=row["event_json"],
)
for row in rows
]
diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py
index 6011e6dc2..00d68aade 100644
--- a/sqlspec/adapters/psycopg/adk/store.py
+++ b/sqlspec/adapters/psycopg/adk/store.py
@@ -6,9 +6,10 @@
from psycopg import sql as pg_sql
from psycopg.types.json import Jsonb
-from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.logging import get_logger
+from sqlspec.utils.sync_tools import async_, run_
if TYPE_CHECKING:
from datetime import datetime
@@ -56,84 +57,32 @@ def _build_insert_params_with_owner(entry: "MemoryRecord", owner_id: "object | N
class PsycopgAsyncADKStore(BaseAsyncADKStore["PsycopgAsyncConfig"]):
- """PostgreSQL ADK store using Psycopg3 driver.
+ """PostgreSQL ADK store using Psycopg3 async driver.
Implements session and event storage for Google Agent Development Kit
using PostgreSQL via psycopg3 with native async/await support.
+ Events are stored as a single JSONB blob (``event_json``) alongside
+ indexed scalar columns for efficient querying.
Provides:
- - Session state management with JSONB storage and merge operations
- - Event history tracking with BYTEA-serialized actions
+ - Session state management with JSONB storage
+ - Full-fidelity event storage via ``event_json`` JSONB column
+ - Atomic ``append_event_and_update_state`` for durable session mutations
- Microsecond-precision timestamps with TIMESTAMPTZ
- Foreign key constraints with cascade delete
- - Efficient upserts using ON CONFLICT
- GIN indexes for JSONB queries
- HOT updates with FILLFACTOR 80
Args:
config: PsycopgAsyncConfig with extension_config["adk"] settings.
-
- Example:
- from sqlspec.adapters.psycopg import PsycopgAsyncConfig
- from sqlspec.adapters.psycopg.adk import PsycopgAsyncADKStore
-
- config = PsycopgAsyncConfig(
- connection_config={"conninfo": "postgresql://..."},
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
- }
- }
- )
- store = PsycopgAsyncADKStore(config)
- await store.ensure_tables()
-
- Notes:
- - PostgreSQL JSONB type used for state (more efficient than JSON)
- - Psycopg requires wrapping dicts with Jsonb() for type safety
- - TIMESTAMPTZ provides timezone-aware microsecond precision
- - State merging uses `state || $1::jsonb` operator for efficiency
- - BYTEA for pre-serialized actions from Google ADK
- - GIN index on state for JSONB queries (partial index)
- - FILLFACTOR 80 leaves space for HOT updates
- - Parameter style: $1, $2, $3 (PostgreSQL numeric placeholders)
- - Configuration is read from config.extension_config["adk"]
"""
__slots__ = ()
def __init__(self, config: "PsycopgAsyncConfig") -> None:
- """Initialize Psycopg ADK store.
-
- Args:
- config: PsycopgAsyncConfig instance.
-
- Notes:
- Configuration is read from config.extension_config["adk"]:
- - session_table: Sessions table name (default: "adk_sessions")
- - events_table: Events table name (default: "adk_events")
- - owner_id_column: Optional owner FK column DDL (default: None)
- """
super().__init__(config)
async def _get_create_sessions_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for sessions.
-
- Returns:
- SQL statement to create adk_sessions table with indexes.
-
- Notes:
- - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names)
- - JSONB type for state storage with default empty object
- - TIMESTAMPTZ with microsecond precision
- - FILLFACTOR 80 for HOT updates (reduces table bloat)
- - Composite index on (app_name, user_id) for listing
- - Index on update_time DESC for recent session queries
- - Partial GIN index on state for JSONB queries (only non-empty)
- - Optional owner ID column for multi-tenancy or user references
- """
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
@@ -160,61 +109,24 @@ async def _get_create_sessions_table_sql(self) -> str:
"""
async def _get_create_events_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for events.
-
- Returns:
- SQL statement to create adk_events table with indexes.
-
- Notes:
- - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256),
- branch(256), error_code(256), error_message(1024)
- - BYTEA for pickled actions (no size limit)
- - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
- - BOOLEAN for partial, turn_complete, interrupted
- - Foreign key to sessions with CASCADE delete
- - Index on (session_id, timestamp ASC) for ordered event retrieval
- """
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256),
- author VARCHAR(256),
- actions BYTEA,
- long_running_tool_ids_json JSONB,
- branch VARCHAR(256),
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(256) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
- );
+ ) WITH (fillfactor = 80);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
ON {self._events_table}(session_id, timestamp ASC);
"""
def _get_drop_tables_sql(self) -> "list[str]":
- """Get PostgreSQL DROP TABLE SQL statements.
-
- Returns:
- List of SQL statements to drop tables and indexes.
-
- Notes:
- Order matters: drop events table (child) before sessions (parent).
- PostgreSQL automatically drops indexes when dropping tables.
- """
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
async def create_tables(self) -> None:
- """Create both sessions and events tables if they don't exist."""
async with self._config.provide_session() as driver:
await driver.execute_script(await self._get_create_sessions_table_sql())
await driver.execute_script(await self._get_create_events_table_sql())
@@ -222,23 +134,6 @@ async def create_tables(self) -> None:
async def create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
- """Create a new session.
-
- Args:
- session_id: Unique session identifier.
- app_name: Application name.
- user_id: User identifier.
- state: Initial session state.
- owner_id: Optional owner ID value for owner_id_column (if configured).
-
- Returns:
- Created session record.
-
- Notes:
- Uses CURRENT_TIMESTAMP for create_time and update_time.
- State is wrapped with Jsonb() for PostgreSQL type safety.
- If owner_id_column is configured, owner_id value must be provided.
- """
params: tuple[Any, ...]
if self._owner_id_column_name:
query = pg_sql.SQL("""
@@ -261,18 +156,6 @@ async def create_session(
return await self.get_session(session_id) # type: ignore[return-value]
async def get_session(self, session_id: str) -> "SessionRecord | None":
- """Get session by ID.
-
- Args:
- session_id: Session identifier.
-
- Returns:
- Session record or None if not found.
-
- Notes:
- PostgreSQL returns datetime objects for TIMESTAMPTZ columns.
- JSONB is automatically deserialized by psycopg to Python dict.
- """
query = pg_sql.SQL("""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {table}
@@ -299,17 +182,6 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
return None
async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
- """Update session state.
-
- Args:
- session_id: Session identifier.
- state: New state dictionary (replaces existing state).
-
- Notes:
- This replaces the entire state dictionary.
- Uses CURRENT_TIMESTAMP for update_time.
- State is wrapped with Jsonb() for PostgreSQL type safety.
- """
query = pg_sql.SQL("""
UPDATE {table}
SET state = %s, update_time = CURRENT_TIMESTAMP
@@ -320,32 +192,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
await cur.execute(query, (Jsonb(state), session_id))
async def delete_session(self, session_id: str) -> None:
- """Delete session and all associated events (cascade).
-
- Args:
- session_id: Session identifier.
-
- Notes:
- Foreign key constraint ensures events are cascade-deleted.
- """
query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table))
async with self._config.provide_connection() as conn, conn.cursor() as cur:
await cur.execute(query, (session_id,))
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
- """List sessions for an app, optionally filtered by user.
-
- Args:
- app_name: Application name.
- user_id: User identifier. If None, lists all sessions for the app.
-
- Returns:
- List of session records ordered by update_time DESC.
-
- Notes:
- Uses composite index on (app_name, user_id) when user_id is provided.
- """
if user_id is None:
query = pg_sql.SQL("""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -383,73 +235,62 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis
return []
async def append_event(self, event_record: EventRecord) -> None:
- """Append an event to a session.
+ query = pg_sql.SQL("""
+ INSERT INTO {table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """).format(table=pg_sql.Identifier(self._events_table))
- Args:
- event_record: Event record to store.
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
- Notes:
- Uses CURRENT_TIMESTAMP for timestamp if not provided.
- JSONB fields are wrapped with Jsonb() for PostgreSQL type safety.
- """
- content_json = event_record.get("content")
- grounding_metadata_json = event_record.get("grounding_metadata")
- custom_metadata_json = event_record.get("custom_metadata")
+ async with self._config.provide_connection() as conn, conn.cursor() as cur:
+ await cur.execute(
+ query,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ jsonb_value,
+ ),
+ )
- query = pg_sql.SQL("""
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ insert_query = pg_sql.SQL("""
INSERT INTO {table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
""").format(table=pg_sql.Identifier(self._events_table))
+ update_query = pg_sql.SQL("""
+ UPDATE {table}
+ SET state = %s, update_time = CURRENT_TIMESTAMP
+ WHERE id = %s
+ """).format(table=pg_sql.Identifier(self._session_table))
+
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
+
async with self._config.provide_connection() as conn, conn.cursor() as cur:
await cur.execute(
- query,
+ insert_query,
(
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
- event_record.get("invocation_id"),
- event_record.get("author"),
- event_record.get("actions"),
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
+ event_record["invocation_id"],
+ event_record["author"],
event_record["timestamp"],
- Jsonb(content_json) if content_json is not None else None,
- Jsonb(grounding_metadata_json) if grounding_metadata_json is not None else None,
- Jsonb(custom_metadata_json) if custom_metadata_json is not None else None,
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ jsonb_value,
),
)
+ await cur.execute(update_query, (Jsonb(state), session_id))
+ await conn.commit()
async def get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
- """Get events for a session.
-
- Args:
- session_id: Session identifier.
- after_timestamp: Only return events after this time.
- limit: Maximum number of events to return.
-
- Returns:
- List of event records ordered by timestamp ASC.
-
- Notes:
- Uses index on (session_id, timestamp ASC).
- JSONB fields are automatically deserialized by psycopg.
- BYTEA actions are converted to bytes.
- """
where_clauses = ["session_id = %s"]
params: list[Any] = [session_id]
@@ -463,10 +304,7 @@ async def get_events(
query = pg_sql.SQL(
"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
@@ -484,24 +322,11 @@ async def get_events(
return [
EventRecord(
- id=row["id"],
session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
invocation_id=row["invocation_id"],
author=row["author"],
- actions=bytes(row["actions"]) if row["actions"] else b"",
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ event_json=row["event_json"],
)
for row in rows
]
@@ -509,85 +334,33 @@ async def get_events(
return []
-class PsycopgSyncADKStore(BaseSyncADKStore["PsycopgSyncConfig"]):
+class PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]):
"""PostgreSQL synchronous ADK store using Psycopg3 driver.
Implements session and event storage for Google Agent Development Kit
using PostgreSQL via psycopg3 with synchronous execution.
+ Events are stored as a single JSONB blob (``event_json``) alongside
+ indexed scalar columns for efficient querying.
Provides:
- - Session state management with JSONB storage and merge operations
- - Event history tracking with BYTEA-serialized actions
+ - Session state management with JSONB storage
+ - Full-fidelity event storage via ``event_json`` JSONB column
+ - Atomic ``append_event_and_update_state`` for durable session mutations
- Microsecond-precision timestamps with TIMESTAMPTZ
- Foreign key constraints with cascade delete
- - Efficient upserts using ON CONFLICT
- GIN indexes for JSONB queries
- HOT updates with FILLFACTOR 80
Args:
config: PsycopgSyncConfig with extension_config["adk"] settings.
-
- Example:
- from sqlspec.adapters.psycopg import PsycopgSyncConfig
- from sqlspec.adapters.psycopg.adk import PsycopgSyncADKStore
-
- config = PsycopgSyncConfig(
- connection_config={"conninfo": "postgresql://..."},
- extension_config={
- "adk": {
- "session_table": "my_sessions",
- "events_table": "my_events",
- "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
- }
- }
- )
- store = PsycopgSyncADKStore(config)
- store.ensure_tables()
-
- Notes:
- - PostgreSQL JSONB type used for state (more efficient than JSON)
- - Psycopg requires wrapping dicts with Jsonb() for type safety
- - TIMESTAMPTZ provides timezone-aware microsecond precision
- - State merging uses `state || $1::jsonb` operator for efficiency
- - BYTEA for pre-serialized actions from Google ADK
- - GIN index on state for JSONB queries (partial index)
- - FILLFACTOR 80 leaves space for HOT updates
- - Parameter style: $1, $2, $3 (PostgreSQL numeric placeholders)
- - Configuration is read from config.extension_config["adk"]
"""
__slots__ = ()
def __init__(self, config: "PsycopgSyncConfig") -> None:
- """Initialize Psycopg synchronous ADK store.
-
- Args:
- config: PsycopgSyncConfig instance.
-
- Notes:
- Configuration is read from config.extension_config["adk"]:
- - session_table: Sessions table name (default: "adk_sessions")
- - events_table: Events table name (default: "adk_events")
- - owner_id_column: Optional owner FK column DDL (default: None)
- """
super().__init__(config)
- def _get_create_sessions_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for sessions.
-
- Returns:
- SQL statement to create adk_sessions table with indexes.
-
- Notes:
- - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names)
- - JSONB type for state storage with default empty object
- - TIMESTAMPTZ with microsecond precision
- - FILLFACTOR 80 for HOT updates (reduces table bloat)
- - Composite index on (app_name, user_id) for listing
- - Index on update_time DESC for recent session queries
- - Partial GIN index on state for JSONB queries (only non-empty)
- - Optional owner ID column for multi-tenancy or user references
- """
+ async def _get_create_sessions_table_sql(self) -> str:
owner_id_line = ""
if self._owner_id_column_ddl:
owner_id_line = f",\n {self._owner_id_column_ddl}"
@@ -613,86 +386,36 @@ def _get_create_sessions_table_sql(self) -> str:
WHERE state != '{{}}'::jsonb;
"""
- def _get_create_events_table_sql(self) -> str:
- """Get PostgreSQL CREATE TABLE SQL for events.
-
- Returns:
- SQL statement to create adk_events table with indexes.
-
- Notes:
- - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256),
- branch(256), error_code(256), error_message(1024)
- - BYTEA for pickled actions (no size limit)
- - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
- - BOOLEAN for partial, turn_complete, interrupted
- - Foreign key to sessions with CASCADE delete
- - Index on (session_id, timestamp ASC) for ordered event retrieval
- """
+ async def _get_create_events_table_sql(self) -> str:
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
- invocation_id VARCHAR(256),
- author VARCHAR(256),
- actions BYTEA,
- long_running_tool_ids_json JSONB,
- branch VARCHAR(256),
+ invocation_id VARCHAR(256) NOT NULL,
+ author VARCHAR(256) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
- content JSONB,
- grounding_metadata JSONB,
- custom_metadata JSONB,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSONB NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
- );
+ ) WITH (fillfactor = 80);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
ON {self._events_table}(session_id, timestamp ASC);
"""
def _get_drop_tables_sql(self) -> "list[str]":
- """Get PostgreSQL DROP TABLE SQL statements.
-
- Returns:
- List of SQL statements to drop tables and indexes.
-
- Notes:
- Order matters: drop events table (child) before sessions (parent).
- PostgreSQL automatically drops indexes when dropping tables.
- """
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- def create_tables(self) -> None:
- """Create both sessions and events tables if they don't exist."""
+ def _create_tables(self) -> None:
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_sessions_table_sql())
- driver.execute_script(self._get_create_events_table_sql())
+ driver.execute_script(run_(self._get_create_sessions_table_sql)())
+ driver.execute_script(run_(self._get_create_events_table_sql)())
- def create_session(
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
- """Create a new session.
-
- Args:
- session_id: Unique session identifier.
- app_name: Application name.
- user_id: User identifier.
- state: Initial session state.
- owner_id: Optional owner ID value for owner_id_column (if configured).
-
- Returns:
- Created session record.
-
- Notes:
- Uses CURRENT_TIMESTAMP for create_time and update_time.
- State is wrapped with Jsonb() for PostgreSQL type safety.
- If owner_id_column is configured, owner_id value must be provided.
- """
params: tuple[Any, ...]
if self._owner_id_column_name:
query = pg_sql.SQL("""
@@ -712,21 +435,19 @@ def create_session(
with self._config.provide_connection() as conn, conn.cursor() as cur:
cur.execute(query, params)
- return self.get_session(session_id) # type: ignore[return-value]
-
- def get_session(self, session_id: str) -> "SessionRecord | None":
- """Get session by ID.
-
- Args:
- session_id: Session identifier.
+ result = self._get_session(session_id)
+ if result is None:
+ msg = "Failed to fetch created session"
+ raise RuntimeError(msg)
+ return result
- Returns:
- Session record or None if not found.
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session."""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
- Notes:
- PostgreSQL returns datetime objects for TIMESTAMPTZ columns.
- JSONB is automatically deserialized by psycopg to Python dict.
- """
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
query = pg_sql.SQL("""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {table}
@@ -752,18 +473,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
except errors.UndefinedTable:
return None
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
- """Update session state.
-
- Args:
- session_id: Session identifier.
- state: New state dictionary (replaces existing state).
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID."""
+ return await async_(self._get_session)(session_id)
- Notes:
- This replaces the entire state dictionary.
- Uses CURRENT_TIMESTAMP for update_time.
- State is wrapped with Jsonb() for PostgreSQL type safety.
- """
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
query = pg_sql.SQL("""
UPDATE {table}
SET state = %s, update_time = CURRENT_TIMESTAMP
@@ -773,33 +487,21 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
with self._config.provide_connection() as conn, conn.cursor() as cur:
cur.execute(query, (Jsonb(state), session_id))
- def delete_session(self, session_id: str) -> None:
- """Delete session and all associated events (cascade).
-
- Args:
- session_id: Session identifier.
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state."""
+ await async_(self._update_session_state)(session_id, state)
- Notes:
- Foreign key constraint ensures events are cascade-deleted.
- """
+ def _delete_session(self, session_id: str) -> None:
query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table))
with self._config.provide_connection() as conn, conn.cursor() as cur:
cur.execute(query, (session_id,))
- def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
- """List sessions for an app, optionally filtered by user.
-
- Args:
- app_name: Application name.
- user_id: User identifier. If None, lists all sessions for the app.
-
- Returns:
- List of session records ordered by update_time DESC.
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and associated events."""
+ await async_(self._delete_session)(session_id)
- Notes:
- Uses composite index on (app_name, user_id) when user_id is provided.
- """
+ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
if user_id is None:
query = pg_sql.SQL("""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -836,165 +538,130 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess
except errors.UndefinedTable:
return []
- def create_event(
- self,
- event_id: str,
- session_id: str,
- app_name: str,
- user_id: str,
- author: "str | None" = None,
- actions: "bytes | None" = None,
- content: "dict[str, Any] | None" = None,
- **kwargs: Any,
- ) -> EventRecord:
- """Create a new event.
-
- Args:
- event_id: Unique event identifier.
- session_id: Session identifier.
- app_name: Application name.
- user_id: User identifier.
- author: Event author (user/assistant/system).
- actions: Pickled actions object.
- content: Event content (JSONB).
- **kwargs: Additional optional fields (invocation_id, branch, timestamp,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message, long_running_tool_ids_json).
-
- Returns:
- Created event record.
-
- Notes:
- Uses CURRENT_TIMESTAMP for timestamp if not provided in kwargs.
- JSONB fields are wrapped with Jsonb() for PostgreSQL type safety.
- """
- content_json = Jsonb(content) if content is not None else None
- grounding_metadata = kwargs.get("grounding_metadata")
- grounding_metadata_json = Jsonb(grounding_metadata) if grounding_metadata is not None else None
- custom_metadata = kwargs.get("custom_metadata")
- custom_metadata_json = Jsonb(custom_metadata) if custom_metadata is not None else None
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app."""
+ return await async_(self._list_sessions)(app_name, user_id)
- query = pg_sql.SQL("""
+ def _insert_event(self, event_record: EventRecord) -> None:
+ insert_query = pg_sql.SQL("""
INSERT INTO {table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, COALESCE(%s, CURRENT_TIMESTAMP), %s, %s, %s, %s, %s, %s, %s, %s
- )
- RETURNING id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
""").format(table=pg_sql.Identifier(self._events_table))
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
+
with self._config.provide_connection() as conn, conn.cursor() as cur:
cur.execute(
- query,
+ insert_query,
(
- event_id,
- session_id,
- app_name,
- user_id,
- kwargs.get("invocation_id"),
- author,
- actions,
- kwargs.get("long_running_tool_ids_json"),
- kwargs.get("branch"),
- kwargs.get("timestamp"),
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- kwargs.get("partial"),
- kwargs.get("turn_complete"),
- kwargs.get("interrupted"),
- kwargs.get("error_code"),
- kwargs.get("error_message"),
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ jsonb_value,
),
)
- row = cur.fetchone()
-
- if row is None:
- msg = f"Failed to create event {event_id}"
- raise RuntimeError(msg)
-
- return EventRecord(
- id=row["id"],
- session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
- invocation_id=row["invocation_id"],
- author=row["author"],
- actions=bytes(row["actions"]) if row["actions"] else b"",
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
- timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ conn.commit()
+
+ def _append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ insert_query = pg_sql.SQL("""
+ INSERT INTO {table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """).format(table=pg_sql.Identifier(self._events_table))
+
+ update_query = pg_sql.SQL("""
+ UPDATE {table}
+ SET state = %s, update_time = CURRENT_TIMESTAMP
+ WHERE id = %s
+ """).format(table=pg_sql.Identifier(self._session_table))
+
+ event_json_value = event_record["event_json"]
+ jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value
+
+ with self._config.provide_connection() as conn, conn.cursor() as cur:
+ cur.execute(
+ insert_query,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ jsonb_value,
+ ),
)
+ cur.execute(update_query, (Jsonb(state), session_id))
+ conn.commit()
- def list_events(self, session_id: str) -> "list[EventRecord]":
- """List events for a session ordered by timestamp.
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state."""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
- Args:
- session_id: Session identifier.
+ def _get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ where_clauses = ["session_id = %s"]
+ params: list[Any] = [session_id]
- Returns:
- List of event records ordered by timestamp ASC.
+ if after_timestamp is not None:
+ where_clauses.append("timestamp > %s")
+ params.append(after_timestamp)
- Notes:
- Uses index on (session_id, timestamp ASC).
- JSONB fields are automatically deserialized by psycopg.
- BYTEA actions are converted to bytes.
- """
- query = pg_sql.SQL("""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ where_clause = " AND ".join(where_clauses)
+ if limit:
+ params.append(limit)
+
+ query = pg_sql.SQL(
+ """
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {table}
- WHERE session_id = %s
- ORDER BY timestamp ASC
- """).format(table=pg_sql.Identifier(self._events_table))
+ WHERE {where_clause}
+ ORDER BY timestamp ASC{limit_clause}
+ """
+ ).format(
+ table=pg_sql.Identifier(self._events_table),
+ where_clause=pg_sql.SQL(where_clause), # pyright: ignore[reportArgumentType]
+ limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), # pyright: ignore[reportArgumentType]
+ )
try:
with self._config.provide_connection() as conn, conn.cursor() as cur:
- cur.execute(query, (session_id,))
+ cur.execute(query, tuple(params))
rows = cur.fetchall()
return [
EventRecord(
- id=row["id"],
session_id=row["session_id"],
- app_name=row["app_name"],
- user_id=row["user_id"],
invocation_id=row["invocation_id"],
author=row["author"],
- actions=bytes(row["actions"]) if row["actions"] else b"",
- long_running_tool_ids_json=row["long_running_tool_ids_json"],
- branch=row["branch"],
timestamp=row["timestamp"],
- content=row["content"],
- grounding_metadata=row["grounding_metadata"],
- custom_metadata=row["custom_metadata"],
- partial=row["partial"],
- turn_complete=row["turn_complete"],
- interrupted=row["interrupted"],
- error_code=row["error_code"],
- error_message=row["error_message"],
+ event_json=row["event_json"],
)
for row in rows
]
except errors.UndefinedTable:
return []
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session."""
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
+
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ self._insert_event(event_record)
+
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session."""
+ await async_(self._append_event)(event_record)
+
class PsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgAsyncConfig"]):
"""PostgreSQL ADK memory store using Psycopg3 async driver."""
@@ -1182,7 +849,7 @@ async def delete_entries_older_than(self, days: int) -> int:
return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0
-class PsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["PsycopgSyncConfig"]):
+class PsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgSyncConfig"]):
"""PostgreSQL ADK memory store using Psycopg3 sync driver."""
__slots__ = ()
@@ -1191,7 +858,7 @@ def __init__(self, config: "PsycopgSyncConfig") -> None:
"""Initialize Psycopg sync memory store."""
super().__init__(config)
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
"""Get PostgreSQL CREATE TABLE SQL for memory entries."""
owner_id_line = ""
if self._owner_id_column_ddl:
@@ -1231,15 +898,19 @@ def _get_drop_memory_table_sql(self) -> "list[str]":
"""Get PostgreSQL DROP TABLE SQL statements."""
return [f"DROP TABLE IF EXISTS {self._memory_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
"""Create the memory table and indexes if they don't exist."""
if not self._enabled:
return
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_memory_table_sql())
+ driver.execute_script(run_(self._get_create_memory_table_sql)())
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
"""Bulk insert memory entries with deduplication."""
if not self._enabled:
msg = "Memory store is disabled"
@@ -1284,7 +955,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
"""Search memory entries by text query."""
@@ -1304,6 +979,12 @@ def search_entries(
except errors.UndefinedTable:
return []
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]":
sql = pg_sql.SQL(
"""
@@ -1344,7 +1025,7 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit:
rows = cur.fetchall()
return _rows_to_records(rows)
- def delete_entries_by_session(self, session_id: str) -> int:
+ def _delete_entries_by_session(self, session_id: str) -> int:
"""Delete all memory entries for a specific session."""
sql = pg_sql.SQL("DELETE FROM {table} WHERE session_id = %s").format(
table=pg_sql.Identifier(self._memory_table)
@@ -1354,7 +1035,11 @@ def delete_entries_by_session(self, session_id: str) -> int:
cur.execute(sql, (session_id,))
return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
"""Delete memory entries older than specified days."""
sql = pg_sql.SQL(
"""
@@ -1367,6 +1052,10 @@ def delete_entries_older_than(self, days: int) -> int:
cur.execute(sql)
return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
+
def _rows_to_records(rows: "list[Any]") -> "list[MemoryRecord]":
return [
diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py
index e30765c7f..5e7ea8513 100644
--- a/sqlspec/adapters/pymysql/adk/store.py
+++ b/sqlspec/adapters/pymysql/adk/store.py
@@ -5,9 +5,10 @@
import pymysql
-from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.serializers import from_json, to_json
+from sqlspec.utils.sync_tools import async_, run_
if TYPE_CHECKING:
from datetime import datetime
@@ -33,8 +34,23 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]":
return (col_def, fk_constraint)
-class PyMysqlADKStore(BaseSyncADKStore["PyMysqlConfig"]):
- """MySQL/MariaDB ADK store using PyMySQL."""
+class PyMysqlADKStore(BaseAsyncADKStore["PyMysqlConfig"]):
+ """MySQL/MariaDB ADK store using PyMySQL.
+
+ Implements session and event storage for Google Agent Development Kit
+    using MySQL/MariaDB via the PyMySQL sync driver, exposed through an async interface. Provides:
+ - Session state management with JSON storage
+ - Full-event JSON storage (single ``event_json`` column)
+ - Atomic event-create + state-update in one transaction
+ - Microsecond-precision timestamps
+ - Foreign key constraints with cascade delete
+
+ Notes:
+    - MySQL JSON type is used; requires MySQL 5.7.8+
+ - TIMESTAMP(6) provides microsecond precision
+ - InnoDB engine required for foreign key support
+ - Configuration is read from config.extension_config["adk"]
+ """
__slots__ = ()
@@ -44,7 +60,7 @@ def __init__(self, config: "PyMysqlConfig") -> None:
def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]":
return _parse_owner_id_column_for_mysql(column_ddl)
- def _get_create_sessions_table_sql(self) -> str:
+ async def _get_create_sessions_table_sql(self) -> str:
owner_id_col = ""
fk_constraint = ""
@@ -68,27 +84,18 @@ def _get_create_sessions_table_sql(self) -> str:
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
"""
- def _get_create_events_table_sql(self) -> str:
+ async def _get_create_events_table_sql(self) -> str:
+ """Get MySQL CREATE TABLE SQL for events.
+
+ Post clean-break schema: 5 columns only.
+ """
return f"""
CREATE TABLE IF NOT EXISTS {self._events_table} (
- id VARCHAR(128) PRIMARY KEY,
session_id VARCHAR(128) NOT NULL,
- app_name VARCHAR(128) NOT NULL,
- user_id VARCHAR(128) NOT NULL,
invocation_id VARCHAR(256) NOT NULL,
- author VARCHAR(256) NOT NULL,
- actions BLOB NOT NULL,
- long_running_tool_ids_json JSON,
- branch VARCHAR(256),
+ author VARCHAR(128) NOT NULL,
timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOLEAN,
- turn_complete BOOLEAN,
- interrupted BOOLEAN,
- error_code VARCHAR(256),
- error_message VARCHAR(1024),
+ event_json JSON NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE,
INDEX idx_{self._events_table}_session (session_id, timestamp ASC)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
@@ -97,12 +104,16 @@ def _get_create_events_table_sql(self) -> str:
def _get_drop_tables_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_sessions_table_sql())
- driver.execute_script(self._get_create_events_table_sql())
+ driver.execute_script(run_(self._get_create_sessions_table_sql)())
+ driver.execute_script(run_(self._get_create_events_table_sql)())
+
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
- def create_session(
+ def _create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
state_json = to_json(state)
@@ -129,13 +140,19 @@ def create_session(
cursor.close()
conn.commit()
- result = self.get_session(session_id)
+ result = self._get_session(session_id)
if result is None:
msg = "Failed to fetch created session"
raise RuntimeError(msg)
return result
- def get_session(self, session_id: str) -> "SessionRecord | None":
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session."""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
+
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
@@ -169,7 +186,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
return None
raise
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID."""
+ return await async_(self._get_session)(session_id)
+
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
state_json = to_json(state)
sql = f"""
@@ -186,7 +207,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
cursor.close()
conn.commit()
- def delete_session(self, session_id: str) -> None:
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state."""
+ await async_(self._update_session_state)(session_id, state)
+
+ def _delete_session(self, session_id: str) -> None:
sql = f"DELETE FROM {self._session_table} WHERE id = %s"
with self._config.provide_connection() as conn:
@@ -197,7 +222,11 @@ def delete_session(self, session_id: str) -> None:
cursor.close()
conn.commit()
- def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and associated events."""
+ await async_(self._delete_session)(session_id)
+
+ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
if user_id is None:
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
@@ -240,24 +269,71 @@ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Sess
return []
raise
- def append_event(self, event_record: EventRecord) -> None:
- content_json = to_json(event_record.get("content")) if event_record.get("content") else None
- grounding_metadata_json = (
- to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None
- )
- custom_metadata_json = (
- to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None
- )
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app."""
+ return await async_(self._list_sessions)(app_name, user_id)
+
+ def _append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically create an event and update the session's durable state.
+
+ The event insert and state update succeed together or fail together
+ within a single transaction.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot.
+ """
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
+ state_json = to_json(state)
+
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
+ """
+
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = %s
+ WHERE id = %s
+ """
+
+ with self._config.provide_connection() as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ insert_sql,
+ (
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ event_record["timestamp"],
+ event_json_str,
+ ),
+ )
+ cursor.execute(update_sql, (state_json, session_id))
+ finally:
+ cursor.close()
+ conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state."""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
+
+ def _insert_event(self, event_record: EventRecord) -> None:
+ event_json = event_record["event_json"]
+ event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
- )
+ session_id, invocation_id, author, timestamp, event_json
+ ) VALUES (%s, %s, %s, %s, %s)
"""
with self._config.provide_connection() as conn:
@@ -266,33 +342,30 @@ def append_event(self, event_record: EventRecord) -> None:
cursor.execute(
sql,
(
- event_record["id"],
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
event_record["timestamp"],
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- event_record.get("partial"),
- event_record.get("turn_complete"),
- event_record.get("interrupted"),
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_json_str,
),
)
finally:
cursor.close()
conn.commit()
- def get_events(
+ def _get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
+ """List events for a session ordered by timestamp.
+
+ Args:
+ session_id: Session identifier.
+ after_timestamp: Only return events after this time.
+ limit: Maximum number of events to return.
+
+ Returns:
+ List of event records ordered by timestamp ASC.
+ """
where_clauses = ["session_id = %s"]
params: list[Any] = [session_id]
@@ -301,47 +374,32 @@ def get_events(
params.append(after_timestamp)
where_clause = " AND ".join(where_clauses)
- limit_clause = f" LIMIT {limit}" if limit else ""
-
+ limit_clause = " LIMIT %s" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
"""
+ if limit:
+ params.append(limit)
try:
with self._config.provide_connection() as conn:
cursor = conn.cursor()
try:
- cursor.execute(sql, params)
+ cursor.execute(sql, tuple(params))
rows = cursor.fetchall()
finally:
cursor.close()
return [
EventRecord(
- id=row[0],
- session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=bytes(row[6]),
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=row[9],
- content=from_json(row[10]) if row[10] and isinstance(row[10], str) else row[10],
- grounding_metadata=from_json(row[11]) if row[11] and isinstance(row[11], str) else row[11],
- custom_metadata=from_json(row[12]) if row[12] and isinstance(row[12], str) else row[12],
- partial=row[13],
- turn_complete=row[14],
- interrupted=row[15],
- error_code=row[16],
- error_message=row[17],
+ session_id=row[0],
+ invocation_id=row[1],
+ author=row[2],
+ timestamp=row[3],
+ event_json=from_json(row[4]) if isinstance(row[4], str) else row[4],
)
for row in rows
]
@@ -350,8 +408,22 @@ def get_events(
return []
raise
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session."""
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
+
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ self._insert_event(event_record)
+
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session."""
+ await async_(self._append_event)(event_record)
+
-class PyMysqlADKMemoryStore(BaseSyncADKMemoryStore["PyMysqlConfig"]):
+class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]):
"""MySQL/MariaDB ADK memory store using PyMySQL."""
__slots__ = ()
@@ -359,7 +431,7 @@ class PyMysqlADKMemoryStore(BaseSyncADKMemoryStore["PyMysqlConfig"]):
def __init__(self, config: "PyMysqlConfig") -> None:
super().__init__(config)
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
owner_id_line = ""
fk_constraint = ""
if self._owner_id_column_ddl:
@@ -393,14 +465,18 @@ def _get_create_memory_table_sql(self) -> str:
def _get_drop_memory_table_sql(self) -> "list[str]":
return [f"DROP TABLE IF EXISTS {self._memory_table}"]
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
if not self._enabled:
return
with self._config.provide_session() as driver:
- driver.execute_script(self._get_create_memory_table_sql())
+ driver.execute_script(run_(self._get_create_memory_table_sql)())
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -466,7 +542,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
conn.commit()
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
if not self._enabled:
@@ -506,7 +586,13 @@ def search_entries(
return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows]
- def delete_entries_by_session(self, session_id: str) -> int:
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
+ def _delete_entries_by_session(self, session_id: str) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -521,7 +607,11 @@ def delete_entries_by_session(self, session_id: str) -> int:
finally:
cursor.close()
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -538,3 +628,7 @@ def delete_entries_older_than(self, days: int) -> int:
return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
finally:
cursor.close()
+
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py
index 22ebebc7e..a180de7e5 100644
--- a/sqlspec/adapters/spanner/adk/store.py
+++ b/sqlspec/adapters/spanner/adk/store.py
@@ -7,11 +7,11 @@
from google.cloud.spanner_v1 import param_types
from sqlspec.adapters.spanner.config import SpannerSyncConfig
-from sqlspec.adapters.spanner.type_converter import bytes_to_spanner, spanner_to_bytes
-from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore
+from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.protocols import SpannerParamTypesProtocol
from sqlspec.utils.serializers import from_json, to_json
+from sqlspec.utils.sync_tools import async_, run_
if TYPE_CHECKING:
from google.cloud.spanner_v1.database import Database
@@ -42,7 +42,7 @@ def __call__(self, transaction: "Transaction") -> None:
transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call]
-class SpannerSyncADKStore(BaseSyncADKStore[SpannerSyncConfig]):
+class SpannerSyncADKStore(BaseAsyncADKStore[SpannerSyncConfig]):
"""Spanner ADK store backed by synchronous Spanner client."""
connector_name: ClassVar[str] = "spanner"
@@ -80,30 +80,15 @@ def _session_param_types(self, include_owner: bool) -> "dict[str, Any]":
types["owner_id"] = SPANNER_PARAM_TYPES.STRING
return types
- def _event_param_types(self, has_branch: bool) -> "dict[str, Any]":
+ def _event_param_types(self) -> "dict[str, Any]":
json_type = _json_param_type()
- types: dict[str, Any] = {
- "id": SPANNER_PARAM_TYPES.STRING,
+ return {
"session_id": SPANNER_PARAM_TYPES.STRING,
- "app_name": SPANNER_PARAM_TYPES.STRING,
- "user_id": SPANNER_PARAM_TYPES.STRING,
- "author": SPANNER_PARAM_TYPES.STRING,
- "actions": SPANNER_PARAM_TYPES.BYTES,
- "long_running_tool_ids_json": json_type,
"invocation_id": SPANNER_PARAM_TYPES.STRING,
+ "author": SPANNER_PARAM_TYPES.STRING,
"timestamp": SPANNER_PARAM_TYPES.TIMESTAMP,
- "content": json_type,
- "grounding_metadata": json_type,
- "custom_metadata": json_type,
- "partial": SPANNER_PARAM_TYPES.BOOL,
- "turn_complete": SPANNER_PARAM_TYPES.BOOL,
- "interrupted": SPANNER_PARAM_TYPES.BOOL,
- "error_code": SPANNER_PARAM_TYPES.STRING,
- "error_message": SPANNER_PARAM_TYPES.STRING,
+ "event_json": json_type,
}
- if has_branch:
- types["branch"] = SPANNER_PARAM_TYPES.STRING
- return types
def _decode_state(self, raw: Any) -> Any:
if isinstance(raw, str):
@@ -117,7 +102,7 @@ def _decode_json(self, raw: Any) -> Any:
return from_json(raw)
return raw
- def create_session(
+ def _create_session(
self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
state_json = to_json(state)
@@ -146,7 +131,13 @@ def create_session(
"update_time": datetime.now(timezone.utc),
}
- def get_session(self, session_id: str) -> "SessionRecord | None":
+ async def create_session(
+ self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
+ ) -> SessionRecord:
+ """Create a new session."""
+ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
+
+ def _get_session(self, session_id: str) -> "SessionRecord | None":
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""}
FROM {self._session_table}
@@ -172,7 +163,11 @@ def get_session(self, session_id: str) -> "SessionRecord | None":
}
return record
- def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ async def get_session(self, session_id: str) -> "SessionRecord | None":
+ """Get session by ID."""
+ return await async_(self._get_session)(session_id)
+
+ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
params = {"id": session_id, "state": to_json(state)}
json_type = _json_param_type()
sql = f"""
@@ -184,7 +179,11 @@ def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None
sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})"
self._run_write([(sql, params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type})])
- def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]":
+ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
+ """Update session state."""
+ await async_(self._update_session_state)(session_id, state)
+
+ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]":
sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""}
FROM {self._session_table}
@@ -198,6 +197,7 @@ def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[Se
types["user_id"] = SPANNER_PARAM_TYPES.STRING
if self._shard_count > 1:
sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(id), {self._shard_count})"
+ sql = f"{sql} ORDER BY update_time DESC"
rows = self._run_read(sql, params, types)
records: list[SessionRecord] = []
@@ -214,7 +214,11 @@ def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[Se
records.append(record)
return records
- def delete_session(self, session_id: str) -> None:
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
+ """List sessions for an app."""
+ return await async_(self._list_sessions)(app_name, user_id)
+
+ def _delete_session(self, session_id: str) -> None:
shard_clause = (
f" AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" if self._shard_count > 1 else ""
)
@@ -224,174 +228,135 @@ def delete_session(self, session_id: str) -> None:
types = {"session_id": SPANNER_PARAM_TYPES.STRING}
self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)])
- def create_event(
- self,
- event_id: str,
- session_id: str,
- app_name: str,
- user_id: str,
- author: "str | None" = None,
- actions: "bytes | None" = None,
- content: "dict[str, Any] | None" = None,
- **kwargs: Any,
- ) -> EventRecord:
- branch = kwargs.get("branch")
- long_running_serialized = (
- to_json(kwargs.get("long_running_tool_ids_json"))
- if kwargs.get("long_running_tool_ids_json") is not None
- else None
- )
- content_serialized = to_json(content) if content is not None else None
- grounding_serialized = (
- to_json(kwargs.get("grounding_metadata")) if kwargs.get("grounding_metadata") is not None else None
- )
- custom_serialized = (
- to_json(kwargs.get("custom_metadata")) if kwargs.get("custom_metadata") is not None else None
- )
- params: dict[str, Any] = {
- "id": event_id,
- "session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
- "author": author,
- "actions": bytes_to_spanner(actions),
- "long_running_tool_ids_json": long_running_serialized,
- "timestamp": datetime.now(timezone.utc),
- "content": content_serialized,
- "grounding_metadata": grounding_serialized,
- "custom_metadata": custom_serialized,
- "invocation_id": kwargs.get("invocation_id"),
- "partial": kwargs.get("partial"),
- "turn_complete": kwargs.get("turn_complete"),
- "interrupted": kwargs.get("interrupted"),
- "error_code": kwargs.get("error_code"),
- "error_message": kwargs.get("error_message"),
- }
- branch = kwargs.get("branch")
- columns = [
- "id",
- "session_id",
- "app_name",
- "user_id",
- "author",
- "actions",
- "long_running_tool_ids_json",
- "timestamp",
- "content",
- "grounding_metadata",
- "custom_metadata",
- "invocation_id",
- "partial",
- "turn_complete",
- "interrupted",
- "error_code",
- "error_message",
- ]
- values = [
- "@id",
- "@session_id",
- "@app_name",
- "@user_id",
- "@author",
- "@actions",
- "@long_running_tool_ids_json",
- "PENDING_COMMIT_TIMESTAMP()",
- "@content",
- "@grounding_metadata",
- "@custom_metadata",
- "@invocation_id",
- "@partial",
- "@turn_complete",
- "@interrupted",
- "@error_code",
- "@error_message",
- ]
- has_branch = branch is not None
- if has_branch:
- params["branch"] = branch
- columns.append("branch")
- values.append("@branch")
+ async def delete_session(self, session_id: str) -> None:
+ """Delete session and associated events."""
+ await async_(self._delete_session)(session_id)
- sql = f"""
- INSERT INTO {self._events_table} ({", ".join(columns)})
- VALUES ({", ".join(values)})
+ def _append_event_and_update_state(
+ self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically insert an event and update session state in one transaction.
+
+ Both the event INSERT and the session state UPDATE execute within a single
+ Spanner transaction so they succeed or fail together.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session whose state should be updated.
+ state: Post-append durable state snapshot.
+ """
+ event_params: dict[str, Any] = {
+ "session_id": event_record["session_id"],
+ "invocation_id": event_record["invocation_id"],
+ "author": event_record["author"],
+ "timestamp": event_record["timestamp"],
+ "event_json": to_json(event_record["event_json"]),
+ }
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json)
+ VALUES (@session_id, @invocation_id, @author, PENDING_COMMIT_TIMESTAMP(), @event_json)
"""
- self._run_write([(sql, params, self._event_param_types(has_branch))])
- record: EventRecord = {
- "id": event_id,
- "session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
- "author": author or "",
- "actions": actions or b"",
- "long_running_tool_ids_json": long_running_serialized,
- "branch": branch,
- "timestamp": params["timestamp"],
- "content": from_json(content_serialized) if content_serialized else None,
- "grounding_metadata": from_json(grounding_serialized) if grounding_serialized else None,
- "custom_metadata": from_json(custom_serialized) if custom_serialized else None,
- "invocation_id": kwargs.get("invocation_id", ""),
- "partial": kwargs.get("partial"),
- "turn_complete": kwargs.get("turn_complete"),
- "interrupted": kwargs.get("interrupted"),
- "error_code": kwargs.get("error_code"),
- "error_message": kwargs.get("error_message"),
+ json_type = _json_param_type()
+ state_params: dict[str, Any] = {"id": session_id, "state": to_json(state)}
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = @state, update_time = PENDING_COMMIT_TIMESTAMP()
+ WHERE id = @id
+ """
+ if self._shard_count > 1:
+ update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})"
+
+ self._run_write([
+ (insert_sql, event_params, self._event_param_types()),
+ (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}),
+ ])
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state."""
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
+
+ def _insert_event(self, event_record: "EventRecord") -> None:
+ event_params: dict[str, Any] = {
+ "session_id": event_record["session_id"],
+ "invocation_id": event_record["invocation_id"],
+ "author": event_record["author"],
+ "timestamp": event_record["timestamp"],
+ "event_json": to_json(event_record["event_json"]),
}
- return record
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json)
+ VALUES (@session_id, @invocation_id, @author, PENDING_COMMIT_TIMESTAMP(), @event_json)
+ """
+ self._run_write([(insert_sql, event_params, self._event_param_types())])
- def list_events(self, session_id: str) -> "list[EventRecord]":
+ def _get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
sql = f"""
- SELECT id, session_id, app_name, user_id, author, actions, long_running_tool_ids_json, branch,
- timestamp, content, grounding_metadata, custom_metadata, invocation_id, partial,
- turn_complete, interrupted, error_code, error_message
+ SELECT session_id, invocation_id, author, timestamp, event_json
FROM {self._events_table}
WHERE session_id = @session_id
"""
if self._shard_count > 1:
sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})"
+ params: dict[str, Any] = {"session_id": session_id}
+ types: dict[str, Any] = {"session_id": SPANNER_PARAM_TYPES.STRING}
+ if after_timestamp is not None:
+ sql = f"{sql} AND timestamp > @after_timestamp"
+ params["after_timestamp"] = after_timestamp
+ types["after_timestamp"] = SPANNER_PARAM_TYPES.TIMESTAMP
sql = f"{sql} ORDER BY timestamp ASC"
- params = {"session_id": session_id}
- types = {"session_id": SPANNER_PARAM_TYPES.STRING}
+ if limit is not None:
+ sql = f"{sql} LIMIT @limit"
+ params["limit"] = limit
+ types["limit"] = SPANNER_PARAM_TYPES.INT64
rows = self._run_read(sql, params, types)
return [
{
- "id": row[0],
- "session_id": row[1],
- "app_name": row[2],
- "user_id": row[3],
- "invocation_id": row[12] or "",
- "author": row[4] or "",
- "actions": spanner_to_bytes(row[5]) or b"",
- "long_running_tool_ids_json": row[6],
- "branch": row[7],
- "timestamp": row[8],
- "content": self._decode_json(row[9]),
- "grounding_metadata": self._decode_json(row[10]),
- "custom_metadata": self._decode_json(row[11]),
- "partial": row[13],
- "turn_complete": row[14],
- "interrupted": row[15],
- "error_code": row[16],
- "error_message": row[17],
+ "session_id": row[0],
+ "invocation_id": row[1] or "",
+ "author": row[2] or "",
+ "timestamp": row[3],
+ "event_json": row[4],
}
for row in rows
]
- def create_tables(self) -> None:
+ async def get_events(
+ self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
+ ) -> "list[EventRecord]":
+ """Get events for a session."""
+ return await async_(self._get_events)(session_id, after_timestamp, limit)
+
+ def _append_event(self, event_record: EventRecord) -> None:
+ """Synchronous implementation of append_event."""
+ self._insert_event(event_record)
+
+ async def append_event(self, event_record: EventRecord) -> None:
+ """Append an event to a session."""
+ await async_(self._append_event)(event_record)
+
+ def _create_tables(self) -> None:
database = self._database()
existing_tables = {t.table_id for t in database.list_tables()} # type: ignore[no-untyped-call]
ddl_statements: list[str] = []
if self._session_table not in existing_tables:
- ddl_statements.append(self._get_create_sessions_table_sql())
+ ddl_statements.append(run_(self._get_create_sessions_table_sql)())
if self._events_table not in existing_tables:
- ddl_statements.append(self._get_create_events_table_sql())
+ ddl_statements.append(run_(self._get_create_events_table_sql)())
if ddl_statements:
database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call]
- def _get_create_sessions_table_sql(self) -> str:
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ async def _get_create_sessions_table_sql(self) -> str:
owner_line = ""
if self._owner_id_column_ddl:
owner_line = f",\n {self._owner_id_column_ddl}"
@@ -414,35 +379,22 @@ def _get_create_sessions_table_sql(self) -> str:
) {pk}{options}
"""
- def _get_create_events_table_sql(self) -> str:
+ async def _get_create_events_table_sql(self) -> str:
shard_column = ""
- pk = "PRIMARY KEY (session_id, timestamp, id)"
+ pk = "PRIMARY KEY (session_id, timestamp)"
if self._shard_count > 1:
shard_column = f",\n shard_id INT64 AS (MOD(FARM_FINGERPRINT(session_id), {self._shard_count})) STORED"
- pk = "PRIMARY KEY (shard_id, session_id, timestamp, id)"
+ pk = "PRIMARY KEY (shard_id, session_id, timestamp)"
options = ""
if self._events_table_options:
options = f"\nOPTIONS ({self._events_table_options})"
return f"""
CREATE TABLE {self._events_table} (
- id STRING(128) NOT NULL,
session_id STRING(128) NOT NULL,
- app_name STRING(128) NOT NULL,
- user_id STRING(128) NOT NULL,
- invocation_id STRING(128),
- author STRING(64),
- actions BYTES(MAX),
- long_running_tool_ids_json JSON,
- branch STRING(64),
+ invocation_id STRING(256) NOT NULL,
+ author STRING(128) NOT NULL,
timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true),
- content JSON,
- grounding_metadata JSON,
- custom_metadata JSON,
- partial BOOL,
- turn_complete BOOL,
- interrupted BOOL,
- error_code STRING(64),
- error_message STRING(255){shard_column}
+ event_json JSON NOT NULL{shard_column}
) {pk}{options}
"""
@@ -479,7 +431,7 @@ def execute_sql(
) -> Iterable[Any]: ...
-class SpannerSyncADKMemoryStore(BaseSyncADKMemoryStore[SpannerSyncConfig]):
+class SpannerSyncADKMemoryStore(BaseAsyncADKMemoryStore[SpannerSyncConfig]):
"""Spanner ADK memory store backed by synchronous Spanner client."""
connector_name: ClassVar[str] = "spanner"
@@ -532,7 +484,7 @@ def _decode_json(self, raw: Any) -> Any:
return from_json(raw)
return raw
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
if not self._enabled:
return
@@ -541,12 +493,16 @@ def create_tables(self) -> None:
ddl_statements: list[str] = []
if self._memory_table not in existing_tables:
- ddl_statements.extend(self._get_create_memory_table_sql())
+ ddl_statements.extend(run_(self._get_create_memory_table_sql)())
if ddl_statements:
database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call]
- def _get_create_memory_table_sql(self) -> "list[str]":
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
+
+ async def _get_create_memory_table_sql(self) -> "list[str]":
owner_line = ""
if self._owner_id_column_ddl:
owner_line = f",\n {self._owner_id_column_ddl}"
@@ -590,7 +546,23 @@ def _get_create_memory_table_sql(self) -> "list[str]":
statements.append(fts_index)
return statements
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ def _get_drop_memory_table_sql(self) -> "list[str]":
+ """Get SQL to drop the memory table and its indexes.
+
+ Returns:
+ List of SQL statements to drop the memory table and associated indexes.
+ """
+ statements: list[str] = []
+ if self._use_fts:
+ statements.append(f"DROP SEARCH INDEX idx_{self._memory_table}_fts")
+ statements.extend([
+ f"DROP INDEX idx_{self._memory_table}_session",
+ f"DROP INDEX idx_{self._memory_table}_app_user_time",
+ f"DROP TABLE {self._memory_table}",
+ ])
+ return statements
+
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
if not self._enabled:
msg = "Memory store is disabled"
raise RuntimeError(msg)
@@ -639,12 +611,16 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
self._run_write(statements)
return inserted_count
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
def _event_exists(self, event_id: str) -> bool:
sql = f"SELECT event_id FROM {self._memory_table} WHERE event_id = @event_id LIMIT 1"
rows = self._run_read(sql, {"event_id": event_id}, {"event_id": SPANNER_PARAM_TYPES.STRING})
return bool(rows)
- def search_entries(
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
if not self._enabled:
@@ -657,6 +633,12 @@ def search_entries(
return self._search_entries_fts(query, app_name, user_id, effective_limit)
return self._search_entries_simple(query, app_name, user_id, effective_limit)
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]":
sql = f"""
SELECT id, session_id, app_name, user_id, event_id, author,
@@ -700,19 +682,27 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit:
rows = self._run_read(sql, params, types)
return self._rows_to_records(rows)
- def delete_entries_by_session(self, session_id: str) -> int:
+ def _delete_entries_by_session(self, session_id: str) -> int:
sql = f"DELETE FROM {self._memory_table} WHERE session_id = @session_id"
params = {"session_id": session_id}
types = {"session_id": SPANNER_PARAM_TYPES.STRING}
return self._execute_update(sql, params, types)
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
cutoff = datetime.now(timezone.utc) - timedelta(days=days)
sql = f"DELETE FROM {self._memory_table} WHERE inserted_at < @cutoff"
params = {"cutoff": cutoff}
types = {"cutoff": SPANNER_PARAM_TYPES.TIMESTAMP}
return self._execute_update(sql, params, types)
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
+
def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]":
return [
{
diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py
index ee3376d9e..bf4e51e52 100644
--- a/sqlspec/adapters/sqlite/adk/store.py
+++ b/sqlspec/adapters/sqlite/adk/store.py
@@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Any
from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
-from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore
+from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json
from sqlspec.utils.sync_tools import async_, run_
@@ -58,34 +58,6 @@ def _julian_to_datetime(julian: float) -> datetime:
return datetime.fromtimestamp(timestamp, tz=timezone.utc)
-def _to_sqlite_bool(value: "bool | None") -> "int | None":
- """Convert Python bool to SQLite INTEGER.
-
- Args:
- value: Boolean value or None.
-
- Returns:
- 1 for True, 0 for False, None for None.
- """
- if value is None:
- return None
- return 1 if value else 0
-
-
-def _from_sqlite_bool(value: "int | None") -> "bool | None":
- """Convert SQLite INTEGER to Python bool.
-
- Args:
- value: Integer value (0/1) or None.
-
- Returns:
- True for 1, False for 0, None for None.
- """
- if value is None:
- return None
- return bool(value)
-
-
class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
"""SQLite ADK store using synchronous SQLite driver.
@@ -95,10 +67,11 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
Provides:
- Session state management with JSON storage (as TEXT)
- - Event history tracking with BLOB-serialized actions
+ - Event history tracking with full-event JSON storage
- Julian Day timestamps (REAL) for efficient date operations
- Foreign key constraints with cascade delete
- - Efficient upserts using INSERT OR REPLACE
+ - Atomic event+state writes via append_event_and_update_state
+ - PRAGMA optimization profile for file-based databases
Args:
config: SqliteConfig instance with extension_config["adk"] settings.
@@ -122,9 +95,8 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
Notes:
- JSON stored as TEXT with SQLSpec serializers (msgspec/orjson/stdlib)
- - BOOLEAN as INTEGER (0/1, with None for NULL)
- Timestamps as REAL (Julian day: julianday('now'))
- - BLOB for pre-serialized actions from Google ADK
+ - Full event stored as JSON TEXT in event_data column
- PRAGMA foreign_keys = ON (enable per connection)
- Configuration is read from config.extension_config["adk"]
"""
@@ -145,6 +117,22 @@ def __init__(self, config: "SqliteConfig") -> None:
"""
super().__init__(config)
+ def _apply_pragmas(self, connection: Any) -> None:
+ """Apply PRAGMA optimization profile for this connection.
+
+ Args:
+ connection: SQLite connection.
+
+ Notes:
+ Enables foreign keys and applies performance PRAGMAs.
+         Applies cache_size, mmap_size, and journal_size_limit
+         optimizations unconditionally on every connection.
+ """
+ connection.execute("PRAGMA foreign_keys = ON")
+ connection.execute("PRAGMA cache_size = -64000")
+ connection.execute("PRAGMA mmap_size = 30000000")
+ connection.execute("PRAGMA journal_size_limit = 67108864")
+
async def _get_create_sessions_table_sql(self) -> str:
"""Get SQLite CREATE TABLE SQL for sessions.
@@ -184,9 +172,8 @@ async def _get_create_events_table_sql(self) -> str:
SQL statement to create adk_events table with indexes.
Notes:
- - TEXT for IDs, strings, and JSON content
- - BLOB for pickled actions
- - INTEGER for booleans (0/1/NULL)
+ - TEXT for IDs and indexed scalars
+ - TEXT for full event JSON (event_data)
- REAL for Julian Day timestamps
- Foreign key to sessions with CASCADE delete
- Index on (session_id, timestamp ASC)
@@ -195,22 +182,10 @@ async def _get_create_events_table_sql(self) -> str:
CREATE TABLE IF NOT EXISTS {self._events_table} (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
- app_name TEXT NOT NULL,
- user_id TEXT NOT NULL,
- invocation_id TEXT NOT NULL,
- author TEXT NOT NULL,
- actions BLOB NOT NULL,
- long_running_tool_ids_json TEXT,
- branch TEXT,
+ invocation_id TEXT,
+ author TEXT,
timestamp REAL NOT NULL,
- content TEXT,
- grounding_metadata TEXT,
- custom_metadata TEXT,
- partial INTEGER,
- turn_complete INTEGER,
- interrupted INTEGER,
- error_code TEXT,
- error_message TEXT,
+ event_data TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
@@ -229,21 +204,10 @@ def _get_drop_tables_sql(self) -> "list[str]":
"""
return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]
- def _enable_foreign_keys(self, connection: Any) -> None:
- """Enable foreign key constraints for this connection.
-
- Args:
- connection: SQLite connection.
-
- Notes:
- SQLite requires PRAGMA foreign_keys = ON per connection.
- """
- connection.execute("PRAGMA foreign_keys = ON")
-
def _create_tables(self) -> None:
"""Synchronous implementation of create_tables."""
with self._config.provide_session() as driver:
- driver.connection.execute("PRAGMA foreign_keys = ON")
+ self._apply_pragmas(driver.connection)
driver.execute_script(run_(self._get_create_sessions_table_sql)())
driver.execute_script(run_(self._get_create_events_table_sql)())
@@ -257,7 +221,7 @@ def _create_session(
"""Synchronous implementation of create_session."""
now = datetime.now(timezone.utc)
now_julian = _datetime_to_julian(now)
- state_json = to_json(state) if state else None
+ state_json = to_json(state)
params: tuple[Any, ...]
if self._owner_id_column_name:
@@ -275,7 +239,7 @@ def _create_session(
params = (session_id, app_name, user_id, state_json, now_julian, now_julian)
with self._config.provide_connection() as conn:
- self._enable_foreign_keys(conn)
+ self._apply_pragmas(conn)
conn.execute(sql, params)
conn.commit()
@@ -300,7 +264,7 @@ async def create_session(
Notes:
Uses Julian Day for create_time and update_time.
- State is JSON-serialized before insertion.
+ State is always JSON-serialized (empty dict becomes '{}', never NULL).
If owner_id_column is configured, owner_id is inserted into that column.
"""
return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id)
@@ -314,7 +278,7 @@ def _get_session(self, session_id: str) -> "SessionRecord | None":
"""
with self._config.provide_connection() as conn:
- self._enable_foreign_keys(conn)
+ self._apply_pragmas(conn)
cursor = conn.execute(sql, (session_id,))
row = cursor.fetchone()
@@ -348,7 +312,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
"""Synchronous implementation of update_session_state."""
now_julian = _datetime_to_julian(datetime.now(timezone.utc))
- state_json = to_json(state) if state else None
+ state_json = to_json(state)
sql = f"""
UPDATE {self._session_table}
@@ -357,7 +321,7 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non
"""
with self._config.provide_connection() as conn:
- self._enable_foreign_keys(conn)
+ self._apply_pragmas(conn)
conn.execute(sql, (state_json, now_julian, session_id))
conn.commit()
@@ -371,6 +335,7 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
Notes:
This replaces the entire state dictionary.
Updates update_time to current Julian Day.
+ Empty dict is serialized as '{}', never NULL.
"""
await async_(self._update_session_state)(session_id, state)
@@ -394,7 +359,7 @@ def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionR
params = (app_name, user_id)
with self._config.provide_connection() as conn:
- self._enable_foreign_keys(conn)
+ self._apply_pragmas(conn)
cursor = conn.execute(sql, params)
rows = cursor.fetchall()
@@ -430,7 +395,7 @@ def _delete_session(self, session_id: str) -> None:
sql = f"DELETE FROM {self._session_table} WHERE id = ?"
with self._config.provide_connection() as conn:
- self._enable_foreign_keys(conn)
+ self._apply_pragmas(conn)
conn.execute(sql, (session_id,))
conn.commit()
@@ -448,53 +413,29 @@ async def delete_session(self, session_id: str) -> None:
def _append_event(self, event_record: EventRecord) -> None:
"""Synchronous implementation of append_event."""
timestamp_julian = _datetime_to_julian(event_record["timestamp"])
-
- content_json = to_json(event_record.get("content")) if event_record.get("content") else None
- grounding_metadata_json = (
- to_json(event_record.get("grounding_metadata")) if event_record.get("grounding_metadata") else None
- )
- custom_metadata_json = (
- to_json(event_record.get("custom_metadata")) if event_record.get("custom_metadata") else None
- )
-
- partial_int = _to_sqlite_bool(event_record.get("partial"))
- turn_complete_int = _to_sqlite_bool(event_record.get("turn_complete"))
- interrupted_int = _to_sqlite_bool(event_record.get("interrupted"))
+ event_data_json = to_json(event_record["event_json"])
sql = f"""
INSERT INTO {self._events_table} (
- id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
- ) VALUES (
- ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
- )
+ id, session_id, invocation_id, author, timestamp, event_data
+ ) VALUES (?, ?, ?, ?, ?, ?)
"""
+ import uuid
+
+ event_id = str(uuid.uuid4())
+
with self._config.provide_connection() as conn:
- self._enable_foreign_keys(conn)
+ self._apply_pragmas(conn)
conn.execute(
sql,
(
- event_record["id"],
+ event_id,
event_record["session_id"],
- event_record["app_name"],
- event_record["user_id"],
event_record["invocation_id"],
event_record["author"],
- event_record["actions"],
- event_record.get("long_running_tool_ids_json"),
- event_record.get("branch"),
timestamp_julian,
- content_json,
- grounding_metadata_json,
- custom_metadata_json,
- partial_int,
- turn_complete_int,
- interrupted_int,
- event_record.get("error_code"),
- event_record.get("error_message"),
+ event_data_json,
),
)
conn.commit()
@@ -503,15 +444,71 @@ async def append_event(self, event_record: EventRecord) -> None:
"""Append an event to a session.
Args:
- event_record: Event record to store.
+ event_record: Event record with 5 keys: session_id, invocation_id,
+ author, timestamp, event_json.
Notes:
Uses Julian Day for timestamp.
- JSON fields are serialized to TEXT.
- Boolean fields converted to INTEGER (0/1/NULL).
+ event_json dict is serialized to TEXT as event_data column.
"""
await async_(self._append_event)(event_record)
+ def _append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Synchronous implementation of append_event_and_update_state."""
+ import uuid
+
+ timestamp_julian = _datetime_to_julian(event_record["timestamp"])
+ event_data_json = to_json(event_record["event_json"])
+ now_julian = _datetime_to_julian(datetime.now(timezone.utc))
+ state_json = to_json(state)
+ event_id = str(uuid.uuid4())
+
+ insert_sql = f"""
+ INSERT INTO {self._events_table} (
+ id, session_id, invocation_id, author, timestamp, event_data
+ ) VALUES (?, ?, ?, ?, ?, ?)
+ """
+
+ update_sql = f"""
+ UPDATE {self._session_table}
+ SET state = ?, update_time = ?
+ WHERE id = ?
+ """
+
+ with self._config.provide_connection() as conn:
+ self._apply_pragmas(conn)
+ conn.execute(
+ insert_sql,
+ (
+ event_id,
+ event_record["session_id"],
+ event_record["invocation_id"],
+ event_record["author"],
+ timestamp_julian,
+ event_data_json,
+ ),
+ )
+ conn.execute(update_sql, (state_json, now_julian, session_id))
+ conn.commit()
+
+ async def append_event_and_update_state(
+ self, event_record: EventRecord, session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state.
+
+ Inserts the event and updates the session state + update_time in a
+ single transaction. Both operations succeed or fail together.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot (temp: keys already
+ stripped by the service layer).
+ """
+ await async_(self._append_event_and_update_state)(event_record, session_id, state)
+
def _get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
@@ -527,40 +524,24 @@ def _get_events(
limit_clause = f" LIMIT {limit}" if limit else ""
sql = f"""
- SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
- long_running_tool_ids_json, branch, timestamp, content,
- grounding_metadata, custom_metadata, partial, turn_complete,
- interrupted, error_code, error_message
+ SELECT id, session_id, invocation_id, author, timestamp, event_data
FROM {self._events_table}
WHERE {where_clause}
ORDER BY timestamp ASC{limit_clause}
"""
with self._config.provide_connection() as conn:
- self._enable_foreign_keys(conn)
+ self._apply_pragmas(conn)
cursor = conn.execute(sql, params)
rows = cursor.fetchall()
return [
EventRecord(
- id=row[0],
session_id=row[1],
- app_name=row[2],
- user_id=row[3],
- invocation_id=row[4],
- author=row[5],
- actions=bytes(row[6]),
- long_running_tool_ids_json=row[7],
- branch=row[8],
- timestamp=_julian_to_datetime(row[9]),
- content=from_json(row[10]) if row[10] else None,
- grounding_metadata=from_json(row[11]) if row[11] else None,
- custom_metadata=from_json(row[12]) if row[12] else None,
- partial=_from_sqlite_bool(row[13]),
- turn_complete=_from_sqlite_bool(row[14]),
- interrupted=_from_sqlite_bool(row[15]),
- error_code=row[16],
- error_message=row[17],
+ invocation_id=row[2],
+ author=row[3],
+ timestamp=_julian_to_datetime(row[4]),
+ event_json=from_json(row[5]) if row[5] else {},
)
for row in rows
]
@@ -580,13 +561,12 @@ async def get_events(
Notes:
Uses index on (session_id, timestamp ASC).
- Parses JSON fields and converts BLOB actions to bytes.
- Converts INTEGER booleans back to bool/None.
+ Parses event_data TEXT back to dict for event_json field.
"""
return await async_(self._get_events)(session_id, after_timestamp, limit)
-class SqliteADKMemoryStore(BaseSyncADKMemoryStore["SqliteConfig"]):
+class SqliteADKMemoryStore(BaseAsyncADKMemoryStore["SqliteConfig"]):
"""SQLite ADK memory store using synchronous SQLite driver.
Implements memory entry storage for Google Agent Development Kit
@@ -645,7 +625,7 @@ def __init__(self, config: "SqliteConfig") -> None:
"""
super().__init__(config)
- def _get_create_memory_table_sql(self) -> str:
+ async def _get_create_memory_table_sql(self) -> str:
"""Get SQLite CREATE TABLE SQL for memory entries.
Returns:
@@ -737,7 +717,7 @@ def _enable_foreign_keys(self, connection: Any) -> None:
"""
connection.execute("PRAGMA foreign_keys = ON")
- def create_tables(self) -> None:
+ def _create_tables(self) -> None:
"""Create the memory table and indexes if they don't exist.
Skips table creation if memory store is disabled.
@@ -747,9 +727,13 @@ def create_tables(self) -> None:
with self._config.provide_session() as driver:
self._enable_foreign_keys(driver.connection)
- driver.execute_script(self._get_create_memory_table_sql())
+ driver.execute_script(run_(self._get_create_memory_table_sql)())
+
+ async def create_tables(self) -> None:
+ """Create tables if they don't exist."""
+ await async_(self._create_tables)()
- def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
"""Bulk insert memory entries with deduplication.
Uses INSERT OR IGNORE to skip duplicates based on event_id
@@ -833,7 +817,11 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object
return inserted_count
- def search_entries(
+ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
+ """Bulk insert memory entries with deduplication."""
+ return await async_(self._insert_memory_entries)(entries, owner_id)
+
+ def _search_entries(
self, query: str, app_name: str, user_id: str, limit: "int | None" = None
) -> "list[MemoryRecord]":
"""Search memory entries by text query.
@@ -863,6 +851,12 @@ def search_entries(
logger.warning("FTS search failed; falling back to simple search: %s", exc)
return self._search_entries_simple(query, app_name, user_id, effective_limit)
+ async def search_entries(
+ self, query: str, app_name: str, user_id: str, limit: "int | None" = None
+ ) -> "list[MemoryRecord]":
+ """Search memory entries by text query."""
+ return await async_(self._search_entries)(query, app_name, user_id, limit)
+
def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]":
sql = f"""
SELECT m.id, m.session_id, m.app_name, m.user_id, m.event_id, m.author,
@@ -915,7 +909,7 @@ def _fetch_records(self, sql: str, params: "tuple[Any, ...]") -> "list[MemoryRec
for row in rows
]
- def delete_entries_by_session(self, session_id: str) -> int:
+ def _delete_entries_by_session(self, session_id: str) -> int:
"""Delete all memory entries for a specific session.
Args:
@@ -934,7 +928,11 @@ def delete_entries_by_session(self, session_id: str) -> int:
return deleted_count
- def delete_entries_older_than(self, days: int) -> int:
+ async def delete_entries_by_session(self, session_id: str) -> int:
+ """Delete all memory entries for a specific session."""
+ return await async_(self._delete_entries_by_session)(session_id)
+
+ def _delete_entries_older_than(self, days: int) -> int:
"""Delete memory entries older than specified days.
Used for TTL cleanup operations.
@@ -956,3 +954,7 @@ def delete_entries_older_than(self, days: int) -> int:
conn.commit()
return deleted_count
+
+ async def delete_entries_older_than(self, days: int) -> int:
+ """Delete memory entries older than specified days."""
+ return await async_(self._delete_entries_older_than)(days)
diff --git a/sqlspec/config.py b/sqlspec/config.py
index e3a96c35e..63d1ae2dd 100644
--- a/sqlspec/config.py
+++ b/sqlspec/config.py
@@ -35,7 +35,11 @@
__all__ = (
+ "ADKCompressionConfig",
"ADKConfig",
+ "ADKPartitionConfig",
+ "ADKRetentionConfig",
+ "ADKSqliteOptimizationConfig",
"AsyncConfigT",
"AsyncDatabaseConfig",
"ConfigT",
@@ -374,6 +378,178 @@ class FastAPIConfig(StarletteConfig):
"""
+class ADKPartitionConfig(TypedDict):
+ """Configuration for table partitioning and sharding strategies.
+
+ Controls how ADK tables are partitioned across backends that support it.
+ Backends without native partitioning support ignore these settings.
+
+ Example:
+ extension_config={
+ "adk": {
+ "partitioning": {
+ "strategy": "range",
+ "partition_key": "created_at",
+ "interval": "month",
+ }
+ }
+ }
+ """
+
+ strategy: NotRequired[Literal["range", "list", "hash"]]
+ """Partitioning strategy. Default: None (no partitioning).
+
+ - range: Partition by range of values (e.g., time-based)
+ - list: Partition by discrete value lists
+ - hash: Partition by hash of the partition key
+
+ Supported by: PostgreSQL, MySQL 8+, Oracle, BigQuery, Spanner.
+ Ignored by: SQLite, DuckDB.
+ """
+
+ partition_key: NotRequired[str]
+ """Column name used as the partition key.
+
+ For range partitioning with time-based data, this is typically a timestamp column
+ like 'created_at'. For hash partitioning, this is typically the primary key.
+ """
+
+ interval: NotRequired[str]
+ """Partition interval for range partitioning.
+
+ Examples: 'day', 'week', 'month', 'year'.
+ Only meaningful when strategy is 'range'.
+ """
+
+
+class ADKRetentionConfig(TypedDict):
+ """Configuration for data retention and TTL policies.
+
+ Controls automatic cleanup of expired data. Backends with native TTL support
+ (CockroachDB Row-Level TTL, Spanner Row Deletion Policy) use database-level
+ enforcement. Others fall back to application-level sweep queries.
+
+ Example:
+ extension_config={
+ "adk": {
+ "retention": {
+ "session_ttl_seconds": 86400,
+ "event_ttl_seconds": 604800,
+ "memory_ttl_seconds": 0,
+ }
+ }
+ }
+ """
+
+ session_ttl_seconds: NotRequired[int]
+ """TTL for session records in seconds. Default: 0 (no expiry).
+
+ When set, sessions older than this threshold are eligible for cleanup.
+ Backends with native TTL (CockroachDB, Spanner) enforce this at the database level.
+ Others require application-level cleanup via periodic sweep.
+ """
+
+ event_ttl_seconds: NotRequired[int]
+ """TTL for event records in seconds. Default: 0 (no expiry).
+
+ When set, events older than this threshold are eligible for cleanup.
+ """
+
+ memory_ttl_seconds: NotRequired[int]
+ """TTL for memory entries in seconds. Default: 0 (no expiry).
+
+ When set, memory entries older than this threshold are eligible for cleanup.
+ """
+
+ sweep_interval_seconds: NotRequired[int]
+ """Interval between application-level cleanup sweeps in seconds. Default: 3600 (1 hour).
+
+ Only used when the backend does not support native TTL enforcement.
+ Set to 0 to disable automatic sweeps (manual cleanup only).
+ """
+
+
+class ADKCompressionConfig(TypedDict):
+ """Configuration for table-level compression.
+
+ Controls compression of ADK table storage. Support and algorithms vary by backend.
+
+ Example:
+ extension_config={
+ "adk": {
+ "compression": {
+ "enabled": True,
+ "algorithm": "zstd",
+ }
+ }
+ }
+ """
+
+ enabled: NotRequired[bool]
+ """Enable table compression. Default: False.
+
+ When True, adapters that support table-level compression will apply it
+ during table creation.
+ """
+
+ algorithm: NotRequired[str]
+ """Compression algorithm name. Backend-specific.
+
+ Examples:
+ - PostgreSQL (with TOAST): 'pglz', 'lz4' (PG14+)
+ - MySQL/InnoDB: 'zlib'
+ - Oracle: 'basic', 'oltp', 'query_high', 'archive_high'
+ - DuckDB: 'zstd', 'snappy'
+
+ When omitted, the backend default is used.
+ """
+
+ level: NotRequired[int]
+ """Compression level (where supported). Higher levels trade CPU for space savings.
+
+ Valid ranges depend on the algorithm and backend.
+ """
+
+
+class ADKSqliteOptimizationConfig(TypedDict):
+ """SQLite-specific PRAGMA optimization settings.
+
+ Controls SQLite performance tuning parameters applied at connection time.
+ These settings are ignored by non-SQLite adapters.
+
+ Example:
+ extension_config={
+ "adk": {
+ "sqlite_optimization": {
+ "cache_size": -64000,
+ "mmap_size": 31457280,
+ "journal_size_limit": 67108864,
+ }
+ }
+ }
+ """
+
+ cache_size: NotRequired[int]
+ """SQLite page cache size. Default: -64000 (64 MB, negative means KiB).
+
+ Larger caches reduce disk I/O for read-heavy workloads.
+ Negative values specify size in KiB; positive values specify page count.
+ """
+
+ mmap_size: NotRequired[int]
+ """SQLite memory-mapped I/O size in bytes. Default: 31457280 (30 MB).
+
+ Enables memory-mapped I/O for faster reads. Set to 0 to disable.
+ """
+
+ journal_size_limit: NotRequired[int]
+ """SQLite journal file size limit in bytes. Default: 67108864 (64 MB).
+
+ Limits the size of the WAL or rollback journal file.
+ Prevents unbounded journal growth in write-heavy workloads.
+ """
+
+
class ADKConfig(TypedDict):
"""Configuration options for ADK session and memory store extension.
@@ -460,6 +636,27 @@ class ADKConfig(TypedDict):
"tenant_acme_memories"
"""
+ artifact_table: NotRequired[str]
+ """Name of the artifact versions table. Default: 'adk_artifact_versions'
+
+ Examples:
+ "agent_artifacts"
+ "my_app_artifact_versions"
+ """
+
+ artifact_storage_uri: NotRequired[str]
+ """Base URI for artifact content storage.
+
+    Points to a ``sqlspec.storage`` backend where artifact binary content
+ is stored. Can be a direct URI (``s3://bucket/path``, ``file:///path``)
+ or a registered alias in the storage registry.
+
+ Examples:
+ "s3://my-bucket/adk-artifacts/"
+ "file:///var/data/artifacts/"
+ "gcs://my-gcs-bucket/artifacts/"
+ """
+
memory_use_fts: NotRequired[bool]
"""Enable full-text search when supported. Default: False.
@@ -585,6 +782,81 @@ class ADKConfig(TypedDict):
expires_index_options: NotRequired[str]
"""Adapter-specific options for the expires/index used in ADK stores."""
+ # --- Capability-based configuration (Chapter 2: schema-capability-config) ---
+
+ fts_language: NotRequired[str]
+ """Language configuration for full-text search indexing. Default: 'english'.
+
+ Controls the language dictionary/stemmer used by FTS implementations:
+ - PostgreSQL: to_tsvector/to_tsquery language parameter
+ - SQLite FTS5: tokenizer language for unicode61/porter
+ - MySQL: FULLTEXT parser language (with ngram for CJK on 5.7.6+)
+ - Oracle: CTXSYS.CONTEXT lexer language
+ - Spanner: TOKENIZE_FULLTEXT language parameter
+ - DuckDB: FTS stemmer language
+
+ Only takes effect when ``memory_use_fts`` is True.
+
+ Common values: 'english', 'simple', 'german', 'french', 'spanish',
+ 'portuguese', 'italian', 'dutch', 'russian', 'chinese', 'japanese', 'korean'.
+
+ Notes:
+ Available languages vary by backend. Backends that do not support the
+ specified language will fall back to 'simple' or 'english'.
+ """
+
+ schema_version: NotRequired[int]
+ """Explicit schema version for ADK tables. Default: None (auto-detect).
+
+ When set, locks the ADK schema to a specific version. This is useful for:
+ - Preventing automatic schema upgrades in production
+ - Pinning to a known-good schema during testing
+ - Coordinating schema changes across multiple application instances
+
+ When None, the ADK extension auto-detects the current schema version
+ and applies any pending upgrades during initialization.
+
+ Notes:
+ Schema versions are monotonically increasing integers managed by
+ the ADK extension migration system. Setting this to a version
+ lower than the current database schema will raise a configuration
+ error at startup.
+ """
+
+ partitioning: NotRequired[ADKPartitionConfig]
+ """Table partitioning configuration. Default: None (no partitioning).
+
+ Controls how ADK tables are partitioned for improved query performance
+ and data management at scale. See ``ADKPartitionConfig`` for options.
+
+ Supported by: PostgreSQL, MySQL 8+, Oracle, BigQuery, Spanner.
+ Ignored by: SQLite, DuckDB.
+ """
+
+ retention: NotRequired[ADKRetentionConfig]
+ """Data retention and TTL configuration. Default: None (no automatic cleanup).
+
+ Controls automatic expiry and cleanup of old session, event, and memory data.
+ See ``ADKRetentionConfig`` for options.
+
+ Backends with native TTL (CockroachDB, Spanner) use database-level enforcement.
+ Others fall back to application-level sweep queries.
+ """
+
+ compression: NotRequired[ADKCompressionConfig]
+ """Table compression configuration. Default: None (no compression).
+
+ Controls table-level compression for ADK tables.
+ See ``ADKCompressionConfig`` for options.
+ """
+
+ sqlite_optimization: NotRequired[ADKSqliteOptimizationConfig]
+ """SQLite-specific PRAGMA optimization settings. Default: None (SQLite defaults).
+
+ Controls SQLite performance tuning parameters. Ignored by non-SQLite adapters.
+ See ``ADKSqliteOptimizationConfig`` for options.
+ """
+
class EventsConfig(TypedDict):
"""Configuration options for the events extension.
diff --git a/sqlspec/core/splitter.py b/sqlspec/core/splitter.py
index 45908cf27..0f5414f62 100644
--- a/sqlspec/core/splitter.py
+++ b/sqlspec/core/splitter.py
@@ -623,6 +623,8 @@ def statement_terminators(self) -> "set[str]":
_pattern_cache: LRUCache | None = None
_result_cache: LRUCache | None = None
_cache_lock = threading.Lock()
+_unknown_dialect_warning_lock = threading.Lock()
+_warned_unknown_dialects: set[str] = set()
def _get_pattern_cache() -> LRUCache:
@@ -653,6 +655,16 @@ def _get_result_cache() -> LRUCache:
return _result_cache
+def _warn_unknown_dialect_once(dialect: "str | None") -> None:
+ """Emit the generic splitter fallback warning once per dialect."""
+ key = "" if dialect is None else dialect.lower()
+ with _unknown_dialect_warning_lock:
+ if key in _warned_unknown_dialects:
+ return
+ _warned_unknown_dialects.add(key)
+ logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect)
+
+
@mypyc_attr(allow_interpreted_subclasses=False)
class StatementSplitter:
"""SQL script splitter with caching and dialect support."""
@@ -933,7 +945,7 @@ def split_sql_script(script: str, dialect: str | None = None, strip_trailing_ter
config = dialect_configs.get(dialect.lower())
if not config:
- logger.warning("Unknown dialect '%s', using generic SQL splitter", dialect)
+ _warn_unknown_dialect_once(dialect)
config = GenericDialectConfig()
splitter = StatementSplitter(config, strip_trailing_semicolon=strip_trailing_terminator)
@@ -949,6 +961,8 @@ def clear_splitter_caches() -> None:
result_cache = _get_result_cache()
pattern_cache.clear()
result_cache.clear()
+ with _unknown_dialect_warning_lock:
+ _warned_unknown_dialects.clear()
def get_splitter_cache_stats() -> "dict[str, Any]":
diff --git a/sqlspec/extensions/adk/__init__.py b/sqlspec/extensions/adk/__init__.py
index c7877b1a5..7f6cfc586 100644
--- a/sqlspec/extensions/adk/__init__.py
+++ b/sqlspec/extensions/adk/__init__.py
@@ -1,20 +1,24 @@
-"""Google ADK session backend extension for SQLSpec.
+"""Google ADK session, memory, and artifact backend extension for SQLSpec.
-Provides session, event, and memory storage for Google Agent Development Kit using
-SQLSpec database adapters.
+Provides session, event, memory, and artifact storage for Google Agent Development Kit
+using SQLSpec database adapters.
Public API exports:
- ADKConfig: TypedDict for extension config (type-safe configuration)
- SQLSpecSessionService: Main service class implementing BaseSessionService
- SQLSpecMemoryService: Main async service class implementing BaseMemoryService
- SQLSpecSyncMemoryService: Sync memory service for sync adapters
+ - SQLSpecArtifactService: Artifact service implementing BaseArtifactService
- BaseAsyncADKStore: Base class for async database store implementations
- BaseSyncADKStore: Base class for sync database store implementations
- BaseAsyncADKMemoryStore: Base class for async memory store implementations
- BaseSyncADKMemoryStore: Base class for sync memory store implementations
+ - BaseAsyncADKArtifactStore: Base class for async artifact metadata stores
+ - BaseSyncADKArtifactStore: Base class for sync artifact metadata stores
- SessionRecord: TypedDict for session database records
- EventRecord: TypedDict for event database records
- MemoryRecord: TypedDict for memory database records
+ - ArtifactRecord: TypedDict for artifact metadata database records
Example (with extension_config):
from sqlspec.adapters.asyncpg import AsyncpgConfig
@@ -45,6 +49,12 @@
from sqlspec.config import ADKConfig
from sqlspec.extensions.adk._types import EventRecord, SessionRecord
+from sqlspec.extensions.adk.artifact import (
+ ArtifactRecord,
+ BaseAsyncADKArtifactStore,
+ BaseSyncADKArtifactStore,
+ SQLSpecArtifactService,
+)
from sqlspec.extensions.adk.memory import (
BaseAsyncADKMemoryStore,
BaseSyncADKMemoryStore,
@@ -57,12 +67,16 @@
__all__ = (
"ADKConfig",
+ "ArtifactRecord",
+ "BaseAsyncADKArtifactStore",
"BaseAsyncADKMemoryStore",
"BaseAsyncADKStore",
+ "BaseSyncADKArtifactStore",
"BaseSyncADKMemoryStore",
"BaseSyncADKStore",
"EventRecord",
"MemoryRecord",
+ "SQLSpecArtifactService",
"SQLSpecMemoryService",
"SQLSpecSessionService",
"SQLSpecSyncMemoryService",
diff --git a/sqlspec/extensions/adk/_types.py b/sqlspec/extensions/adk/_types.py
index 651431651..3f11b62f0 100644
--- a/sqlspec/extensions/adk/_types.py
+++ b/sqlspec/extensions/adk/_types.py
@@ -27,25 +27,15 @@ class SessionRecord(TypedDict):
class EventRecord(TypedDict):
"""Database record for an event.
- Represents the schema for events stored in the database.
- Follows the ADK Event model plus session metadata.
+ Stores the full ADK Event as a single JSON blob (``event_json``) alongside
+ a small number of indexed scalar columns used for query filtering.
+
+ This design eliminates column drift with upstream ADK: new Event fields are
+ automatically captured in ``event_json`` without schema changes.
"""
- id: str
- app_name: str
- user_id: str
session_id: str
invocation_id: str
author: str
- branch: "str | None"
- actions: bytes
- long_running_tool_ids_json: "str | None"
timestamp: datetime
- content: "dict[str, Any] | None"
- grounding_metadata: "dict[str, Any] | None"
- custom_metadata: "dict[str, Any] | None"
- partial: "bool | None"
- turn_complete: "bool | None"
- interrupted: "bool | None"
- error_code: "str | None"
- error_message: "str | None"
+ event_json: "dict[str, Any]"
diff --git a/sqlspec/extensions/adk/artifact/__init__.py b/sqlspec/extensions/adk/artifact/__init__.py
new file mode 100644
index 000000000..36c3f8478
--- /dev/null
+++ b/sqlspec/extensions/adk/artifact/__init__.py
@@ -0,0 +1,57 @@
+"""Google ADK artifact service extension for SQLSpec.
+
+Provides artifact versioning and storage for Google Agent Development Kit
+using SQLSpec database adapters for metadata and ``sqlspec/storage/`` backends
+for content.
+
+Public API exports:
+ - SQLSpecArtifactService: Main service implementing BaseArtifactService
+ - BaseAsyncADKArtifactStore: Base class for async artifact metadata stores
+ - BaseSyncADKArtifactStore: Base class for sync artifact metadata stores
+ - ArtifactRecord: TypedDict for artifact metadata database records
+
+Example:
+ from sqlspec.adapters.asyncpg import AsyncpgConfig
+ from sqlspec.extensions.adk.artifact import SQLSpecArtifactService
+
+ config = AsyncpgConfig(
+ connection_config={"dsn": "postgresql://..."},
+ extension_config={
+ "adk": {
+ "artifact_table": "adk_artifact_versions",
+ }
+ }
+ )
+
+ # Create an adapter-specific artifact store (e.g., AsyncpgADKArtifactStore)
+ # and ensure tables exist:
+ artifact_store = AsyncpgADKArtifactStore(config)
+ await artifact_store.ensure_table()
+
+ # Create the service with a storage backend URI:
+ service = SQLSpecArtifactService(
+ store=artifact_store,
+ artifact_storage_uri="s3://my-bucket/adk-artifacts/",
+ )
+
+ # Save an artifact (returns version number starting from 0):
+ version = await service.save_artifact(
+ app_name="my_app",
+ user_id="user123",
+ filename="report.pdf",
+ artifact=part,
+ )
+
+ # Load artifact content:
+ loaded = await service.load_artifact(
+ app_name="my_app",
+ user_id="user123",
+ filename="report.pdf",
+ )
+"""
+
+from sqlspec.extensions.adk.artifact._types import ArtifactRecord
+from sqlspec.extensions.adk.artifact.service import SQLSpecArtifactService
+from sqlspec.extensions.adk.artifact.store import BaseAsyncADKArtifactStore, BaseSyncADKArtifactStore
+
+__all__ = ("ArtifactRecord", "BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore", "SQLSpecArtifactService")
diff --git a/sqlspec/extensions/adk/artifact/_types.py b/sqlspec/extensions/adk/artifact/_types.py
new file mode 100644
index 000000000..dcedffcf6
--- /dev/null
+++ b/sqlspec/extensions/adk/artifact/_types.py
@@ -0,0 +1,32 @@
+"""Type definitions for ADK artifact extension.
+
+These types define the database record structures for storing artifact metadata.
+They are separate from the Pydantic models to keep mypyc compilation working.
+"""
+
+from datetime import datetime
+from typing import Any, TypedDict
+
+__all__ = ("ArtifactRecord",)
+
+
+class ArtifactRecord(TypedDict):
+    """Database record for an artifact version.
+
+    Represents the schema for artifact metadata stored in the database.
+    Content is stored separately in object storage; this record tracks
+    versioning, ownership, and the canonical URI pointing to the content.
+
+    The composite key is (app_name, user_id, session_id, filename, version),
+    where session_id may be NULL for user-scoped artifacts.
+    """
+
+    app_name: str  # Owning application (part of the composite key)
+    user_id: str  # Owning user (part of the composite key)
+    session_id: "str | None"  # Session scope; None for user-scoped artifacts
+    filename: str  # Artifact filename (part of the composite key)
+    version: int  # Monotonically increasing version number, starting from 0
+    mime_type: "str | None"  # MIME type extracted from the Part, if determinable
+    canonical_uri: str  # Full storage URI where the content bytes live
+    created_at: datetime  # Row creation time (written as UTC-aware by the service)
diff --git a/sqlspec/extensions/adk/artifact/service.py b/sqlspec/extensions/adk/artifact/service.py
new file mode 100644
index 000000000..cecb8f076
--- /dev/null
+++ b/sqlspec/extensions/adk/artifact/service.py
@@ -0,0 +1,509 @@
+"""SQLSpec-backed artifact service for Google ADK.
+
+Implements ``BaseArtifactService`` by composing SQL-backed metadata storage
+(via :class:`BaseAsyncADKArtifactStore`) with ``sqlspec/storage/`` content
+backends (via :class:`StorageRegistry`).
+
+Metadata (version, filename, MIME type, custom metadata, canonical URI) lives
+in a SQL table. Content bytes live in object storage addressed by canonical
+URI. Versioning is append-only with monotonically increasing version numbers
+starting from 0.
+"""
+
+import json
+import logging
+import re
+from typing import TYPE_CHECKING, Any
+
+from google.adk.artifacts.base_artifact_service import BaseArtifactService
+
+from sqlspec.extensions.adk.artifact._types import ArtifactRecord
+from sqlspec.storage.registry import StorageRegistry, storage_registry
+from sqlspec.utils.logging import get_logger, log_with_context
+
+if TYPE_CHECKING:
+ from google.adk.artifacts.base_artifact_service import ArtifactVersion
+ from google.genai import types
+
+ from sqlspec.extensions.adk.artifact.store import BaseAsyncADKArtifactStore
+
+logger = get_logger("sqlspec.extensions.adk.artifact.service")
+
+__all__ = ("SQLSpecArtifactService",)
+
+# Matches path traversal and absolute path components
+_UNSAFE_PATH_CHARS = re.compile(r"(?:^|/)\.\.(?:/|$)|[\x00]")
+
+
+def _sanitize_path_component(value: str) -> str:
+    """Sanitize a path component to prevent directory traversal.
+
+    Strips leading/trailing slashes, then rejects components containing
+    ``..`` traversal segments or NUL bytes.
+
+    Args:
+        value: Raw path component.
+
+    Returns:
+        Sanitized path component.
+
+    Raises:
+        ValueError: If the value contains path traversal sequences or NUL bytes.
+    """
+    # NOTE(review): backslash separators are not rejected — confirm storage
+    # backends treat "\\" as a literal character, not a path separator.
+    value = value.strip("/")
+    if _UNSAFE_PATH_CHARS.search(value):
+        msg = f"Unsafe path component: {value!r}"
+        raise ValueError(msg)
+    return value
+
+
+def _build_content_path(
+    app_name: str, user_id: str, filename: str, version: int, session_id: "str | None" = None
+) -> str:
+    """Build the storage path for artifact content.
+
+    Pattern:
+        ``apps/{app_name}/users/{user_id}/[sessions/{session_id}/]artifacts/{filename}/v{version}``
+
+    All path components are sanitized to prevent directory traversal.
+
+    Args:
+        app_name: Application name.
+        user_id: User identifier.
+        filename: Artifact filename.
+        version: Version number.
+        session_id: Optional session identifier.
+
+    Returns:
+        Sanitized storage path.
+
+    Raises:
+        ValueError: If any component fails sanitization.
+    """
+    parts = ["apps", _sanitize_path_component(app_name), "users", _sanitize_path_component(user_id)]
+    if session_id is not None:
+        parts.extend(["sessions", _sanitize_path_component(session_id)])
+    parts.extend(["artifacts", _sanitize_path_component(filename), f"v{version}"])
+    return "/".join(parts)
+
+
+def _extract_mime_type(artifact: "types.Part | dict[str, Any]") -> "str | None":
+ """Extract MIME type from an artifact Part.
+
+ Checks ``inline_data.mime_type`` and ``file_data.mime_type`` on the Part.
+
+ Args:
+ artifact: ADK Part or dict representation.
+
+ Returns:
+ MIME type string, or None if not determinable.
+ """
+ if isinstance(artifact, dict):
+ # Handle camelCase and snake_case keys
+ inline = artifact.get("inline_data") or artifact.get("inlineData")
+ if isinstance(inline, dict):
+ return inline.get("mime_type") or inline.get("mimeType")
+ file_data = artifact.get("file_data") or artifact.get("fileData")
+ if isinstance(file_data, dict):
+ return file_data.get("mime_type") or file_data.get("mimeType")
+ return None
+
+ # types.Part object
+ if hasattr(artifact, "inline_data") and artifact.inline_data is not None:
+ return getattr(artifact.inline_data, "mime_type", None)
+ if hasattr(artifact, "file_data") and artifact.file_data is not None:
+ return getattr(artifact.file_data, "mime_type", None)
+ return None
+
+
+def _serialize_artifact(artifact: "types.Part | dict[str, Any]") -> bytes:
+ """Serialize an artifact Part to bytes for content storage.
+
+ The artifact is serialized as JSON via ``model_dump(exclude_none=True)``.
+ This preserves the full Part structure including text, inline_data,
+ file_data, and any future Part fields.
+
+ Args:
+ artifact: ADK Part or dict representation.
+
+ Returns:
+ JSON-encoded bytes.
+ """
+ if isinstance(artifact, dict):
+ return json.dumps(artifact, default=str).encode("utf-8")
+
+ # Use Pydantic model serialization
+ if hasattr(artifact, "model_dump"):
+ data = artifact.model_dump(exclude_none=True)
+ return json.dumps(data, default=str).encode("utf-8")
+
+ # Fallback for unexpected types
+ return json.dumps({"text": str(artifact)}).encode("utf-8")
+
+
+def _deserialize_artifact(data: bytes) -> "types.Part":
+ """Deserialize bytes back into an ADK Part.
+
+ Args:
+ data: JSON-encoded bytes from content storage.
+
+ Returns:
+ Reconstructed Part object.
+ """
+ from google.genai import types
+
+ parsed = json.loads(data.decode("utf-8"))
+ return types.Part.model_validate(parsed)
+
+
+def _record_to_artifact_version(record: "ArtifactRecord") -> "ArtifactVersion":
+    """Convert a database artifact record to an ADK ArtifactVersion.
+
+    Args:
+        record: Database artifact record.
+
+    Returns:
+        ArtifactVersion model instance.
+    """
+    from google.adk.artifacts.base_artifact_service import ArtifactVersion
+
+    return ArtifactVersion(
+        version=record["version"],
+        canonical_uri=record["canonical_uri"],
+        custom_metadata=record["custom_metadata"] or {},
+        # NOTE(review): assumes created_at is timezone-aware; .timestamp() on a
+        # naive datetime is interpreted in local time — confirm DB drivers
+        # round-trip tzinfo.
+        create_time=record["created_at"].timestamp(),
+        mime_type=record["mime_type"],
+    )
+
+
+class SQLSpecArtifactService(BaseArtifactService):
+    """SQLSpec-backed implementation of BaseArtifactService.
+
+    Composes SQL metadata storage with ``sqlspec/storage/`` content backends
+    to provide versioned artifact management for Google ADK.
+
+    Metadata (version number, filename, MIME type, custom metadata, canonical
+    URI) is stored in a SQL table managed by the artifact store. Content
+    bytes are stored in object storage (S3, GCS, Azure, local filesystem)
+    via the storage registry.
+
+    Args:
+        store: Artifact metadata store implementation.
+        artifact_storage_uri: Base URI for content storage (e.g.,
+            ``"s3://my-bucket/adk-artifacts/"``, ``"file:///var/data/artifacts/"``).
+            Can also be a registered alias in the storage registry.
+        registry: Storage registry to use. Defaults to the global singleton.
+
+    Example:
+        from sqlspec.adapters.asyncpg.adk.artifact_store import AsyncpgADKArtifactStore
+        from sqlspec.extensions.adk.artifact import SQLSpecArtifactService
+
+        artifact_store = AsyncpgADKArtifactStore(config)
+        await artifact_store.ensure_table()
+
+        service = SQLSpecArtifactService(
+            store=artifact_store,
+            artifact_storage_uri="s3://my-bucket/adk-artifacts/",
+        )
+
+        version = await service.save_artifact(
+            app_name="my_app",
+            user_id="user123",
+            filename="output.png",
+            artifact=part,
+        )
+    """
+
+    def __init__(
+        self, store: "BaseAsyncADKArtifactStore", artifact_storage_uri: str, registry: "StorageRegistry | None" = None
+    ) -> None:
+        """Initialize the service with a metadata store and content-storage base URI."""
+        self._store = store
+        # Trailing slash stripped so canonical URIs join cleanly with "/".
+        self._artifact_storage_uri = artifact_storage_uri.rstrip("/")
+        self._registry = registry or storage_registry
+
+    @property
+    def store(self) -> "BaseAsyncADKArtifactStore":
+        """Return the artifact metadata store."""
+        return self._store
+
+    @property
+    def artifact_storage_uri(self) -> str:
+        """Return the base URI for content storage."""
+        return self._artifact_storage_uri
+
+    async def save_artifact(
+        self,
+        *,
+        app_name: str,
+        user_id: str,
+        filename: str,
+        artifact: "types.Part | dict[str, Any]",
+        session_id: "str | None" = None,
+        custom_metadata: "dict[str, Any] | None" = None,
+    ) -> int:
+        """Save an artifact, returning the new version number.
+
+        Writes content to object storage first, then inserts the metadata
+        row. If content write succeeds but metadata insert fails, the
+        orphaned content blob is not automatically cleaned up and the
+        exception propagates to the caller (eventual consistency is
+        acceptable; orphan sweep can be added later).
+
+        Args:
+            app_name: Application name.
+            user_id: User identifier.
+            filename: Artifact filename.
+            artifact: ADK Part or dict to save.
+            session_id: Session identifier (None for user-scoped).
+            custom_metadata: Optional per-version metadata dict.
+
+        Returns:
+            The version number (0-based, incrementing).
+        """
+        from google.adk.artifacts.base_artifact_service import ensure_part
+
+        # Normalize artifact to Part
+        artifact_part: types.Part = ensure_part(artifact)
+
+        # Determine the next version
+        # NOTE(review): version allocation and the insert below are not atomic —
+        # confirm the store enforces uniqueness on the composite key so
+        # concurrent saves of the same filename cannot silently collide.
+        version = await self._store.get_next_version(
+            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
+        )
+
+        # Build the content path and canonical URI
+        content_path = _build_content_path(
+            app_name=app_name, user_id=user_id, filename=filename, version=version, session_id=session_id
+        )
+        canonical_uri = f"{self._artifact_storage_uri}/{content_path}"
+
+        # Serialize content
+        content_bytes = _serialize_artifact(artifact_part)
+
+        # Extract MIME type
+        mime_type = _extract_mime_type(artifact_part)
+
+        # Write content first (fail-fast before metadata)
+        # NOTE(review): presumes backends expose write_bytes_async or
+        # write_bytes_sync — confirm against the storage backend protocol.
+        backend = self._registry.get(self._artifact_storage_uri)
+        if hasattr(backend, "write_bytes_async"):
+            await backend.write_bytes_async(content_path, content_bytes)
+        else:
+            backend.write_bytes_sync(content_path, content_bytes)
+
+        # Insert metadata row
+        from datetime import datetime, timezone
+
+        record = ArtifactRecord(
+            app_name=app_name,
+            user_id=user_id,
+            session_id=session_id,
+            filename=filename,
+            version=version,
+            mime_type=mime_type,
+            canonical_uri=canonical_uri,
+            custom_metadata=custom_metadata,
+            created_at=datetime.now(tz=timezone.utc),
+        )
+        await self._store.insert_artifact(record)
+
+        log_with_context(
+            logger,
+            logging.DEBUG,
+            "adk.artifact.save",
+            app_name=app_name,
+            user_id=user_id,
+            filename=filename,
+            version=version,
+            session_id=session_id,
+            mime_type=mime_type,
+        )
+        return version
+
+    async def load_artifact(
+        self,
+        *,
+        app_name: str,
+        user_id: str,
+        filename: str,
+        session_id: "str | None" = None,
+        version: "int | None" = None,
+    ) -> "types.Part | None":
+        """Load an artifact by reading metadata then fetching content.
+
+        Args:
+            app_name: Application name.
+            user_id: User identifier.
+            filename: Artifact filename.
+            session_id: Session identifier (None for user-scoped).
+            version: Specific version, or None for latest.
+
+        Returns:
+            Deserialized Part, or None if not found.
+        """
+        record = await self._store.get_artifact(
+            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id, version=version
+        )
+        if record is None:
+            log_with_context(
+                logger,
+                logging.DEBUG,
+                "adk.artifact.load",
+                app_name=app_name,
+                filename=filename,
+                version=version,
+                found=False,
+            )
+            return None
+
+        # Derive content path from canonical URI
+        # NOTE(review): if artifact_storage_uri differs from the one in effect
+        # at save time, removeprefix() is a no-op and the full URI is passed to
+        # the backend as a relative path — confirm base-URI stability.
+        content_path = record["canonical_uri"].removeprefix(self._artifact_storage_uri + "/")
+
+        backend = self._registry.get(self._artifact_storage_uri)
+        if hasattr(backend, "read_bytes_async"):
+            content_bytes = await backend.read_bytes_async(content_path)
+        else:
+            content_bytes = backend.read_bytes_sync(content_path)
+
+        log_with_context(
+            logger,
+            logging.DEBUG,
+            "adk.artifact.load",
+            app_name=app_name,
+            filename=filename,
+            version=record["version"],
+            found=True,
+        )
+        return _deserialize_artifact(content_bytes)
+
+    async def list_artifact_keys(self, *, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]":
+        """List distinct artifact filenames.
+
+        When ``session_id`` is provided, returns both session-scoped and
+        user-scoped filenames. When None, returns only user-scoped filenames.
+
+        Args:
+            app_name: Application name.
+            user_id: User identifier.
+            session_id: Session identifier.
+
+        Returns:
+            List of artifact filenames.
+        """
+        keys = await self._store.list_artifact_keys(app_name=app_name, user_id=user_id, session_id=session_id)
+        log_with_context(
+            logger,
+            logging.DEBUG,
+            "adk.artifact.list_keys",
+            app_name=app_name,
+            user_id=user_id,
+            session_id=session_id,
+            count=len(keys),
+        )
+        return keys
+
+    async def delete_artifact(
+        self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+    ) -> None:
+        """Delete an artifact and all its versions.
+
+        Deletes metadata rows first (fail-fast), then removes content
+        objects from storage (best-effort; failures are logged, not raised).
+
+        Args:
+            app_name: Application name.
+            user_id: User identifier.
+            filename: Artifact filename.
+            session_id: Session identifier (None for user-scoped).
+        """
+        deleted_records = await self._store.delete_artifact(
+            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
+        )
+
+        # Best-effort content cleanup
+        backend = self._registry.get(self._artifact_storage_uri)
+        for record in deleted_records:
+            content_path = record["canonical_uri"].removeprefix(self._artifact_storage_uri + "/")
+            try:
+                if hasattr(backend, "delete_async"):
+                    await backend.delete_async(content_path)
+                else:
+                    backend.delete_sync(content_path)
+            except Exception:
+                # Orphaned content is tolerable; surface it in logs only.
+                log_with_context(
+                    logger,
+                    logging.WARNING,
+                    "adk.artifact.delete.content_cleanup_failed",
+                    canonical_uri=record["canonical_uri"],
+                    version=record["version"],
+                )
+
+        log_with_context(
+            logger,
+            logging.DEBUG,
+            "adk.artifact.delete",
+            app_name=app_name,
+            filename=filename,
+            session_id=session_id,
+            versions_deleted=len(deleted_records),
+        )
+
+    async def list_versions(
+        self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+    ) -> "list[int]":
+        """List all version numbers for an artifact.
+
+        Args:
+            app_name: Application name.
+            user_id: User identifier.
+            filename: Artifact filename.
+            session_id: Session identifier (None for user-scoped).
+
+        Returns:
+            Sorted list of version numbers.
+        """
+        records = await self._store.list_artifact_versions(
+            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
+        )
+        return [r["version"] for r in records]
+
+    async def list_artifact_versions(
+        self, *, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+    ) -> "list[ArtifactVersion]":
+        """List all versions with full metadata for an artifact.
+
+        Args:
+            app_name: Application name.
+            user_id: User identifier.
+            filename: Artifact filename.
+            session_id: Session identifier (None for user-scoped).
+
+        Returns:
+            List of ArtifactVersion objects ordered by version ascending.
+        """
+        records = await self._store.list_artifact_versions(
+            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id
+        )
+        return [_record_to_artifact_version(r) for r in records]
+
+    async def get_artifact_version(
+        self,
+        *,
+        app_name: str,
+        user_id: str,
+        filename: str,
+        session_id: "str | None" = None,
+        version: "int | None" = None,
+    ) -> "ArtifactVersion | None":
+        """Get metadata for a specific artifact version.
+
+        Args:
+            app_name: Application name.
+            user_id: User identifier.
+            filename: Artifact filename.
+            session_id: Session identifier (None for user-scoped).
+            version: Version number, or None for latest.
+
+        Returns:
+            ArtifactVersion if found, None otherwise.
+        """
+        record = await self._store.get_artifact(
+            app_name=app_name, user_id=user_id, filename=filename, session_id=session_id, version=version
+        )
+        if record is None:
+            return None
+        return _record_to_artifact_version(record)
diff --git a/sqlspec/extensions/adk/artifact/store.py b/sqlspec/extensions/adk/artifact/store.py
new file mode 100644
index 000000000..ec9c08a33
--- /dev/null
+++ b/sqlspec/extensions/adk/artifact/store.py
@@ -0,0 +1,363 @@
+"""Base store classes for ADK artifact metadata backend (sync and async).
+
+These abstract base classes define the database operations needed to manage
+artifact version metadata. Content storage is handled separately by
+``sqlspec/storage/`` backends; these stores only manage the relational
+metadata rows.
+
+Adapter-specific subclasses (e.g., ``AsyncpgADKArtifactStore``) implement
+the abstract methods with dialect-specific SQL.
+"""
+
+import logging
+import re
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast
+
+from sqlspec.observability import resolve_db_system
+from sqlspec.utils.logging import get_logger, log_with_context
+
+if TYPE_CHECKING:
+ from sqlspec.config import ADKConfig, DatabaseConfigProtocol
+ from sqlspec.extensions.adk.artifact._types import ArtifactRecord
+
+ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]")
+
+logger = get_logger("sqlspec.extensions.adk.artifact.store")
+
+__all__ = ("BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore")
+
+VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
+MAX_TABLE_NAME_LENGTH: Final = 63
+
+
+def _validate_table_name(table_name: str) -> None:
+ """Validate table name for SQL safety.
+
+ Args:
+ table_name: Table name to validate.
+
+ Raises:
+ ValueError: If table name is invalid.
+ """
+ if not table_name:
+ msg = "Table name cannot be empty"
+ raise ValueError(msg)
+
+ if len(table_name) > MAX_TABLE_NAME_LENGTH:
+ msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})"
+ raise ValueError(msg)
+
+ if not VALID_TABLE_NAME_PATTERN.match(table_name):
+ msg = (
+ f"Invalid table name: {table_name!r}. "
+ "Must start with letter/underscore and contain only alphanumeric characters and underscores"
+ )
+ raise ValueError(msg)
+
+
+class BaseAsyncADKArtifactStore(ABC, Generic[ConfigT]):
+ """Base class for async SQLSpec-backed ADK artifact metadata stores.
+
+ Manages artifact version metadata in a SQL table. Content bytes are
+ stored externally via ``sqlspec/storage/`` backends and referenced
+ by canonical URI in each metadata row.
+
+ Subclasses must implement dialect-specific SQL queries.
+
+ Args:
+ config: SQLSpec database configuration with extension_config["adk"] settings.
+
+ Notes:
+ Configuration is read from config.extension_config["adk"]:
+ - artifact_table: Artifact versions table name (default: "adk_artifact_versions")
+ """
+
+ __slots__ = ("_artifact_table", "_config")
+
+ def __init__(self, config: ConfigT) -> None:
+ """Initialize the async ADK artifact store.
+
+ Args:
+ config: SQLSpec database configuration.
+ """
+ self._config = config
+ adk_config = self._get_adk_config()
+ self._artifact_table: str = str(adk_config.get("artifact_table", "adk_artifact_versions"))
+ _validate_table_name(self._artifact_table)
+
+ def _get_adk_config(self) -> "dict[str, Any]":
+ """Extract ADK configuration from extension_config.
+
+ Returns:
+ Dict with ADK configuration values.
+ """
+ extension_config = self._config.extension_config
+ return dict(cast("ADKConfig", extension_config.get("adk", {})))
+
+ @property
+ def config(self) -> ConfigT:
+ """Return the database configuration."""
+ return self._config
+
+ @property
+ def artifact_table(self) -> str:
+ """Return the artifact versions table name."""
+ return self._artifact_table
+
+ @abstractmethod
+ async def insert_artifact(self, record: "ArtifactRecord") -> None:
+ """Insert an artifact version metadata row.
+
+ Args:
+ record: Artifact metadata record to insert.
+ """
+
+ @abstractmethod
+ async def get_artifact(
+ self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None, version: "int | None" = None
+ ) -> "ArtifactRecord | None":
+ """Get a specific artifact version's metadata.
+
+ When ``version`` is None, returns the latest version.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+ version: Specific version number, or None for latest.
+
+ Returns:
+ Artifact record if found, None otherwise.
+ """
+
+ @abstractmethod
+ async def list_artifact_keys(self, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]":
+ """List distinct artifact filenames.
+
+ When ``session_id`` is provided, returns filenames from both
+ session-scoped and user-scoped artifacts. When None, returns
+ only user-scoped artifact filenames.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ session_id: Session identifier (None for user-scoped only).
+
+ Returns:
+ List of distinct artifact filenames.
+ """
+
+ @abstractmethod
+ async def list_artifact_versions(
+ self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+ ) -> "list[ArtifactRecord]":
+ """List all version records for an artifact, ordered by version ascending.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+
+ Returns:
+ List of artifact records ordered by version ascending.
+ """
+
+ @abstractmethod
+ async def delete_artifact(
+ self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+ ) -> "list[ArtifactRecord]":
+ """Delete all version records for an artifact and return them.
+
+ The caller uses the returned records to clean up content from
+ object storage. Metadata is deleted first (fail-fast); content
+ cleanup is best-effort.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+
+ Returns:
+ List of deleted artifact records (needed for content cleanup).
+ """
+
+ @abstractmethod
+ async def get_next_version(
+ self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+ ) -> int:
+ """Get the next version number for an artifact.
+
+ Returns 0 if no versions exist (first version), otherwise
+ ``max(version) + 1``.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+
+ Returns:
+ Next version number (0-based).
+ """
+
+ @abstractmethod
+ async def create_table(self) -> None:
+ """Create the artifact versions table if it does not exist."""
+
+ async def ensure_table(self) -> None:
+ """Create the artifact table and emit a standardized log entry."""
+ await self.create_table()
+ log_with_context(
+ logger,
+ logging.DEBUG,
+ "adk.artifact.table.ready",
+ db_system=resolve_db_system(type(self).__name__),
+ artifact_table=self._artifact_table,
+ )
+
+
+class BaseSyncADKArtifactStore(ABC, Generic[ConfigT]):
+ """Base class for sync SQLSpec-backed ADK artifact metadata stores.
+
+ Synchronous counterpart of :class:`BaseAsyncADKArtifactStore`.
+
+ Args:
+ config: SQLSpec database configuration with extension_config["adk"] settings.
+ """
+
+ __slots__ = ("_artifact_table", "_config")
+
+ def __init__(self, config: ConfigT) -> None:
+ """Initialize the sync ADK artifact store.
+
+ Args:
+ config: SQLSpec database configuration.
+ """
+ self._config = config
+ adk_config = self._get_adk_config()
+ self._artifact_table: str = str(adk_config.get("artifact_table", "adk_artifact_versions"))
+ _validate_table_name(self._artifact_table)
+
+ def _get_adk_config(self) -> "dict[str, Any]":
+ """Extract ADK configuration from extension_config.
+
+ Returns:
+ Dict with ADK configuration values.
+ """
+ extension_config = self._config.extension_config
+ return dict(cast("ADKConfig", extension_config.get("adk", {})))
+
+ @property
+ def config(self) -> ConfigT:
+ """Return the database configuration."""
+ return self._config
+
+ @property
+ def artifact_table(self) -> str:
+ """Return the artifact versions table name."""
+ return self._artifact_table
+
+ @abstractmethod
+ def insert_artifact(self, record: "ArtifactRecord") -> None:
+ """Insert an artifact version metadata row.
+
+ Args:
+ record: Artifact metadata record to insert.
+ """
+
+ @abstractmethod
+ def get_artifact(
+ self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None, version: "int | None" = None
+ ) -> "ArtifactRecord | None":
+ """Get a specific artifact version's metadata.
+
+ When ``version`` is None, returns the latest version.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+ version: Specific version number, or None for latest.
+
+ Returns:
+ Artifact record if found, None otherwise.
+ """
+
+ @abstractmethod
+ def list_artifact_keys(self, app_name: str, user_id: str, session_id: "str | None" = None) -> "list[str]":
+ """List distinct artifact filenames.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ session_id: Session identifier (None for user-scoped only).
+
+ Returns:
+ List of distinct artifact filenames.
+ """
+
+ @abstractmethod
+ def list_artifact_versions(
+ self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+ ) -> "list[ArtifactRecord]":
+ """List all version records for an artifact, ordered by version ascending.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+
+ Returns:
+ List of artifact records ordered by version ascending.
+ """
+
+ @abstractmethod
+ def delete_artifact(
+ self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None
+ ) -> "list[ArtifactRecord]":
+ """Delete all version records for an artifact and return them.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+
+ Returns:
+ List of deleted artifact records (needed for content cleanup).
+ """
+
+ @abstractmethod
+ def get_next_version(self, app_name: str, user_id: str, filename: str, session_id: "str | None" = None) -> int:
+ """Get the next version number for an artifact.
+
+ Args:
+ app_name: Application name.
+ user_id: User identifier.
+ filename: Artifact filename.
+ session_id: Session identifier (None for user-scoped).
+
+ Returns:
+ Next version number (0-based).
+ """
+
+ @abstractmethod
+ def create_table(self) -> None:
+ """Create the artifact versions table if it does not exist."""
+
+ def ensure_table(self) -> None:
+ """Create the artifact table and emit a standardized log entry."""
+ self.create_table()
+ log_with_context(
+ logger,
+ logging.DEBUG,
+ "adk.artifact.table.ready",
+ db_system=resolve_db_system(type(self).__name__),
+ artifact_table=self._artifact_table,
+ )
diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py
index d68cbc141..f1bff6ec1 100644
--- a/sqlspec/extensions/adk/converters.py
+++ b/sqlspec/extensions/adk/converters.py
@@ -1,20 +1,37 @@
-"""Conversion functions between ADK models and database records."""
+"""Conversion functions between ADK models and database records.
+
+Implements full-event JSON storage: the entire Event is serialized via
+``Event.model_dump(exclude_none=True, mode="json")`` into a single
+``event_json`` column, with a small set of indexed scalar columns extracted
+alongside for query performance. Reconstruction uses ``Event.model_validate()``.
+
+Also provides scoped-state helpers that normalise ADK state prefixes
+(``app:``, ``user:``, ``temp:``) so the shared service layer can split,
+filter, and merge state before handing it to backend stores.
+"""
-import json
-import pickle
from datetime import datetime, timezone
from typing import Any
from google.adk.events.event import Event
from google.adk.sessions import Session
-from google.genai import types
from sqlspec.extensions.adk._types import EventRecord, SessionRecord
-from sqlspec.utils.logging import get_logger
-logger = get_logger("sqlspec.extensions.adk.converters")
+__all__ = (
+ "event_to_record",
+ "filter_temp_state",
+ "merge_scoped_state",
+ "record_to_event",
+ "record_to_session",
+ "session_to_record",
+ "split_scoped_state",
+)
+
-__all__ = ("event_to_record", "record_to_event", "record_to_session", "session_to_record")
+# ---------------------------------------------------------------------------
+# Session converters
+# ---------------------------------------------------------------------------
def session_to_record(session: "Session") -> SessionRecord:
@@ -30,7 +47,7 @@ def session_to_record(session: "Session") -> SessionRecord:
id=session.id,
app_name=session.app_name,
user_id=session.user_id,
- state=session.state,
+ state=filter_temp_state(session.state),
create_time=datetime.now(timezone.utc),
update_time=datetime.fromtimestamp(session.last_update_time, tz=timezone.utc),
)
@@ -58,115 +75,115 @@ def record_to_session(record: SessionRecord, events: "list[EventRecord]") -> "Se
)
-def event_to_record(event: "Event", session_id: str, app_name: str, user_id: str) -> EventRecord:
- """Convert ADK Event to database record.
+# ---------------------------------------------------------------------------
+# Event converters (full-event JSON storage)
+# ---------------------------------------------------------------------------
+
+
+def event_to_record(event: "Event", session_id: str) -> EventRecord:
+ """Convert ADK Event to database record using full-event JSON storage.
+
+    The entire Event is serialized into ``event_json`` via Pydantic's
+    ``model_dump(exclude_none=True, mode="json")``. A small number of indexed
+    scalar columns are extracted alongside for query performance.
Args:
event: ADK Event object.
session_id: ID of the parent session.
- app_name: Name of the application.
- user_id: ID of the user.
Returns:
EventRecord for database storage.
"""
- actions_bytes = pickle.dumps(event.actions)
-
- long_running_tool_ids_json = None
- if event.long_running_tool_ids:
- long_running_tool_ids_json = json.dumps(list(event.long_running_tool_ids))
-
- content_dict = None
- if event.content:
- content_dict = event.content.model_dump(exclude_none=True, mode="json")
-
- grounding_metadata_dict = None
- if event.grounding_metadata:
- grounding_metadata_dict = event.grounding_metadata.model_dump(exclude_none=True, mode="json")
-
- custom_metadata_dict = event.custom_metadata
-
return EventRecord(
- id=event.id,
- app_name=app_name,
- user_id=user_id,
session_id=session_id,
invocation_id=event.invocation_id,
author=event.author,
- branch=event.branch,
- actions=actions_bytes,
- long_running_tool_ids_json=long_running_tool_ids_json,
timestamp=datetime.fromtimestamp(event.timestamp, tz=timezone.utc),
- content=content_dict,
- grounding_metadata=grounding_metadata_dict,
- custom_metadata=custom_metadata_dict,
- partial=event.partial,
- turn_complete=event.turn_complete,
- interrupted=event.interrupted,
- error_code=event.error_code,
- error_message=event.error_message,
+ event_json=event.model_dump(exclude_none=True, mode="json"),
)
def record_to_event(record: "EventRecord") -> "Event":
"""Convert database record to ADK Event.
+    Reconstruction is lossless: the full Event is restored from
+    ``event_json`` via ``Event.model_validate()``.
+
Args:
record: Event database record.
Returns:
ADK Event object.
"""
- actions = pickle.loads(record["actions"]) # noqa: S301
+ return Event.model_validate(record["event_json"])
- long_running_tool_ids = None
- if record["long_running_tool_ids_json"]:
- long_running_tool_ids = set(json.loads(record["long_running_tool_ids_json"]))
- return Event(
- id=record["id"],
- invocation_id=record["invocation_id"],
- author=record["author"],
- branch=record["branch"],
- actions=actions,
- timestamp=record["timestamp"].timestamp(),
- content=_decode_content(record["content"]),
- long_running_tool_ids=long_running_tool_ids,
- partial=record["partial"],
- turn_complete=record["turn_complete"],
- error_code=record["error_code"],
- error_message=record["error_message"],
- interrupted=record["interrupted"],
- grounding_metadata=_decode_grounding_metadata(record["grounding_metadata"]),
- custom_metadata=record["custom_metadata"],
- )
+# ---------------------------------------------------------------------------
+# Scoped-state helpers
+# ---------------------------------------------------------------------------
+
+def filter_temp_state(state: "dict[str, Any]") -> "dict[str, Any]":
+ """Return a copy of *state* with all ``temp:`` keys removed.
-def _decode_content(content_dict: "dict[str, Any] | None") -> Any:
- """Decode content dictionary from database to ADK Content object.
+ ``temp:`` keys are process-local/session-runtime state and must never be
+ written to persistent storage.
Args:
- content_dict: Content dictionary from database.
+ state: ADK state dictionary (may contain ``temp:`` prefixed keys).
Returns:
- ADK Content object or None.
+ A new dict without any ``temp:``-prefixed keys.
"""
- if not content_dict:
- return None
+ return {k: v for k, v in state.items() if not k.startswith("temp:")}
- return types.Content.model_validate(content_dict)
-
-def _decode_grounding_metadata(grounding_dict: "dict[str, Any] | None") -> Any:
- """Decode grounding metadata dictionary from database to ADK object.
+def split_scoped_state(state: "dict[str, Any]") -> "tuple[dict[str, Any], dict[str, Any], dict[str, Any]]":
+ """Split state into app-scoped, user-scoped, and session-scoped buckets.
Args:
- grounding_dict: Grounding metadata dictionary from database.
+ state: Full session state dict (temp: already stripped).
Returns:
- ADK GroundingMetadata object or None.
+ Tuple of (app_state, user_state, session_state).
+ app_state: keys starting with "app:"
+ user_state: keys starting with "user:"
+ session_state: all other keys
"""
- if not grounding_dict:
- return None
+ app_state: dict[str, Any] = {}
+ user_state: dict[str, Any] = {}
+ session_state: dict[str, Any] = {}
+ for k, v in state.items():
+ if k.startswith("app:"):
+ app_state[k] = v
+ elif k.startswith("user:"):
+ user_state[k] = v
+ else:
+ session_state[k] = v
+ return app_state, user_state, session_state
+
+
+def merge_scoped_state(
+ session_state: "dict[str, Any]",
+ app_state: "dict[str, Any] | None" = None,
+ user_state: "dict[str, Any] | None" = None,
+) -> "dict[str, Any]":
+ """Merge scoped state buckets into a single state dict.
+
+ Priority: session_state is base, app_state and user_state overlay.
+ This matches ADK's documented merge semantics on session load.
+
+ Args:
+ session_state: Per-session state.
+ app_state: App-scoped state (shared across sessions for same app).
+ user_state: User-scoped state (shared across sessions for same app+user).
- return types.GroundingMetadata.model_validate(grounding_dict)
+ Returns:
+ Merged state dict.
+ """
+ merged = dict(session_state)
+ if app_state:
+ merged.update(app_state)
+ if user_state:
+ merged.update(user_state)
+ return merged
diff --git a/sqlspec/extensions/adk/memory/converters.py b/sqlspec/extensions/adk/memory/converters.py
index 1dc0dbf83..fafea6d12 100644
--- a/sqlspec/extensions/adk/memory/converters.py
+++ b/sqlspec/extensions/adk/memory/converters.py
@@ -19,15 +19,22 @@
logger = get_logger("sqlspec.extensions.adk.memory.converters")
-__all__ = ("event_to_memory_record", "extract_content_text", "record_to_memory_entry", "session_to_memory_records")
+__all__ = (
+ "event_to_memory_record",
+ "extract_content_text",
+ "memory_entry_to_record",
+ "record_to_memory_entry",
+ "records_to_memory_entries",
+ "session_to_memory_records",
+)
def extract_content_text(content: "types.Content") -> str:
"""Extract plain text from ADK Content for search indexing.
Handles multi-modal Content.parts including text, function calls,
- and other part types. Non-text parts are indexed by their type
- for discoverability.
+ function responses, and other part types. Non-text parts are indexed
+ by their type for discoverability.
Args:
content: ADK Content object with parts list.
@@ -91,6 +98,66 @@ def event_to_memory_record(event: "Event", session_id: str, app_name: str, user_
)
+def memory_entry_to_record(
+ entry: "MemoryEntry", app_name: str, user_id: str, extra_metadata: "dict[str, Any] | None" = None
+) -> "MemoryRecord | None":
+ """Convert an ADK MemoryEntry to a database record.
+
+ Serializes the entry's ``content`` to ``content_json``, extracts text
+ from ``content.parts`` for ``content_text``, and merges entry-level
+ ``custom_metadata`` with the optional ``extra_metadata`` parameter.
+
+ Args:
+ entry: ADK MemoryEntry object.
+ app_name: Name of the application.
+ user_id: ID of the user.
+ extra_metadata: Optional call-level metadata to merge with the
+ entry's own ``custom_metadata``.
+
+ Returns:
+ MemoryRecord for database storage, or None if entry has no
+ indexable content.
+ """
+ content_text = extract_content_text(entry.content)
+ if not content_text.strip():
+ return None
+
+ content_dict = entry.content.model_dump(exclude_none=True, mode="json")
+
+ # Merge entry-level and call-level metadata
+ merged_metadata: dict[str, Any] | None = None
+ if entry.custom_metadata or extra_metadata:
+ merged_metadata = {}
+ if extra_metadata:
+ merged_metadata.update(extra_metadata)
+ if entry.custom_metadata:
+ merged_metadata.update(entry.custom_metadata)
+
+ now = datetime.now(timezone.utc)
+
+ # Parse timestamp from entry if available
+ timestamp = now
+ if entry.timestamp:
+ try:
+ timestamp = datetime.fromisoformat(entry.timestamp)
+ except (ValueError, TypeError):
+ timestamp = now
+
+ return MemoryRecord(
+ id=entry.id or str(uuid.uuid4()),
+ session_id="",
+ app_name=app_name,
+ user_id=user_id,
+ event_id="",
+ author=entry.author or "",
+ timestamp=timestamp,
+ content_json=content_dict,
+ content_text=content_text,
+ metadata_json=merged_metadata,
+ inserted_at=now,
+ )
+
+
def session_to_memory_records(session: "Session") -> list["MemoryRecord"]:
"""Convert a completed ADK Session to a list of memory records.
@@ -121,11 +188,14 @@ def session_to_memory_records(session: "Session") -> list["MemoryRecord"]:
def record_to_memory_entry(record: "MemoryRecord") -> "MemoryEntry":
"""Convert a database record to an ADK MemoryEntry.
+ Preserves ``id`` and ``custom_metadata`` fields that were previously
+ dropped on readback.
+
Args:
record: Memory database record.
Returns:
- ADK MemoryEntry object.
+ ADK MemoryEntry object with all available fields populated.
"""
from google.adk.memory.memory_entry import MemoryEntry
from google.genai import types
@@ -134,7 +204,13 @@ def record_to_memory_entry(record: "MemoryRecord") -> "MemoryEntry":
timestamp_str = record["timestamp"].isoformat() if record["timestamp"] else None
- return MemoryEntry(content=content, author=record["author"], timestamp=timestamp_str)
+ return MemoryEntry(
+ id=record["id"],
+ content=content,
+ author=record["author"],
+ timestamp=timestamp_str,
+ custom_metadata=record["metadata_json"] or {},
+ )
def records_to_memory_entries(records: list["MemoryRecord"]) -> list["Any"]:
diff --git a/sqlspec/extensions/adk/memory/service.py b/sqlspec/extensions/adk/memory/service.py
index abc6360d5..a94c7e9cd 100644
--- a/sqlspec/extensions/adk/memory/service.py
+++ b/sqlspec/extensions/adk/memory/service.py
@@ -4,10 +4,17 @@
from google.adk.memory.base_memory_service import BaseMemoryService, SearchMemoryResponse
-from sqlspec.extensions.adk.memory.converters import records_to_memory_entries, session_to_memory_records
+from sqlspec.extensions.adk.memory.converters import (
+ memory_entry_to_record,
+ records_to_memory_entries,
+ session_to_memory_records,
+)
from sqlspec.utils.logging import get_logger
if TYPE_CHECKING:
+ from collections.abc import Mapping, Sequence
+
+ from google.adk.events.event import Event
from google.adk.memory.memory_entry import MemoryEntry
from google.adk.sessions import Session
@@ -102,6 +109,98 @@ async def add_session_to_memory(self, session: "Session") -> None:
"Stored %d memory entries for session %s (total events: %d)", inserted_count, session.id, len(records)
)
+ async def add_events_to_memory(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ events: "Sequence[Event]",
+ session_id: "str | None" = None,
+ custom_metadata: "Mapping[str, object] | None" = None,
+ ) -> None:
+ """Add an explicit list of events to the memory service.
+
+ Same Event-to-MemoryRecord extraction logic as
+ ``add_session_to_memory``, but operates on a sequence of Events
+ directly (no Session wrapper needed).
+
+ Args:
+ app_name: The application name for memory scope.
+ user_id: The user ID for memory scope.
+ events: The events to add to memory.
+ session_id: Optional session ID for memory scope/partitioning.
+ If None, memory entries are user-scoped only.
+ custom_metadata: Optional portable metadata stored in
+ ``MemoryRecord.metadata_json``.
+ """
+ from sqlspec.extensions.adk.memory.converters import event_to_memory_record
+
+ metadata_dict = dict(custom_metadata) if custom_metadata else None
+ records = []
+ for event in events:
+ record = event_to_memory_record(
+ event=event, session_id=session_id or "", app_name=app_name, user_id=user_id
+ )
+ if record is not None:
+ if metadata_dict:
+ record["metadata_json"] = metadata_dict
+ records.append(record)
+
+ if not records:
+ logger.debug(
+ "No content to store for events (app=%s, user=%s, count=%d)", app_name, user_id, len(list(events))
+ )
+ return
+
+ inserted_count = await self._store.insert_memory_entries(records)
+ logger.debug(
+ "Stored %d memory entries from %d events (app=%s, user=%s)", inserted_count, len(records), app_name, user_id
+ )
+
+ async def add_memory(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ memories: "Sequence[MemoryEntry]",
+ custom_metadata: "Mapping[str, object] | None" = None,
+ ) -> None:
+ """Add explicit memory items directly to the memory service.
+
+ Each entry's ``content`` is serialized to ``content_json``, text is
+ extracted from ``content.parts`` for ``content_text``, and
+ ``custom_metadata`` merges the entry-level ``entry.custom_metadata``
+ with the call-level ``custom_metadata`` parameter.
+
+ Args:
+ app_name: The application name for memory scope.
+ user_id: The user ID for memory scope.
+ memories: Explicit memory items to add.
+ custom_metadata: Optional portable metadata for memory writes.
+ Merged with each entry's ``custom_metadata``.
+ """
+ call_metadata = dict(custom_metadata) if custom_metadata else {}
+ records = []
+ for entry in memories:
+ record = memory_entry_to_record(
+ entry=entry, app_name=app_name, user_id=user_id, extra_metadata=call_metadata
+ )
+ if record is not None:
+ records.append(record)
+
+ if not records:
+ logger.debug("No content to store for memories (app=%s, user=%s)", app_name, user_id)
+ return
+
+ inserted_count = await self._store.insert_memory_entries(records)
+ logger.debug(
+ "Stored %d memory entries from %d memories (app=%s, user=%s)",
+ inserted_count,
+ len(records),
+ app_name,
+ user_id,
+ )
+
async def search_memory(self, *, app_name: str, user_id: str, query: str) -> "SearchMemoryResponse":
"""Search memory entries by text query.
diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py
index 8656f4beb..132d9ad69 100644
--- a/sqlspec/extensions/adk/service.py
+++ b/sqlspec/extensions/adk/service.py
@@ -7,7 +7,7 @@
from google.adk.sessions.base_session_service import BaseSessionService, GetSessionConfig, ListSessionsResponse
-from sqlspec.extensions.adk.converters import event_to_record, record_to_session
+from sqlspec.extensions.adk.converters import event_to_record, filter_temp_state, record_to_session
from sqlspec.utils.logging import get_logger, log_with_context
if TYPE_CHECKING:
@@ -80,8 +80,10 @@ async def create_session(
if state is None:
state = {}
+ persisted_state = filter_temp_state(state)
+
record = await self._store.create_session(
- session_id=session_id, app_name=app_name, user_id=user_id, state=state
+ session_id=session_id, app_name=app_name, user_id=user_id, state=persisted_state
)
log_with_context(
logger, logging.DEBUG, "adk.session.create", app_name=app_name, session_id=session_id, has_state=bool(state)
@@ -192,6 +194,11 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str)
async def append_event(self, session: "Session", event: "Event") -> "Event":
"""Append an event to a session.
+ Persists the event record and the post-append durable state
+ atomically via ``store.append_event_and_update_state()``. ``temp:``
+ keys are stripped from the persisted state snapshot so they never
+ survive a reload.
+
Args:
session: Session to append to.
event: Event to append.
@@ -204,11 +211,14 @@ async def append_event(self, session: "Session", event: "Event") -> "Event":
if event.partial:
return event
- event_record = event_to_record(
- event=event, session_id=session.id, app_name=session.app_name, user_id=session.user_id
- )
+ event_record = event_to_record(event=event, session_id=session.id)
- await self._store.append_event(event_record)
+ # Strip temp: keys before persisting state
+ durable_state = filter_temp_state(session.state)
+
+ await self._store.append_event_and_update_state(
+ event_record=event_record, session_id=session.id, state=durable_state
+ )
log_with_context(
logger,
logging.DEBUG,
diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py
index 9903ee7b8..9f2ea8d1f 100644
--- a/sqlspec/extensions/adk/store.py
+++ b/sqlspec/extensions/adk/store.py
@@ -250,6 +250,24 @@ async def append_event(self, event_record: "EventRecord") -> None:
"""
raise NotImplementedError
+ @abstractmethod
+ async def append_event_and_update_state(
+ self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically append an event and update the session's durable state.
+
+ This is the authoritative durable write boundary for post-creation
+ session mutations. The event insert and state update must succeed
+ together or fail together.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot (``temp:`` keys already
+ stripped by the service layer).
+ """
+ raise NotImplementedError
+
@abstractmethod
async def get_events(
self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
@@ -505,6 +523,24 @@ def create_event(
"""
raise NotImplementedError
+ @abstractmethod
+ def create_event_and_update_state(
+ self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]"
+ ) -> None:
+ """Atomically create an event and update the session's durable state.
+
+ This is the authoritative durable write boundary for post-creation
+ session mutations. The event insert and state update must succeed
+ together or fail together.
+
+ Args:
+ event_record: Event record to store.
+ session_id: Session identifier whose state should be updated.
+ state: Post-append durable state snapshot (``temp:`` keys already
+ stripped by the service layer).
+ """
+ raise NotImplementedError
+
@abstractmethod
def list_events(self, session_id: str) -> "list[EventRecord]":
"""List events for a session ordered by timestamp.
diff --git a/tests/conftest.py b/tests/conftest.py
index 0a98005ed..b3fabd251 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+import logging
import os
import warnings
@@ -63,6 +64,26 @@ def disable_spanner_builtin_metrics() -> "Generator[None, None, None]":
yield
+@pytest.fixture(scope="session", autouse=True)
+def suppress_noisy_test_loggers() -> "Generator[None, None, None]":
+ """Lower especially noisy library loggers during test runs."""
+ overrides = {
+ "httpx": logging.WARNING,
+ "httpcore": logging.WARNING,
+ "mysql.connector": logging.WARNING,
+ "asyncmy": logging.ERROR,
+ "sqlspec.migrations.tracker": logging.WARNING,
+ }
+ original_levels = {name: logging.getLogger(name).level for name in overrides}
+ for name, level in overrides.items():
+ logging.getLogger(name).setLevel(level)
+ try:
+ yield
+ finally:
+ for name, level in original_levels.items():
+ logging.getLogger(name).setLevel(level)
+
+
@pytest.fixture(scope="session")
def minio_client(minio_service: "MinioService", minio_default_bucket_name: str) -> Generator[Minio, None, None]:
"""Override pytest-databases minio_client to use new minio API with keyword arguments."""
diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py
index 7c0451ba8..e20536d83 100644
--- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py
+++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py
@@ -10,6 +10,7 @@
if the driver is not installed.
"""
+import json
from pathlib import Path
from typing import Any
@@ -22,16 +23,16 @@
@pytest.fixture()
-def sqlite_store(tmp_path: Path) -> Any:
+async def sqlite_store(tmp_path: Path) -> Any:
"""SQLite ADBC store fixture."""
db_path = tmp_path / "sqlite_test.db"
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
return store
-def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None:
+async def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None:
"""Test SQLite dialect creates TEXT columns for JSON."""
with sqlite_store.config.provide_connection() as conn:
cursor = conn.cursor()
@@ -45,54 +46,61 @@ def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None:
cursor.close() # type: ignore[no-untyped-call]
-def test_sqlite_dialect_session_operations(sqlite_store: Any) -> None:
+async def test_sqlite_dialect_session_operations(sqlite_store: Any) -> None:
"""Test SQLite dialect with full session CRUD."""
session_id = "sqlite-session-1"
app_name = "test-app"
user_id = "user-123"
state = {"nested": {"key": "value"}, "count": 42}
- created = sqlite_store.create_session(session_id, app_name, user_id, state)
+ created = await sqlite_store.create_session(session_id, app_name, user_id, state)
assert created["id"] == session_id
assert created["state"] == state
- retrieved = sqlite_store.get_session(session_id)
+ retrieved = await sqlite_store.get_session(session_id)
assert retrieved["state"] == state
new_state = {"updated": True}
- sqlite_store.update_session_state(session_id, new_state)
+ await sqlite_store.update_session_state(session_id, new_state)
- updated = sqlite_store.get_session(session_id)
+ updated = await sqlite_store.get_session(session_id)
assert updated["state"] == new_state
-def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None:
+async def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None:
"""Test SQLite dialect with event operations."""
session_id = "sqlite-session-events"
app_name = "test-app"
user_id = "user-123"
- sqlite_store.create_session(session_id, app_name, user_id, {})
+ await sqlite_store.create_session(session_id, app_name, user_id, {})
- event_id = "event-1"
- actions = b"pickled_actions_data"
content = {"message": "Hello"}
- event = sqlite_store.create_event(
- event_id=event_id, session_id=session_id, app_name=app_name, user_id=user_id, actions=actions, content=content
- )
+ from datetime import datetime, timezone
+
+ from sqlspec.extensions.adk import EventRecord
- assert event["id"] == event_id
- assert event["content"] == content
+ event_record: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-1", "content": content, "app_name": app_name, "user_id": user_id},
+ }
+ await sqlite_store.append_event(event_record)
- events = sqlite_store.list_events(session_id)
+ events = await sqlite_store.get_events(session_id)
assert len(events) == 1
- assert events[0]["content"] == content
+ retrieved_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ assert retrieved_data["content"] == content
@pytest.mark.postgres
@pytest.mark.skipif(True, reason="Requires adbc-driver-postgresql and PostgreSQL server")
-def test_postgresql_dialect_creates_jsonb_columns() -> None:
+async def test_postgresql_dialect_creates_jsonb_columns() -> None:
"""Test PostgreSQL dialect creates JSONB columns.
This test is skipped by default. To run:
@@ -105,7 +113,7 @@ def test_postgresql_dialect_creates_jsonb_columns() -> None:
connection_config={"driver_name": "postgresql", "uri": "postgresql://user:pass@localhost/testdb"}
)
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
with store.config.provide_connection() as conn:
cursor = conn.cursor()
@@ -127,7 +135,7 @@ def test_postgresql_dialect_creates_jsonb_columns() -> None:
@pytest.mark.duckdb
@pytest.mark.skipif(True, reason="Requires adbc-driver-duckdb")
-def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None:
+async def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None:
"""Test DuckDB dialect creates JSON columns.
This test is skipped by default. To run:
@@ -137,18 +145,18 @@ def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None:
db_path = tmp_path / "duckdb_test.db"
config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": f"file:{db_path}"})
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
session_id = "duckdb-session-1"
state = {"analytics": {"count": 1000, "revenue": 50000.00}}
- created = store.create_session(session_id, "app", "user", state)
+ created = await store.create_session(session_id, "app", "user", state)
assert created["state"] == state
@pytest.mark.snowflake
@pytest.mark.skipif(True, reason="Requires adbc-driver-snowflake and Snowflake account")
-def test_snowflake_dialect_creates_variant_columns() -> None:
+async def test_snowflake_dialect_creates_variant_columns() -> None:
"""Test Snowflake dialect creates VARIANT columns.
This test is skipped by default. To run:
@@ -165,7 +173,7 @@ def test_snowflake_dialect_creates_variant_columns() -> None:
}
)
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
with store.config.provide_connection() as conn:
cursor = conn.cursor()
@@ -185,7 +193,7 @@ def test_snowflake_dialect_creates_variant_columns() -> None:
cursor.close() # type: ignore[no-untyped-call]
-def test_sqlite_with_owner_id_column(tmp_path: Path) -> None:
+async def test_sqlite_with_owner_id_column(tmp_path: Path) -> None:
"""Test SQLite with owner ID column creates proper constraints."""
db_path = tmp_path / "sqlite_fk_test.db"
base_config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
@@ -205,16 +213,16 @@ def test_sqlite_with_owner_id_column(tmp_path: Path) -> None:
extension_config={"adk": {"owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"}},
)
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
- session = store.create_session("s1", "app", "user", {"data": "test"}, owner_id=1)
+ session = await store.create_session("s1", "app", "user", {"data": "test"}, owner_id=1)
assert session["id"] == "s1"
- retrieved = store.get_session("s1")
+ retrieved = await store.get_session("s1")
assert retrieved is not None
-def test_generic_dialect_fallback(tmp_path: Path) -> None:
+async def test_generic_dialect_fallback(tmp_path: Path) -> None:
"""Test generic dialect is used for unknown drivers."""
db_path = tmp_path / "generic_test.db"
@@ -223,7 +231,7 @@ def test_generic_dialect_fallback(tmp_path: Path) -> None:
store = AdbcADKStore(config)
assert store.dialect in ["sqlite", "generic"]
- store.create_tables()
+ await store.create_tables()
- session = store.create_session("generic-1", "app", "user", {"test": True})
+ session = await store.create_session("generic-1", "app", "user", {"test": True})
assert session["state"]["test"] is True
diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py
index c87302f23..703d40437 100644
--- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py
+++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py
@@ -89,54 +89,59 @@ def test_generic_sessions_ddl_contains_text() -> None:
assert "TIMESTAMP" in ddl
-def test_postgresql_events_ddl_contains_jsonb() -> None:
- """Test PostgreSQL events DDL uses JSONB for content fields."""
+def test_postgresql_events_ddl_uses_jsonb() -> None:
+ """Test PostgreSQL events DDL uses JSONB for event_json."""
config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"})
store = AdbcADKStore(config)
ddl = store._get_events_ddl_postgresql() # pyright: ignore[reportPrivateUsage]
assert "JSONB" in ddl
- assert "BYTEA" in ddl
- assert "BOOLEAN" in ddl
+ assert "event_json" in ddl
+ assert "session_id" in ddl
+ assert "invocation_id" in ddl
+ assert "author" in ddl
+ assert "timestamp" in ddl.lower()
-def test_sqlite_events_ddl_contains_text_and_integer() -> None:
- """Test SQLite events DDL uses TEXT for JSON and INTEGER for booleans."""
+def test_sqlite_events_ddl_uses_text() -> None:
+ """Test SQLite events DDL uses TEXT for event_json."""
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": ":memory:"})
store = AdbcADKStore(config)
ddl = store._get_events_ddl_sqlite() # pyright: ignore[reportPrivateUsage]
assert "TEXT" in ddl
- assert "BLOB" in ddl
- assert "INTEGER" in ddl
+ assert "event_json" in ddl
+ assert "session_id" in ddl
+ assert "REAL" in ddl # SQLite uses REAL for timestamps
-def test_duckdb_events_ddl_contains_json_and_boolean() -> None:
- """Test DuckDB events DDL uses JSON and BOOLEAN types."""
+def test_duckdb_events_ddl_uses_json() -> None:
+ """Test DuckDB events DDL uses JSON type for event_json."""
config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": ":memory:"})
store = AdbcADKStore(config)
ddl = store._get_events_ddl_duckdb() # pyright: ignore[reportPrivateUsage]
assert "JSON" in ddl
- assert "BOOLEAN" in ddl
+ assert "event_json" in ddl
-def test_snowflake_events_ddl_contains_variant() -> None:
- """Test Snowflake events DDL uses VARIANT for content."""
+def test_snowflake_events_ddl_uses_variant() -> None:
+ """Test Snowflake events DDL uses VARIANT for event_json."""
config = AdbcConfig(connection_config={"driver_name": "snowflake", "uri": "snowflake://test"})
store = AdbcADKStore(config)
ddl = store._get_events_ddl_snowflake() # pyright: ignore[reportPrivateUsage]
assert "VARIANT" in ddl
- assert "BINARY" in ddl
+ assert "event_json" in ddl
-def test_ddl_dispatch_uses_correct_dialect() -> None:
+async def test_ddl_dispatch_uses_correct_dialect() -> None:
"""Test that DDL dispatch selects correct dialect method."""
config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"})
store = AdbcADKStore(config)
- sessions_ddl = store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage]
+ sessions_ddl = await store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage]
assert "JSONB" in sessions_ddl
- events_ddl = store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage]
+ events_ddl = await store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage]
assert "JSONB" in events_ddl
+ assert "event_json" in events_ddl
def test_owner_id_column_included_in_sessions_ddl() -> None:
diff --git a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py
index fc39cebb2..0e11dd0bb 100644
--- a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py
+++ b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py
@@ -1,5 +1,6 @@
"""Tests for ADBC ADK store edge cases and error handling."""
+import json
from pathlib import Path
from typing import Any
@@ -12,19 +13,19 @@
@pytest.fixture()
-def adbc_store(tmp_path: Path) -> AdbcADKStore:
+async def adbc_store(tmp_path: Path) -> AdbcADKStore:
"""Create ADBC ADK store with SQLite backend."""
db_path = tmp_path / "test_adk.db"
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
return store
-def test_create_tables_idempotent(adbc_store: Any) -> None:
+async def test_create_tables_idempotent(adbc_store: Any) -> None:
"""Test that create_tables can be called multiple times safely."""
- adbc_store.create_tables()
- adbc_store.create_tables()
+ await adbc_store.create_tables()
+ await adbc_store.create_tables()
def test_table_names_validation(tmp_path: Path) -> None:
@@ -61,23 +62,23 @@ def test_table_names_validation(tmp_path: Path) -> None:
AdbcADKStore(config)
-def test_operations_before_create_tables(tmp_path: Path) -> None:
+async def test_operations_before_create_tables(tmp_path: Path) -> None:
"""Test operations gracefully handle missing tables."""
db_path = tmp_path / "test_no_tables.db"
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
store = AdbcADKStore(config)
- session = store.get_session("nonexistent")
+ session = await store.get_session("nonexistent")
assert session is None
- sessions = store.list_sessions("app", "user")
+ sessions = await store.list_sessions("app", "user")
assert sessions == []
- events = store.list_events("session")
+ events = await store.get_events("session")
assert events == []
-def test_custom_table_names(tmp_path: Path) -> None:
+async def test_custom_table_names(tmp_path: Path) -> None:
"""Test using custom table names."""
db_path = tmp_path / "test_custom.db"
config = AdbcConfig(
@@ -85,43 +86,56 @@ def test_custom_table_names(tmp_path: Path) -> None:
extension_config={"adk": {"session_table": "custom_sessions", "events_table": "custom_events"}},
)
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
session_id = "test"
- session = store.create_session(session_id, "app", "user", {"data": "test"})
+ session = await store.create_session(session_id, "app", "user", {"data": "test"})
assert session["id"] == session_id
- retrieved = store.get_session(session_id)
+ retrieved = await store.get_session(session_id)
assert retrieved is not None
-def test_unicode_in_fields(adbc_store: Any) -> None:
+async def test_unicode_in_fields(adbc_store: Any) -> None:
"""Test Unicode characters in various fields."""
session_id = "unicode-session"
- app_name = "测试应用"
- user_id = "ユーザー123"
- state = {"message": "Hello 世界", "emoji": "🎉"}
+ app_name = "\u6d4b\u8bd5\u5e94\u7528"
+ user_id = "\u30e6\u30fc\u30b6\u30fc123"
+ state = {"message": "Hello \u4e16\u754c"}
- created_session = adbc_store.create_session(session_id, app_name, user_id, state)
+ created_session = await adbc_store.create_session(session_id, app_name, user_id, state)
assert created_session["app_name"] == app_name
assert created_session["user_id"] == user_id
- assert created_session["state"]["message"] == "Hello 世界"
- assert created_session["state"]["emoji"] == "🎉"
-
- event = adbc_store.create_event(
- event_id="unicode-event",
- session_id=session_id,
- app_name=app_name,
- user_id=user_id,
- author="アシスタント",
- content={"text": "こんにちは 🌍"},
- )
+ assert created_session["state"]["message"] == "Hello \u4e16\u754c"
+
+ from datetime import datetime, timezone
+
+ from sqlspec.extensions.adk import EventRecord
+
+ event_record: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "unicode-event",
+ "content": {"text": "\u3053\u3093\u306b\u3061\u306f"},
+ "app_name": app_name,
+ "user_id": user_id,
+ },
+ }
+ await adbc_store.append_event(event_record)
- assert event["author"] == "アシスタント"
- assert event["content"]["text"] == "こんにちは 🌍"
+ events = await adbc_store.get_events(session_id)
+ assert len(events) == 1
+ assert events[0]["author"] == "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8"
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ assert event_data["content"]["text"] == "\u3053\u3093\u306b\u3061\u306f"
-def test_special_characters_in_json(adbc_store: Any) -> None:
+async def test_special_characters_in_json(adbc_store: Any) -> None:
"""Test special characters in JSON fields."""
session_id = "special-chars"
state = {
@@ -131,115 +145,105 @@ def test_special_characters_in_json(adbc_store: Any) -> None:
"tab": "Col1\tCol2",
}
- adbc_store.create_session(session_id, "app", "user", state)
- retrieved = adbc_store.get_session(session_id)
+ await adbc_store.create_session(session_id, "app", "user", state)
+ retrieved = await adbc_store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"] == state
-def test_very_long_strings(adbc_store: Any) -> None:
+async def test_very_long_strings(adbc_store: Any) -> None:
"""Test handling very long strings in VARCHAR fields."""
long_id = "x" * 127
long_app = "a" * 127
long_user = "u" * 127
- session = adbc_store.create_session(long_id, long_app, long_user, {})
+ session = await adbc_store.create_session(long_id, long_app, long_user, {})
assert session["id"] == long_id
assert session["app_name"] == long_app
assert session["user_id"] == long_user
-def test_session_state_with_deeply_nested_data(adbc_store: Any) -> None:
+async def test_session_state_with_deeply_nested_data(adbc_store: Any) -> None:
"""Test deeply nested JSON structures."""
session_id = "deep-nest"
deeply_nested = {"level1": {"level2": {"level3": {"level4": {"level5": {"value": "deep"}}}}}}
- adbc_store.create_session(session_id, "app", "user", deeply_nested)
- retrieved = adbc_store.get_session(session_id)
+ await adbc_store.create_session(session_id, "app", "user", deeply_nested)
+ retrieved = await adbc_store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"]["level1"]["level2"]["level3"]["level4"]["level5"]["value"] == "deep"
-def test_concurrent_session_updates(adbc_store: Any) -> None:
+async def test_concurrent_session_updates(adbc_store: Any) -> None:
"""Test multiple updates to the same session."""
session_id = "concurrent-test"
- adbc_store.create_session(session_id, "app", "user", {"version": 1})
+ await adbc_store.create_session(session_id, "app", "user", {"version": 1})
for i in range(10):
- adbc_store.update_session_state(session_id, {"version": i + 2})
+ await adbc_store.update_session_state(session_id, {"version": i + 2})
- final_session = adbc_store.get_session(session_id)
+ final_session = await adbc_store.get_session(session_id)
assert final_session is not None
assert final_session["state"]["version"] == 11
-def test_event_with_none_values(adbc_store: Any) -> None:
- """Test creating event with explicit None values."""
+async def test_event_with_none_values(adbc_store: Any) -> None:
+ """Test appending an event with optional event_json fields omitted."""
session_id = "none-test"
- adbc_store.create_session(session_id, "app", "user", {})
-
- event = adbc_store.create_event(
- event_id="none-event",
- session_id=session_id,
- app_name="app",
- user_id="user",
- invocation_id=None,
- author=None,
- actions=None,
- content=None,
- grounding_metadata=None,
- custom_metadata=None,
- partial=None,
- turn_complete=None,
- interrupted=None,
- error_code=None,
- error_message=None,
- )
+ await adbc_store.create_session(session_id, "app", "user", {})
+
+ from datetime import datetime, timezone
+
+ from sqlspec.extensions.adk import EventRecord
+
+ event_record: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "none-event", "app_name": "app", "user_id": "user"},
+ }
+ await adbc_store.append_event(event_record)
- assert event["invocation_id"] is None
- assert event["author"] is None
- assert event["actions"] == b""
- assert event["content"] is None
- assert event["grounding_metadata"] is None
- assert event["custom_metadata"] is None
- assert event["partial"] is None
- assert event["turn_complete"] is None
- assert event["interrupted"] is None
+ events = await adbc_store.get_events(session_id)
+ assert len(events) == 1
+ assert events[0]["session_id"] == session_id
+ assert "event_json" in events[0]
-def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None:
+async def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None:
"""Test listing sessions doesn't mix data across apps."""
user_id = "user-123"
app1 = "app1"
app2 = "app2"
- adbc_store.create_session("s1", app1, user_id, {})
- adbc_store.create_session("s2", app1, user_id, {})
- adbc_store.create_session("s3", app2, user_id, {})
+ await adbc_store.create_session("s1", app1, user_id, {})
+ await adbc_store.create_session("s2", app1, user_id, {})
+ await adbc_store.create_session("s3", app2, user_id, {})
- app1_sessions = adbc_store.list_sessions(app1, user_id)
- app2_sessions = adbc_store.list_sessions(app2, user_id)
+ app1_sessions = await adbc_store.list_sessions(app1, user_id)
+ app2_sessions = await adbc_store.list_sessions(app2, user_id)
assert len(app1_sessions) == 2
assert len(app2_sessions) == 1
-def test_delete_nonexistent_session(adbc_store: Any) -> None:
+async def test_delete_nonexistent_session(adbc_store: Any) -> None:
"""Test deleting a session that doesn't exist."""
- adbc_store.delete_session("nonexistent-session")
+ await adbc_store.delete_session("nonexistent-session")
-def test_update_nonexistent_session(adbc_store: Any) -> None:
+async def test_update_nonexistent_session(adbc_store: Any) -> None:
"""Test updating a session that doesn't exist."""
- adbc_store.update_session_state("nonexistent-session", {"data": "test"})
+ await adbc_store.update_session_state("nonexistent-session", {"data": "test"})
-def test_drop_and_recreate_tables(adbc_store: Any) -> None:
+async def test_drop_and_recreate_tables(adbc_store: Any) -> None:
"""Test dropping and recreating tables."""
session_id = "test-session"
- adbc_store.create_session(session_id, "app", "user", {"data": "test"})
+ await adbc_store.create_session(session_id, "app", "user", {"data": "test"})
drop_sqls = adbc_store._get_drop_tables_sql()
with adbc_store._config.provide_connection() as conn:
@@ -251,19 +255,19 @@ def test_drop_and_recreate_tables(adbc_store: Any) -> None:
finally:
cursor.close()
- adbc_store.create_tables()
+ await adbc_store.create_tables()
- session = adbc_store.get_session(session_id)
+ session = await adbc_store.get_session(session_id)
assert session is None
-def test_json_with_escaped_characters(adbc_store: Any) -> None:
+async def test_json_with_escaped_characters(adbc_store: Any) -> None:
"""Test JSON serialization of escaped characters."""
session_id = "escaped-json"
state = {"escaped": r"test\nvalue\t", "quotes": r'"quoted"'}
- adbc_store.create_session(session_id, "app", "user", state)
- retrieved = adbc_store.get_session(session_id)
+ await adbc_store.create_session(session_id, "app", "user", state)
+ retrieved = await adbc_store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"] == state
diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py
index e18cd1496..8e54bf766 100644
--- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py
+++ b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py
@@ -1,6 +1,8 @@
"""Tests for ADBC ADK store event operations."""
-from datetime import datetime, timezone
+import asyncio
+import json
+from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
@@ -8,313 +10,387 @@
from sqlspec.adapters.adbc import AdbcConfig
from sqlspec.adapters.adbc.adk import AdbcADKStore
+from sqlspec.extensions.adk import EventRecord
pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration]
@pytest.fixture()
-def adbc_store(tmp_path: Path) -> AdbcADKStore:
+async def adbc_store(tmp_path: Path) -> AdbcADKStore:
"""Create ADBC ADK store with SQLite backend."""
db_path = tmp_path / "test_adk.db"
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
return store
@pytest.fixture()
-def session_fixture(adbc_store: Any) -> dict[str, str]:
+async def session_fixture(adbc_store: Any) -> dict[str, str]:
"""Create a test session."""
session_id = "test-session"
app_name = "test-app"
user_id = "user-123"
state = {"test": True}
- adbc_store.create_session(session_id, app_name, user_id, state)
+ await adbc_store.create_session(session_id, app_name, user_id, state)
return {"session_id": session_id, "app_name": app_name, "user_id": user_id}
-def test_create_event(adbc_store: Any, session_fixture: Any) -> None:
- """Test creating a new event."""
- event_id = "event-1"
- event = adbc_store.create_event(
- event_id=event_id,
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- author="user",
- actions=b"serialized_actions",
- content={"message": "Hello"},
+async def test_create_event(adbc_store: Any, session_fixture: Any) -> None:
+ """Test appending an event stores a 5-key EventRecord."""
+ event_record: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "user",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "event-1",
+ "content": {"message": "Hello"},
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+ await adbc_store.append_event(event_record)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
+ assert len(events) == 1
+ assert events[0]["session_id"] == session_fixture["session_id"]
+ assert events[0]["author"] == "user"
+ assert events[0]["timestamp"] is not None
+ assert "event_json" in events[0]
+
+ # Content is stored inside event_json
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
)
+ assert event_data["content"] == {"message": "Hello"}
- assert event["id"] == event_id
- assert event["session_id"] == session_fixture["session_id"]
- assert event["author"] == "user"
- assert event["actions"] == b"serialized_actions"
- assert event["content"] == {"message": "Hello"}
- assert event["timestamp"] is not None
-
-def test_list_events(adbc_store: Any, session_fixture: Any) -> None:
+async def test_list_events(adbc_store: Any, session_fixture: Any) -> None:
"""Test listing events for a session."""
- adbc_store.create_event(
- event_id="event-1",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- author="user",
- content={"seq": 1},
- )
- adbc_store.create_event(
- event_id="event-2",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- author="assistant",
- content={"seq": 2},
- )
-
- events = adbc_store.list_events(session_fixture["session_id"])
+ event1: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "user",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "event-1",
+ "content": {"seq": 1},
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+ event2: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "assistant",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "event-2",
+ "content": {"seq": 2},
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+ await adbc_store.append_event(event1)
+ await adbc_store.append_event(event2)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
assert len(events) == 2
- assert events[0]["id"] == "event-1"
- assert events[1]["id"] == "event-2"
+ assert events[0]["author"] == "user"
+ assert events[1]["author"] == "assistant"
-def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None:
+async def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None:
"""Test listing events when none exist."""
- events = adbc_store.list_events(session_fixture["session_id"])
+ events = await adbc_store.get_events(session_fixture["session_id"])
assert events == []
-def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None:
- """Test creating event with all optional fields."""
- timestamp = datetime.now(timezone.utc)
- event = adbc_store.create_event(
- event_id="full-event",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- invocation_id="invocation-123",
- author="assistant",
- actions=b"complex_action_data",
- long_running_tool_ids_json='["tool1", "tool2"]',
- branch="main",
- timestamp=timestamp,
- content={"text": "Response"},
- grounding_metadata={"sources": ["doc1", "doc2"]},
- custom_metadata={"custom": "data"},
- partial=True,
- turn_complete=False,
- interrupted=False,
- error_code="NONE",
- error_message="No errors",
+async def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None:
+ """Test creating event with all optional fields stored in event_json."""
+ event_record: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "invocation-123",
+ "author": "assistant",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "full-event",
+ "content": {"text": "Response"},
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ "branch": "main",
+ "grounding_metadata": {"sources": ["doc1", "doc2"]},
+ "custom_metadata": {"custom": "data"},
+ "partial": True,
+ "turn_complete": False,
+ "interrupted": False,
+ "error_code": "NONE",
+ "error_message": "No errors",
+ },
+ }
+ await adbc_store.append_event(event_record)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
+ assert len(events) == 1
+
+ # Top-level indexed columns
+ assert events[0]["invocation_id"] == "invocation-123"
+ assert events[0]["author"] == "assistant"
+
+ # Everything else is in event_json
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
)
-
- assert event["invocation_id"] == "invocation-123"
- assert event["author"] == "assistant"
- assert event["actions"] == b"complex_action_data"
- assert event["long_running_tool_ids_json"] == '["tool1", "tool2"]'
- assert event["branch"] == "main"
- assert event["content"] == {"text": "Response"}
- assert event["grounding_metadata"] == {"sources": ["doc1", "doc2"]}
- assert event["custom_metadata"] == {"custom": "data"}
- assert event["partial"] is True
- assert event["turn_complete"] is False
- assert event["interrupted"] is False
- assert event["error_code"] == "NONE"
- assert event["error_message"] == "No errors"
-
-
-def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None:
+ assert event_data["content"] == {"text": "Response"}
+ assert event_data["branch"] == "main"
+ assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]}
+ assert event_data["custom_metadata"] == {"custom": "data"}
+ assert event_data["partial"] is True
+ assert event_data["turn_complete"] is False
+ assert event_data["interrupted"] is False
+ assert event_data["error_code"] == "NONE"
+ assert event_data["error_message"] == "No errors"
+
+
+async def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None:
"""Test creating event with only required fields."""
- event = adbc_store.create_event(
- event_id="minimal-event",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- )
-
- assert event["id"] == "minimal-event"
- assert event["session_id"] == session_fixture["session_id"]
- assert event["app_name"] == session_fixture["app_name"]
- assert event["user_id"] == session_fixture["user_id"]
- assert event["author"] is None
- assert event["actions"] == b""
- assert event["content"] is None
-
-
-def test_event_boolean_fields(adbc_store: Any, session_fixture: Any) -> None:
- """Test event boolean field conversion."""
- event_true = adbc_store.create_event(
- event_id="event-true",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- partial=True,
- turn_complete=True,
- interrupted=True,
- )
-
- assert event_true["partial"] is True
- assert event_true["turn_complete"] is True
- assert event_true["interrupted"] is True
-
- event_false = adbc_store.create_event(
- event_id="event-false",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- partial=False,
- turn_complete=False,
- interrupted=False,
- )
-
- assert event_false["partial"] is False
- assert event_false["turn_complete"] is False
- assert event_false["interrupted"] is False
-
- event_none = adbc_store.create_event(
- event_id="event-none",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- )
-
- assert event_none["partial"] is None
- assert event_none["turn_complete"] is None
- assert event_none["interrupted"] is None
-
-
-def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None:
- """Test event JSON field serialization and deserialization."""
+ event_record: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "minimal-event",
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+ await adbc_store.append_event(event_record)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
+ assert len(events) == 1
+ assert events[0]["session_id"] == session_fixture["session_id"]
+ assert "event_json" in events[0]
+
+
+async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None:
+ """Test event JSON field serialization and deserialization via event_json."""
complex_content = {"nested": {"data": "value"}, "list": [1, 2, 3], "null": None}
complex_grounding = {"sources": [{"title": "Doc", "url": "http://example.com"}]}
complex_custom = {"metadata": {"version": 1, "tags": ["tag1", "tag2"]}}
- event = adbc_store.create_event(
- event_id="json-event",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- content=complex_content,
- grounding_metadata=complex_grounding,
- custom_metadata=complex_custom,
+ event_record: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "json-event",
+ "content": complex_content,
+ "grounding_metadata": complex_grounding,
+ "custom_metadata": complex_custom,
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+ await adbc_store.append_event(event_record)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
+ assert len(events) == 1
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
)
+ assert event_data["content"] == complex_content
+ assert event_data["grounding_metadata"] == complex_grounding
+ assert event_data["custom_metadata"] == complex_custom
- assert event["content"] == complex_content
- assert event["grounding_metadata"] == complex_grounding
- assert event["custom_metadata"] == complex_custom
-
- events = adbc_store.list_events(session_fixture["session_id"])
- retrieved = events[0]
-
- assert retrieved["content"] == complex_content
- assert retrieved["grounding_metadata"] == complex_grounding
- assert retrieved["custom_metadata"] == complex_custom
-
-def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None:
+async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None:
"""Test that events are ordered by timestamp ASC."""
- import time
-
- adbc_store.create_event(
- event_id="event-1",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- )
-
- time.sleep(0.01)
-
- adbc_store.create_event(
- event_id="event-2",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- )
-
- time.sleep(0.01)
-
- adbc_store.create_event(
- event_id="event-3",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- )
-
- events = adbc_store.list_events(session_fixture["session_id"])
+ ev1: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]},
+ }
+ await adbc_store.append_event(ev1)
+
+ await asyncio.sleep(0.01)
+
+ ev2: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]},
+ }
+ await adbc_store.append_event(ev2)
+
+ await asyncio.sleep(0.01)
+
+ ev3: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-3", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]},
+ }
+ await adbc_store.append_event(ev3)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
assert len(events) == 3
- assert events[0]["id"] == "event-1"
- assert events[1]["id"] == "event-2"
- assert events[2]["id"] == "event-3"
assert events[0]["timestamp"] < events[1]["timestamp"]
assert events[1]["timestamp"] < events[2]["timestamp"]
-def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, tmp_path: Path) -> None:
+async def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, tmp_path: Path) -> None:
"""Test that deleting a session cascades to delete events.
Note: SQLite with ADBC requires foreign key enforcement to be explicitly
enabled for cascade deletes to work. This test manually enables it.
"""
- adbc_store.create_event(
- event_id="event-1",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- )
- adbc_store.create_event(
- event_id="event-2",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- )
-
- events_before = adbc_store.list_events(session_fixture["session_id"])
+ ev1: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]},
+ }
+ ev2: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]},
+ }
+ await adbc_store.append_event(ev1)
+ await adbc_store.append_event(ev2)
+
+ events_before = await adbc_store.get_events(session_fixture["session_id"])
assert len(events_before) == 2
- # For SQLite with separate connections per operation, we need to manually delete events
- # or note that cascade deletes require persistent connections
- # For this test, just verify the session deletion works
- adbc_store.delete_session(session_fixture["session_id"])
+ await adbc_store.delete_session(session_fixture["session_id"])
- # Session should be gone
- session_after = adbc_store.get_session(session_fixture["session_id"])
+ session_after = await adbc_store.get_session(session_fixture["session_id"])
assert session_after is None
- # Events may still exist with ADBC SQLite due to FK enforcement across connections
- # This is a known limitation when using ADBC with SQLite in-memory or file-based
- # with separate connections per operation
-
-def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None:
+async def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None:
"""Test creating event with empty actions bytes."""
- event = adbc_store.create_event(
- event_id="empty-actions",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- actions=b"",
+ event_record: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "empty-actions",
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+ await adbc_store.append_event(event_record)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
+ assert len(events) == 1
+ assert "event_json" in events[0]
+
+
+async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) -> None:
+ """Test creating event with large content in event_json."""
+ large_content = {"data": "x" * 10000}
+
+ event_record: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "large-content",
+ "content": large_content,
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+ await adbc_store.append_event(event_record)
+
+ events = await adbc_store.get_events(session_fixture["session_id"])
+ assert len(events) == 1
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
)
-
- assert event["actions"] == b""
-
- events = adbc_store.list_events(session_fixture["session_id"])
- assert events[0]["actions"] == b""
-
-
-def test_event_with_large_actions(adbc_store: Any, session_fixture: Any) -> None:
- """Test creating event with large actions BLOB."""
- large_actions = b"x" * 10000
-
- event = adbc_store.create_event(
- event_id="large-actions",
- session_id=session_fixture["session_id"],
- app_name=session_fixture["app_name"],
- user_id=session_fixture["user_id"],
- actions=large_actions,
- )
-
- assert event["actions"] == large_actions
- assert len(event["actions"]) == 10000
+ assert event_data["content"] == large_content
+
+
+async def test_append_event_preserves_existing_session_state(adbc_store: Any, session_fixture: Any) -> None:
+ """append_event must not overwrite the durable session state."""
+ event_record: EventRecord = {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "append-only",
+ "author": "user",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "append-only-event",
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ }
+
+ await adbc_store.append_event(event_record)
+
+ session = await adbc_store.get_session(session_fixture["session_id"])
+ assert session is not None
+ assert session["state"] == {"test": True}
+
+
+async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, session_fixture: Any) -> None:
+ """get_events must respect both after_timestamp and limit."""
+ base_time = datetime(2026, 1, 1, tzinfo=timezone.utc)
+ event_records = [
+ {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "user",
+ "timestamp": base_time,
+ "event_json": {
+ "id": "event-1",
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ },
+ {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "assistant",
+ "timestamp": base_time + timedelta(seconds=1),
+ "event_json": {
+ "id": "event-2",
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ },
+ {
+ "session_id": session_fixture["session_id"],
+ "invocation_id": "",
+ "author": "assistant",
+ "timestamp": base_time + timedelta(seconds=2),
+ "event_json": {
+ "id": "event-3",
+ "app_name": session_fixture["app_name"],
+ "user_id": session_fixture["user_id"],
+ },
+ },
+ ]
+
+ for event_record in event_records:
+ await adbc_store.append_event(event_record)
+
+ filtered_events = await adbc_store.get_events(session_fixture["session_id"], after_timestamp=base_time, limit=1)
+
+ assert len(filtered_events) == 1
+ filtered_event = filtered_events[0]["event_json"]
+ filtered_data = json.loads(filtered_event) if isinstance(filtered_event, str) else filtered_event
+ assert filtered_data["id"] == "event-2"
diff --git a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py
index a417cd84f..0c3afa838 100644
--- a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py
+++ b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py
@@ -30,63 +30,63 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted
)
-def _build_store(tmp_path: Path) -> AdbcADKMemoryStore:
+async def _build_store(tmp_path: Path) -> AdbcADKMemoryStore:
db_path = tmp_path / "test_adk_memory.db"
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
store = AdbcADKMemoryStore(config)
- store.create_tables()
+ await store.create_tables()
return store
-def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None:
+async def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None:
"""Insert memory entries, search by text, and skip duplicates."""
- store = _build_store(tmp_path)
+ store = await _build_store(tmp_path)
now = datetime.now(timezone.utc)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now)
record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now)
- inserted = store.insert_memory_entries([record1, record2])
+ inserted = await store.insert_memory_entries([record1, record2])
assert inserted == 2
- results = store.search_entries(query="espresso", app_name="app", user_id="user")
+ results = await store.search_entries(query="espresso", app_name="app", user_id="user")
assert len(results) == 1
assert results[0]["event_id"] == "evt-1"
- deduped = store.insert_memory_entries([record1])
+ deduped = await store.insert_memory_entries([record1])
assert deduped == 0
-def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None:
+async def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None:
"""Delete memory entries by session id."""
- store = _build_store(tmp_path)
+ store = await _build_store(tmp_path)
now = datetime.now(timezone.utc)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now)
record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now)
- store.insert_memory_entries([record1, record2])
+ await store.insert_memory_entries([record1, record2])
- deleted = store.delete_entries_by_session("s1")
+ deleted = await store.delete_entries_by_session("s1")
assert deleted == 1
- remaining = store.search_entries(query="latte", app_name="app", user_id="user")
+ remaining = await store.search_entries(query="latte", app_name="app", user_id="user")
assert len(remaining) == 1
assert remaining[0]["session_id"] == "s2"
-def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None:
+async def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None:
"""Delete memory entries older than a cutoff."""
- store = _build_store(tmp_path)
+ store = await _build_store(tmp_path)
now = datetime.now(timezone.utc)
old = now - timedelta(days=40)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old)
record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now)
- store.insert_memory_entries([record1, record2])
+ await store.insert_memory_entries([record1, record2])
- deleted = store.delete_entries_older_than(30)
+ deleted = await store.delete_entries_older_than(30)
assert deleted == 1
- remaining = store.search_entries(query="new", app_name="app", user_id="user")
+ remaining = await store.search_entries(query="new", app_name="app", user_id="user")
assert len(remaining) == 1
assert remaining[0]["event_id"] == "evt-2"
diff --git a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py
index ce2a1bbfa..1b1752ae4 100644
--- a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py
+++ b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py
@@ -9,7 +9,7 @@
@pytest.fixture()
-def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def]
+async def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def]
"""Create ADBC ADK store with owner ID column (SQLite)."""
db_path = tmp_path / "test_fk.db"
config = AdbcConfig(
@@ -29,21 +29,21 @@ def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def]
finally:
cursor.close() # type: ignore[no-untyped-call]
- store.create_tables()
+ await store.create_tables()
return store
@pytest.fixture()
-def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def]
+async def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def]
"""Create ADBC ADK store without owner ID column (SQLite)."""
db_path = tmp_path / "test_no_fk.db"
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
return store
-def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def]
+async def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def]
"""Test creating session with owner ID value."""
session_id = "test-session-1"
app_name = "test-app"
@@ -51,32 +51,32 @@ def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-un
state = {"key": "value"}
tenant_id = 1
- session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id)
+ session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id)
assert session["id"] == session_id
assert session["state"] == state
-def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def]
+async def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def]
"""Test creating session without providing owner ID value still works."""
session_id = "test-session-2"
app_name = "test-app"
user_id = "user-123"
state = {"key": "value"}
- session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state)
+ session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state)
assert session["id"] == session_id
-def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def]
+async def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def]
"""Test creating session when no FK column configured."""
session_id = "test-session-3"
app_name = "test-app"
user_id = "user-123"
state = {"key": "value"}
- session = adbc_store_no_fk.create_session(session_id, app_name, user_id, state)
+ session = await adbc_store_no_fk.create_session(session_id, app_name, user_id, state)
assert session["id"] == session_id
assert session["state"] == state
@@ -109,16 +109,16 @@ def test_owner_id_column_complex_ddl() -> None:
assert store._owner_id_column_ddl == complex_ddl # pyright: ignore[reportPrivateUsage]
-def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def]
+async def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def]
"""Test sessions are properly isolated by tenant."""
app_name = "test-app"
user_id = "user-123"
- adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, owner_id=1)
- adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2)
+ await adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, owner_id=1)
+ await adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2)
- retrieved1 = adbc_store_with_fk.get_session("session-tenant1")
- retrieved2 = adbc_store_with_fk.get_session("session-tenant2")
+ retrieved1 = await adbc_store_with_fk.get_session("session-tenant1")
+ retrieved2 = await adbc_store_with_fk.get_session("session-tenant2")
assert retrieved1["state"]["data"] == "tenant1"
assert retrieved2["state"]["data"] == "tenant2"
diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py
index 819002edc..b749461d7 100644
--- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py
+++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py
@@ -12,23 +12,23 @@
@pytest.fixture()
-def adbc_store(tmp_path: Path) -> AdbcADKStore:
+async def adbc_store(tmp_path: Path) -> AdbcADKStore:
"""Create ADBC ADK store with SQLite backend."""
db_path = tmp_path / "test_adk.db"
config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"})
store = AdbcADKStore(config)
- store.create_tables()
+ await store.create_tables()
return store
-def test_create_session(adbc_store: Any) -> None:
+async def test_create_session(adbc_store: Any) -> None:
"""Test creating a new session."""
session_id = "test-session-1"
app_name = "test-app"
user_id = "user-123"
state = {"key": "value", "count": 42}
- session = adbc_store.create_session(session_id, app_name, user_id, state)
+ session = await adbc_store.create_session(session_id, app_name, user_id, state)
assert session["id"] == session_id
assert session["app_name"] == app_name
@@ -38,82 +38,82 @@ def test_create_session(adbc_store: Any) -> None:
assert session["update_time"] is not None
-def test_get_session(adbc_store: Any) -> None:
+async def test_get_session(adbc_store: Any) -> None:
"""Test retrieving a session by ID."""
session_id = "test-session-2"
app_name = "test-app"
user_id = "user-123"
state = {"data": "test"}
- adbc_store.create_session(session_id, app_name, user_id, state)
- retrieved = adbc_store.get_session(session_id)
+ await adbc_store.create_session(session_id, app_name, user_id, state)
+ retrieved = await adbc_store.get_session(session_id)
assert retrieved is not None
assert retrieved["id"] == session_id
assert retrieved["state"] == state
-def test_get_nonexistent_session(adbc_store: Any) -> None:
+async def test_get_nonexistent_session(adbc_store: Any) -> None:
"""Test retrieving a session that doesn't exist."""
- result = adbc_store.get_session("nonexistent-id")
+ result = await adbc_store.get_session("nonexistent-id")
assert result is None
-def test_update_session_state(adbc_store: Any) -> None:
+async def test_update_session_state(adbc_store: Any) -> None:
"""Test updating session state."""
session_id = "test-session-3"
app_name = "test-app"
user_id = "user-123"
initial_state = {"version": 1}
- adbc_store.create_session(session_id, app_name, user_id, initial_state)
+ await adbc_store.create_session(session_id, app_name, user_id, initial_state)
new_state = {"version": 2, "updated": True}
- adbc_store.update_session_state(session_id, new_state)
+ await adbc_store.update_session_state(session_id, new_state)
- updated = adbc_store.get_session(session_id)
+ updated = await adbc_store.get_session(session_id)
assert updated is not None
assert updated["state"] == new_state
assert updated["state"] != initial_state
-def test_delete_session(adbc_store: Any) -> None:
+async def test_delete_session(adbc_store: Any) -> None:
"""Test deleting a session."""
session_id = "test-session-4"
app_name = "test-app"
user_id = "user-123"
state = {"data": "test"}
- adbc_store.create_session(session_id, app_name, user_id, state)
- assert adbc_store.get_session(session_id) is not None
+ await adbc_store.create_session(session_id, app_name, user_id, state)
+ assert await adbc_store.get_session(session_id) is not None
- adbc_store.delete_session(session_id)
- assert adbc_store.get_session(session_id) is None
+ await adbc_store.delete_session(session_id)
+ assert await adbc_store.get_session(session_id) is None
-def test_list_sessions(adbc_store: Any) -> None:
+async def test_list_sessions(adbc_store: Any) -> None:
"""Test listing sessions for an app and user."""
app_name = "test-app"
user_id = "user-123"
- adbc_store.create_session("session-1", app_name, user_id, {"num": 1})
- adbc_store.create_session("session-2", app_name, user_id, {"num": 2})
- adbc_store.create_session("session-3", "other-app", user_id, {"num": 3})
+ await adbc_store.create_session("session-1", app_name, user_id, {"num": 1})
+ await adbc_store.create_session("session-2", app_name, user_id, {"num": 2})
+ await adbc_store.create_session("session-3", "other-app", user_id, {"num": 3})
- sessions = adbc_store.list_sessions(app_name, user_id)
+ sessions = await adbc_store.list_sessions(app_name, user_id)
assert len(sessions) == 2
session_ids = {s["id"] for s in sessions}
assert session_ids == {"session-1", "session-2"}
-def test_list_sessions_empty(adbc_store: Any) -> None:
+async def test_list_sessions_empty(adbc_store: Any) -> None:
"""Test listing sessions when none exist."""
- sessions = adbc_store.list_sessions("nonexistent-app", "nonexistent-user")
+ sessions = await adbc_store.list_sessions("nonexistent-app", "nonexistent-user")
assert sessions == []
-def test_session_state_with_complex_data(adbc_store: Any) -> None:
+async def test_session_state_with_complex_data(adbc_store: Any) -> None:
"""Test session state with nested complex data structures."""
session_id = "complex-session"
app_name = "test-app"
@@ -125,41 +125,41 @@ def test_session_state_with_complex_data(adbc_store: Any) -> None:
"null_value": None,
}
- session = adbc_store.create_session(session_id, app_name, user_id, complex_state)
+ session = await adbc_store.create_session(session_id, app_name, user_id, complex_state)
assert session["state"] == complex_state
- retrieved = adbc_store.get_session(session_id)
+ retrieved = await adbc_store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"] == complex_state
-def test_session_state_empty_dict(adbc_store: Any) -> None:
+async def test_session_state_empty_dict(adbc_store: Any) -> None:
"""Test creating session with empty state dictionary."""
session_id = "empty-state-session"
app_name = "test-app"
user_id = "user-123"
empty_state: dict[str, Any] = {}
- session = adbc_store.create_session(session_id, app_name, user_id, empty_state)
+ session = await adbc_store.create_session(session_id, app_name, user_id, empty_state)
assert session["state"] == empty_state
- retrieved = adbc_store.get_session(session_id)
+ retrieved = await adbc_store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"] == empty_state
-def test_multiple_users_same_app(adbc_store: Any) -> None:
+async def test_multiple_users_same_app(adbc_store: Any) -> None:
"""Test sessions for multiple users in the same app."""
app_name = "test-app"
user1 = "user-1"
user2 = "user-2"
- adbc_store.create_session("session-user1-1", app_name, user1, {"user": 1})
- adbc_store.create_session("session-user1-2", app_name, user1, {"user": 1})
- adbc_store.create_session("session-user2-1", app_name, user2, {"user": 2})
+ await adbc_store.create_session("session-user1-1", app_name, user1, {"user": 1})
+ await adbc_store.create_session("session-user1-2", app_name, user1, {"user": 1})
+ await adbc_store.create_session("session-user2-1", app_name, user2, {"user": 2})
- user1_sessions = adbc_store.list_sessions(app_name, user1)
- user2_sessions = adbc_store.list_sessions(app_name, user2)
+ user1_sessions = await adbc_store.list_sessions(app_name, user1)
+ user2_sessions = await adbc_store.list_sessions(app_name, user2)
assert len(user1_sessions) == 2
assert len(user2_sessions) == 1
@@ -167,18 +167,18 @@ def test_multiple_users_same_app(adbc_store: Any) -> None:
assert all(s["user_id"] == user2 for s in user2_sessions)
-def test_session_ordering(adbc_store: Any) -> None:
+async def test_session_ordering(adbc_store: Any) -> None:
"""Test that sessions are ordered by update_time DESC."""
app_name = "test-app"
user_id = "user-123"
- adbc_store.create_session("session-1", app_name, user_id, {"order": 1})
- adbc_store.create_session("session-2", app_name, user_id, {"order": 2})
- adbc_store.create_session("session-3", app_name, user_id, {"order": 3})
+ await adbc_store.create_session("session-1", app_name, user_id, {"order": 1})
+ await adbc_store.create_session("session-2", app_name, user_id, {"order": 2})
+ await adbc_store.create_session("session-3", app_name, user_id, {"order": 3})
- adbc_store.update_session_state("session-1", {"order": 1, "updated": True})
+ await adbc_store.update_session_state("session-1", {"order": 1, "updated": True})
- sessions = adbc_store.list_sessions(app_name, user_id)
+ sessions = await adbc_store.list_sessions(app_name, user_id)
assert len(sessions) == 3
assert sessions[0]["id"] == "session-1"
diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py
index d53323565..bd52cade2 100644
--- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py
+++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py
@@ -1,11 +1,12 @@
"""Integration tests for AsyncMY ADK session store."""
-import pickle
+import json
from datetime import datetime, timezone
import pytest
from sqlspec.adapters.asyncmy.adk.store import AsyncmyADKStore
+from sqlspec.extensions.adk import EventRecord
pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.asyncmy, pytest.mark.integration]
@@ -51,13 +52,14 @@ async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) ->
ORDER BY ORDINAL_POSITION
""")
event_columns = await cursor.fetchall()
+ event_col_names = [col[0] for col in event_columns]
- actions_col = next(col for col in event_columns if col[0] == "actions")
- assert actions_col[1] == "blob", "actions column must use BLOB type for pickled data"
-
- content_col = next((col for col in event_columns if col[0] == "content"), None)
- if content_col:
- assert content_col[1] == "json", "content column must use native JSON type"
+ # New 5-column schema: session_id, invocation_id, author, timestamp, event_json
+ assert "session_id" in event_col_names
+ assert "invocation_id" in event_col_names
+ assert "author" in event_col_names
+ assert "timestamp" in event_col_names
+ assert "event_json" in event_col_names
timestamp_col = next(col for col in event_columns if col[0] == "timestamp")
assert "timestamp(6)" in timestamp_col[2].lower(), "timestamp must be TIMESTAMP(6) for microseconds"
@@ -141,18 +143,14 @@ async def test_delete_session_cascade(asyncmy_adk_store: AsyncmyADKStore) -> Non
await asyncmy_adk_store.create_session(session_id, app_name, user_id, {"status": "active"})
- event_record = {
- "id": "event-001",
+ event_record: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-001",
"author": "user",
- "actions": pickle.dumps([{"type": "test_action"}]),
"timestamp": datetime.now(timezone.utc),
- "content": {"text": "Hello"},
+ "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id},
}
- await asyncmy_adk_store.append_event(event_record) # type: ignore[arg-type]
+ await asyncmy_adk_store.append_event(event_record)
events_before = await asyncmy_adk_store.get_events(session_id)
assert len(events_before) == 1
@@ -174,48 +172,39 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None
await asyncmy_adk_store.create_session(session_id, app_name, user_id, {"status": "active"})
- event1 = {
- "id": "event-001",
+ event1: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-001",
"author": "user",
- "actions": pickle.dumps([{"type": "message", "content": "Hello"}]),
"timestamp": datetime.now(timezone.utc),
- "content": {"text": "Hello", "role": "user"},
- "partial": False,
- "turn_complete": True,
+ "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name},
}
- event2 = {
- "id": "event-002",
+ event2: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-002",
"author": "assistant",
- "actions": pickle.dumps([{"type": "response", "content": "Hi there"}]),
"timestamp": datetime.now(timezone.utc),
- "content": {"text": "Hi there", "role": "assistant"},
- "partial": False,
- "turn_complete": True,
+ "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name},
}
- await asyncmy_adk_store.append_event(event1) # type: ignore[arg-type]
- await asyncmy_adk_store.append_event(event2) # type: ignore[arg-type]
+ await asyncmy_adk_store.append_event(event1)
+ await asyncmy_adk_store.append_event(event2)
events = await asyncmy_adk_store.get_events(session_id)
assert len(events) == 2
- assert events[0]["id"] == "event-001"
- assert events[1]["id"] == "event-002"
- assert events[0]["content"] is not None
- assert events[1]["content"] is not None
- assert events[0]["content"]["text"] == "Hello"
- assert events[1]["content"]["text"] == "Hi there"
- assert isinstance(events[0]["actions"], bytes)
- assert pickle.loads(events[0]["actions"])[0]["type"] == "message"
+ assert events[0]["author"] == "user"
+ assert events[1]["author"] == "assistant"
+ # Content is inside event_json
+ event0_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ event1_data = (
+ json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"]
+ )
+ assert event0_data["content"]["text"] == "Hello"
+ assert event1_data["content"]["text"] == "Hi there"
async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None:
@@ -230,17 +219,14 @@ async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None:
assert hasattr(created["create_time"], "microsecond")
event_time = datetime.now(timezone.utc)
- event = {
- "id": "event-micro",
+ event: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-micro",
"author": "system",
- "actions": b"",
"timestamp": event_time,
+ "event_json": {"app_name": app_name},
}
- await asyncmy_adk_store.append_event(event) # type: ignore[arg-type]
+ await asyncmy_adk_store.append_event(event)
events = await asyncmy_adk_store.get_events(session_id)
assert len(events) == 1
diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py
index e3092176b..b86fef3ae 100644
--- a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py
+++ b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py
@@ -30,63 +30,63 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted
)
-def _build_store(tmp_path: Path, worker_id: str) -> DuckdbADKMemoryStore:
+async def _build_store(tmp_path: Path, worker_id: str) -> DuckdbADKMemoryStore:
db_path = tmp_path / f"test_adk_memory_{worker_id}.duckdb"
config = DuckDBConfig(connection_config={"database": str(db_path)})
store = DuckdbADKMemoryStore(config)
- store.create_tables()
+ await store.create_tables()
return store
-def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path, worker_id: str) -> None:
+async def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path, worker_id: str) -> None:
"""Insert memory entries, search by text, and skip duplicates."""
- store = _build_store(tmp_path, worker_id)
+ store = await _build_store(tmp_path, worker_id)
now = datetime.now(timezone.utc)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now)
record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now)
- inserted = store.insert_memory_entries([record1, record2])
+ inserted = await store.insert_memory_entries([record1, record2])
assert inserted == 2
- results = store.search_entries(query="espresso", app_name="app", user_id="user")
+ results = await store.search_entries(query="espresso", app_name="app", user_id="user")
assert len(results) == 1
assert results[0]["event_id"] == "evt-1"
- deduped = store.insert_memory_entries([record1])
+ deduped = await store.insert_memory_entries([record1])
assert deduped == 0
-def test_duckdb_memory_store_delete_by_session(tmp_path: Path, worker_id: str) -> None:
+async def test_duckdb_memory_store_delete_by_session(tmp_path: Path, worker_id: str) -> None:
"""Delete memory entries by session id."""
- store = _build_store(tmp_path, worker_id)
+ store = await _build_store(tmp_path, worker_id)
now = datetime.now(timezone.utc)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now)
record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now)
- store.insert_memory_entries([record1, record2])
+ await store.insert_memory_entries([record1, record2])
- deleted = store.delete_entries_by_session("s1")
+ deleted = await store.delete_entries_by_session("s1")
assert deleted == 1
- remaining = store.search_entries(query="latte", app_name="app", user_id="user")
+ remaining = await store.search_entries(query="latte", app_name="app", user_id="user")
assert len(remaining) == 1
assert remaining[0]["session_id"] == "s2"
-def test_duckdb_memory_store_delete_older_than(tmp_path: Path, worker_id: str) -> None:
+async def test_duckdb_memory_store_delete_older_than(tmp_path: Path, worker_id: str) -> None:
"""Delete memory entries older than a cutoff."""
- store = _build_store(tmp_path, worker_id)
+ store = await _build_store(tmp_path, worker_id)
now = datetime.now(timezone.utc)
old = now - timedelta(days=40)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old)
record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now)
- store.insert_memory_entries([record1, record2])
+ await store.insert_memory_entries([record1, record2])
- deleted = store.delete_entries_older_than(30)
+ deleted = await store.delete_entries_older_than(30)
assert deleted == 1
- remaining = store.search_entries(query="new", app_name="app", user_id="user")
+ remaining = await store.search_entries(query="new", app_name="app", user_id="user")
assert len(remaining) == 1
assert remaining[0]["event_id"] == "evt-2"
diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py
index a685e08b8..bc67ad007 100644
--- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py
+++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py
@@ -1,6 +1,7 @@
"""Integration tests for DuckDB ADK session store."""
-from collections.abc import Generator
+import json
+from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from pathlib import Path
@@ -8,12 +9,13 @@
from sqlspec.adapters.duckdb.adk.store import DuckdbADKStore
from sqlspec.adapters.duckdb.config import DuckDBConfig
+from sqlspec.extensions.adk import EventRecord
pytestmark = [pytest.mark.duckdb, pytest.mark.integration]
@pytest.fixture
-def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "Generator[DuckdbADKStore, None, None]":
+async def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "AsyncGenerator[DuckdbADKStore, None]":
"""Create DuckDB ADK store with temporary file-based database.
Args:
@@ -34,27 +36,27 @@ def duckdb_adk_store(tmp_path: Path, worker_id: str) -> "Generator[DuckdbADKStor
extension_config={"adk": {"session_table": "test_sessions", "events_table": "test_events"}},
)
store = DuckdbADKStore(config)
- store.create_tables()
+ await store.create_tables()
yield store
finally:
if db_path.exists():
db_path.unlink()
-def test_create_tables(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_create_tables(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test table creation succeeds without errors."""
assert duckdb_adk_store.session_table == "test_sessions"
assert duckdb_adk_store.events_table == "test_events"
-def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test creating and retrieving a session."""
session_id = "session-001"
app_name = "test-app"
user_id = "user-001"
state = {"key": "value", "count": 42}
- created_session = duckdb_adk_store.create_session(
+ created_session = await duckdb_adk_store.create_session(
session_id=session_id, app_name=app_name, user_id=user_id, state=state
)
@@ -65,49 +67,51 @@ def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None:
assert isinstance(created_session["create_time"], datetime)
assert isinstance(created_session["update_time"], datetime)
- retrieved_session = duckdb_adk_store.get_session(session_id)
+ retrieved_session = await duckdb_adk_store.get_session(session_id)
assert retrieved_session is not None
assert retrieved_session["id"] == session_id
assert retrieved_session["state"] == state
-def test_get_nonexistent_session(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_get_nonexistent_session(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test getting a non-existent session returns None."""
- result = duckdb_adk_store.get_session("nonexistent-session")
+ result = await duckdb_adk_store.get_session("nonexistent-session")
assert result is None
-def test_update_session_state(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_update_session_state(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test updating session state."""
session_id = "session-002"
initial_state = {"status": "active"}
updated_state = {"status": "completed", "result": "success"}
- duckdb_adk_store.create_session(session_id=session_id, app_name="test-app", user_id="user-002", state=initial_state)
+ await duckdb_adk_store.create_session(
+ session_id=session_id, app_name="test-app", user_id="user-002", state=initial_state
+ )
- session_before = duckdb_adk_store.get_session(session_id)
+ session_before = await duckdb_adk_store.get_session(session_id)
assert session_before is not None
assert session_before["state"] == initial_state
- duckdb_adk_store.update_session_state(session_id, updated_state)
+ await duckdb_adk_store.update_session_state(session_id, updated_state)
- session_after = duckdb_adk_store.get_session(session_id)
+ session_after = await duckdb_adk_store.get_session(session_id)
assert session_after is not None
assert session_after["state"] == updated_state
assert session_after["update_time"] >= session_before["update_time"]
-def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test listing sessions for an app and user."""
app_name = "test-app"
user_id = "user-003"
- duckdb_adk_store.create_session("session-1", app_name, user_id, {"num": 1})
- duckdb_adk_store.create_session("session-2", app_name, user_id, {"num": 2})
- duckdb_adk_store.create_session("session-3", app_name, user_id, {"num": 3})
- duckdb_adk_store.create_session("session-other", "other-app", user_id, {"num": 999})
+ await duckdb_adk_store.create_session("session-1", app_name, user_id, {"num": 1})
+ await duckdb_adk_store.create_session("session-2", app_name, user_id, {"num": 2})
+ await duckdb_adk_store.create_session("session-3", app_name, user_id, {"num": 3})
+ await duckdb_adk_store.create_session("session-other", "other-app", user_id, {"num": 999})
- sessions = duckdb_adk_store.list_sessions(app_name, user_id)
+ sessions = await duckdb_adk_store.list_sessions(app_name, user_id)
assert len(sessions) == 3
session_ids = {s["id"] for s in sessions}
@@ -116,192 +120,214 @@ def test_list_sessions(duckdb_adk_store: DuckdbADKStore) -> None:
assert all(s["user_id"] == user_id for s in sessions)
-def test_list_sessions_empty(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_list_sessions_empty(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test listing sessions when none exist."""
- sessions = duckdb_adk_store.list_sessions("nonexistent-app", "nonexistent-user")
+ sessions = await duckdb_adk_store.list_sessions("nonexistent-app", "nonexistent-user")
assert sessions == []
-def test_delete_session(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_delete_session(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test deleting a session."""
session_id = "session-to-delete"
- duckdb_adk_store.create_session(session_id, "test-app", "user-004", {"data": "test"})
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-004", {"data": "test"})
- assert duckdb_adk_store.get_session(session_id) is not None
+ assert await duckdb_adk_store.get_session(session_id) is not None
- duckdb_adk_store.delete_session(session_id)
+ await duckdb_adk_store.delete_session(session_id)
- assert duckdb_adk_store.get_session(session_id) is None
+ assert await duckdb_adk_store.get_session(session_id) is None
-def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test deleting a session also deletes associated events."""
session_id = "session-with-events"
- duckdb_adk_store.create_session(session_id, "test-app", "user-005", {"data": "test"})
-
- event = duckdb_adk_store.create_event(
- event_id="event-001",
- session_id=session_id,
- app_name="test-app",
- user_id="user-005",
- author="user",
- actions=b"test-actions",
- content={"message": "Hello"},
- )
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-005", {"data": "test"})
+
+ event_record: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "user",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "event-001",
+ "content": {"message": "Hello"},
+ "app_name": "test-app",
+ "user_id": "user-005",
+ },
+ }
+ await duckdb_adk_store.append_event(event_record)
- assert event["id"] == "event-001"
- events = duckdb_adk_store.list_events(session_id)
+ events = await duckdb_adk_store.get_events(session_id)
assert len(events) == 1
- duckdb_adk_store.delete_session(session_id)
+ await duckdb_adk_store.delete_session(session_id)
- assert duckdb_adk_store.get_session(session_id) is None
- events_after = duckdb_adk_store.list_events(session_id)
+ assert await duckdb_adk_store.get_session(session_id) is None
+ events_after = await duckdb_adk_store.get_events(session_id)
assert len(events_after) == 0
-def test_create_and_get_event(duckdb_adk_store: DuckdbADKStore) -> None:
- """Test creating and retrieving an event."""
+async def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None:
+ """Test creating an event and verifying the returned 5-key EventRecord."""
session_id = "session-006"
- duckdb_adk_store.create_session(session_id, "test-app", "user-006", {})
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-006", {})
- event_id = "event-002"
timestamp = datetime.now(timezone.utc)
content = {"text": "Test message", "role": "user"}
- custom_metadata = {"source": "test"}
-
- created_event = duckdb_adk_store.create_event(
- event_id=event_id,
- session_id=session_id,
- app_name="test-app",
- user_id="user-006",
- author="user",
- actions=b"pickled-actions",
- content=content,
- timestamp=timestamp,
- custom_metadata=custom_metadata,
- )
-
- assert created_event["id"] == event_id
- assert created_event["session_id"] == session_id
- assert created_event["author"] == "user"
- assert created_event["content"] == content
- assert created_event["custom_metadata"] == custom_metadata
- retrieved_event = duckdb_adk_store.get_event(event_id)
- assert retrieved_event is not None
- assert retrieved_event["id"] == event_id
- assert retrieved_event["content"] == content
+ event_record: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "user",
+ "timestamp": timestamp,
+ "event_json": {"id": "event-002", "content": content, "app_name": "test-app", "user_id": "user-006"},
+ }
+ await duckdb_adk_store.append_event(event_record)
+ events = await duckdb_adk_store.get_events(session_id)
+ assert len(events) == 1
+ assert events[0]["session_id"] == session_id
+ assert events[0]["author"] == "user"
-def test_get_nonexistent_event(duckdb_adk_store: DuckdbADKStore) -> None:
- """Test getting a non-existent event returns None."""
- result = duckdb_adk_store.get_event("nonexistent-event")
- assert result is None
+ # Content is stored inside event_json
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ assert event_data["content"] == content
-def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test listing events for a session."""
session_id = "session-007"
- duckdb_adk_store.create_session(session_id, "test-app", "user-007", {})
-
- duckdb_adk_store.create_event(
- event_id="event-1",
- session_id=session_id,
- app_name="test-app",
- user_id="user-007",
- author="user",
- content={"message": "First"},
- )
- duckdb_adk_store.create_event(
- event_id="event-2",
- session_id=session_id,
- app_name="test-app",
- user_id="user-007",
- author="assistant",
- content={"message": "Second"},
- )
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-007", {})
+
+ event1: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "user",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-1", "content": {"message": "First"}, "app_name": "test-app", "user_id": "user-007"},
+ }
+ event2: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "assistant",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "event-2",
+ "content": {"message": "Second"},
+ "app_name": "test-app",
+ "user_id": "user-007",
+ },
+ }
+ await duckdb_adk_store.append_event(event1)
+ await duckdb_adk_store.append_event(event2)
- events = duckdb_adk_store.list_events(session_id)
+ events = await duckdb_adk_store.get_events(session_id)
assert len(events) == 2
- assert events[0]["id"] == "event-1"
- assert events[1]["id"] == "event-2"
+ assert events[0]["author"] == "user"
+ assert events[1]["author"] == "assistant"
assert events[0]["timestamp"] <= events[1]["timestamp"]
-def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test listing events when none exist."""
session_id = "session-no-events"
- duckdb_adk_store.create_session(session_id, "test-app", "user-008", {})
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {})
- events = duckdb_adk_store.list_events(session_id)
+ events = await duckdb_adk_store.get_events(session_id)
assert events == []
-def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None:
- """Test creating events with all optional fields."""
+async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None:
+ """Test creating events with optional fields stored in event_json."""
session_id = "session-008"
- duckdb_adk_store.create_session(session_id, "test-app", "user-008", {})
-
- event = duckdb_adk_store.create_event(
- event_id="event-full",
- session_id=session_id,
- app_name="test-app",
- user_id="user-008",
- author="assistant",
- actions=b"actions-data",
- content={"text": "Response"},
- invocation_id="inv-123",
- branch="main",
- grounding_metadata={"sources": ["doc1", "doc2"]},
- custom_metadata={"priority": "high"},
- partial=True,
- turn_complete=False,
- interrupted=False,
- error_code=None,
- error_message=None,
- )
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {})
+
+ event_record: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "inv-123",
+ "author": "assistant",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {
+ "id": "event-full",
+ "content": {"text": "Response"},
+ "app_name": "test-app",
+ "user_id": "user-008",
+ "branch": "main",
+ "grounding_metadata": {"sources": ["doc1", "doc2"]},
+ "custom_metadata": {"priority": "high"},
+ "partial": True,
+ "turn_complete": False,
+ "interrupted": False,
+ },
+ }
+ await duckdb_adk_store.append_event(event_record)
+
+ events = await duckdb_adk_store.get_events(session_id)
+ assert len(events) == 1
- assert event["invocation_id"] == "inv-123"
- assert event["branch"] == "main"
- assert event["grounding_metadata"] == {"sources": ["doc1", "doc2"]}
- assert event["partial"] is True
- assert event["turn_complete"] is False
+ # The 5-key record has invocation_id as a top-level indexed column
+ assert events[0]["invocation_id"] == "inv-123"
- retrieved = duckdb_adk_store.get_event("event-full")
- assert retrieved is not None
- assert retrieved["grounding_metadata"] == {"sources": ["doc1", "doc2"]}
+ # Other fields are inside event_json
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ assert event_data["branch"] == "main"
+ assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]}
+ assert event_data["partial"] is True
+ assert event_data["turn_complete"] is False
-def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test events are ordered by timestamp ascending."""
session_id = "session-009"
- duckdb_adk_store.create_session(session_id, "test-app", "user-009", {})
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-009", {})
t1 = datetime.now(timezone.utc)
t2 = datetime.now(timezone.utc)
t3 = datetime.now(timezone.utc)
- duckdb_adk_store.create_event(
- event_id="event-middle", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t2
- )
- duckdb_adk_store.create_event(
- event_id="event-last", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t3
- )
- duckdb_adk_store.create_event(
- event_id="event-first", session_id=session_id, app_name="test-app", user_id="user-009", timestamp=t1
- )
+ ev_middle: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "",
+ "timestamp": t2,
+ "event_json": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"},
+ }
+ ev_last: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "",
+ "timestamp": t3,
+ "event_json": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"},
+ }
+ ev_first: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "",
+ "timestamp": t1,
+ "event_json": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"},
+ }
+
+ await duckdb_adk_store.append_event(ev_middle)
+ await duckdb_adk_store.append_event(ev_last)
+ await duckdb_adk_store.append_event(ev_first)
- events = duckdb_adk_store.list_events(session_id)
+ events = await duckdb_adk_store.get_events(session_id)
assert len(events) == 3
- assert events[0]["id"] == "event-first"
- assert events[1]["id"] == "event-middle"
- assert events[2]["id"] == "event-last"
+ # Events should be ordered by timestamp ASC
+ event_ids = []
+ for e in events:
+ data = json.loads(e["event_json"]) if isinstance(e["event_json"], str) else e["event_json"]
+ event_ids.append(data["id"])
+ assert event_ids == ["event-first", "event-middle", "event-last"]
-def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test session state with nested JSON structures."""
session_id = "session-complex"
complex_state = {
@@ -314,86 +340,84 @@ def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> No
"flags": [True, False, True],
}
- duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state)
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state)
- session = duckdb_adk_store.get_session(session_id)
+ session = await duckdb_adk_store.get_session(session_id)
assert session is not None
assert session["state"] == complex_state
assert session["state"]["user"]["preferences"]["theme"] == "dark"
assert session["state"]["conversation"]["turn_count"] == 5
-def test_empty_state(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_empty_state(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test creating session with empty state."""
session_id = "session-empty-state"
- duckdb_adk_store.create_session(session_id, "test-app", "user-011", {})
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-011", {})
- session = duckdb_adk_store.get_session(session_id)
+ session = await duckdb_adk_store.get_session(session_id)
assert session is not None
assert session["state"] == {}
-def test_table_not_found_handling(tmp_path: Path, worker_id: str) -> None:
+async def test_table_not_found_handling(tmp_path: Path, worker_id: str) -> None:
"""Test graceful handling when tables don't exist."""
db_path = tmp_path / f"test_no_tables_{worker_id}.duckdb"
try:
config = DuckDBConfig(connection_config={"database": str(db_path)})
store = DuckdbADKStore(config)
- result = store.get_session("nonexistent")
+ result = await store.get_session("nonexistent")
assert result is None
- sessions = store.list_sessions("app", "user")
+ sessions = await store.list_sessions("app", "user")
assert sessions == []
- events = store.list_events("session")
+ events = await store.get_events("session")
assert events == []
finally:
if db_path.exists():
db_path.unlink()
-def test_binary_actions_data(duckdb_adk_store: DuckdbADKStore) -> None:
- """Test storing and retrieving binary actions data."""
- session_id = "session-binary"
- duckdb_adk_store.create_session(session_id, "test-app", "user-012", {})
+async def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None:
+ """Test storing and retrieving event data via event_json."""
+ session_id = "session-json-rt"
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-012", {})
- binary_data = bytes(range(256))
+ event_record: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "",
+ "author": "system",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-json", "content": {"data": "value"}, "app_name": "test-app", "user_id": "user-012"},
+ }
+ await duckdb_adk_store.append_event(event_record)
- event = duckdb_adk_store.create_event(
- event_id="event-binary",
- session_id=session_id,
- app_name="test-app",
- user_id="user-012",
- author="system",
- actions=binary_data,
+ events = await duckdb_adk_store.get_events(session_id)
+ assert len(events) == 1
+ event_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
)
-
- assert event["actions"] == binary_data
-
- retrieved = duckdb_adk_store.get_event("event-binary")
- assert retrieved is not None
- assert retrieved["actions"] == binary_data
- assert len(retrieved["actions"]) == 256
+ assert event_data["content"] == {"data": "value"}
-def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None:
+async def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None:
"""Test multiple updates to same session."""
session_id = "session-concurrent"
- duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0})
+ await duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0})
for i in range(10):
- session = duckdb_adk_store.get_session(session_id)
+ session = await duckdb_adk_store.get_session(session_id)
assert session is not None
current_counter = session["state"]["counter"]
- duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1})
+ await duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1})
- final_session = duckdb_adk_store.get_session(session_id)
+ final_session = await duckdb_adk_store.get_session(session_id)
assert final_session is not None
assert final_session["state"]["counter"] == 10
-def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None:
+async def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None:
"""Test owner ID column with INTEGER type."""
db_path = tmp_path / f"test_owner_id_int_{worker_id}.duckdb"
try:
@@ -415,12 +439,12 @@ def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None:
},
)
store = DuckdbADKStore(config_with_extension)
- store.create_tables()
+ await store.create_tables()
assert store.owner_id_column_name == "tenant_id"
assert store.owner_id_column_ddl == "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"
- session = store.create_session(
+ session = await store.create_session(
session_id="session-tenant-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=1
)
@@ -436,7 +460,7 @@ def test_owner_id_column_with_integer(tmp_path: Path, worker_id: str) -> None:
db_path.unlink()
-def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None:
+async def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None:
"""Test owner ID column with DuckDB UBIGINT type."""
db_path = tmp_path / f"test_owner_id_ubigint_{worker_id}.duckdb"
try:
@@ -458,11 +482,11 @@ def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None:
},
)
store = DuckdbADKStore(config_with_extension)
- store.create_tables()
+ await store.create_tables()
assert store.owner_id_column_name == "owner_id"
- session = store.create_session(
+ session = await store.create_session(
session_id="session-user-1",
app_name="test-app",
user_id="user-001",
@@ -482,7 +506,7 @@ def test_owner_id_column_with_ubigint(tmp_path: Path, worker_id: str) -> None:
db_path.unlink()
-def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) -> None:
+async def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str) -> None:
"""Test that FK constraint is enforced."""
db_path = tmp_path / f"test_owner_id_constraint_{worker_id}.duckdb"
try:
@@ -504,14 +528,14 @@ def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str)
},
)
store = DuckdbADKStore(config_with_extension)
- store.create_tables()
+ await store.create_tables()
- store.create_session(
+ await store.create_session(
session_id="session-org-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=100
)
with pytest.raises(Exception) as exc_info:
- store.create_session(
+ await store.create_session(
session_id="session-org-invalid",
app_name="test-app",
user_id="user-002",
@@ -525,7 +549,7 @@ def test_owner_id_column_foreign_key_constraint(tmp_path: Path, worker_id: str)
db_path.unlink()
-def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None:
+async def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None:
"""Test creating session without owner_id when column is configured but nullable."""
db_path = tmp_path / f"test_owner_id_nullable_{worker_id}.duckdb"
try:
@@ -546,22 +570,22 @@ def test_owner_id_column_without_value(tmp_path: Path, worker_id: str) -> None:
},
)
store = DuckdbADKStore(config_with_extension)
- store.create_tables()
+ await store.create_tables()
- session = store.create_session(
+ session = await store.create_session(
session_id="session-no-fk", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=None
)
assert session["id"] == "session-no-fk"
- retrieved = store.get_session("session-no-fk")
+ retrieved = await store.get_session("session-no-fk")
assert retrieved is not None
finally:
if db_path.exists():
db_path.unlink()
-def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None:
+async def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None:
"""Test owner ID column with VARCHAR type."""
db_path = tmp_path / f"test_owner_id_varchar_{worker_id}.duckdb"
try:
@@ -583,9 +607,9 @@ def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None:
},
)
store = DuckdbADKStore(config_with_extension)
- store.create_tables()
+ await store.create_tables()
- session = store.create_session(
+ session = await store.create_session(
session_id="session-company-1",
app_name="test-app",
user_id="user-001",
@@ -605,7 +629,7 @@ def test_owner_id_column_with_varchar(tmp_path: Path, worker_id: str) -> None:
db_path.unlink()
-def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> None:
+async def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> None:
"""Test multiple sessions with same FK value."""
db_path = tmp_path / f"test_owner_id_multiple_{worker_id}.duckdb"
try:
@@ -627,10 +651,10 @@ def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> No
},
)
store = DuckdbADKStore(config_with_extension)
- store.create_tables()
+ await store.create_tables()
for i in range(5):
- store.create_session(
+ await store.create_session(
session_id=f"session-dept-{i}",
app_name="test-app",
user_id=f"user-{i}",
@@ -648,7 +672,7 @@ def test_owner_id_column_multiple_sessions(tmp_path: Path, worker_id: str) -> No
db_path.unlink()
-def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None:
+async def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None:
"""Test querying sessions by FK column value."""
db_path = tmp_path / f"test_owner_id_query_{worker_id}.duckdb"
try:
@@ -670,11 +694,11 @@ def test_owner_id_column_query_by_fk(tmp_path: Path, worker_id: str) -> None:
},
)
store = DuckdbADKStore(config_with_extension)
- store.create_tables()
+ await store.create_tables()
- store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1)
- store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1)
- store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2)
+ await store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1)
+ await store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1)
+ await store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2)
with config.provide_connection() as conn:
cursor = conn.execute("SELECT id FROM sessions_with_project WHERE project_id = ? ORDER BY id", (1,))
diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py
index 5e6253a47..eb12c69f8 100644
--- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py
+++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py
@@ -1,12 +1,13 @@
"""Integration tests for MysqlConnector ADK session store."""
-import pickle
+import json
from datetime import datetime, timezone
from typing import cast
import pytest
from sqlspec.adapters.mysqlconnector.adk.store import MysqlConnectorAsyncADKStore
+from sqlspec.extensions.adk import EventRecord
pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.integration]
@@ -50,13 +51,14 @@ async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnect
ORDER BY ORDINAL_POSITION
""")
event_columns = await cursor.fetchall()
+ event_col_names = [col[0] for col in event_columns]
- actions_col = next(col for col in event_columns if col[0] == "actions")
- assert actions_col[1] == "blob"
-
- content_col = next((col for col in event_columns if col[0] == "content"), None)
- if content_col:
- assert content_col[1] == "json"
+ # New 5-column schema: session_id, invocation_id, author, timestamp, event_json
+ assert "session_id" in event_col_names
+ assert "invocation_id" in event_col_names
+ assert "author" in event_col_names
+ assert "timestamp" in event_col_names
+ assert "event_json" in event_col_names
timestamp_col = next(col for col in event_columns if col[0] == "timestamp")
assert "timestamp(6)" in cast("str", timestamp_col[2]).lower()
@@ -142,18 +144,14 @@ async def test_delete_session_cascade(mysqlconnector_adk_store: MysqlConnectorAs
await mysqlconnector_adk_store.create_session(session_id, app_name, user_id, {"status": "active"})
- event_record = {
- "id": "event-001",
+ event_record: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-001",
"author": "user",
- "actions": pickle.dumps([{"type": "test_action"}]),
"timestamp": datetime.now(timezone.utc),
- "content": {"text": "Hello"},
+ "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id},
}
- await mysqlconnector_adk_store.append_event(event_record) # type: ignore[arg-type]
+ await mysqlconnector_adk_store.append_event(event_record)
events_before = await mysqlconnector_adk_store.get_events(session_id)
assert len(events_before) == 1
@@ -175,48 +173,39 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy
await mysqlconnector_adk_store.create_session(session_id, app_name, user_id, {"status": "active"})
- event1 = {
- "id": "event-001",
+ event1: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-001",
"author": "user",
- "actions": pickle.dumps([{"type": "message", "content": "Hello"}]),
"timestamp": datetime.now(timezone.utc),
- "content": {"text": "Hello", "role": "user"},
- "partial": False,
- "turn_complete": True,
+ "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name},
}
- event2 = {
- "id": "event-002",
+ event2: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-002",
"author": "assistant",
- "actions": pickle.dumps([{"type": "response", "content": "Hi there"}]),
"timestamp": datetime.now(timezone.utc),
- "content": {"text": "Hi there", "role": "assistant"},
- "partial": False,
- "turn_complete": True,
+ "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name},
}
- await mysqlconnector_adk_store.append_event(event1) # type: ignore[arg-type]
- await mysqlconnector_adk_store.append_event(event2) # type: ignore[arg-type]
+ await mysqlconnector_adk_store.append_event(event1)
+ await mysqlconnector_adk_store.append_event(event2)
events = await mysqlconnector_adk_store.get_events(session_id)
assert len(events) == 2
- assert events[0]["id"] == "event-001"
- assert events[1]["id"] == "event-002"
- content0 = events[0]["content"]
- content1 = events[1]["content"]
- assert content0 is not None and content0["text"] == "Hello"
- assert content1 is not None and content1["text"] == "Hi there"
- assert isinstance(events[0]["actions"], bytes)
- assert pickle.loads(events[0]["actions"])[0]["type"] == "message"
+ assert events[0]["author"] == "user"
+ assert events[1]["author"] == "assistant"
+ # Content is inside event_json
+ event0_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ event1_data = (
+ json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"]
+ )
+ assert event0_data["content"]["text"] == "Hello"
+ assert event1_data["content"]["text"] == "Hi there"
async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsyncADKStore) -> None:
@@ -230,17 +219,14 @@ async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsync
assert hasattr(created["create_time"], "microsecond")
event_time = datetime.now(timezone.utc)
- event = {
- "id": "event-micro",
+ event: EventRecord = {
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
"invocation_id": "inv-micro",
"author": "system",
- "actions": b"",
"timestamp": event_time,
+ "event_json": {"app_name": app_name},
}
- await mysqlconnector_adk_store.append_event(event) # type: ignore[arg-type]
+ await mysqlconnector_adk_store.append_event(event)
events = await mysqlconnector_adk_store.get_events(session_id)
assert len(events) == 1
diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py
index 26d86165f..5a32a911b 100644
--- a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py
+++ b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py
@@ -302,14 +302,14 @@ async def test_inmemory_tables_functional_async(oracle_async_config: OracleAsync
@pytest.mark.oracledb
-def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None:
+async def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None:
"""Test that in_memory=True works with sync store."""
config = OracleSyncConfig(
connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": True}}
)
store = OracleSyncADKStore(config)
- store.create_tables()
+ await store.create_tables()
try:
with config.provide_connection() as conn:
@@ -343,14 +343,14 @@ def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> None:
@pytest.mark.oracledb
-def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None:
+async def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None:
"""Test that in_memory=False works with sync store."""
config = OracleSyncConfig(
connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": False}}
)
store = OracleSyncADKStore(config)
- store.create_tables()
+ await store.create_tables()
try:
with config.provide_connection() as conn:
@@ -382,14 +382,14 @@ def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> None:
@pytest.mark.oracledb
-def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) -> None:
+async def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) -> None:
"""Test that INMEMORY tables work correctly in sync mode."""
config = OracleSyncConfig(
connection_config=oracle_sync_config.connection_config, extension_config={"adk": {"in_memory": True}}
)
store = OracleSyncADKStore(config)
- store.create_tables()
+ await store.create_tables()
try:
session_id = "inmemory-sync-session"
@@ -397,11 +397,11 @@ def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncConfig) -
user_id = "user-456"
state = {"sync": True, "value": 99}
- session = store.create_session(session_id, app_name, user_id, state)
+ session = await store.create_session(session_id, app_name, user_id, state)
assert session["id"] == session_id
assert session["state"] == state
- retrieved = store.get_session(session_id)
+ retrieved = await store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"] == state
diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py
index 52cd230f4..1e6cb1f94 100644
--- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py
+++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py
@@ -1,7 +1,7 @@
"""Oracle-specific ADK store tests for LOB handling, JSON types, and FK columns."""
-import pickle
-from collections.abc import AsyncGenerator, Generator
+import json
+from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import Any, cast
from uuid import uuid4
@@ -62,11 +62,11 @@ async def oracle_async_store(oracle_async_config: "OracleAsyncConfig") -> "Async
await _cleanup_async_store(store, oracle_async_config)
-@pytest.fixture(scope="module")
-def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "Generator[OracleSyncADKStore, None, None]":
- """Create a sync Oracle ADK store with tables created once per module."""
+@pytest.fixture
+async def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "AsyncGenerator[OracleSyncADKStore, None]":
+ """Create a sync Oracle ADK store with tables created per test."""
store = OracleSyncADKStore(oracle_sync_config)
- store.create_tables()
+ await store.create_tables()
try:
yield store
finally:
@@ -140,7 +140,9 @@ async def oracle_store_with_fk(
@pytest.fixture
-def oracle_config_with_users_table(oracle_sync_config: "OracleSyncConfig") -> "Generator[OracleSyncConfig, None, None]":
+async def oracle_config_with_users_table(
+ oracle_sync_config: "OracleSyncConfig",
+) -> "AsyncGenerator[OracleSyncConfig, None]":
"""Create a users table for FK testing."""
with oracle_sync_config.provide_connection() as conn:
cursor = conn.cursor()
@@ -187,9 +189,9 @@ def oracle_config_with_users_table(oracle_sync_config: "OracleSyncConfig") -> "G
@pytest.fixture
-def oracle_store_sync_with_fk(
+async def oracle_store_sync_with_fk(
oracle_config_with_users_table: "OracleSyncConfig",
-) -> "Generator[OracleSyncADKStore, None, None]":
+) -> "AsyncGenerator[OracleSyncADKStore, None]":
"""Create a sync Oracle ADK store with owner_id FK column."""
config_with_extension = OracleSyncConfig(
connection_config=oracle_config_with_users_table.connection_config,
@@ -197,7 +199,7 @@ def oracle_store_sync_with_fk(
)
store = OracleSyncADKStore(config_with_extension)
_cleanup_sync_store(store, config_with_extension)
- store.create_tables()
+ await store.create_tables()
try:
yield store
finally:
@@ -220,8 +222,8 @@ async def test_state_lob_deserialization(oracle_async_store: "OracleAsyncADKStor
assert retrieved["state"]["large_field"] == "x" * 10000
-async def test_event_content_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None:
- """Test event content CLOB is correctly deserialized."""
+async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None:
+ """Test event_json LOB data is correctly deserialized."""
session_id = _unique_session_id("event-lob")
app_name = "test-app"
user_id = "user-123"
@@ -229,169 +231,130 @@ async def test_event_content_lob_deserialization(oracle_async_store: "OracleAsyn
await oracle_async_store.create_session(session_id, app_name, user_id, {})
content = {"message": "x" * 5000, "data": {"nested": True}}
- grounding_metadata = {"sources": ["a" * 1000, "b" * 1000]}
- custom_metadata = {"tags": ["tag1", "tag2"], "priority": "high"}
+ event_data = {
+ "content": content,
+ "app_name": app_name,
+ "user_id": user_id,
+ "grounding_metadata": {"sources": ["a" * 1000, "b" * 1000]},
+ "custom_metadata": {"tags": ["tag1", "tag2"], "priority": "high"},
+ }
event_record: EventRecord = {
- "id": "event-1",
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
+ "invocation_id": "",
"author": "assistant",
- "actions": pickle.dumps([{"name": "test", "args": {}}]),
- "content": content,
- "grounding_metadata": grounding_metadata,
- "custom_metadata": custom_metadata,
"timestamp": datetime.now(timezone.utc),
- "partial": False,
- "turn_complete": True,
- "interrupted": False,
- "error_code": None,
- "error_message": None,
- "invocation_id": "",
- "branch": None,
- "long_running_tool_ids_json": None,
+ "event_json": event_data,
}
await oracle_async_store.append_event(event_record)
events = await oracle_async_store.get_events(session_id)
assert len(events) == 1
- assert events[0]["content"] == content
- assert events[0]["grounding_metadata"] == grounding_metadata
- assert events[0]["custom_metadata"] == custom_metadata
+ # event_json contains all the data
+ retrieved_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ assert retrieved_data["content"] == content
+ assert retrieved_data["grounding_metadata"] == {"sources": ["a" * 1000, "b" * 1000]}
+ assert retrieved_data["custom_metadata"] == {"tags": ["tag1", "tag2"], "priority": "high"}
-async def test_actions_blob_handling(oracle_async_store: "OracleAsyncADKStore") -> None:
- """Test actions BLOB is correctly read and unpickled."""
- session_id = _unique_session_id("actions-blob")
+async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> None:
+ """Test event_json blob is correctly stored and retrieved."""
+ session_id = _unique_session_id("event-json")
app_name = "test-app"
user_id = "user-123"
await oracle_async_store.create_session(session_id, app_name, user_id, {})
- test_actions = [{"function": "test_func", "args": {"param": "value"}, "result": 42}]
- actions_bytes = pickle.dumps(test_actions)
+ event_data = {"function": "test_func", "args": {"param": "value"}, "result": 42}
event_record: EventRecord = {
- "id": "event-actions",
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
+ "invocation_id": "",
"author": "user",
- "actions": actions_bytes,
- "content": None,
- "grounding_metadata": None,
- "custom_metadata": None,
"timestamp": datetime.now(timezone.utc),
- "partial": None,
- "turn_complete": None,
- "interrupted": None,
- "error_code": None,
- "error_message": None,
- "invocation_id": "",
- "branch": None,
- "long_running_tool_ids_json": None,
+ "event_json": event_data,
}
await oracle_async_store.append_event(event_record)
events = await oracle_async_store.get_events(session_id)
assert len(events) == 1
- assert events[0]["actions"] == actions_bytes
- unpickled = pickle.loads(events[0]["actions"])
- assert unpickled == test_actions
+ retrieved_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ assert retrieved_data == event_data
-def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None:
+async def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None:
"""Test state CLOB/BLOB is correctly deserialized in sync mode."""
session_id = _unique_session_id("lob-session-sync")
app_name = "test-app"
user_id = "user-123"
state = {"large_field": "y" * 10000, "nested": {"data": [4, 5, 6]}}
- session = oracle_sync_store.create_session(session_id, app_name, user_id, state)
+ session = await oracle_sync_store.create_session(session_id, app_name, user_id, state)
assert session["state"] == state
- retrieved = oracle_sync_store.get_session(session_id)
+ retrieved = await oracle_sync_store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"] == state
-async def test_boolean_fields_conversion(oracle_async_store: "OracleAsyncADKStore") -> None:
- """Test partial, turn_complete, interrupted converted to NUMBER(1)."""
- session_id = _unique_session_id("bool-session")
+async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncADKStore") -> None:
+ """Test the new 5-column EventRecord contract with append_event."""
+ session_id = _unique_session_id("5col-session")
app_name = "test-app"
user_id = "user-123"
await oracle_async_store.create_session(session_id, app_name, user_id, {})
event_record: EventRecord = {
- "id": "bool-event-1",
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
+ "invocation_id": "inv-001",
"author": "assistant",
- "actions": b"",
- "content": None,
- "grounding_metadata": None,
- "custom_metadata": None,
"timestamp": datetime.now(timezone.utc),
- "partial": True,
- "turn_complete": False,
- "interrupted": True,
- "error_code": None,
- "error_message": None,
- "invocation_id": "",
- "branch": None,
- "long_running_tool_ids_json": None,
+ "event_json": {"content": {"text": "Hello"}, "partial": True, "turn_complete": False, "interrupted": True},
}
await oracle_async_store.append_event(event_record)
events = await oracle_async_store.get_events(session_id)
assert len(events) == 1
- assert events[0]["partial"] is True
- assert events[0]["turn_complete"] is False
- assert events[0]["interrupted"] is True
+ assert events[0]["session_id"] == session_id
+ assert events[0]["invocation_id"] == "inv-001"
+ assert events[0]["author"] == "assistant"
+
+ retrieved_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
+ )
+ assert retrieved_data["partial"] is True
+ assert retrieved_data["turn_complete"] is False
+ assert retrieved_data["interrupted"] is True
-async def test_boolean_fields_none_values(oracle_async_store: "OracleAsyncADKStore") -> None:
- """Test None values for boolean fields."""
- session_id = _unique_session_id("bool-none-session")
+async def test_event_with_none_values(oracle_async_store: "OracleAsyncADKStore") -> None:
+ """Test event with minimal event_json content."""
+ session_id = _unique_session_id("none-session")
app_name = "test-app"
user_id = "user-123"
await oracle_async_store.create_session(session_id, app_name, user_id, {})
event_record: EventRecord = {
- "id": "bool-event-none",
"session_id": session_id,
- "app_name": app_name,
- "user_id": user_id,
+ "invocation_id": "",
"author": "user",
- "actions": b"",
- "content": None,
- "grounding_metadata": None,
- "custom_metadata": None,
"timestamp": datetime.now(timezone.utc),
- "partial": None,
- "turn_complete": None,
- "interrupted": None,
- "error_code": None,
- "error_message": None,
- "invocation_id": "",
- "branch": None,
- "long_running_tool_ids_json": None,
+ "event_json": {"app_name": app_name},
}
await oracle_async_store.append_event(event_record)
events = await oracle_async_store.get_events(session_id)
assert len(events) == 1
- assert events[0]["partial"] is None
- assert events[0]["turn_complete"] is None
- assert events[0]["interrupted"] is None
async def test_create_session_with_owner_id(oracle_store_with_fk: "OracleAsyncADKStore") -> None:
@@ -453,7 +416,7 @@ async def test_json_storage_type_detection(oracle_async_store: "OracleAsyncADKSt
detector = cast("Any", oracle_async_store)
storage_type = await detector._detect_json_storage_type()
- assert storage_type in ["json", "blob_json", "clob_json", "blob_plain"]
+ assert storage_type in ["json", "blob_json", "blob_plain"]
async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsyncADKStore") -> None:
@@ -465,7 +428,7 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync
"complex": {
"nested": {"deep": {"structure": "value"}},
"array": [1, 2, 3, {"key": "value"}],
- "unicode": "こんにちは世界",
+ "unicode": "\u65e5\u672c\u8a9e\u30c6\u30b9\u30c8",
"special_chars": "test@example.com | value > 100",
}
}
@@ -476,10 +439,10 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync
retrieved = await oracle_async_store.get_session(session_id)
assert retrieved is not None
assert retrieved["state"] == state
- assert retrieved["state"]["complex"]["unicode"] == "こんにちは世界"
+ assert retrieved["state"]["complex"]["unicode"] == "\u65e5\u672c\u8a9e\u30c6\u30b9\u30c8"
-def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyncADKStore") -> None:
+async def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyncADKStore") -> None:
"""Test creating session with owner_id in sync mode."""
session_id = _unique_session_id("sync-fk")
app_name = "test-app"
@@ -487,10 +450,10 @@ def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "OracleSyn
state = {"data": "sync test"}
owner_id = 100
- session = oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id)
+ session = await oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id)
assert session["id"] == session_id
assert session["state"] == state
- retrieved = oracle_store_sync_with_fk.get_session(session_id)
+ retrieved = await oracle_store_sync_with_fk.get_session(session_id)
assert retrieved is not None
assert retrieved["id"] == session_id
diff --git a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py
index 789eff133..5e7fb0123 100644
--- a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py
+++ b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py
@@ -1,6 +1,6 @@
"""Integration tests for Psycopg ADK store owner_id_column feature."""
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any
import pytest
@@ -42,7 +42,7 @@ async def psycopg_async_store_with_fk(postgres_service: "PostgresService") -> "A
@pytest.fixture
-def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "Generator[Any, None, None]":
+async def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "AsyncGenerator[Any, None]":
"""Create Psycopg sync ADK store with owner_id_column configured."""
config = PsycopgSyncConfig(
connection_config={
@@ -57,7 +57,7 @@ def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "Generato
},
)
store = PsycopgSyncADKStore(config)
- store.create_tables()
+ await store.create_tables()
yield store
with config.provide_connection() as conn, conn.cursor() as cur:
@@ -74,7 +74,7 @@ async def test_async_store_owner_id_column_initialization(psycopg_async_store_wi
assert psycopg_async_store_with_fk.owner_id_column_name == "tenant_id"
-def test_sync_store_owner_id_column_initialization(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None:
+async def test_sync_store_owner_id_column_initialization(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None:
"""Test that owner_id_column is properly initialized in sync store."""
assert psycopg_sync_store_with_fk.owner_id_column_ddl == "account_id VARCHAR(64) NOT NULL"
assert psycopg_sync_store_with_fk.owner_id_column_name == "account_id"
@@ -105,7 +105,7 @@ async def test_async_store_inherits_owner_id_column(postgres_service: "PostgresS
await config.close_pool()
-def test_sync_store_inherits_owner_id_column(postgres_service: "PostgresService") -> None:
+async def test_sync_store_inherits_owner_id_column(postgres_service: "PostgresService") -> None:
"""Test that sync store correctly inherits owner_id_column from base class."""
config = PsycopgSyncConfig(
connection_config={
@@ -147,7 +147,7 @@ async def test_async_store_without_owner_id_column(postgres_service: "PostgresSe
await config.close_pool()
-def test_sync_store_without_owner_id_column(postgres_service: "PostgresService") -> None:
+async def test_sync_store_without_owner_id_column(postgres_service: "PostgresService") -> None:
"""Test that sync store works without owner_id_column (default behavior)."""
config = PsycopgSyncConfig(
connection_config={
@@ -172,9 +172,9 @@ async def test_async_ddl_includes_owner_id_column(psycopg_async_store_with_fk: P
assert "test_sessions_fk" in ddl
-def test_sync_ddl_includes_owner_id_column(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None:
+async def test_sync_ddl_includes_owner_id_column(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None:
"""Test that the DDL generation includes the owner_id_column."""
- ddl = psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage]
+ ddl = await psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage]
assert "account_id VARCHAR(64) NOT NULL" in ddl
assert "test_sessions_sync_fk" in ddl
diff --git a/tests/integration/adapters/spanner/extensions/adk/conftest.py b/tests/integration/adapters/spanner/extensions/adk/conftest.py
index 4ae8782d2..57ad9bace 100644
--- a/tests/integration/adapters/spanner/extensions/adk/conftest.py
+++ b/tests/integration/adapters/spanner/extensions/adk/conftest.py
@@ -1,4 +1,4 @@
-from collections.abc import Generator
+from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
import pytest
@@ -30,7 +30,7 @@ def spanner_adk_config(spanner_service: SpannerService, spanner_database: "Datab
@pytest.fixture
-def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> Generator[SpannerSyncADKStore, None, None]:
+async def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> AsyncGenerator[SpannerSyncADKStore, None]:
store = SpannerSyncADKStore(spanner_adk_config)
- store.create_tables()
+ await store.create_tables()
yield store
diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py
index 98a313d15..b7cca39f2 100644
--- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py
+++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py
@@ -1,75 +1,95 @@
-"""Integration tests for Spanner ADK store (sync)."""
+"""Integration tests for Spanner ADK store."""
+import json
+from datetime import datetime, timezone
from typing import Any
import pytest
+from sqlspec.extensions.adk import EventRecord
+
pytestmark = [pytest.mark.spanner, pytest.mark.integration]
-def test_create_and_get_session(spanner_adk_store: Any) -> None:
+async def test_create_and_get_session(spanner_adk_store: Any) -> None:
session_id = "session-create"
- spanner_adk_store.delete_session(session_id)
- created = spanner_adk_store.create_session(session_id, "app", "user", {"a": 1})
+ await spanner_adk_store.delete_session(session_id)
+ created = await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1})
assert created["id"] == session_id
- fetched = spanner_adk_store.get_session(session_id)
+ fetched = await spanner_adk_store.get_session(session_id)
assert fetched is not None
assert fetched["state"] == {"a": 1}
-def test_update_session_state(spanner_adk_store: Any) -> None:
+async def test_update_session_state(spanner_adk_store: Any) -> None:
session_id = "session-update"
- spanner_adk_store.delete_session(session_id)
- spanner_adk_store.create_session(session_id, "app", "user", {"a": 1})
+ await spanner_adk_store.delete_session(session_id)
+ await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1})
- spanner_adk_store.update_session_state(session_id, {"a": 2, "b": True})
+ await spanner_adk_store.update_session_state(session_id, {"a": 2, "b": True})
- fetched = spanner_adk_store.get_session(session_id)
+ fetched = await spanner_adk_store.get_session(session_id)
assert fetched is not None
assert fetched["state"] == {"a": 2, "b": True}
-def test_list_sessions(spanner_adk_store: Any) -> None:
- spanner_adk_store.delete_session("session-list-1")
- spanner_adk_store.delete_session("session-list-2")
- spanner_adk_store.delete_session("session-list-3")
- spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1})
- spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2})
- spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3})
+async def test_list_sessions(spanner_adk_store: Any) -> None:
+ await spanner_adk_store.delete_session("session-list-1")
+ await spanner_adk_store.delete_session("session-list-2")
+ await spanner_adk_store.delete_session("session-list-3")
+ await spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1})
+ await spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2})
+ await spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3})
- sessions = spanner_adk_store.list_sessions("app-list", "user1")
+ sessions = await spanner_adk_store.list_sessions("app-list", "user1")
session_ids = {s["id"] for s in sessions}
assert session_ids == {"session-list-1", "session-list-2"}
-def test_delete_session(spanner_adk_store: Any) -> None:
+async def test_delete_session(spanner_adk_store: Any) -> None:
session_id = "session-delete"
- spanner_adk_store.delete_session(session_id)
- spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"})
- spanner_adk_store.delete_session(session_id)
+ await spanner_adk_store.delete_session(session_id)
+ await spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"})
+ await spanner_adk_store.delete_session(session_id)
- assert spanner_adk_store.get_session(session_id) is None
+ assert await spanner_adk_store.get_session(session_id) is None
-def test_create_and_list_events(spanner_adk_store: Any) -> None:
+async def test_create_and_list_events(spanner_adk_store: Any) -> None:
session_id = "session-events"
- spanner_adk_store.delete_session(session_id)
- spanner_adk_store.create_session(session_id, "app", "user", {"x": 1})
-
- spanner_adk_store.create_event("event-1", session_id, "app", "user", author="user", content={"msg": "hi"})
- spanner_adk_store.create_event(
- "event-2",
- session_id,
- "app",
- "user",
- author="assistant",
- content={"msg": "ok"},
- partial=False,
- turn_complete=True,
+ await spanner_adk_store.delete_session(session_id)
+ await spanner_adk_store.create_session(session_id, "app", "user", {"x": 1})
+
+ event_one: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "event-1",
+ "author": "user",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-1", "content": {"msg": "hi"}, "app_name": "app", "user_id": "user"},
+ }
+ event_two: EventRecord = {
+ "session_id": session_id,
+ "invocation_id": "event-2",
+ "author": "assistant",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-2", "content": {"msg": "ok"}, "app_name": "app", "user_id": "user"},
+ }
+
+ await spanner_adk_store.append_event(event_one)
+ await spanner_adk_store.append_event(event_two)
+
+ events = await spanner_adk_store.get_events(session_id)
+ assert len(events) == 2
+ assert events[0]["author"] == "user"
+ assert events[1]["author"] == "assistant"
+
+ # Content is inside event_json in the new 5-column schema
+ event0_data = (
+ json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"]
)
-
- events = spanner_adk_store.list_events(session_id)
- ids = [e["id"] for e in events]
- assert ids == ["event-1", "event-2"]
- assert events[0]["content"] == {"msg": "hi"}
+ event1_data = (
+ json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"]
+ )
+ assert event0_data["content"] == {"msg": "hi"}
+ assert event1_data["content"] == {"msg": "ok"}
diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py
index e3397dd8e..71645b388 100644
--- a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py
+++ b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py
@@ -30,64 +30,64 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted
)
-def test_sqlite_memory_store_insert_search_dedup() -> None:
+async def test_sqlite_memory_store_insert_search_dedup() -> None:
"""Insert memory entries, search by text, and skip duplicates."""
with tempfile.NamedTemporaryFile(suffix=".db") as tmp:
config = SqliteConfig(connection_config={"database": tmp.name})
store = SqliteADKMemoryStore(config)
- store.create_tables()
+ await store.create_tables()
now = datetime.now(timezone.utc)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now)
record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now)
- inserted = store.insert_memory_entries([record1, record2])
+ inserted = await store.insert_memory_entries([record1, record2])
assert inserted == 2
- results = store.search_entries(query="espresso", app_name="app", user_id="user")
+ results = await store.search_entries(query="espresso", app_name="app", user_id="user")
assert len(results) == 1
assert results[0]["event_id"] == "evt-1"
- deduped = store.insert_memory_entries([record1])
+ deduped = await store.insert_memory_entries([record1])
assert deduped == 0
-def test_sqlite_memory_store_delete_by_session() -> None:
+async def test_sqlite_memory_store_delete_by_session() -> None:
"""Delete memory entries by session id."""
with tempfile.NamedTemporaryFile(suffix=".db") as tmp:
config = SqliteConfig(connection_config={"database": tmp.name})
store = SqliteADKMemoryStore(config)
- store.create_tables()
+ await store.create_tables()
now = datetime.now(timezone.utc)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now)
record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now)
- store.insert_memory_entries([record1, record2])
+ await store.insert_memory_entries([record1, record2])
- deleted = store.delete_entries_by_session("s1")
+ deleted = await store.delete_entries_by_session("s1")
assert deleted == 1
- remaining = store.search_entries(query="latte", app_name="app", user_id="user")
+ remaining = await store.search_entries(query="latte", app_name="app", user_id="user")
assert len(remaining) == 1
assert remaining[0]["session_id"] == "s2"
-def test_sqlite_memory_store_delete_older_than() -> None:
+async def test_sqlite_memory_store_delete_older_than() -> None:
"""Delete memory entries older than a cutoff."""
with tempfile.NamedTemporaryFile(suffix=".db") as tmp:
config = SqliteConfig(connection_config={"database": tmp.name})
store = SqliteADKMemoryStore(config)
- store.create_tables()
+ await store.create_tables()
now = datetime.now(timezone.utc)
old = now - timedelta(days=40)
record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old)
record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now)
- store.insert_memory_entries([record1, record2])
+ await store.insert_memory_entries([record1, record2])
- deleted = store.delete_entries_older_than(30)
+ deleted = await store.delete_entries_older_than(30)
assert deleted == 1
- remaining = store.search_entries(query="new", app_name="app", user_id="user")
+ remaining = await store.search_entries(query="new", app_name="app", user_id="user")
assert len(remaining) == 1
assert remaining[0]["event_id"] == "evt-2"
diff --git a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py
index ec618e94e..b275006c5 100644
--- a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py
+++ b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py
@@ -2,7 +2,12 @@
from decimal import Decimal
-from sqlspec.adapters.oracledb.adk.store import OracleAsyncADKStore, OracleSyncADKStore
+from sqlspec.adapters.oracledb.adk.store import (
+ JSONStorageType,
+ OracleAsyncADKStore,
+ OracleSyncADKStore,
+ _event_json_column_ddl,
+)
async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None:
@@ -43,3 +48,29 @@ def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None:
result = store._deserialize_state(payload) # type: ignore[attr-defined]
assert result == {"state": 5.0}
+
+
+def test_oracle_event_json_column_ddl_prefers_blob_over_clob() -> None:
+ assert _event_json_column_ddl(JSONStorageType.JSON_NATIVE) == "event_json JSON NOT NULL"
+ assert _event_json_column_ddl(JSONStorageType.BLOB_JSON) == "event_json BLOB CHECK (event_json IS JSON) NOT NULL"
+ assert _event_json_column_ddl(JSONStorageType.BLOB_PLAIN) == "event_json BLOB NOT NULL"
+
+
+async def test_oracle_async_adk_store_serialize_event_json_uses_blob_for_non_native() -> None:
+ store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg]
+ store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined]
+
+ result = await store._serialize_event_json({"value": 1}) # type: ignore[attr-defined]
+
+ assert isinstance(result, bytes)
+ assert b'"value":1' in result
+
+
+def test_oracle_sync_adk_store_serialize_event_json_uses_blob_for_non_native() -> None:
+ store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg]
+ store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined]
+
+ result = store._serialize_event_json({"value": 1}) # type: ignore[attr-defined]
+
+ assert isinstance(result, bytes)
+ assert b'"value":1' in result
diff --git a/tests/unit/adapters/test_psycopg/test_adk_store.py b/tests/unit/adapters/test_psycopg/test_adk_store.py
new file mode 100644
index 000000000..a0d3eb6c1
--- /dev/null
+++ b/tests/unit/adapters/test_psycopg/test_adk_store.py
@@ -0,0 +1,109 @@
+"""Unit tests for psycopg ADK store sync wrappers."""
+
+from datetime import datetime, timezone
+from typing import Any
+
+from psycopg.types.json import Jsonb
+from typing_extensions import Self
+
+from sqlspec.adapters.psycopg.adk.store import PsycopgSyncADKStore
+
+
+class _DummyCursor:
+ def __init__(self, rows: "list[dict[str, Any]] | None" = None) -> None:
+ self.execute_calls: list[tuple[Any, Any]] = []
+ self._rows = rows or []
+
+ def __enter__(self) -> Self:
+ return self
+
+ def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
+ return None
+
+ def execute(self, query: Any, params: Any) -> None:
+ self.execute_calls.append((query, params))
+
+ def fetchall(self) -> "list[dict[str, Any]]":
+ return self._rows
+
+
+class _DummyConnection:
+ def __init__(self, cursor: _DummyCursor) -> None:
+ self._cursor = cursor
+ self.commit_called = False
+
+ def __enter__(self) -> Self:
+ return self
+
+ def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
+ return None
+
+ def cursor(self) -> _DummyCursor:
+ return self._cursor
+
+ def commit(self) -> None:
+ self.commit_called = True
+
+
+class _DummyConfig:
+ def __init__(self, connection: _DummyConnection) -> None:
+ self._connection = connection
+
+ def provide_connection(self) -> _DummyConnection:
+ return self._connection
+
+
+def _build_store(
+ rows: "list[dict[str, Any]] | None" = None,
+) -> "tuple[PsycopgSyncADKStore, _DummyCursor, _DummyConnection]":
+ cursor = _DummyCursor(rows)
+ connection = _DummyConnection(cursor)
+ store = PsycopgSyncADKStore.__new__(PsycopgSyncADKStore) # type: ignore[call-arg]
+ store._config = _DummyConfig(connection) # type: ignore[attr-defined]
+ store._events_table = "test_events" # type: ignore[attr-defined]
+ store._session_table = "test_sessions" # type: ignore[attr-defined]
+ store._owner_id_column_ddl = None # type: ignore[attr-defined]
+ store._owner_id_column_name = None # type: ignore[attr-defined]
+ return store, cursor, connection
+
+
+def test_sync_append_event_inserts_without_session_update() -> None:
+ """append_event must insert a single event without writing session state."""
+ store, cursor, connection = _build_store()
+ event_record = {
+ "session_id": "session-1",
+ "invocation_id": "",
+ "author": "assistant",
+ "timestamp": datetime.now(timezone.utc),
+ "event_json": {"id": "event-1"},
+ }
+
+ store._append_event(event_record) # type: ignore[arg-type]
+
+ assert len(cursor.execute_calls) == 1
+ _, params = cursor.execute_calls[0]
+ assert params[0] == "session-1"
+ assert isinstance(params[4], Jsonb)
+ assert connection.commit_called
+
+
+def test_sync_get_events_passes_after_timestamp_and_limit() -> None:
+ """get_events must forward after_timestamp and limit to the sync query."""
+ base_time = datetime(2026, 1, 1, tzinfo=timezone.utc)
+ rows = [
+ {
+ "session_id": "session-1",
+ "invocation_id": "",
+ "author": "assistant",
+ "timestamp": base_time,
+ "event_json": {"id": "event-2"},
+ }
+ ]
+ store, cursor, _ = _build_store(rows)
+
+ result = store._get_events("session-1", after_timestamp=base_time, limit=1)
+
+ assert len(cursor.execute_calls) == 1
+ _, params = cursor.execute_calls[0]
+ assert params == ("session-1", base_time, 1)
+ assert result[0]["event_json"]["id"] == "event-2"
diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py
new file mode 100644
index 000000000..903358ae9
--- /dev/null
+++ b/tests/unit/extensions/test_adk/test_converters.py
@@ -0,0 +1,478 @@
+"""Unit tests for ADK session/event converters and scoped state helpers.
+
+Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul:
+- EventRecord has exactly 5 keys (session_id, invocation_id, author, timestamp, event_json)
+- event_to_record takes only (event, session_id), not (event, session_id, app_name, user_id)
+- record_to_event uses Event.model_validate for full round-trip fidelity
+- filter_temp_state, split_scoped_state, merge_scoped_state for scoped state handling
+- session_to_record strips temp: keys from state
+"""
+
+import importlib.util
+from datetime import datetime, timezone
+
+import pytest
+
+if importlib.util.find_spec("google.genai") is None or importlib.util.find_spec("google.adk") is None:
+ pytest.skip("google-adk not installed", allow_module_level=True)
+
+from google.adk.events.event import Event
+from google.adk.events.event_actions import EventActions
+from google.adk.sessions.session import Session
+from google.genai import types
+
+from sqlspec.extensions.adk.converters import (
+ event_to_record,
+ filter_temp_state,
+ merge_scoped_state,
+ record_to_event,
+ record_to_session,
+ session_to_record,
+ split_scoped_state,
+)
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_event(
+ *,
+ event_id: str = "evt-1",
+ invocation_id: str = "inv-1",
+ author: str = "user",
+ text: "str | None" = None,
+ state_delta: "dict | None" = None,
+ branch: "str | None" = None,
+ partial: "bool | None" = None,
+ turn_complete: "bool | None" = None,
+ custom_metadata: "dict | None" = None,
+) -> Event:
+ content = types.Content(parts=[types.Part(text=text)]) if text is not None else None
+ actions = EventActions(state_delta=state_delta or {})
+ return Event(
+ id=event_id,
+ invocation_id=invocation_id,
+ author=author,
+ content=content,
+ actions=actions,
+ timestamp=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp(),
+ branch=branch,
+ partial=partial,
+ turn_complete=turn_complete,
+ custom_metadata=custom_metadata,
+ )
+
+
+def _make_session(
+ *, session_id: str = "session-1", app_name: str = "test-app", user_id: str = "user-1", state: "dict | None" = None
+) -> Session:
+ return Session(
+ id=session_id,
+ app_name=app_name,
+ user_id=user_id,
+ state=state or {},
+ last_update_time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp(),
+ )
+
+
+# ---------------------------------------------------------------------------
+# filter_temp_state
+# ---------------------------------------------------------------------------
+
+
+def test_filter_temp_state_removes_temp_keys() -> None:
+ """temp:-prefixed keys are removed; all other keys are kept."""
+ state = {"x": 1, "temp:y": 2, "app:z": 3, "user:w": 4}
+ result = filter_temp_state(state)
+ assert result == {"x": 1, "app:z": 3, "user:w": 4}
+
+
+def test_filter_temp_state_empty_dict() -> None:
+ """Empty dict returns empty dict."""
+ assert filter_temp_state({}) == {}
+
+
+def test_filter_temp_state_all_temp_keys() -> None:
+ """Dict with only temp: keys returns empty dict."""
+ state = {"temp:a": 1, "temp:b": 2, "temp:": 3}
+ assert filter_temp_state(state) == {}
+
+
+def test_filter_temp_state_no_temp_keys() -> None:
+ """Dict with no temp: keys is returned unchanged."""
+ state = {"x": 1, "app:y": 2, "user:z": 3}
+ result = filter_temp_state(state)
+ assert result == state
+
+
+def test_filter_temp_state_does_not_mutate_input() -> None:
+ """Input dict is not mutated."""
+ state = {"key": "v", "temp:remove": "gone"}
+ original = dict(state)
+ filter_temp_state(state)
+ assert state == original
+
+
+# ---------------------------------------------------------------------------
+# split_scoped_state
+# ---------------------------------------------------------------------------
+
+
+def test_split_scoped_state_separates_buckets() -> None:
+ """app:, user:, and plain keys go into the correct buckets."""
+ state = {"app:shared": "a", "user:profile": "u", "session_key": "s", "another": "v"}
+ app, user, session = split_scoped_state(state)
+ assert app == {"app:shared": "a"}
+ assert user == {"user:profile": "u"}
+ assert session == {"session_key": "s", "another": "v"}
+
+
+def test_split_scoped_state_empty() -> None:
+ """Empty state produces three empty dicts."""
+ app, user, session = split_scoped_state({})
+ assert app == {}
+ assert user == {}
+ assert session == {}
+
+
+def test_split_scoped_state_only_app_keys() -> None:
+ """State with only app: keys puts everything in app bucket."""
+ state = {"app:x": 1, "app:y": 2}
+ app, user, session = split_scoped_state(state)
+ assert app == {"app:x": 1, "app:y": 2}
+ assert user == {}
+ assert session == {}
+
+
+def test_split_scoped_state_only_user_keys() -> None:
+ """State with only user: keys puts everything in user bucket."""
+ state = {"user:a": "one", "user:b": "two"}
+ app, user, session = split_scoped_state(state)
+ assert app == {}
+ assert user == {"user:a": "one", "user:b": "two"}
+ assert session == {}
+
+
+def test_split_scoped_state_only_session_keys() -> None:
+ """State with no prefix puts everything in session bucket."""
+ state = {"key1": 1, "key2": 2}
+ app, user, session = split_scoped_state(state)
+ assert app == {}
+ assert user == {}
+ assert session == {"key1": 1, "key2": 2}
+
+
+def test_split_scoped_state_preserves_full_key_names() -> None:
+ """Keys are not stripped of their prefix in the returned buckets."""
+ state = {"app:my_key": "val", "user:my_key": "val2"}
+ app, user, _ = split_scoped_state(state)
+ assert "app:my_key" in app
+ assert "user:my_key" in user
+
+
+# ---------------------------------------------------------------------------
+# merge_scoped_state
+# ---------------------------------------------------------------------------
+
+
+def test_merge_scoped_state_combines_all_buckets() -> None:
+ """All three buckets appear in the merged result."""
+ merged = merge_scoped_state(session_state={"key": "s"}, app_state={"app:x": "a"}, user_state={"user:y": "u"})
+ assert merged == {"key": "s", "app:x": "a", "user:y": "u"}
+
+
+def test_merge_scoped_state_overlay_priority_app_over_session() -> None:
+ """app_state overlays session_state for the same key."""
+ merged = merge_scoped_state(session_state={"app:x": "old"}, app_state={"app:x": "new"})
+ assert merged["app:x"] == "new"
+
+
+def test_merge_scoped_state_overlay_priority_user_over_session() -> None:
+ """user_state overlays session_state for the same key."""
+ merged = merge_scoped_state(session_state={"user:y": "session_val"}, user_state={"user:y": "user_val"})
+ assert merged["user:y"] == "user_val"
+
+
+def test_merge_scoped_state_no_app_no_user() -> None:
+ """Merging without app_state or user_state returns session_state copy."""
+ session = {"key": "v", "other": 42}
+ merged = merge_scoped_state(session_state=session)
+ assert merged == session
+
+
+def test_merge_scoped_state_empty_session_state() -> None:
+ """Empty session_state with app/user state returns combined app+user keys."""
+ merged = merge_scoped_state(session_state={}, app_state={"app:a": 1}, user_state={"user:b": 2})
+ assert merged == {"app:a": 1, "user:b": 2}
+
+
+def test_merge_scoped_state_does_not_mutate_session_state() -> None:
+ """Input session_state dict is not mutated."""
+ session = {"key": "v"}
+ original = dict(session)
+ merge_scoped_state(session_state=session, app_state={"app:x": 1})
+ assert session == original
+
+
+# ---------------------------------------------------------------------------
+# event_to_record — signature and structure
+# ---------------------------------------------------------------------------
+
+
+def test_event_to_record_only_5_keys() -> None:
+ """EventRecord has exactly session_id, invocation_id, author, timestamp, event_json."""
+ event = _make_event()
+ record = event_to_record(event, "session-1")
+ assert set(record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"}
+
+
+def test_event_to_record_signature_two_args_only() -> None:
+ """event_to_record raises TypeError if called with extra positional args (old 4-arg signature)."""
+ event = _make_event()
+ with pytest.raises(TypeError):
+ event_to_record(event, "session-1", "app-name", "user-id") # type: ignore[call-arg]
+
+
+def test_event_to_record_session_id_stored_correctly() -> None:
+ """session_id in the record matches the argument passed."""
+ event = _make_event(invocation_id="inv-abc", author="model")
+ record = event_to_record(event, "my-session-id")
+ assert record["session_id"] == "my-session-id"
+
+
+def test_event_to_record_indexed_fields_match_event() -> None:
+ """Indexed scalar columns (invocation_id, author, timestamp) match the source event."""
+ event = _make_event(invocation_id="inv-xyz", author="tool")
+ record = event_to_record(event, "s1")
+ assert record["invocation_id"] == "inv-xyz"
+ assert record["author"] == "tool"
+ assert isinstance(record["timestamp"], datetime)
+
+
+def test_event_to_record_event_json_matches_model_dump() -> None:
+ """event_json in the record equals event.model_dump(exclude_none=True, mode='json')."""
+ event = _make_event(text="hello", state_delta={"key": "val"}, custom_metadata={"foo": "bar"})
+ record = event_to_record(event, "s1")
+ expected_json = event.model_dump(exclude_none=True, mode="json")
+ assert record["event_json"] == expected_json
+
+
+def test_event_to_record_event_json_is_dict() -> None:
+ """event_json field is a plain dict (not bytes, not string)."""
+ event = _make_event()
+ record = event_to_record(event, "s1")
+ assert isinstance(record["event_json"], dict)
+
+
+def test_event_to_record_actions_in_event_json_is_structured() -> None:
+ """Actions are stored as structured JSON dict in event_json, not as raw bytes."""
+ event = _make_event(state_delta={"x": "y"})
+ record = event_to_record(event, "s1")
+ event_json = record["event_json"]
+ # actions should be a dict in the JSON blob
+ if "actions" in event_json:
+ assert isinstance(event_json["actions"], dict)
+
+
+def test_event_to_record_timestamp_is_datetime() -> None:
+ """timestamp column is a datetime object with timezone."""
+ event = _make_event()
+ record = event_to_record(event, "s1")
+ assert isinstance(record["timestamp"], datetime)
+ assert record["timestamp"].tzinfo is not None
+
+
+# ---------------------------------------------------------------------------
+# record_to_event — full round-trip fidelity
+# ---------------------------------------------------------------------------
+
+
+def test_record_to_event_full_roundtrip_basic() -> None:
+ """Event -> record -> Event produces an identical object for basic fields."""
+ original = _make_event(event_id="evt-rt", invocation_id="inv-rt", author="model")
+ record = event_to_record(original, "s1")
+ restored = record_to_event(record)
+
+ assert restored.id == original.id
+ assert restored.invocation_id == original.invocation_id
+ assert restored.author == original.author
+
+
+def test_record_to_event_roundtrip_preserves_content() -> None:
+ """Content (parts) survives the round-trip."""
+ original = _make_event(text="hello world", author="model")
+ record = event_to_record(original, "s1")
+ restored = record_to_event(record)
+
+ assert restored.content is not None
+ assert restored.content.parts is not None
+ assert restored.content.parts[0].text == "hello world"
+
+
+def test_record_to_event_roundtrip_preserves_actions() -> None:
+ """EventActions (state_delta) survives the round-trip."""
+ original = _make_event(state_delta={"key": "v1", "other": 42})
+ record = event_to_record(original, "s1")
+ restored = record_to_event(record)
+
+ assert restored.actions is not None
+ assert restored.actions.state_delta == {"key": "v1", "other": 42}
+
+
+def test_record_to_event_roundtrip_preserves_custom_metadata() -> None:
+ """custom_metadata survives the round-trip."""
+ original = _make_event(custom_metadata={"tag": "v2", "score": 0.9})
+ record = event_to_record(original, "s1")
+ restored = record_to_event(record)
+
+ assert restored.custom_metadata == {"tag": "v2", "score": 0.9}
+
+
+def test_record_to_event_roundtrip_preserves_branch() -> None:
+ """branch field survives the round-trip."""
+ original = _make_event(branch="feature-branch")
+ record = event_to_record(original, "s1")
+ restored = record_to_event(record)
+
+ assert restored.branch == "feature-branch"
+
+
+def test_record_to_event_roundtrip_preserves_partial_flag() -> None:
+ """partial flag survives the round-trip."""
+ original = _make_event(partial=True)
+ record = event_to_record(original, "s1")
+ restored = record_to_event(record)
+
+ assert restored.partial is True
+
+
+def test_record_to_event_roundtrip_preserves_turn_complete() -> None:
+ """turn_complete flag survives the round-trip."""
+ original = _make_event(turn_complete=True)
+ record = event_to_record(original, "s1")
+ restored = record_to_event(record)
+
+ assert restored.turn_complete is True
+
+
+def test_record_to_event_roundtrip_preserves_timestamp() -> None:
+ """timestamp survives the round-trip within float precision."""
+ fixed_ts = datetime(2024, 6, 1, 10, 30, 0, tzinfo=timezone.utc).timestamp()
+ event = Event(id="ts-evt", invocation_id="inv-1", author="user", actions=EventActions(), timestamp=fixed_ts)
+ record = event_to_record(event, "s1")
+ restored = record_to_event(record)
+
+ assert abs(restored.timestamp - fixed_ts) < 1.0 # within 1 second
+
+
+@pytest.mark.xfail(
+ reason="ADK Event model uses extra='forbid' — unknown fields raise ValidationError. "
+ "Future ADK versions that add fields will also update the model, so this is safe.",
+ strict=True,
+)
+def test_record_to_event_with_extra_fields_in_event_json() -> None:
+ """Events with extra/unknown fields in event_json are rejected by Event model."""
+ event = _make_event(event_id="extra-fields-evt", author="tool")
+ record = event_to_record(event, "s1")
+
+ # Inject hypothetical future ADK field into event_json
+ record["event_json"]["hypothetical_v3_field"] = "some_value" # type: ignore[index]
+
+ # This WILL raise because Event has extra='forbid'
+ restored = record_to_event(record)
+ assert restored.id == "extra-fields-evt"
+
+
+# ---------------------------------------------------------------------------
+# session_to_record — strips temp: keys
+# ---------------------------------------------------------------------------
+
+
+def test_session_to_record_strips_temp_keys_from_state() -> None:
+ """session_to_record removes temp:-prefixed keys before persisting."""
+ session = _make_session(state={"key": "v", "temp:x": "t", "app:y": "a"})
+ record = session_to_record(session)
+ assert "temp:x" not in record["state"]
+ assert record["state"]["key"] == "v"
+ assert record["state"]["app:y"] == "a"
+
+
+def test_session_to_record_empty_state_stays_empty() -> None:
+ """Empty state produces empty state in record."""
+ session = _make_session(state={})
+ record = session_to_record(session)
+ assert record["state"] == {}
+
+
+def test_session_to_record_all_temp_state_produces_empty() -> None:
+ """Session state with only temp: keys produces empty state in record."""
+ session = _make_session(state={"temp:a": 1, "temp:b": 2})
+ record = session_to_record(session)
+ assert record["state"] == {}
+
+
+def test_session_to_record_no_temp_state_unchanged() -> None:
+ """Session state with no temp: keys is stored without modification."""
+ state = {"x": 1, "app:y": 2, "user:z": 3}
+ session = _make_session(state=state)
+ record = session_to_record(session)
+ assert record["state"] == state
+
+
+def test_session_to_record_includes_required_fields() -> None:
+ """Session record includes id, app_name, user_id, state, create_time, update_time."""
+ session = _make_session()
+ record = session_to_record(session)
+ assert "id" in record
+ assert "app_name" in record
+ assert "user_id" in record
+ assert "state" in record
+ assert "create_time" in record
+ assert "update_time" in record
+
+
+# ---------------------------------------------------------------------------
+# record_to_session — integrates with record_to_event
+# ---------------------------------------------------------------------------
+
+
+def test_record_to_session_with_events_round_trip() -> None:
+ """Sessions with events reconstruct correctly using record_to_session."""
+ from sqlspec.extensions.adk._types import SessionRecord
+
+ session_record = SessionRecord(
+ id="s1",
+ app_name="app",
+ user_id="u1",
+ state={"key": "val"},
+ create_time=datetime.now(timezone.utc),
+ update_time=datetime.now(timezone.utc),
+ )
+ event = _make_event(text="hello", author="user")
+ event_record = event_to_record(event, "s1")
+
+ session = record_to_session(session_record, [event_record])
+
+ assert session.id == "s1"
+ assert session.app_name == "app"
+ assert session.user_id == "u1"
+ assert session.state == {"key": "val"}
+ assert len(session.events) == 1
+ assert session.events[0].id == event.id
+
+
+def test_record_to_session_empty_events() -> None:
+ """Sessions without events reconstruct with empty events list."""
+ from sqlspec.extensions.adk._types import SessionRecord
+
+ session_record = SessionRecord(
+ id="s2",
+ app_name="app",
+ user_id="u2",
+ state={},
+ create_time=datetime.now(timezone.utc),
+ update_time=datetime.now(timezone.utc),
+ )
+ session = record_to_session(session_record, [])
+ assert session.events == []
diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py
new file mode 100644
index 000000000..a34d78862
--- /dev/null
+++ b/tests/unit/extensions/test_adk/test_service.py
@@ -0,0 +1,321 @@
+"""Unit tests for SQLSpecSessionService — state persistence fix.
+
+Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul:
+- append_event calls append_event_and_update_state (not the old append_event)
+- temp: keys are stripped before persisting session state
+- partial events are not persisted
+- create_session strips temp: keys from initial state
+
+The store is mocked — no database required.
+"""
+
+import importlib.util
+from datetime import datetime, timezone
+from typing import Any
+
+import pytest
+
+if importlib.util.find_spec("google.genai") is None or importlib.util.find_spec("google.adk") is None:
+ pytest.skip("google-adk not installed", allow_module_level=True)
+
+from google.adk.events.event import Event
+from google.adk.events.event_actions import EventActions
+from google.adk.sessions.session import Session
+
+from sqlspec.extensions.adk.service import SQLSpecSessionService
+
+# ---------------------------------------------------------------------------
+# Mock store
+# ---------------------------------------------------------------------------
+
+
+class MockStore:
+ """Simple mock that records calls to store methods.
+
+ Attributes are set to AsyncMock so that await works out of the box,
+ and call arguments are captured for assertion.
+ """
+
+ def __init__(self) -> None:
+ # Track calls to the new combined method
+ self.append_event_and_update_state_calls: list[dict[str, Any]] = []
+ self.append_event_and_update_state_called = False
+
+ # Track calls to create_session
+ self.create_session_calls: list[dict[str, Any]] = []
+
+ # Provide a get_session that returns a minimal session record
+ self._session_record = {
+ "id": "s1",
+ "app_name": "app",
+ "user_id": "u1",
+ "state": {},
+ "create_time": datetime.now(timezone.utc),
+ "update_time": datetime.now(timezone.utc),
+ }
+
+ async def append_event_and_update_state(self, event_record: Any, session_id: str, state: "dict[str, Any]") -> None:
+ self.append_event_and_update_state_called = True
+ self.append_event_and_update_state_calls.append({
+ "event_record": event_record,
+ "session_id": session_id,
+ "state": state,
+ })
+
+ async def get_session(self, session_id: str) -> "dict[str, Any] | None":
+ return self._session_record
+
+ async def create_session(
+ self, *, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]"
+ ) -> "dict[str, Any]":
+ self.create_session_calls.append({
+ "session_id": session_id,
+ "app_name": app_name,
+ "user_id": user_id,
+ "state": state,
+ })
+ return {
+ "id": session_id,
+ "app_name": app_name,
+ "user_id": user_id,
+ "state": state,
+ "create_time": datetime.now(timezone.utc),
+ "update_time": datetime.now(timezone.utc),
+ }
+
+ # Old method — should NOT be called by the new service
+ async def append_event(self, event_record: Any) -> None:
+ raise AssertionError("append_event (old method) must not be called — use append_event_and_update_state")
+
+ async def get_events(self, *, session_id: str, after_timestamp: Any = None, limit: Any = None) -> list:
+ return []
+
+ async def list_sessions(self, *, app_name: str, user_id: "str | None" = None) -> list:
+ return []
+
+ async def delete_session(self, session_id: str) -> None:
+ pass
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_session(
+ *, session_id: str = "s1", app_name: str = "app", user_id: str = "u1", state: "dict | None" = None
+) -> Session:
+ return Session(
+ id=session_id,
+ app_name=app_name,
+ user_id=user_id,
+ state=state or {},
+ last_update_time=datetime.now(timezone.utc).timestamp(),
+ )
+
+
+def _make_event(
+ *, invocation_id: str = "inv-1", author: str = "model", state_delta: "dict | None" = None, partial: bool = False
+) -> Event:
+ actions = EventActions(state_delta=state_delta or {})
+ return Event(
+ invocation_id=invocation_id,
+ author=author,
+ actions=actions,
+ timestamp=datetime.now(timezone.utc).timestamp(),
+ partial=partial,
+ )
+
+
+# ---------------------------------------------------------------------------
+# append_event — calls append_event_and_update_state
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.anyio
+async def test_append_event_calls_append_event_and_update_state() -> None:
+ """append_event must call append_event_and_update_state, not the old append_event."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+ session = _make_session(state={"key": "v0"})
+ event = _make_event(state_delta={"key": "v1"})
+
+ await service.append_event(session, event)
+
+ assert store.append_event_and_update_state_called, (
+ "append_event_and_update_state was never called — state will not be persisted"
+ )
+
+
+@pytest.mark.anyio
+async def test_append_event_persists_updated_state() -> None:
+ """append_event persists the state AFTER applying event.actions.state_delta."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+ session = _make_session(state={"key": "v0"})
+ event = _make_event(state_delta={"key": "v1"})
+
+ await service.append_event(session, event)
+
+ assert store.append_event_and_update_state_called
+ last_call = store.append_event_and_update_state_calls[-1]
+ # The persisted state must reflect the mutation from state_delta
+ assert last_call["state"]["key"] == "v1"
+
+
+@pytest.mark.anyio
+async def test_append_event_strips_temp_from_persisted_state() -> None:
+ """temp: keys are removed before state persistence."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+ session = _make_session(state={"key": "v", "temp:transient": "should_not_persist"})
+ event = _make_event()
+
+ await service.append_event(session, event)
+
+ assert store.append_event_and_update_state_called
+ last_call = store.append_event_and_update_state_calls[-1]
+ persisted_state = last_call["state"]
+ assert "temp:transient" not in persisted_state
+ assert persisted_state["key"] == "v"
+
+
+@pytest.mark.anyio
+async def test_append_event_strips_temp_state_delta_from_persisted_state() -> None:
+ """temp: keys added via state_delta are also stripped before persisting."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+    # The event's state_delta carries a temp: key, as an agent would emit
+ session = _make_session(state={"regular": "v"})
+ event = _make_event(state_delta={"temp:output": "transient", "regular": "updated"})
+
+ await service.append_event(session, event)
+
+ last_call = store.append_event_and_update_state_calls[-1]
+ persisted_state = last_call["state"]
+ assert "temp:output" not in persisted_state
+ assert persisted_state["regular"] == "updated"
+
+
+@pytest.mark.anyio
+async def test_append_event_skips_partial_events() -> None:
+ """Partial events are not persisted to the store."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+ session = _make_session()
+ partial_event = _make_event(partial=True)
+
+ result = await service.append_event(session, partial_event)
+
+ assert not store.append_event_and_update_state_called, (
+ "append_event_and_update_state must NOT be called for partial events"
+ )
+ assert result.partial is True
+
+
+@pytest.mark.anyio
+async def test_append_event_passes_correct_session_id_to_store() -> None:
+ """append_event_and_update_state receives the correct session_id."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+ session = _make_session(session_id="my-unique-session-id")
+ event = _make_event()
+
+ await service.append_event(session, event)
+
+ last_call = store.append_event_and_update_state_calls[-1]
+ assert last_call["session_id"] == "my-unique-session-id"
+
+
+@pytest.mark.anyio
+async def test_append_event_event_record_has_5_keys() -> None:
+ """The event_record passed to the store has exactly 5 keys (new schema)."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+ session = _make_session()
+ event = _make_event()
+
+ await service.append_event(session, event)
+
+ last_call = store.append_event_and_update_state_calls[-1]
+ event_record = last_call["event_record"]
+ assert set(event_record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"}
+
+
+@pytest.mark.anyio
+async def test_append_event_returns_the_event() -> None:
+ """append_event returns the event after persisting."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+ session = _make_session()
+ event = _make_event(author="model")
+
+ result = await service.append_event(session, event)
+
+ assert result is not None
+ assert result.author == "model"
+
+
+# ---------------------------------------------------------------------------
+# create_session — strips temp: keys from initial state
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.anyio
+async def test_create_session_strips_temp_keys_from_initial_state() -> None:
+ """create_session filters temp: keys before passing state to the store."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+
+ await service.create_session(app_name="app", user_id="u1", state={"x": 1, "temp:y": 2, "app:z": 3})
+
+ assert len(store.create_session_calls) == 1
+ persisted_state = store.create_session_calls[0]["state"]
+ assert "temp:y" not in persisted_state
+ assert persisted_state["x"] == 1
+ assert persisted_state["app:z"] == 3
+
+
+@pytest.mark.anyio
+async def test_create_session_with_only_temp_state_persists_empty() -> None:
+ """create_session with only temp: state persists empty state dict."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+
+ await service.create_session(app_name="app", user_id="u1", state={"temp:only": "gone"})
+
+ assert store.create_session_calls[0]["state"] == {}
+
+
+@pytest.mark.anyio
+async def test_create_session_none_state_persists_empty() -> None:
+ """create_session with state=None persists empty state dict."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+
+ await service.create_session(app_name="app", user_id="u1")
+
+ assert store.create_session_calls[0]["state"] == {}
+
+
+@pytest.mark.anyio
+async def test_create_session_generates_uuid_if_no_session_id() -> None:
+ """create_session generates a UUID if no session_id is provided."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+
+ session = await service.create_session(app_name="app", user_id="u1")
+
+ assert session.id is not None
+ assert len(session.id) > 0
+
+
+@pytest.mark.anyio
+async def test_create_session_uses_provided_session_id() -> None:
+ """create_session uses the caller-provided session_id."""
+ store = MockStore()
+ service = SQLSpecSessionService(store) # type: ignore[arg-type]
+
+ session = await service.create_session(app_name="app", user_id="u1", session_id="my-id")
+
+ assert session.id == "my-id"
diff --git a/tests/unit/extensions/test_adk/test_store_instantiation.py b/tests/unit/extensions/test_adk/test_store_instantiation.py
new file mode 100644
index 000000000..b98662fee
--- /dev/null
+++ b/tests/unit/extensions/test_adk/test_store_instantiation.py
@@ -0,0 +1,87 @@
+"""Smoke tests verifying all shipped ADK store classes are instantiable (not abstract).
+
+Every shipped store class must be concrete — no unsatisfied abstract methods.
+This catches bugs where stores have method signature mismatches with the base
+class, such as cockroach, mysqlconnector sync, pymysql, and spanner stores
+that are missing abstract method implementations added to the base contract.
+
+NOTE: Some stores WILL fail this test currently — that is expected and
+documents one of the bugs the ADK Clean-Break Overhaul (Ch1) is fixing.
+"""
+
+import importlib
+
+import pytest
+
+# Session stores (async)
+ASYNC_SESSION_STORES = [
+ "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKStore",
+ "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKStore",
+ "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKStore",
+ "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKStore",
+ "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKStore",
+ "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKStore",
+ "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKStore",
+ "sqlspec.adapters.psqlpy.adk.store.PsqlpyADKStore",
+ "sqlspec.adapters.psycopg.adk.store.PsycopgAsyncADKStore",
+ # sqlite uses BaseAsyncADKStore despite being backed by a sync driver
+ "sqlspec.adapters.sqlite.adk.store.SqliteADKStore",
+]
+
+# Session stores (sync)
+SYNC_SESSION_STORES = [
+ "sqlspec.adapters.adbc.adk.store.AdbcADKStore",
+ "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgSyncADKStore",
+ "sqlspec.adapters.duckdb.adk.store.DuckdbADKStore",
+ "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorSyncADKStore",
+ "sqlspec.adapters.oracledb.adk.store.OracleSyncADKStore",
+ "sqlspec.adapters.psycopg.adk.store.PsycopgSyncADKStore",
+ "sqlspec.adapters.pymysql.adk.store.PyMysqlADKStore",
+ "sqlspec.adapters.spanner.adk.store.SpannerSyncADKStore",
+]
+
+# Memory stores (async)
+ASYNC_MEMORY_STORES = [
+ "sqlspec.adapters.asyncpg.adk.store.AsyncpgADKMemoryStore",
+ "sqlspec.adapters.aiosqlite.adk.store.AiosqliteADKMemoryStore",
+ "sqlspec.adapters.asyncmy.adk.store.AsyncmyADKMemoryStore",
+ "sqlspec.adapters.cockroach_asyncpg.adk.store.CockroachAsyncpgADKMemoryStore",
+ "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgAsyncADKMemoryStore",
+ "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorAsyncADKMemoryStore",
+ "sqlspec.adapters.oracledb.adk.store.OracleAsyncADKMemoryStore",
+ "sqlspec.adapters.psqlpy.adk.store.PsqlpyADKMemoryStore",
+ "sqlspec.adapters.psycopg.adk.store.PsycopgAsyncADKMemoryStore",
+]
+
+# Memory stores (sync)
+SYNC_MEMORY_STORES = [
+ "sqlspec.adapters.adbc.adk.store.AdbcADKMemoryStore",
+ "sqlspec.adapters.cockroach_psycopg.adk.store.CockroachPsycopgSyncADKMemoryStore",
+ "sqlspec.adapters.duckdb.adk.store.DuckdbADKMemoryStore",
+ "sqlspec.adapters.mysqlconnector.adk.store.MysqlConnectorSyncADKMemoryStore",
+ "sqlspec.adapters.oracledb.adk.store.OracleSyncADKMemoryStore",
+ "sqlspec.adapters.psycopg.adk.store.PsycopgSyncADKMemoryStore",
+ "sqlspec.adapters.pymysql.adk.store.PyMysqlADKMemoryStore",
+ "sqlspec.adapters.spanner.adk.store.SpannerSyncADKMemoryStore",
+ "sqlspec.adapters.sqlite.adk.store.SqliteADKMemoryStore",
+]
+
+ALL_STORE_CLASSES = ASYNC_SESSION_STORES + SYNC_SESSION_STORES + ASYNC_MEMORY_STORES + SYNC_MEMORY_STORES
+
+
+@pytest.mark.parametrize("class_path", ALL_STORE_CLASSES)
+def test_store_has_no_abstract_methods(class_path: str) -> None:
+ """Every shipped store class must be concrete (no unsatisfied abstract methods).
+
+ A class with entries in ``__abstractmethods__`` cannot be instantiated and
+ signals that the concrete store is missing one or more method implementations
+ required by its base class contract.
+ """
+ module_path, class_name = class_path.rsplit(".", 1)
+ try:
+ module = importlib.import_module(module_path)
+ except ImportError:
+ pytest.skip(f"Module {module_path} not importable (missing optional dependency)")
+ cls = getattr(module, class_name)
+ abstract: set[str] = getattr(cls, "__abstractmethods__", set())
+ assert not abstract, f"{class_path} has unsatisfied abstract methods: {abstract}"
diff --git a/uv.lock b/uv.lock
index 4300db512..b499bb1d7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1509,7 +1509,7 @@ name = "exceptiongroup"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+ { name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
wheels = [
@@ -1963,7 +1963,7 @@ wheels = [
[[package]]
name = "google-cloud-aiplatform"
-version = "1.141.0"
+version = "1.142.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "docstring-parser" },
@@ -1979,13 +1979,14 @@ dependencies = [
{ name = "pydantic" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/ac/dc/1209c7aab43bd7233cf631165a3b1b4284d22fc7fe7387c66228d07868ab/google_cloud_aiplatform-1.141.0.tar.gz", hash = "sha256:e3b1cdb28865dd862aac9c685dfc5ac076488705aba0a5354016efadcddd59c6", size = 10152688, upload-time = "2026-03-10T22:20:08.692Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/6a/fc/428af69a69ff2e477e7f5e12d227b31fe5790f1a8234aacd54297f49c836/google_cloud_aiplatform-1.141.0-py2.py3-none-any.whl", hash = "sha256:6bd25b4d514c40b8181ca703e1b313ad6d0454ab8006fc9907fb3e9f672f31d1", size = 8358409, upload-time = "2026-03-10T22:20:04.871Z" },
+ { url = "https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" },
]
[package.optional-dependencies]
agent-engines = [
+ { name = "aiohttp" },
{ name = "cloudpickle" },
{ name = "google-cloud-iam" },
{ name = "google-cloud-logging" },
@@ -7071,7 +7072,7 @@ wheels = [
[[package]]
name = "sqlspec"
-version = "0.41.1"
+version = "0.42.0"
source = { editable = "." }
dependencies = [
{ name = "mypy-extensions" },
@@ -7121,6 +7122,7 @@ cockroachdb = [
]
duckdb = [
{ name = "duckdb" },
+ { name = "pytz" },
]
fastapi = [
{ name = "fastapi" },
@@ -7397,6 +7399,7 @@ requires-dist = [
{ name = "pydantic-extra-types", marker = "extra == 'pydantic'" },
{ name = "pymssql", marker = "extra == 'pymssql'" },
{ name = "pymysql", marker = "extra == 'pymysql'" },
+ { name = "pytz", marker = "extra == 'duckdb'" },
{ name = "rich-click", specifier = ">=1.9.0" },
{ name = "sqlglot", specifier = ">=30.0.0" },
{ name = "sqlglot", extras = ["c"], marker = "extra == 'mypyc'", specifier = ">=30.0.0" },