diff --git a/src/bedrock_agentcore/memory/client.py b/src/bedrock_agentcore/memory/client.py index eb0d6d4..e03fcb0 100644 --- a/src/bedrock_agentcore/memory/client.py +++ b/src/bedrock_agentcore/memory/client.py @@ -860,6 +860,7 @@ def list_events( all_events.extend(events) next_token = response.get("nextToken") + # Break if: no more pages or reached max if not next_token or len(all_events) >= max_results: break diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index 1f0905c..b9c06c4 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -72,7 +72,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: list[SessionMessage]: list of SessionMessage objects. """ messages = [] - for event in events: + for event in reversed(events): for payload_item in event.get("payload", []): if "conversational" in payload_item: conv = payload_item["conversational"] @@ -93,7 +93,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: logger.error("This is not a SessionMessage but just a blob message. Ignoring") except (json.JSONDecodeError, ValueError): logger.error("Failed to parse blob content: %s", payload_item) - return list(reversed(messages)) + return messages @staticmethod def total_length(message: tuple[str, str]) -> int: diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index d2d5cef..7017568 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -29,9 +29,12 @@ class AgentCoreMemoryConfig(BaseModel): session_id: Required unique ID for the session actor_id: Required unique ID for the agent instance/user retrieval_config: Optional dictionary mapping namespaces to retrieval configurations + batch_size: Number of messages to batch before sending to AgentCore Memory. + Default of 1 means immediate sending (no batching). Max 100. """ memory_id: str = Field(min_length=1) session_id: str = Field(min_length=1) actor_id: str = Field(min_length=1) retrieval_config: Optional[Dict[str, RetrievalConfig]] = None + batch_size: int = Field(default=1, ge=1, le=100) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index c3912a7..ce8f020 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -5,6 +5,7 @@ import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone +from enum import Enum from typing import TYPE_CHECKING, Any, Optional import boto3 @@ -19,6 +20,7 @@ from typing_extensions import override from bedrock_agentcore.memory.client import MemoryClient +from bedrock_agentcore.memory.models.filters import EventMetadataFilter, LeftExpression, OperatorType, RightExpression from .bedrock_converter import AgentCoreMemoryConverter from .config import AgentCoreMemoryConfig, RetrievalConfig @@ -28,11 +30,23 @@ logger = logging.getLogger(__name__) -SESSION_PREFIX = "session_" -AGENT_PREFIX = "agent_" -MESSAGE_PREFIX = "message_" MAX_FETCH_ALL_RESULTS = 10000 +# Legacy prefixes for backwards compatibility with old events +LEGACY_SESSION_PREFIX = "session_" +LEGACY_AGENT_PREFIX = "agent_" + +# Metadata keys for event identification +STATE_TYPE_KEY = "stateType" +AGENT_ID_KEY = "agentId" + + +class StateType(Enum): + """State type for distinguishing session and agent metadata in events.""" + + SESSION = "SESSION" + AGENT = "AGENT" + class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository): """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration. @@ -104,7 +118,10 @@ def __init__( session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False - # Override the clients if custom boto session or config is provided + # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp) + self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = [] + self._buffer_lock = threading.Lock() + # Add strands-agents to the request user agent if boto_client_config: existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) @@ -125,38 +142,6 @@ def __init__( ) super().__init__(session_id=self.config.session_id, session_repository=self) - def _get_full_session_id(self, session_id: str) -> str: - """Get the full session ID with the configured prefix. - - Args: - session_id (str): The session ID. - - Returns: - str: The full session ID with the prefix. - """ - full_session_id = f"{SESSION_PREFIX}{session_id}" - if full_session_id == self.config.actor_id: - raise SessionException( - f"Cannot have session [ {full_session_id} ] with the same ID as the actor ID: {self.config.actor_id}" - ) - return full_session_id - - def _get_full_agent_id(self, agent_id: str) -> str: - """Get the full agent ID with the configured prefix. - - Args: - agent_id (str): The agent ID. - - Returns: - str: The full agent ID with the prefix. - """ - full_agent_id = f"{AGENT_PREFIX}{agent_id}" - if full_agent_id == self.config.actor_id: - raise SessionException( - f"Cannot create agent [ {full_agent_id} ] with the same ID as the actor ID: {self.config.actor_id}" - ) - return full_agent_id - # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -179,12 +164,13 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, - actorId=self._get_full_session_id(session.session_id), + actorId=self.config.actor_id, sessionId=self.session_id, payload=[ {"blob": json.dumps(session.to_dict())}, ], eventTimestamp=self._get_monotonic_timestamp(), + metadata={STATE_TYPE_KEY: {"stringValue": StateType.SESSION.value}}, ) logger.info("Created session: %s with event: %s", session.session_id, event.get("event", {}).get("eventId")) return session @@ -206,17 +192,50 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: if session_id != self.config.session_id: return None + # 1. Try new approach (metadata filter) + event_metadata = [ + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(STATE_TYPE_KEY), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(StateType.SESSION.value), + ) + ] + events = self.memory_client.list_events( memory_id=self.config.memory_id, - actor_id=self._get_full_session_id(session_id), + actor_id=self.config.actor_id, session_id=session_id, + event_metadata=event_metadata, max_results=1, ) - if not events: - return None + if events: + session_data = json.loads(events[0].get("payload", {})[0].get("blob")) + return Session.from_dict(session_data) + + # 2. Fallback: check for legacy event and migrate + legacy_actor_id = f"{LEGACY_SESSION_PREFIX}{session_id}" + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=legacy_actor_id, + session_id=session_id, + max_results=1, + ) + if events: + old_event = events[0] + session_data = json.loads(old_event.get("payload", {})[0].get("blob")) + session = Session.from_dict(session_data) + # Migrate: create new event with metadata, delete old + self.create_session(session) + self.memory_client.gmdp_client.delete_event( + memoryId=self.config.memory_id, + actorId=legacy_actor_id, + sessionId=session_id, + eventId=old_event.get("eventId"), + ) + logger.info("Migrated legacy session event for session: %s", session_id) + return session - session_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return Session.from_dict(session_data) + return None def delete_session(self, session_id: str, **kwargs: Any) -> None: """Delete session and all associated data. @@ -250,12 +269,16 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, - actorId=self._get_full_agent_id(session_agent.agent_id), + actorId=self.config.actor_id, sessionId=self.session_id, payload=[ {"blob": json.dumps(session_agent.to_dict())}, ], eventTimestamp=self._get_monotonic_timestamp(), + metadata={ + STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value}, + AGENT_ID_KEY: {"stringValue": session_agent.agent_id}, + }, ) logger.info( "Created agent: %s in session: %s with event %s", @@ -280,18 +303,56 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[ if session_id != self.config.session_id: return None try: + # 1. Try new approach (metadata filter) + event_metadata = [ + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(STATE_TYPE_KEY), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(StateType.AGENT.value), + ), + EventMetadataFilter.build_expression( + left_operand=LeftExpression.build(AGENT_ID_KEY), + operator=OperatorType.EQUALS_TO, + right_operand=RightExpression.build(agent_id), + ), + ] + events = self.memory_client.list_events( memory_id=self.config.memory_id, - actor_id=self._get_full_agent_id(agent_id), + actor_id=self.config.actor_id, session_id=session_id, + event_metadata=event_metadata, max_results=1, ) - if not events: - return None + if events: + agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) + return SessionAgent.from_dict(agent_data) + + # 2. Fallback: check for legacy event and migrate + legacy_actor_id = f"{LEGACY_AGENT_PREFIX}{agent_id}" + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=legacy_actor_id, + session_id=session_id, + max_results=1, + ) + if events: + old_event = events[0] + agent_data = json.loads(old_event.get("payload", {})[0].get("blob")) + agent = SessionAgent.from_dict(agent_data) + # Migrate: create new event with metadata, delete old + self.create_agent(session_id, agent) + self.memory_client.gmdp_client.delete_event( + memoryId=self.config.memory_id, + actorId=legacy_actor_id, + sessionId=session_id, + eventId=old_event.get("eventId"), + ) + logger.info("Migrated legacy agent event for agent: %s", agent_id) + return agent - agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return SessionAgent.from_dict(agent_data) + return None except Exception as e: logger.error("Failed to read agent %s", e) return None @@ -311,8 +372,9 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id) if previous_agent is None: raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") + else: + session_agent.created_at = previous_agent.created_at - session_agent.created_at = previous_agent.created_at # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent` self.create_agent(session_id, session_agent) @@ -321,6 +383,9 @@ def create_message( ) -> Optional[dict[str, Any]]: """Create a new message in AgentCore Memory. + If batch_size > 1, the message is buffered and sent when the buffer reaches batch_size. + Use _flush_messages() or close() to send any remaining buffered messages. + Args: session_id (str): The session ID to create the message in. agent_id (str): The agent ID associated with the message (only here for the interface. @@ -330,6 +395,7 @@ def create_message( Returns: Optional[dict[str, Any]]: The created event data from AgentCore Memory. + Returns empty dict if message is buffered (batch_size > 1). Raises: SessionException: If session ID doesn't match configuration or message creation fails. @@ -350,16 +416,33 @@ def create_message( if session_id != self.config.session_id: raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") - try: - messages = AgentCoreMemoryConverter.message_to_payload(session_message) - if not messages: - return + # Convert and check size ONCE (not again at flush) + messages = AgentCoreMemoryConverter.message_to_payload(session_message) + if not messages: + return None - # Parse the original timestamp and use it as desired timestamp - original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) - monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) + is_blob = AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]) - if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]): + # Parse the original timestamp and use it as desired timestamp + original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) + monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) + + if self.config.batch_size > 1: + # Buffer the pre-processed message + should_flush = False + with self._buffer_lock: + self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) + should_flush = len(self._message_buffer) >= self.config.batch_size + + # Flush outside the lock to prevent deadlock + if should_flush: + self._flush_messages() + + return {} # No eventId yet + + # Immediate send (batch_size == 1) + try: + if not is_blob: event = self.memory_client.create_event( memory_id=self.config.memory_id, actor_id=self.config.actor_id, @@ -586,3 +669,131 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: RepositorySessionManager.initialize(self, agent, **kwargs) # endregion RepositorySessionManager overrides + + # region Batching support + + def _flush_messages(self) -> list[dict[str, Any]]: + """Flush all buffered messages to AgentCore Memory. + + Call this method to send any remaining buffered messages when batch_size > 1. + This is automatically called when the buffer reaches batch_size, but should + also be called explicitly when the session is complete (via close() or context manager). + + Messages are batched by session_id - all conversational messages for the same + session are combined into a single create_event() call to reduce API calls. + Blob messages (>9KB) are sent individually as they require a different API path. + + Returns: + list[dict[str, Any]]: List of created event responses from AgentCore Memory. + + Raises: + SessionException: If any message creation fails. On failure, all messages + remain in the buffer to prevent data loss. + """ + with self._buffer_lock: + messages_to_send = list(self._message_buffer) + + if not messages_to_send: + return [] + + # Group conversational messages by session_id, preserve order + # Structure: {session_id: {"messages": [...], "timestamp": latest_timestamp}} + session_groups: dict[str, dict[str, Any]] = {} + blob_messages: list[tuple[str, list[tuple[str, str]], datetime]] = [] + + for session_id, messages, is_blob, monotonic_timestamp in messages_to_send: + if is_blob: + # Blobs cannot be combined - collect them separately + blob_messages.append((session_id, messages, monotonic_timestamp)) + else: + # Group conversational messages by session_id + if session_id not in session_groups: + session_groups[session_id] = {"messages": [], "timestamp": monotonic_timestamp} + # Extend messages list to preserve order (earlier messages first) + session_groups[session_id]["messages"].extend(messages) + # Use the latest timestamp for the combined event + if monotonic_timestamp > session_groups[session_id]["timestamp"]: + session_groups[session_id]["timestamp"] = monotonic_timestamp + + results = [] + try: + # Send one create_event per session_id with combined messages + for session_id, group in session_groups.items(): + event = self.memory_client.create_event( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + messages=group["messages"], + event_timestamp=group["timestamp"], + ) + results.append(event) + logger.debug("Flushed batched event for session %s: %s", session_id, event.get("eventId")) + + # Send blob messages individually (they use a different API path) + for session_id, messages, monotonic_timestamp in blob_messages: + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=session_id, + payload=[ + {"blob": json.dumps(messages[0])}, + ], + eventTimestamp=monotonic_timestamp, + ) + results.append(event) + logger.debug("Flushed blob event for session %s: %s", session_id, event.get("eventId")) + + # Clear buffer only after ALL messages succeed + with self._buffer_lock: + self._message_buffer.clear() + + except Exception as e: + logger.error("Failed to flush messages to AgentCore Memory for session: %s", e) + raise SessionException(f"Failed to flush messages: {e}") from e + + logger.info("Flushed %d events to AgentCore Memory", len(results)) + return results + + def pending_message_count(self) -> int: + """Return the number of messages pending in the buffer. + + Returns: + int: Number of buffered messages waiting to be sent. + """ + with self._buffer_lock: + return len(self._message_buffer) + + def close(self) -> None: + """Explicitly flush pending messages and close the session manager. + + Call this method when the session is complete to ensure all buffered + messages are sent to AgentCore Memory. Alternatively, use the context + manager protocol (with statement) for automatic cleanup. + """ + self._flush_messages() + + def __enter__(self) -> "AgentCoreMemorySessionManager": + """Enter the context manager. + + Returns: + AgentCoreMemorySessionManager: This session manager instance. + """ + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit the context manager and flush any pending messages. + + Args: + exc_type: Exception type if an exception occurred. + exc_val: Exception value if an exception occurred. + exc_tb: Exception traceback if an exception occurred. + """ + try: + self._flush_messages() + except Exception as e: + if exc_type is not None: + logger.error("Failed to flush messages during exception handling: %s", e) + else: + raise + + # endregion Batching support diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index a01973c..a75ed3c 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1,5 +1,6 @@ """Tests for AgentCoreMemorySessionManager.""" +import logging from unittest.mock import Mock, patch import pytest @@ -48,25 +49,48 @@ def mock_memory_client(): return client +def _create_session_manager(config, mock_memory_client): + """Helper to create a session manager with mocked dependencies.""" + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_memory_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + manager.session_id = config.session_id + manager.session = Session(session_id=config.session_id, session_type=SessionType.AGENT) + return manager + + @pytest.fixture def session_manager(agentcore_config, mock_memory_client): """Create an AgentCoreMemorySessionManager with mocked dependencies.""" - with patch( - "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", return_value=mock_memory_client - ): - with patch("boto3.Session") as mock_boto_session: - mock_session = Mock() - mock_session.region_name = "us-west-2" - mock_session.client.return_value = Mock() - mock_boto_session.return_value = mock_session + return _create_session_manager(agentcore_config, mock_memory_client) + + +@pytest.fixture +def batching_config(): + """Create a config with batch_size > 1.""" + return AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=10, + ) - with patch( - "strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None - ): - manager = AgentCoreMemorySessionManager(agentcore_config) - manager.session_id = agentcore_config.session_id - manager.session = Session(session_id=agentcore_config.session_id, session_type=SessionType.AGENT) - return manager + +@pytest.fixture +def batching_session_manager(batching_config, mock_memory_client): + """Create a session manager with batching enabled.""" + return _create_session_manager(batching_config, mock_memory_client) @pytest.fixture @@ -160,6 +184,37 @@ def test_read_session_invalid(self, session_manager): assert result is None + def test_read_session_legacy_migration(self, session_manager, mock_memory_client): + """Test reading a legacy session event triggers migration.""" + legacy_session_data = '{"session_id": "test-session-456", "session_type": "AGENT"}' + + # First call (new approach with metadata) returns empty + # Second call (legacy actor_id) returns the legacy event + mock_memory_client.list_events.side_effect = [ + [], # New approach returns nothing + [{"eventId": "legacy-event-1", "payload": [{"blob": legacy_session_data}]}], # Legacy approach + ] + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "new-event-1"}} + + result = session_manager.read_session("test-session-456") + + # Verify session was returned + assert result is not None + assert result.session_id == "test-session-456" + assert result.session_type == SessionType.AGENT + + # Verify migration: new event created with metadata + mock_memory_client.gmdp_client.create_event.assert_called_once() + create_call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs + assert "metadata" in create_call_kwargs + assert create_call_kwargs["metadata"]["stateType"]["stringValue"] == "SESSION" + + # Verify migration: old event deleted + mock_memory_client.gmdp_client.delete_event.assert_called_once() + delete_call_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs + assert delete_call_kwargs["actorId"] == "session_test-session-456" + assert delete_call_kwargs["eventId"] == "legacy-event-1" + def test_create_agent(self, session_manager): """Test creating an agent.""" session_agent = SessionAgent(agent_id="test-agent-123", state={}, conversation_manager_state={}) @@ -198,6 +253,36 @@ def test_read_agent_no_events(self, session_manager, mock_memory_client): assert result is None + def test_read_agent_legacy_migration(self, session_manager, mock_memory_client): + """Test reading a legacy agent event triggers migration.""" + legacy_agent_data = '{"agent_id": "test-agent-123", "state": {}, "conversation_manager_state": {}}' + + # New approach with metadata returns empty, then legacy approach returns the event + mock_memory_client.list_events.side_effect = [ + [], # New approach with metadata - returns empty + [{"eventId": "legacy-agent-event-1", "payload": [{"blob": legacy_agent_data}]}], # Legacy approach + ] + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "new-agent-event-1"}} + + result = session_manager.read_agent("test-session-456", "test-agent-123") + + # Verify agent was returned + assert result is not None + assert result.agent_id == "test-agent-123" + + # Verify migration: new event created with metadata + mock_memory_client.gmdp_client.create_event.assert_called_once() + create_call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs + assert "metadata" in create_call_kwargs + assert create_call_kwargs["metadata"]["stateType"]["stringValue"] == "AGENT" + assert create_call_kwargs["metadata"]["agentId"]["stringValue"] == "test-agent-123" + + # Verify migration: old event deleted + mock_memory_client.gmdp_client.delete_event.assert_called_once() + delete_call_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs + assert delete_call_kwargs["actorId"] == "agent_test-agent-123" + assert delete_call_kwargs["eventId"] == "legacy-agent-event-1" + def test_create_message(self, session_manager, mock_memory_client): """Test creating a message.""" mock_memory_client.create_event.return_value = {"eventId": "event-123"} @@ -515,8 +600,8 @@ def test_load_long_term_memories_with_validation_failure(self, mock_memory_clien # Should not call retrieve_memories due to validation failure assert mock_memory_client.retrieve_memories.call_count == 0 - # No memories should be stored - assert "ltm_memories" not in test_agent.state._state + # No memories should be stored (agent.state is unmodified since we mocked the method) + assert test_agent.state.get("ltm_memories") is None def test_retry_with_backoff_success(self, session_manager): """Test retry mechanism with eventual success.""" @@ -925,22 +1010,6 @@ def test_init_with_boto_config(self, agentcore_config, mock_memory_client): manager = AgentCoreMemorySessionManager(agentcore_config, boto_client_config=boto_config) assert manager.memory_client is not None - def test_get_full_session_id_conflict(self, session_manager): - """Test session ID conflict with actor ID.""" - # Set up a scenario where session ID would conflict with actor ID - session_manager.config.actor_id = "session_test-session" - - with pytest.raises(SessionException, match="Cannot have session"): - session_manager._get_full_session_id("test-session") - - def test_get_full_agent_id_conflict(self, session_manager): - """Test agent ID conflict with actor ID.""" - # Set up a scenario where agent ID would conflict with actor ID - session_manager.config.actor_id = "agent_test-agent" - - with pytest.raises(SessionException, match="Cannot create agent"): - session_manager._get_full_agent_id("test-agent") - def test_retrieve_customer_context_no_messages(self, agentcore_config_with_retrieval, mock_memory_client): """Test retrieve_customer_context with no messages.""" with patch( @@ -1116,3 +1185,705 @@ def test_list_messages_with_limit_calculates_max_results(self, session_manager, mock_memory_client.list_events.assert_called_once() call_kwargs = mock_memory_client.list_events.call_args[1] assert call_kwargs["max_results"] == 550 # limit + offset + + +class TestBatchingConfig: + """Test batch_size configuration validation.""" + + def test_batch_size_default_value(self): + """Test batch_size defaults to 1 (immediate send).""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + ) + assert config.batch_size == 1 + + def test_batch_size_custom_value(self): + """Test batch_size can be set to a custom value.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + ) + assert config.batch_size == 10 + + def test_batch_size_maximum_value(self): + """Test batch_size accepts maximum value of 100.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=100, + ) + assert config.batch_size == 100 + + def test_batch_size_exceeds_maximum_raises_error(self): + """Test batch_size above 100 raises validation error.""" + with pytest.raises(ValueError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=101, + ) + + def test_batch_size_zero_raises_error(self): + """Test batch_size of 0 raises validation error.""" + with pytest.raises(ValueError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=0, + ) + + def test_batch_size_negative_raises_error(self): + """Test negative batch_size raises validation error.""" + with pytest.raises(ValueError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=-1, + ) + + +class TestBatchingBufferManagement: + """Test batching buffer management and pending_message_count.""" + + @pytest.fixture + def batching_config(self): + """Override with batch_size=5 for buffer management tests.""" + return AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789", + batch_size=5, + ) + + @pytest.fixture + def batching_session_manager(self, batching_config, mock_memory_client): + """Create a session manager with batch_size=5.""" + return _create_session_manager(batching_config, mock_memory_client) + + def test_pending_message_count_empty_buffer(self, batching_session_manager): + """Test pending_message_count returns 0 for empty buffer.""" + assert batching_session_manager.pending_message_count() == 0 + + def test_pending_message_count_with_buffered_messages(self, batching_session_manager, mock_memory_client): + """Test pending_message_count returns correct count.""" + # Add messages to buffer (batch_size=5, so won't auto-flush) + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + # Verify no events were sent (still buffered) + mock_memory_client.create_event.assert_not_called() + + def test_buffer_auto_flushes_at_batch_size(self, batching_session_manager, mock_memory_client): + """Test buffer automatically flushes when reaching batch_size.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + # Add exactly batch_size messages (5) + for i in range(5): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Buffer should have been flushed + assert batching_session_manager.pending_message_count() == 0 + # One batched API call for all messages in the same session + assert mock_memory_client.create_event.call_count == 1 + + def test_create_message_returns_empty_dict_when_buffered(self, batching_session_manager): + """Test create_message returns empty dict when message is buffered.""" + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + result = batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert result == {} + + +class TestBatchingFlush: + """Test _flush_messages behavior.""" + + def test__flush_messages_empty_buffer(self, batching_session_manager): + """Test _flush_messages with empty buffer returns empty list.""" + results = batching_session_manager._flush_messages() + assert results == [] + + def test__flush_messages_sends_all_buffered(self, batching_session_manager, mock_memory_client): + """Test _flush_messages sends all buffered messages in a single batched call.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + # Add 3 messages (below batch_size of 10) + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + + # Flush manually + results = batching_session_manager._flush_messages() + + # One batched API call for all messages in the same session + assert len(results) == 1 + assert batching_session_manager.pending_message_count() == 0 + assert mock_memory_client.create_event.call_count == 1 + + def test__flush_messages_maintains_order(self, batching_session_manager, mock_memory_client): + """Test _flush_messages maintains message order within batched payload.""" + sent_payloads = [] + + def track_create_event(**kwargs): + sent_payloads.append(kwargs.get("messages")) + return {"eventId": f"event_{len(sent_payloads)}"} + + mock_memory_client.create_event.side_effect = track_create_event + + # Add messages with distinct content + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message_{i}"}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # Should be one batched call with messages in order + assert len(sent_payloads) == 1 + combined_messages = sent_payloads[0] + assert len(combined_messages) == 3 + for i, msg in enumerate(combined_messages): + assert f"Message_{i}" in msg[0] + + def test__flush_messages_clears_buffer(self, batching_session_manager, mock_memory_client): + """Test _flush_messages clears the buffer after sending.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # First flush + batching_session_manager._flush_messages() + assert batching_session_manager.pending_message_count() == 0 + + # Second flush should be no-op + results = batching_session_manager._flush_messages() + assert results == [] + + def test__flush_messages_exception_handling(self, batching_session_manager, mock_memory_client): + """Test _flush_messages raises SessionException on failure.""" + mock_memory_client.create_event.side_effect = Exception("API Error") + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + with pytest.raises(SessionException, match="Failed to flush messages"): + batching_session_manager._flush_messages() + + def test_partial_flush_failure_preserves_all_messages(self, batching_session_manager, mock_memory_client): + """Test that on flush failure, all messages remain in buffer to prevent data loss.""" + mock_memory_client.create_event.side_effect = Exception("API Error") + + # Add multiple messages + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + + # Flush should fail + with pytest.raises(SessionException): + batching_session_manager._flush_messages() + + # All messages should still be in buffer (not cleared on failure) + assert batching_session_manager.pending_message_count() == 3 + + # Fix the mock and retry - should succeed now + mock_memory_client.create_event.side_effect = None + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + results = batching_session_manager._flush_messages() + assert len(results) == 1 # One batched call for all messages + assert batching_session_manager.pending_message_count() == 0 + + def test_batching_combines_messages_for_same_session(self, batching_session_manager, mock_memory_client): + """Test that multiple messages for the same session are combined into one API call.""" + sent_payloads = [] + + def track_create_event(**kwargs): + sent_payloads.append(kwargs.get("messages")) + return {"eventId": f"event_{len(sent_payloads)}"} + + mock_memory_client.create_event.side_effect = track_create_event + + # Add 5 messages to the same session + for i in range(5): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message_{i}"}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # Should be ONE API call with all 5 messages combined + assert mock_memory_client.create_event.call_count == 1 + assert len(sent_payloads) == 1 + # The combined payload should have all 5 messages + assert len(sent_payloads[0]) == 5 + # Messages should be in order + for i in range(5): + assert f"Message_{i}" in sent_payloads[0][i][0] + + def test_multiple_sessions_grouped_into_separate_api_calls(self, batching_session_manager, mock_memory_client): + """Test that messages to different sessions are grouped into separate API calls. + + Note: In normal usage, create_message enforces session_id == config.session_id, + so all messages go to one session. This test verifies the internal grouping logic + by directly manipulating the buffer. + """ + from datetime import datetime, timezone + + calls_by_session = {} + + def track_create_event(**kwargs): + session_id = kwargs.get("session_id") + messages = kwargs.get("messages") + calls_by_session[session_id] = messages + return {"eventId": f"event_{session_id}"} + + mock_memory_client.create_event.side_effect = track_create_event + + # Directly populate buffer with messages for multiple sessions + # Buffer format: (session_id, messages, is_blob, monotonic_timestamp) + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + batching_session_manager._message_buffer = [ + ("session-A", [("SessionA_Message_0", "user")], False, base_time), + ("session-A", [("SessionA_Message_1", "user")], False, base_time), + ("session-B", [("SessionB_Message_0", "user")], False, base_time), + ("session-B", [("SessionB_Message_1", "user")], False, base_time), + ("session-B", [("SessionB_Message_2", "user")], False, base_time), + ("session-A", [("SessionA_Message_2", "user")], False, base_time), # Non-consecutive + ] + + batching_session_manager._flush_messages() + + # Should be TWO API calls - one per session + assert mock_memory_client.create_event.call_count == 2 + assert len(calls_by_session) == 2 + + # Session A should have 3 messages combined + assert "session-A" in calls_by_session + assert len(calls_by_session["session-A"]) == 3 + assert calls_by_session["session-A"][0] == ("SessionA_Message_0", "user") + assert calls_by_session["session-A"][1] == ("SessionA_Message_1", "user") + assert calls_by_session["session-A"][2] == ("SessionA_Message_2", "user") + + # Session B should have 3 messages combined + assert "session-B" in calls_by_session + assert len(calls_by_session["session-B"]) == 3 + for i in range(3): + assert calls_by_session["session-B"][i] == (f"SessionB_Message_{i}", "user") + + def test_latest_timestamp_used_for_combined_events(self, batching_session_manager, mock_memory_client): + """Test that the latest timestamp from grouped messages is used for the combined event.""" + captured_timestamps = [] + + def track_create_event(**kwargs): + captured_timestamps.append(kwargs.get("event_timestamp")) + return {"eventId": "event_123"} + + mock_memory_client.create_event.side_effect = track_create_event + + # Add messages with different timestamps (out of order) + timestamps = ["2024-01-01T12:05:00Z", "2024-01-01T12:01:00Z", "2024-01-01T12:10:00Z"] + for i, ts in enumerate(timestamps): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message_{i}"}]}, + message_id=i, + created_at=ts, + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # The combined event should use the latest timestamp (12:10:00) + assert len(captured_timestamps) == 1 + # The timestamp should be the latest one (12:10:00) + from datetime import datetime, timezone + + expected_latest = datetime(2024, 1, 1, 12, 10, 0, tzinfo=timezone.utc) + # Account for monotonic timestamp adjustment (may add microseconds) + assert captured_timestamps[0] >= expected_latest + + def test_partial_failure_multiple_sessions_preserves_buffer(self, batching_session_manager, mock_memory_client): + """Test that when one session fails, ALL messages remain in buffer. + + Note: Tests internal grouping logic by directly manipulating buffer. + """ + from datetime import datetime, timezone + + def fail_on_second_session(**kwargs): + session_id = kwargs.get("session_id") + if session_id == "session-B": + raise Exception("API Error for session B") + return {"eventId": f"event_{session_id}"} + + mock_memory_client.create_event.side_effect = fail_on_second_session + + # Directly populate buffer with messages for multiple sessions + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + batching_session_manager._message_buffer = [ + ("session-A", [("SessionA_Message_0", "user")], False, base_time), + ("session-A", [("SessionA_Message_1", "user")], False, base_time), + ("session-B", [("SessionB_Message_0", "user")], False, base_time), + ("session-B", [("SessionB_Message_1", "user")], False, base_time), + ] + + assert batching_session_manager.pending_message_count() == 4 + + # Flush should fail + with pytest.raises(SessionException, match="Failed to flush messages"): + batching_session_manager._flush_messages() + + # ALL messages should still be in buffer (even session A's which "succeeded") + # This is because buffer is only cleared after ALL succeed + assert batching_session_manager.pending_message_count() == 4 + + def test_blob_messages_sent_individually_not_batched(self, batching_session_manager, mock_memory_client): + """Test that multiple blob messages are sent as individual API calls, not batched.""" + blob_calls = [] + + def track_blob_event(**kwargs): + blob_calls.append(kwargs) + return {"event": {"eventId": f"blob_event_{len(blob_calls)}"}} + + mock_memory_client.gmdp_client.create_event.side_effect = track_blob_event + mock_memory_client.create_event.return_value = {"eventId": "conv_event"} + + # Add multiple blob messages (>9KB each) + for i in range(3): + large_text = f"blob_{i}_" + "x" * 10000 + message = SessionMessage( + message={"role": "user", "content": [{"text": large_text}]}, + message_id=i, + created_at=f"2024-01-01T12:0{i}:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + batching_session_manager._flush_messages() + + # Each blob should be sent individually (3 separate API calls) + assert mock_memory_client.gmdp_client.create_event.call_count == 3 + assert len(blob_calls) == 3 + + # Verify each blob was sent separately with correct content + for i, call in enumerate(blob_calls): + assert "payload" in call + assert "blob" in call["payload"][0] + assert f"blob_{i}_" in call["payload"][0]["blob"] + + def test_mixed_sessions_with_blobs_and_conversational(self, batching_session_manager, mock_memory_client): + """Test complex scenario: multiple sessions with both blob and conversational messages. + + Note: Tests internal grouping logic by directly manipulating buffer. + """ + from datetime import datetime, timezone + + conv_calls = {} + blob_calls = [] + + def track_conv_event(**kwargs): + session_id = kwargs.get("session_id") + conv_calls[session_id] = kwargs.get("messages") + return {"eventId": f"conv_event_{session_id}"} + + def track_blob_event(**kwargs): + blob_calls.append(kwargs) + return {"event": {"eventId": f"blob_event_{len(blob_calls)}"}} + + mock_memory_client.create_event.side_effect = track_conv_event + mock_memory_client.gmdp_client.create_event.side_effect = track_blob_event + + # Directly populate buffer with mixed messages + base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + blob_content = {"role": "user", "content": [{"text": "blob_A_" + "x" * 10000}]} + batching_session_manager._message_buffer = [ + # Session A: 2 conversational messages + ("session-A", [("SessionA_conv_0", "user")], False, base_time), + ("session-A", [("SessionA_conv_1", "user")], False, base_time), + # Session A: 1 blob message + ("session-A", [blob_content], True, base_time), + # Session B: 1 conversational message + ("session-B", [("SessionB_conv_0", "user")], False, base_time), + ] + + batching_session_manager._flush_messages() + + # Should have: + # - 2 conversational API calls (one per session) + # - 1 blob API call + assert mock_memory_client.create_event.call_count == 2 + assert mock_memory_client.gmdp_client.create_event.call_count == 1 + + # Session A conversational messages should be batched together + assert "session-A" in conv_calls + assert len(conv_calls["session-A"]) == 2 + + # Session B conversational message + assert "session-B" in conv_calls + assert len(conv_calls["session-B"]) == 1 + + # Blob sent separately + assert len(blob_calls) == 1 + assert "blob_A_" in blob_calls[0]["payload"][0]["blob"] + + +class TestBatchingBackwardsCompatibility: + """Test batch_size=1 behaves identically to previous implementation.""" + + def test_batch_size_one_sends_immediately(self, session_manager, mock_memory_client): + """Test batch_size=1 (default) sends message immediately.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + result = session_manager.create_message("test-session-456", "test-agent-123", message) + + # Should return event immediately + assert result.get("eventId") == "event_123" + # Should have sent immediately + mock_memory_client.create_event.assert_called_once() + # Buffer should be empty + assert session_manager.pending_message_count() == 0 + + def test_batch_size_one_returns_event_id(self, session_manager, mock_memory_client): + """Test batch_size=1 returns the event with eventId.""" + mock_memory_client.create_event.return_value = {"eventId": "unique_event_id"} + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + result = session_manager.create_message("test-session-456", "test-agent-123", message) + + assert "eventId" in result + assert result["eventId"] == "unique_event_id" + + +class TestBatchingContextManager: + """Test context manager (__enter__/__exit__) functionality.""" + + def test_context_manager_returns_self(self, batching_session_manager): + """Test __enter__ returns the session manager instance.""" + with batching_session_manager as ctx: + assert ctx is batching_session_manager + + def test_context_manager_flushes_on_exit(self, batching_session_manager, mock_memory_client): + """Test __exit__ flushes pending messages.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Should still be buffered + assert batching_session_manager.pending_message_count() == 1 + + # After exiting context, should have flushed + assert batching_session_manager.pending_message_count() == 0 + mock_memory_client.create_event.assert_called_once() + + def test_context_manager_flushes_on_exception(self, batching_session_manager, mock_memory_client): + """Test __exit__ flushes even when exception occurs.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + try: + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + raise ValueError("Test exception") + except ValueError: + pass + + # Should have flushed despite exception + assert batching_session_manager.pending_message_count() == 0 + mock_memory_client.create_event.assert_called_once() + + def test_exit_preserves_original_exception_when_flush_fails( + self, batching_session_manager, mock_memory_client, caplog + ): + """Test __exit__ logs flush failure and preserves the original exception.""" + mock_memory_client.create_event.side_effect = RuntimeError("flush failed") + + with caplog.at_level(logging.ERROR): + with pytest.raises(ValueError, match="original error"): + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + raise ValueError("original error") + + assert any( + "Failed to flush messages during exception handling" in record.message and record.levelno == logging.ERROR + for record in caplog.records + ) + + def test_exit_raises_flush_exception_when_no_original_exception( + self, batching_session_manager, mock_memory_client, caplog + ): + """Test __exit__ still raises flush exceptions when no original exception.""" + mock_memory_client.create_event.side_effect = RuntimeError("flush failed") + + with caplog.at_level(logging.ERROR): + with pytest.raises(SessionException, match="flush failed"): + with batching_session_manager: + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert not any( + "Failed to flush messages during exception handling" in record.message for record in caplog.records + ) + + +class TestBatchingClose: + """Test close() method functionality.""" + + def test_close_flushes_pending_messages(self, batching_session_manager, mock_memory_client): + """Test close() flushes all pending messages in a batched call.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + # Add messages + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + assert batching_session_manager.pending_message_count() == 3 + + # Close should flush + batching_session_manager.close() + + assert batching_session_manager.pending_message_count() == 0 + # One batched API call for all messages in the same session + assert mock_memory_client.create_event.call_count == 1 + + def test_close_with_empty_buffer(self, batching_session_manager, mock_memory_client): + """Test close() with empty buffer is a no-op.""" + batching_session_manager.close() + + mock_memory_client.create_event.assert_not_called() + assert batching_session_manager.pending_message_count() == 0 + + +class TestBatchingBlobMessages: + """Test batching handles blob messages (exceeding conversational limit) correctly.""" + + def test_blob_message_sent_via_gmdp_client(self, batching_session_manager, mock_memory_client): + """Test large messages (blobs) are sent via gmdp_client.""" + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "blob_event_123"}} + + # Create a message that exceeds CONVERSATIONAL_MAX_SIZE (9000) + large_text = "x" * 10000 + message = SessionMessage( + message={"role": "user", "content": [{"text": large_text}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + # Flush and verify blob path was used + batching_session_manager._flush_messages() + + mock_memory_client.gmdp_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args.kwargs + assert "payload" in call_kwargs + assert "blob" in call_kwargs["payload"][0] + + def test_mixed_conversational_and_blob_messages(self, batching_session_manager, mock_memory_client): + """Test batching correctly handles mix of conversational and blob messages.""" + mock_memory_client.create_event.return_value = {"eventId": "conv_event"} + mock_memory_client.gmdp_client.create_event.return_value = {"event": {"eventId": "blob_event"}} + + # Add small (conversational) message + small_message = SessionMessage( + message={"role": "user", "content": [{"text": "Small message"}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", small_message) + + # Add large (blob) message + large_text = "x" * 10000 + large_message = SessionMessage( + message={"role": "user", "content": [{"text": large_text}]}, + message_id=2, + created_at="2024-01-01T12:01:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", large_message) + + # Flush + batching_session_manager._flush_messages() + + # Verify both paths were used + assert mock_memory_client.create_event.call_count == 1 # Conversational + assert mock_memory_client.gmdp_client.create_event.call_count == 1 # Blob diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py index e15a457..c44175f 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py @@ -8,6 +8,21 @@ from bedrock_agentcore.memory.integrations.strands.bedrock_converter import AgentCoreMemoryConverter +def _make_conversational_event(session_messages): + """Build one event with multiple conversational payloads.""" + payloads = [] + for sm in session_messages: + payloads.append( + { + "conversational": { + "content": {"text": json.dumps(sm.to_dict())}, + "role": sm.message["role"].upper(), + } + } + ) + return {"payload": payloads} + + class TestAgentCoreMemoryConverter: """Test cases for AgentCoreMemoryConverter.""" @@ -221,3 +236,142 @@ def test_message_to_payload_with_bytes_encodes_before_filtering(self): assert isinstance(encoded_bytes, dict) assert encoded_bytes.get("__bytes_encoded__") is True assert "data" in encoded_bytes + + # --- Ordering tests for events_to_messages --- + + def test_events_to_messages_empty_events(self): + """Test that empty input returns empty output.""" + result = AgentCoreMemoryConverter.events_to_messages([]) + assert result == [] + + def test_events_to_messages_multiple_events_chronological_order(self): + """Test two single-payload events in reverse chronological order produce chronological result.""" + msg_first = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "First"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg_second = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Second"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + # API returns newest first + event_newer = _make_conversational_event([msg_second]) + event_older = _make_conversational_event([msg_first]) + events = [event_newer, event_older] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "First" + assert result[1].message["content"][0]["text"] == "Second" + + def test_events_to_messages_single_event_multiple_payloads_preserves_order(self): + """Test one event with 3 conversational payloads preserves payload order.""" + msgs = [ + SessionMessage( + message_id=i, + message={"role": "user", "content": [{"text": f"msg{i}"}]}, + created_at="2023-01-01T00:00:00Z", + ) + for i in range(1, 4) + ] + + event = _make_conversational_event(msgs) + result = AgentCoreMemoryConverter.events_to_messages([event]) + + assert len(result) == 3 + assert result[0].message["content"][0]["text"] == "msg1" + assert result[1].message["content"][0]["text"] == "msg2" + assert result[2].message["content"][0]["text"] == "msg3" + + def test_events_to_messages_multiple_batched_events_ordering(self): + """Test two multi-payload events: event order reversed, intra-event payload order preserved. + + This is the exact scenario that the original reverse-after-flatten bug broke. + """ + msg1 = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "msg1"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg2 = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "msg2"}]}, + created_at="2023-01-01T00:00:01Z", + ) + msg3 = SessionMessage( + message_id=3, message={"role": "user", "content": [{"text": "msg3"}]}, created_at="2023-01-01T00:00:02Z" + ) + msg4 = SessionMessage( + message_id=4, + message={"role": "assistant", "content": [{"text": "msg4"}]}, + created_at="2023-01-01T00:00:03Z", + ) + + # API returns newest event first + event_newer = _make_conversational_event([msg3, msg4]) + event_older = _make_conversational_event([msg1, msg2]) + events = [event_newer, event_older] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 4 + assert result[0].message["content"][0]["text"] == "msg1" + assert result[1].message["content"][0]["text"] == "msg2" + assert result[2].message["content"][0]["text"] == "msg3" + assert result[3].message["content"][0]["text"] == "msg4" + + def test_events_to_messages_mixed_blob_and_conversational_ordering(self): + """Test blob and conversational events in reverse chronological order produce chronological result.""" + msg_first = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "First"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg_second = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Second"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + # Newer event uses blob format, older event uses conversational format + blob_data = [json.dumps(msg_second.to_dict()), "assistant"] + event_newer = {"payload": [{"blob": json.dumps(blob_data)}]} + event_older = _make_conversational_event([msg_first]) + events = [event_newer, event_older] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "First" + assert result[1].message["content"][0]["text"] == "Second" + + @patch("bedrock_agentcore.memory.integrations.strands.bedrock_converter.logger") + def test_events_to_messages_malformed_payload_does_not_break_batch(self, mock_logger): + """Test a malformed blob payload between two valid conversational payloads in a single event.""" + msg1 = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "msg1"}]}, created_at="2023-01-01T00:00:00Z" + ) + msg3 = SessionMessage( + message_id=3, message={"role": "user", "content": [{"text": "msg3"}]}, created_at="2023-01-01T00:00:02Z" + ) + + conv1 = { + "conversational": { + "content": {"text": json.dumps(msg1.to_dict())}, + "role": "USER", + } + } + bad_blob = {"blob": "invalid json"} + conv3 = { + "conversational": { + "content": {"text": json.dumps(msg3.to_dict())}, + "role": "USER", + } + } + + events = [{"payload": [conv1, bad_blob, conv3]}] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "msg1" + assert result[1].message["content"][0]["text"] == "msg3" + mock_logger.error.assert_called() diff --git a/tests_integ/memory/integrations/test_session_manager.py b/tests_integ/memory/integrations/test_session_manager.py index 218fbe4..f6709d0 100644 --- a/tests_integ/memory/integrations/test_session_manager.py +++ b/tests_integ/memory/integrations/test_session_manager.py @@ -4,17 +4,22 @@ Run with: python -m pytest tests_integ/memory/integrations/test_session_manager.py -v """ +import json import logging import os import time import uuid +from datetime import datetime, timezone import pytest from strands import Agent +from strands.types.session import Session, SessionAgent, SessionType from bedrock_agentcore.memory import MemoryClient +from bedrock_agentcore.memory.integrations.strands.bedrock_converter import AgentCoreMemoryConverter from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig, RetrievalConfig from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager +from bedrock_agentcore.memory.models.filters import EventMetadataFilter, LeftExpression, OperatorType logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -41,7 +46,9 @@ def memory_client(self): def test_memory_stm(self, memory_client): """Create a test memory for integration tests.""" memory_name = f"testmemorySTM{uuid.uuid4().hex[:8]}" - memory = memory_client.create_memory(name=memory_name, description="Test STM memory for integration tests") + memory = memory_client.create_memory_and_wait( + name=memory_name, description="Test STM memory for integration tests", strategies=[] + ) yield memory # Cleanup try: @@ -190,3 +197,169 @@ def test_session_manager_error_handling(self): # This should fail when trying to use the session manager agent = Agent(system_prompt="Test", session_manager=session_manager) agent("Test message") + + def test_legacy_event_migration(self, test_memory_stm, memory_client): + """Test that legacy events with prefixed actorIds are migrated to metadata format. + + The constructor calls read_session which creates a metadata-path session if none exists. + To test legacy migration, we create the legacy event BEFORE constructing the session manager, + so the constructor's read_session finds it via the fallback and migrates it on first access. + """ + session_id = f"test-legacy-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + # --- Session migration --- + # Create a legacy session event BEFORE constructing the session manager. + # Legacy events use blob payloads with the session data, so we use gmdp_client directly. + legacy_session_actor_id = f"session_{session_id}" + session_data = Session(session_id=session_id, session_type=SessionType.AGENT) + memory_client.gmdp_client.create_event( + memoryId=test_memory_stm["id"], + actorId=legacy_session_actor_id, + sessionId=session_id, + payload=[{"blob": json.dumps(session_data.to_dict())}], + eventTimestamp=datetime.now(timezone.utc), + ) + + # Verify legacy event exists before migration + legacy_events_before = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=legacy_session_actor_id, + session_id=session_id, + ) + assert len(legacy_events_before) >= 1 + + # Constructing the session manager triggers read_session in __init__, + # which should find the legacy event, migrate it, and delete the old one + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + ) + session_manager = AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) + + # Verify migration: legacy event should be deleted + legacy_events_after = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=legacy_session_actor_id, + session_id=session_id, + ) + assert len(legacy_events_after) == 0 + + # Verify migration: read_session finds it via the new metadata path + read_session_result = session_manager.read_session(session_id) + assert read_session_result is not None + assert read_session_result.session_id == session_id + + # --- Agent migration --- + agent_id = f"test-agent-{uuid.uuid4().hex[:8]}" + legacy_agent_actor_id = f"agent_{agent_id}" + agent_data = SessionAgent( + agent_id=agent_id, + state={"key": "value"}, + conversation_manager_state={}, + ) + memory_client.gmdp_client.create_event( + memoryId=test_memory_stm["id"], + actorId=legacy_agent_actor_id, + sessionId=session_id, + payload=[{"blob": json.dumps(agent_data.to_dict())}], + eventTimestamp=datetime.now(timezone.utc), + ) + + # read_agent should find via fallback and migrate + read_agent_result = session_manager.read_agent(session_id, agent_id) + assert read_agent_result is not None + assert read_agent_result.agent_id == agent_id + + # Verify migration: legacy event should be deleted + legacy_agent_events = memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=legacy_agent_actor_id, + session_id=session_id, + ) + assert len(legacy_agent_events) == 0 + + # endregion Event metadata integration tests + + # region End-to-end agent with batching tests + + def test_agent_conversation_with_context_manager(self, test_memory_stm): + """Test that Agent messages are flushed when the context manager exits, and session resume loads them.""" + session_id = f"test-agent-ctx-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + batch_size=10, + ) + + # Use context manager — __exit__ calls _flush_messages() which is blocking + with AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) as sm: + agent = Agent(system_prompt="You are a helpful assistant.", session_manager=sm) + response1 = agent("Hello, my name is Bob") + assert response1 is not None + + # After __exit__, buffered messages have been flushed (blocking). + # Resume session with a new session manager to verify persistence. + config2 = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + batch_size=10, + ) + sm2 = AgentCoreMemorySessionManager(agentcore_memory_config=config2, region_name=REGION) + agent2 = Agent(system_prompt="You are a helpful assistant.", session_manager=sm2) + + response2 = agent2("What is my name?") + assert response2 is not None + assert "Bob" in response2.message["content"][0]["text"] + + sm2.close() + + def test_agent_multi_turn_with_batching(self, test_memory_stm): + """Test that a multi-turn conversation within a single Agent works with batching.""" + session_id = f"test-agent-multi-{uuid.uuid4().hex[:8]}" + actor_id = f"test-actor-{uuid.uuid4().hex[:8]}" + + config = AgentCoreMemoryConfig( + memory_id=test_memory_stm["id"], + session_id=session_id, + actor_id=actor_id, + batch_size=10, + ) + session_manager = AgentCoreMemorySessionManager(agentcore_memory_config=config, region_name=REGION) + + agent = Agent(system_prompt="You are a helpful assistant.", session_manager=session_manager) + + agent("Hello, my name is Charlie") + agent("I live in Seattle") + response3 = agent("What is my name and where do I live?") + assert response3 is not None + response_text = response3.message["content"][0]["text"] + assert "Charlie" in response_text + assert "Seattle" in response_text + + # Flush remaining buffered messages (blocking) + session_manager.close() + + # Verify batched messages are persisted — filter out state events + message_filter = EventMetadataFilter.build_expression( + left_operand=LeftExpression.build("stateType"), + operator=OperatorType.NOT_EXISTS, + ) + events = session_manager.memory_client.list_events( + memory_id=test_memory_stm["id"], + actor_id=actor_id, + session_id=session_id, + event_metadata=[message_filter], + ) + + # Convert events back to messages and verify all turns are present + messages = AgentCoreMemoryConverter.events_to_messages(events) + # At least 3 user + 3 assistant messages + assert len(messages) >= 6 + + # endregion End-to-end agent with batching tests