diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index 99947710c9..0ae0df1454 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- **agent-framework-azure-cosmos**: [BREAKING] `CosmosCheckpointStorage` now uses restricted pickle deserialization by default, matching `FileCheckpointStorage` behavior. If your checkpoints contain application-defined types, pass them via `allowed_checkpoint_types=["my_app.models:MyState"]`. ([#5200](https://github.com/microsoft/agent-framework/issues/5200)) + ## [1.0.1] - 2026-04-09 ### Added diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py index 1b6257f203..496d95d7c3 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -43,9 +43,34 @@ class CosmosCheckpointStorage: ``FileCheckpointStorage``, allowing full Python object fidelity for complex workflow state while keeping the document structure human-readable. - SECURITY WARNING: Checkpoints use pickle for data serialization. Only load - checkpoints from trusted sources. Loading a malicious checkpoint can execute - arbitrary code. + Security warning: checkpoints use pickle for non-JSON-native values. Loading + checkpoints from untrusted sources is unsafe and can execute arbitrary code + during deserialization. The built-in deserialization restrictions reduce risk, + but they do not make untrusted checkpoints safe to load. Extending + ``allowed_checkpoint_types`` may further increase risk and should only be done + for trusted application types. + + By default, checkpoint deserialization is restricted to a built-in set of safe + Python types (primitives, datetime, uuid, ...) and all ``agent_framework`` + internal types. To allow additional application-specific types, pass them via + the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format. + + Example: + + .. code-block:: python + + from azure.identity.aio import DefaultAzureCredential + from agent_framework_azure_cosmos import CosmosCheckpointStorage + + storage = CosmosCheckpointStorage( + endpoint="https://my-account.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-db", + container_name="checkpoints", + allowed_checkpoint_types=[ + "my_app.models:MyState", + ], + ) The database and container are created automatically on first use if they do not already exist. The container uses partition key @@ -97,6 +122,7 @@ def __init__( container_client: ContainerProxy | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, + allowed_checkpoint_types: list[str] | None = None, ) -> None: """Initialize the Azure Cosmos DB checkpoint storage. @@ -129,10 +155,15 @@ def __init__( container_client: Pre-created Cosmos container client. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. + allowed_checkpoint_types: Additional types (beyond the built-in safe set + and framework types) that are permitted during checkpoint + deserialization. Each entry should be a ``"module:qualname"`` + string (e.g., ``"my_app.models:MyState"``). """ self._cosmos_client: CosmosClient | None = cosmos_client self._container_proxy: ContainerProxy | None = container_client self._owns_client = False + self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or []) if self._container_proxy is not None: self.database_name: str = database_name or "" @@ -401,8 +432,7 @@ async def _ensure_container_proxy(self) -> None: partition_key=PartitionKey(path="/workflow_name"), ) - @staticmethod - def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint: + def _document_to_checkpoint(self, document: dict[str, Any]) -> WorkflowCheckpoint: """Convert a Cosmos DB document back to a WorkflowCheckpoint. Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``, @@ -413,7 +443,7 @@ def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint: cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"} cleaned = {k: v for k, v in document.items() if k not in cosmos_keys} - decoded = decode_checkpoint_value(cleaned) + decoded = decode_checkpoint_value(cleaned, allowed_types=self._allowed_types) return WorkflowCheckpoint.from_dict(decoded) @staticmethod diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py index 52155d0e21..016220e693 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py @@ -6,6 +6,7 @@ import uuid from collections.abc import AsyncIterator from contextlib import suppress +from dataclasses import dataclass from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -595,3 +596,142 @@ async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None: finally: with suppress(Exception): await cosmos_client.delete_database(database_name) + + +# --- Tests for allowed_checkpoint_types --- + + +@dataclass +class _AppState: + """Application-defined state type used to test allowed_checkpoint_types.""" + + label: str + count: int + + +_APP_STATE_TYPE_KEY = f"{_AppState.__module__}:{_AppState.__qualname__}" + + +def _make_checkpoint_with_state(state: dict[str, Any]) -> WorkflowCheckpoint: + """Create a checkpoint with custom state for serialization tests.""" + return WorkflowCheckpoint( + workflow_name="test-workflow", + graph_signature_hash="abc123", + timestamp="2025-01-01T00:00:00+00:00", + state=state, + iteration_count=1, + ) + + +async def test_init_accepts_allowed_checkpoint_types(mock_container: MagicMock) -> None: + """CosmosCheckpointStorage.__init__ accepts allowed_checkpoint_types.""" + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=["some.module:SomeType"], + ) + assert storage is not None + + +async def test_load_allows_builtin_safe_types(mock_container: MagicMock) -> None: + """Built-in safe types load without opt-in via allowed_checkpoint_types.""" + from datetime import datetime, timezone + + checkpoint = _make_checkpoint_with_state({ + "ts": datetime(2025, 1, 1, tzinfo=timezone.utc), + "tags": {1, 2, 3}, + }) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + loaded = await storage.load(checkpoint.checkpoint_id) + + assert loaded.state["ts"] == datetime(2025, 1, 1, tzinfo=timezone.utc) + assert loaded.state["tags"] == {1, 2, 3} + + +async def test_load_blocks_unlisted_app_type(mock_container: MagicMock) -> None: + """Application types are blocked when not listed in allowed_checkpoint_types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + + with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"): + await storage.load(checkpoint.checkpoint_id) + + +async def test_load_allows_listed_app_type(mock_container: MagicMock) -> None: + """Application types are allowed when listed in allowed_checkpoint_types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=7)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=[_APP_STATE_TYPE_KEY], + ) + loaded = await storage.load(checkpoint.checkpoint_id) + + assert isinstance(loaded.state["data"], _AppState) + assert loaded.state["data"].label == "ok" + assert loaded.state["data"].count == 7 + + +async def test_list_checkpoints_blocks_unlisted_app_type(mock_container: MagicMock) -> None: + """list_checkpoints skips documents with unlisted application types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + # The document is skipped (logged as warning) because the type is blocked + assert len(results) == 0 + + +async def test_list_checkpoints_allows_listed_app_type(mock_container: MagicMock) -> None: + """list_checkpoints decodes documents with listed application types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=3)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=[_APP_STATE_TYPE_KEY], + ) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + assert len(results) == 1 + assert isinstance(results[0].state["data"], _AppState) + + +async def test_get_latest_blocks_unlisted_app_type(mock_container: MagicMock) -> None: + """get_latest raises when the checkpoint contains an unlisted application type.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + + with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"): + await storage.get_latest(workflow_name="test-workflow") + + +async def test_get_latest_allows_listed_app_type(mock_container: MagicMock) -> None: + """get_latest decodes checkpoints with listed application types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="latest", count=9)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=[_APP_STATE_TYPE_KEY], + ) + result = await storage.get_latest(workflow_name="test-workflow") + + assert result is not None + assert isinstance(result.state["data"], _AppState) + assert result.state["data"].label == "latest"