diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 7cc51e4930..3c4fb68fe8 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -1297,7 +1297,6 @@ def test_text_only_messages(self) -> None: assert result[0].content[0].text == "hello" def test_image_uri_content(self) -> None: - from agent_framework import Content img = Content.from_uri(uri="https://example.com/photo.png", media_type="image/png") messages = [Message(role="user", contents=[img])] @@ -1309,7 +1308,6 @@ def test_image_uri_content(self) -> None: assert result[0].content[0].image.url == "https://example.com/photo.png" def test_mixed_text_and_image_content(self) -> None: - from agent_framework import Content text = Content.from_text("describe this image") img = Content.from_uri(uri="https://example.com/img.jpg", media_type="image/jpeg") @@ -1319,7 +1317,6 @@ def test_mixed_text_and_image_content(self) -> None: assert len(result[0].content) == 2 def test_skips_non_text_non_image_content(self) -> None: - from agent_framework import Content error = Content.from_error(message="oops") messages = [Message(role="user", contents=[error])] @@ -1327,7 +1324,6 @@ def test_skips_non_text_non_image_content(self) -> None: assert len(result) == 0 # message had no usable content def test_skips_empty_text(self) -> None: - from agent_framework import Content empty = Content.from_text("") messages = [Message(role="user", contents=[empty])] @@ -1341,7 +1337,6 @@ def test_fallback_to_msg_text_when_no_contents(self) -> None: assert result[0].content[0].text == "fallback text" def test_data_uri_image(self) -> None: - from agent_framework import Content img = Content.from_data(data=b"\x89PNG", media_type="image/png") messages = [Message(role="user", contents=[img])] @@ -1352,7 +1347,6 @@ def test_data_uri_image(self) -> None: assert isinstance(result[0].content[0], KnowledgeBaseMessageImageContent) def test_non_image_uri_skipped(self) -> None: - from agent_framework import Content pdf = Content.from_uri(uri="https://example.com/doc.pdf", media_type="application/pdf") messages = [Message(role="user", contents=[pdf])] @@ -1568,9 +1562,7 @@ def test_references_become_annotations(self) -> None: KnowledgeBaseMessage(role="assistant", content=[KnowledgeBaseMessageTextContent(text="answer")]), ], references=[ - KnowledgeBaseWebReference( - id="ref-1", activity_source=0, url="https://example.com", title="Example" - ), + KnowledgeBaseWebReference(id="ref-1", activity_source=0, url="https://example.com", title="Example"), ], ) result = AzureAISearchContextProvider._parse_messages_from_kb_response(response) diff --git a/python/packages/azure-cosmos/AGENTS.md b/python/packages/azure-cosmos/AGENTS.md new file mode 100644 index 0000000000..7cb0c2c717 --- /dev/null +++ b/python/packages/azure-cosmos/AGENTS.md @@ -0,0 +1,28 @@ +# Azure Cosmos DB Package (agent-framework-azure-cosmos) + +Azure Cosmos DB history provider integration for Agent Framework. + +## Main Classes + +- **`CosmosHistoryProvider`** - Persistent conversation history storage backed by Azure Cosmos DB + +## Usage + +```python +from agent_framework_azure_cosmos import CosmosHistoryProvider + +provider = CosmosHistoryProvider( + endpoint="https://.documents.azure.com:443/", + credential="", + database_name="agent-framework", + container_name="chat-history", +) +``` + +Container name is configured on the provider. `session_id` is used as the partition key. + +## Import Path + +```python +from agent_framework_azure_cosmos import CosmosHistoryProvider +``` diff --git a/python/packages/azure-cosmos/LICENSE b/python/packages/azure-cosmos/LICENSE new file mode 100644 index 0000000000..9e841e7a26 --- /dev/null +++ b/python/packages/azure-cosmos/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/python/packages/azure-cosmos/README.md b/python/packages/azure-cosmos/README.md new file mode 100644 index 0000000000..198376bcbb --- /dev/null +++ b/python/packages/azure-cosmos/README.md @@ -0,0 +1,38 @@ +# Get Started with Microsoft Agent Framework Azure Cosmos DB + +Please install this package via pip: + +```bash +pip install agent-framework-azure-cosmos --pre +``` + +## Azure Cosmos DB History Provider + +The Azure Cosmos DB integration provides `CosmosHistoryProvider` for persistent conversation history storage. + +### Basic Usage Example + +```python +from azure.identity.aio import DefaultAzureCredential +from agent_framework_azure_cosmos import CosmosHistoryProvider + +provider = CosmosHistoryProvider( + endpoint="https://.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-framework", + container_name="chat-history", +) +``` + +Credentials follow the same pattern used by other Azure connectors in the repository: + +- Pass a credential object (for example `DefaultAzureCredential`) +- Or pass a key string directly +- Or set `AZURE_COSMOS_KEY` in the environment + +Container naming behavior: + +- Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`) +- `session_id` is used as the Cosmos partition key for reads/writes + +See `samples/cosmos_history_provider.py` for a runnable package-local example. diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py new file mode 100644 index 0000000000..5bcfb3928b --- /dev/null +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft. All rights reserved. + +import importlib.metadata + +from ._history_provider import CosmosHistoryProvider + +try: + __version__ = importlib.metadata.version(__name__) +except importlib.metadata.PackageNotFoundError: + __version__ = "0.0.0" # Fallback for development mode + +__all__ = [ + "CosmosHistoryProvider", + "__version__", +] diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py new file mode 100644 index 0000000000..790cdfb391 --- /dev/null +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -0,0 +1,254 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Azure Cosmos DB history provider.""" + +from __future__ import annotations + +import time +import uuid +from collections.abc import Sequence +from typing import Any, ClassVar, TypedDict + +from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message +from agent_framework._sessions import BaseHistoryProvider +from agent_framework._settings import SecretString, load_settings +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes +from azure.cosmos import PartitionKey +from azure.cosmos.aio import ContainerProxy, CosmosClient, DatabaseProxy + + +class AzureCosmosHistorySettings(TypedDict, total=False): + """Settings for CosmosHistoryProvider resolved from args and environment.""" + + endpoint: str | None + database_name: str | None + container_name: str | None + key: SecretString | None + + +class CosmosHistoryProvider(BaseHistoryProvider): + """Azure Cosmos DB-backed history provider using BaseHistoryProvider hooks.""" + + DEFAULT_SOURCE_ID: ClassVar[str] = "azure_cosmos_history" + _BATCH_OPERATION_LIMIT: ClassVar[int] = 100 + + def __init__( + self, + source_id: str = DEFAULT_SOURCE_ID, + *, + load_messages: bool = True, + store_outputs: bool = True, + store_inputs: bool = True, + store_context_messages: bool = False, + store_context_from: set[str] | None = None, + endpoint: str | None = None, + database_name: str | None = None, + container_name: str | None = None, + credential: str | AzureCredentialTypes | None = None, + cosmos_client: CosmosClient | None = None, + container_client: ContainerProxy | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize the Azure Cosmos DB history provider. + + Args: + source_id: Unique identifier for this provider instance. + load_messages: Whether to load messages before invocation. + store_outputs: Whether to store response messages. + store_inputs: Whether to store input messages. + store_context_messages: Whether to store context from other providers. + store_context_from: If set, only store context from these source_ids. + endpoint: Cosmos DB account endpoint. + Can be set via ``AZURE_COSMOS_ENDPOINT``. + database_name: Cosmos DB database name. + Can be set via ``AZURE_COSMOS_DATABASE_NAME``. + container_name: Cosmos DB container name. + Can be set via ``AZURE_COSMOS_CONTAINER_NAME``. + credential: Credential to authenticate with Cosmos DB. + Supports key string and Azure credential objects. + Can be set via ``AZURE_COSMOS_KEY`` when omitted. + cosmos_client: Pre-created Cosmos async client. + container_client: Pre-created Cosmos container client for fixed-container usage. + env_file_path: Path to environment file for loading settings. + env_file_encoding: Encoding of the environment file. + """ + super().__init__( + source_id, + load_messages=load_messages, + store_outputs=store_outputs, + store_inputs=store_inputs, + store_context_messages=store_context_messages, + store_context_from=store_context_from, + ) + + self._cosmos_client: CosmosClient | None = cosmos_client + self._database_client: DatabaseProxy | None = None + self._container: ContainerProxy | None = container_client + self._owns_client = False + + if self._container is not None: + return + + required_fields: list[str] = ["database_name", "container_name"] + if cosmos_client is None: + required_fields.append("endpoint") + if credential is None: + required_fields.append("key") + + settings = load_settings( + AzureCosmosHistorySettings, + env_prefix="AZURE_COSMOS_", + required_fields=required_fields, + endpoint=endpoint, + database_name=database_name, + container_name=container_name, + key=credential if isinstance(credential, str) else None, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + + database_name_value = settings.get("database_name") + container_name_value = settings.get("container_name") + if database_name_value is None or container_name_value is None: + raise ValueError("Both database_name and container_name are required to initialize CosmosHistoryProvider.") + self.database_name = database_name_value + self.container_name = container_name_value + + if self._cosmos_client is None: + endpoint_value = settings.get("endpoint") + if endpoint_value is None: + raise ValueError( + "endpoint is required to initialize CosmosHistoryProvider when cosmos_client is not set." + ) + + resolved_credential = self._resolve_credential(credential, settings) + self._cosmos_client = CosmosClient( + url=endpoint_value, + credential=resolved_credential, # type: ignore[arg-type] + user_agent_suffix=AGENT_FRAMEWORK_USER_AGENT, + ) + self._owns_client = True + + self._database_client = self._cosmos_client.get_database_client(self.database_name) + + @staticmethod + def _resolve_credential( + credential: str | AzureCredentialTypes | None, settings: AzureCosmosHistorySettings + ) -> str | AzureCredentialTypes: + if credential is not None: + return credential + + if settings.get("key") is not None: + return settings["key"].get_secret_value() # type: ignore[union-attr] + + raise ValueError( + "Azure Cosmos credential is required. Provide 'credential' or set 'AZURE_COSMOS_KEY' environment variable." + ) + + async def _get_container(self) -> ContainerProxy: + """Get or create the Cosmos DB container for storing messages.""" + if self._container is not None: + return self._container + if self._database_client is None: + raise RuntimeError("Cosmos database client is not initialized.") + + self._container = await self._database_client.create_container_if_not_exists( + id=self.container_name, + partition_key=PartitionKey(path="/session_id"), + ) + return self._container + + @staticmethod + def _session_partition_key(session_id: str | None) -> str: + return session_id or "default" + + async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + """Retrieve stored messages for this session from Azure Cosmos DB.""" + session_key = self._session_partition_key(session_id) + container = await self._get_container() + + query = "SELECT c.message FROM c WHERE c.session_id = @session_id ORDER BY c.sort_key ASC" + parameters: list[dict[str, object]] = [{"name": "@session_id", "value": session_key}] + items = container.query_items(query=query, parameters=parameters, partition_key=session_key) + + messages: list[Message] = [] + async for item in items: + message_payload = item.get("message") + if isinstance(message_payload, dict): + messages.append(Message.from_dict(message_payload)) + + return messages + + async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + """Persist messages for this session to Azure Cosmos DB.""" + if not messages: + return + + session_key = self._session_partition_key(session_id) + container = await self._get_container() + + base_sort_key = time.time_ns() + operations: list[tuple[str, tuple[dict[str, Any]]]] = [] + for index, message in enumerate(messages): + document = { + "id": str(uuid.uuid4()), + "session_id": session_key, + "sort_key": base_sort_key + index, + "source_id": self.source_id, + "message": message.to_dict(), + } + operations.append(("upsert", (document,))) + + for start in range(0, len(operations), self._BATCH_OPERATION_LIMIT): + batch = operations[start : start + self._BATCH_OPERATION_LIMIT] + await container.execute_item_batch(batch_operations=batch, partition_key=session_key) + + async def clear(self, session_id: str | None) -> None: + """Clear all messages for a session from Azure Cosmos DB.""" + session_key = self._session_partition_key(session_id) + container = await self._get_container() + + query = "SELECT c.id FROM c WHERE c.session_id = @session_id" + parameters: list[dict[str, object]] = [{"name": "@session_id", "value": session_key}] + items = container.query_items(query=query, parameters=parameters, partition_key=session_key) + + delete_operations: list[tuple[str, tuple[str]]] = [] + async for item in items: + item_id = item.get("id") + if isinstance(item_id, str): + delete_operations.append(("delete", (item_id,))) + + for start in range(0, len(delete_operations), self._BATCH_OPERATION_LIMIT): + batch = delete_operations[start : start + self._BATCH_OPERATION_LIMIT] + await container.execute_item_batch(batch_operations=batch, partition_key=session_key) + + async def list_sessions(self) -> list[str]: + """List all session IDs stored in this provider's Cosmos container.""" + container = await self._get_container() + query = "SELECT DISTINCT VALUE c.session_id FROM c" + items = container.query_items(query=query, enable_cross_partition_query=True) + + session_ids: set[str] = set() + async for item in items: + if isinstance(item, str): + session_ids.add(item) + return sorted(session_ids) + + async def close(self) -> None: + """Close the underlying Cosmos client when this provider owns it.""" + if self._owns_client and self._cosmos_client is not None: + await self._cosmos_client.close() + + async def __aenter__(self) -> CosmosHistoryProvider: + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.close() diff --git a/python/packages/azure-cosmos/pyproject.toml b/python/packages/azure-cosmos/pyproject.toml new file mode 100644 index 0000000000..031de97844 --- /dev/null +++ b/python/packages/azure-cosmos/pyproject.toml @@ -0,0 +1,89 @@ +[project] +name = "agent-framework-azure-cosmos" +description = "Azure Cosmos DB history provider integration for Microsoft Agent Framework." +authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] +readme = "README.md" +requires-python = ">=3.10" +version = "1.0.0b260219" +license-files = ["LICENSE"] +urls.homepage = "https://aka.ms/agent-framework" +urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" +urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true" +urls.issues = "https://github.com/microsoft/agent-framework/issues" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Typing :: Typed", +] +dependencies = [ + "agent-framework-core>=1.0.0rc1", + "azure-cosmos>=4.9.0", +] + +[tool.uv] +prerelease = "if-necessary-or-explicit" +environments = [ + "sys_platform == 'darwin'", + "sys_platform == 'linux'", + "sys_platform == 'win32'" +] + +[tool.uv-dynamic-versioning] +fallback-version = "0.0.0" + +[tool.pytest.ini_options] +testpaths = 'tests' +addopts = "-ra -q -r fEX" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [ + "ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*" +] +timeout = 120 + +[tool.ruff] +extend = "../../pyproject.toml" + +[tool.coverage.run] +omit = [ + "**/__init__.py" +] + +[tool.pyright] +extends = "../../pyproject.toml" + +[tool.mypy] +plugins = ['pydantic.mypy'] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +no_implicit_optional = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false +disallow_incomplete_defs = true +disallow_untyped_decorators = true + +[tool.bandit] +targets = ["agent_framework_azure_cosmos"] +exclude_dirs = ["tests"] + +[tool.poe] +executor.type = "uv" +include = "../../shared_tasks.toml" +[tool.poe.tasks] +mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_cosmos" +test = "pytest --cov=agent_framework_azure_cosmos --cov-report=term-missing:skip-covered tests" + +[build-system] +requires = ["flit-core >= 3.11,<4.0"] +build-backend = "flit_core.buildapi" diff --git a/python/packages/azure-cosmos/samples/README.md b/python/packages/azure-cosmos/samples/README.md new file mode 100644 index 0000000000..082a9c2cfe --- /dev/null +++ b/python/packages/azure-cosmos/samples/README.md @@ -0,0 +1,20 @@ +# Azure Cosmos DB Package Samples + +This folder contains samples for `agent-framework-azure-cosmos`. + +| File | Description | +| --- | --- | +| [`cosmos_history_provider.py`](cosmos_history_provider.py) | Demonstrates an Agent using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient` (project endpoint), provider-configured container name, and `session_id` partitioning. | + +## Prerequisites + +- `AZURE_COSMOS_ENDPOINT` +- `AZURE_COSMOS_DATABASE_NAME` +- `AZURE_COSMOS_CONTAINER_NAME` +- `AZURE_COSMOS_KEY` (or equivalent credential flow) + +## Run + +```bash +uv run --directory packages/azure-cosmos python samples/cosmos_history_provider.py +``` diff --git a/python/packages/azure-cosmos/samples/__init__.py b/python/packages/azure-cosmos/samples/__init__.py new file mode 100644 index 0000000000..516b9492f6 --- /dev/null +++ b/python/packages/azure-cosmos/samples/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Samples for the Azure Cosmos history provider package.""" diff --git a/python/packages/azure-cosmos/samples/cosmos_history_provider.py b/python/packages/azure-cosmos/samples/cosmos_history_provider.py new file mode 100644 index 0000000000..f612d4f05e --- /dev/null +++ b/python/packages/azure-cosmos/samples/cosmos_history_provider.py @@ -0,0 +1,100 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosHistoryProvider + +""" +This sample demonstrates CosmosHistoryProvider as an agent context provider. + +Key components: +- AzureOpenAIResponsesClient configured with an Azure AI project endpoint +- CosmosHistoryProvider configured for Cosmos DB-backed message history +- Provider-configured container name with session_id as partition key + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT + AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + +# Load environment variables from .env file. +load_dotenv() + + +async def main() -> None: + """Run the Cosmos history provider sample with an Agent.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not deployment_name + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # 1. Create the Azure OpenAI Responses client using project endpoint auth. + credential = AzureCliCredential() + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + # 3. Create an agent that uses the history provider as a context provider. + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="CosmosHistoryAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 4. Create a session (session_id is used as the partition key). + session = agent.create_session() + + # 5. Run a multi-turn conversation; history is persisted by CosmosHistoryProvider. + response1 = await agent.run("My name is Ada and I enjoy distributed systems.", session=session) + print(f"Assistant: {response1.text}") + + response2 = await agent.run("What do you remember about me?", session=session) + print(f"Assistant: {response2.text}") + print(f"Container: {history_provider.container_name}") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +Assistant: Nice to meet you, Ada! Distributed systems are a fascinating area. +Assistant: You told me your name is Ada and that you enjoy distributed systems. +Container: +""" diff --git a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py new file mode 100644 index 0000000000..e8c65cc2d7 --- /dev/null +++ b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py @@ -0,0 +1,265 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from agent_framework import AgentResponse, Message +from agent_framework._sessions import AgentSession, SessionContext +from agent_framework.exceptions import SettingNotFoundError + +import agent_framework_azure_cosmos._history_provider as history_provider_module +from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider + + +def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]: + async def _iterator() -> AsyncIterator[Any]: + for item in items: + yield item + + return _iterator() + + +@pytest.fixture +def mock_container() -> MagicMock: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([])) + container.execute_item_batch = AsyncMock(return_value=[]) + return container + + +@pytest.fixture +def mock_cosmos_client(mock_container: MagicMock) -> MagicMock: + database_client = MagicMock() + database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container) + + client = MagicMock() + client.get_database_client.return_value = database_client + client.close = AsyncMock() + return client + + +class TestCosmosHistoryProviderInit: + def test_uses_provided_container_client(self, mock_container: MagicMock) -> None: + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + assert provider.source_id == "mem" + assert provider.load_messages is True + assert provider.store_outputs is True + assert provider.store_inputs is True + + def test_uses_provided_cosmos_client(self, mock_cosmos_client: MagicMock) -> None: + provider = CosmosHistoryProvider( + source_id="mem", + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="history", + ) + + mock_cosmos_client.get_database_client.assert_called_once_with("db1") + assert provider.database_name == "db1" + assert provider.container_name == "history" + + def test_missing_required_settings_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False) + monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False) + monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False) + monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False) + + with pytest.raises(SettingNotFoundError, match="database_name"): + CosmosHistoryProvider() + + def test_constructs_client_with_string_credential( + self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock + ) -> None: + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory) + + CosmosHistoryProvider( + endpoint="https://account.documents.azure.com:443/", + credential="key-123", + database_name="db1", + container_name="history", + ) + + mock_factory.assert_called_once() + kwargs = mock_factory.call_args.kwargs + assert kwargs["url"] == "https://account.documents.azure.com:443/" + assert kwargs["credential"] == "key-123" + + +class TestCosmosHistoryProviderContainerConfig: + async def test_provider_container_name_is_used(self, mock_cosmos_client: MagicMock) -> None: + provider = CosmosHistoryProvider( + source_id="mem", + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="custom-history", + ) + + await provider.get_messages("session-123") + + database_client = mock_cosmos_client.get_database_client.return_value + assert database_client.create_container_if_not_exists.await_count == 1 + kwargs = database_client.create_container_if_not_exists.await_args.kwargs + assert kwargs["id"] == "custom-history" + + +class TestCosmosHistoryProviderGetMessages: + async def test_returns_deserialized_messages(self, mock_container: MagicMock) -> None: + msg1 = Message(role="user", contents=["Hello"]) + msg2 = Message(role="assistant", contents=["Hi"]) + mock_container.query_items.return_value = _to_async_iter([ + {"message": msg1.to_dict()}, + {"message": msg2.to_dict()}, + ]) + + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + messages = await provider.get_messages("s1") + + assert len(messages) == 2 + assert messages[0].role == "user" + assert messages[0].text == "Hello" + assert messages[1].role == "assistant" + assert messages[1].text == "Hi" + + async def test_empty_returns_empty(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + messages = await provider.get_messages("s1") + + assert messages == [] + + +class TestCosmosHistoryProviderListSessions: + async def test_list_sessions_returns_unique_sorted_ids(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter(["s2", "s1", "s1", "s3"]) + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + + sessions = await provider.list_sessions() + + assert sessions == ["s1", "s2", "s3"] + kwargs = mock_container.query_items.call_args.kwargs + assert kwargs["query"] == "SELECT DISTINCT VALUE c.session_id FROM c" + assert kwargs["enable_cross_partition_query"] is True + + +class TestCosmosHistoryProviderSaveMessages: + async def test_saves_messages(self, mock_container: MagicMock) -> None: + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + messages = [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi"])] + + await provider.save_messages("s1", messages) + + mock_container.execute_item_batch.assert_awaited_once() + batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"] + assert len(batch_operations) == 2 + first_operation, first_args = batch_operations[0] + assert first_operation == "upsert" + first_document = first_args[0] + assert first_document["session_id"] == "s1" + assert first_document["message"]["role"] == "user" + assert mock_container.execute_item_batch.await_args.kwargs["partition_key"] == "s1" + + async def test_empty_messages_noop(self, mock_container: MagicMock) -> None: + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + + await provider.save_messages("s1", []) + + mock_container.execute_item_batch.assert_not_awaited() + + +class TestCosmosHistoryProviderClear: + async def test_clear_deletes_all_session_items(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([{"id": "1"}, {"id": "2"}]) + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + + await provider.clear("s1") + + mock_container.execute_item_batch.assert_awaited_once() + batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"] + assert len(batch_operations) == 2 + assert batch_operations[0] == ("delete", ("1",)) + assert batch_operations[1] == ("delete", ("2",)) + assert mock_container.execute_item_batch.await_args.kwargs["partition_key"] == "s1" + + +class TestCosmosHistoryProviderBeforeAfterRun: + async def test_before_run_loads_history(self, mock_container: MagicMock) -> None: + msg = Message(role="user", contents=["old msg"]) + mock_container.query_items.return_value = _to_async_iter([{"message": msg.to_dict()}]) + + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + session = AgentSession(session_id="test") + context = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + assert "mem" in context.context_messages + assert context.context_messages["mem"][0].text == "old msg" + + async def test_after_run_stores_input_and_response(self, mock_container: MagicMock) -> None: + provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) + session = AgentSession(session_id="test") + context = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1") + context._response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])]) + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + mock_container.execute_item_batch.assert_awaited_once() + batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"] + assert len(batch_operations) == 2 + + +class TestCosmosHistoryProviderClose: + async def test_close_closes_owned_client( + self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock + ) -> None: + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory) + + provider = CosmosHistoryProvider( + endpoint="https://account.documents.azure.com:443/", + credential="key-123", + database_name="db1", + container_name="history", + ) + + await provider.close() + + mock_cosmos_client.close.assert_awaited_once() + + async def test_close_does_not_close_external_client(self, mock_cosmos_client: MagicMock) -> None: + provider = CosmosHistoryProvider( + source_id="mem", + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="history", + ) + + await provider.close() + + mock_cosmos_client.close.assert_not_awaited() + + async def test_async_context_manager_closes_owned_client( + self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock + ) -> None: + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory) + + async with CosmosHistoryProvider( + endpoint="https://account.documents.azure.com:443/", + credential="key-123", + database_name="db1", + container_name="history", + ) as provider: + assert provider is not None + + mock_cosmos_client.close.assert_awaited_once() diff --git a/python/packages/core/agent_framework/_skills.py b/python/packages/core/agent_framework/_skills.py index 0e132a9336..33d001b6f2 100644 --- a/python/packages/core/agent_framework/_skills.py +++ b/python/packages/core/agent_framework/_skills.py @@ -115,6 +115,7 @@ class _FileAgentSkill: source_path: str resource_names: list[str] = field(default_factory=list) + # endregion # region Private module-level functions (skill discovery, parsing, security) @@ -165,9 +166,7 @@ def _has_symlink_in_path(full_path: str, directory_path: str) -> bool: try: relative = Path(full_path).relative_to(dir_path) except ValueError as exc: - raise ValueError( - f"full_path {full_path!r} does not start with directory_path {directory_path!r}" - ) from exc + raise ValueError(f"full_path {full_path!r} does not start with directory_path {directory_path!r}") from exc current = dir_path for part in relative.parts: @@ -436,6 +435,7 @@ def _build_skills_instruction_prompt( return template.format("\n".join(lines)) + # endregion # region Public API @@ -494,7 +494,9 @@ def __init__( """ super().__init__(source_id or self.DEFAULT_SOURCE_ID) - resolved_paths: Sequence[str] = [str(skill_paths)] if isinstance(skill_paths, (str, Path)) else [str(p) for p in skill_paths] + resolved_paths: Sequence[str] = ( + [str(skill_paths)] if isinstance(skill_paths, (str, Path)) else [str(p) for p in skill_paths] + ) self._skills = _discover_and_load_skills(resolved_paths) self._skills_instruction_prompt = _build_skills_instruction_prompt(skills_instruction_prompt, self._skills) @@ -594,4 +596,5 @@ def _read_skill_resource(self, skill_name: str, resource_name: str) -> str: logger.exception("Failed to read resource '%s' from skill '%s'", resource_name, skill_name) return f"Error: Failed to read resource '{resource_name}' from skill '{skill_name}'." + # endregion diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index f278afaeac..319d35f152 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -937,8 +937,7 @@ def ai_func(arg1: str) -> str: orphaned_calls = all_call_ids - all_result_ids assert not orphaned_calls, ( - f"Response contains orphaned FunctionCallContent without matching " - f"FunctionResultContent: {orphaned_calls}." + f"Response contains orphaned FunctionCallContent without matching FunctionResultContent: {orphaned_calls}." ) @@ -1123,8 +1122,7 @@ def browser_snapshot(url: str) -> str: orphaned_calls = all_call_ids - all_result_ids assert not orphaned_calls, ( - f"Response contains orphaned function calls {orphaned_calls}. " - f"This would cause API errors on the next call." + f"Response contains orphaned function calls {orphaned_calls}. This would cause API errors on the next call." ) diff --git a/python/packages/core/tests/core/test_skills.py b/python/packages/core/tests/core/test_skills.py index 571abeaa21..a77f214718 100644 --- a/python/packages/core/tests/core/test_skills.py +++ b/python/packages/core/tests/core/test_skills.py @@ -38,6 +38,7 @@ def _symlinks_supported(tmp: Path) -> bool: test_link.unlink(missing_ok=True) test_target.unlink(missing_ok=True) + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/python/pyproject.toml b/python/pyproject.toml index 259caffaf9..18258f6b5f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -76,6 +76,7 @@ agent-framework-core = { workspace = true } agent-framework-a2a = { workspace = true } agent-framework-ag-ui = { workspace = true } agent-framework-azure-ai-search = { workspace = true } +agent-framework-azure-cosmos = { workspace = true } agent-framework-anthropic = { workspace = true } agent-framework-azure-ai = { workspace = true } agent-framework-azurefunctions = { workspace = true } @@ -236,6 +237,7 @@ check = ["check-packages", "samples-lint", "samples-syntax", "test", "markdown-c [tool.poe.tasks.all-tests-cov] cmd = """ pytest --import-mode=importlib +-m "not integration" --cov=agent_framework --cov=agent_framework_core --cov=agent_framework_a2a @@ -263,6 +265,7 @@ pytest --import-mode=importlib [tool.poe.tasks.all-tests] cmd = """ pytest --import-mode=importlib +-m "not integration" --ignore-glob=packages/lab/** --ignore-glob=packages/devui/** -rs diff --git a/python/uv.lock b/python/uv.lock index 1cd73c34ae..d779ad2067 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -32,6 +32,7 @@ members = [ "agent-framework-anthropic", "agent-framework-azure-ai", "agent-framework-azure-ai-search", + "agent-framework-azure-cosmos", "agent-framework-azurefunctions", "agent-framework-bedrock", "agent-framework-chatkit", @@ -235,6 +236,21 @@ requires-dist = [ { name = "azure-search-documents", specifier = "==11.7.0b2" }, ] +[[package]] +name = "agent-framework-azure-cosmos" +version = "1.0.0b260219" +source = { editable = "packages/azure-cosmos" } +dependencies = [ + { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "azure-cosmos", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "azure-cosmos", specifier = ">=4.9.0" }, +] + [[package]] name = "agent-framework-azurefunctions" version = "1.0.0b260219" @@ -1033,6 +1049,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/23/6371a551800d3812d6019cd813acd985f9fac0fedc1290129211a73da4ae/azure_core-1.38.2-py3-none-any.whl", hash = "sha256:074806c75cf239ea284a33a66827695ef7aeddac0b4e19dda266a93e4665ead9", size = 217957, upload-time = "2026-02-18T19:33:07.696Z" }, ] +[[package]] +name = "azure-cosmos" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/a3/0474e622bf9676e3206d61269461ed16a05958363c254ea3b15af16219b2/azure_cosmos-4.15.0.tar.gz", hash = "sha256:be1cf49837c197d9da880ec47fe020a24d679075b89e0e1e2aca8d376b3a5a24", size = 2100744, upload-time = "2026-02-23T16:01:52.293Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/5f/b6e3d3ae16fa121fdc17e62447800d378b7e716cd6103c3650977a6c4618/azure_cosmos-4.15.0-py3-none-any.whl", hash = "sha256:83c1da7386bcd0df9a15c52116cc35012225d8a72d4f1379938b83ea5eb19fff", size = 424870, upload-time = "2026-02-23T16:01:54.514Z" }, +] + [[package]] name = "azure-functions" version = "1.24.0"