diff --git a/src/bedrock_agentcore/memory/integrations/strands/__init__.py b/src/bedrock_agentcore/memory/integrations/strands/__init__.py index 9f162933..5f4c0bfb 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/__init__.py +++ b/src/bedrock_agentcore/memory/integrations/strands/__init__.py @@ -1 +1,5 @@ """Strands integration for Bedrock AgentCore Memory.""" + +from .converters import BedrockConverseConverter, MemoryConverter, OpenAIConverseConverter + +__all__ = ["BedrockConverseConverter", "MemoryConverter", "OpenAIConverseConverter"] diff --git a/src/bedrock_agentcore/memory/integrations/strands/converters/__init__.py b/src/bedrock_agentcore/memory/integrations/strands/converters/__init__.py new file mode 100644 index 00000000..5a08a5d7 --- /dev/null +++ b/src/bedrock_agentcore/memory/integrations/strands/converters/__init__.py @@ -0,0 +1,13 @@ +"""Memory converters for AgentCore Memory STM events.""" + +from .bedrock import BedrockConverseConverter +from .openai import OpenAIConverseConverter +from .protocol import CONVERSATIONAL_MAX_SIZE, MemoryConverter, exceeds_conversational_limit + +__all__ = [ + "BedrockConverseConverter", + "CONVERSATIONAL_MAX_SIZE", + "MemoryConverter", + "OpenAIConverseConverter", + "exceeds_conversational_limit", +] diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/converters/bedrock.py similarity index 88% rename from src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py rename to src/bedrock_agentcore/memory/integrations/strands/converters/bedrock.py index 1f0905c7..3a066f51 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/converters/bedrock.py @@ -1,4 +1,4 @@ -"""Bedrock AgentCore Memory conversion utilities.""" +"""Bedrock Converse format converter for AgentCore Memory.""" import json import logging @@ -6,13 +6,13 @@ from strands.types.session import SessionMessage -logger = logging.getLogger(__name__) +from .protocol import exceeds_conversational_limit -CONVERSATIONAL_MAX_SIZE = 9000 +logger = logging.getLogger(__name__) -class AgentCoreMemoryConverter: - """Handles conversion between Strands and Bedrock AgentCore Memory formats.""" +class BedrockConverseConverter: + """Handles conversion between Strands SessionMessages and Bedrock Converse event payloads.""" @staticmethod def _filter_empty_text(message: dict) -> dict: @@ -35,7 +35,7 @@ def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]] # First convert to dict (which encodes bytes to base64), # then filter empty text on the encoded version session_dict = session_message.to_dict() - filtered_message = AgentCoreMemoryConverter._filter_empty_text(session_dict["message"]) + filtered_message = BedrockConverseConverter._filter_empty_text(session_dict["message"]) if not filtered_message.get("content"): logger.debug("Skipping message with no content after filtering empty text") return [] @@ -77,7 +77,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: if "conversational" in payload_item: conv = payload_item["conversational"] session_msg = SessionMessage.from_dict(json.loads(conv["content"]["text"])) - session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + session_msg.message = BedrockConverseConverter._filter_empty_text(session_msg.message) if session_msg.message.get("content"): messages.append(session_msg) elif "blob" in payload_item: @@ -86,7 +86,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2: try: session_msg = SessionMessage.from_dict(json.loads(blob_data[0])) - session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + session_msg.message = BedrockConverseConverter._filter_empty_text(session_msg.message) if session_msg.message.get("content"): messages.append(session_msg) except (json.JSONDecodeError, ValueError): @@ -103,4 +103,4 @@ def total_length(message: tuple[str, str]) -> int: @staticmethod def exceeds_conversational_limit(message: tuple[str, str]) -> bool: """Check if message exceeds conversational size limit.""" - return AgentCoreMemoryConverter.total_length(message) >= CONVERSATIONAL_MAX_SIZE + return exceeds_conversational_limit(message) diff --git a/src/bedrock_agentcore/memory/integrations/strands/converters/openai.py b/src/bedrock_agentcore/memory/integrations/strands/converters/openai.py new file mode 100644 index 00000000..3f3c75b7 --- /dev/null +++ b/src/bedrock_agentcore/memory/integrations/strands/converters/openai.py @@ -0,0 +1,214 @@ +"""OpenAI-format converter for AgentCore Memory. + +Converts between Strands SessionMessages (Bedrock Converse format) and OpenAI message format +stored in AgentCore Memory STM events. +""" + +import json +import logging +from typing import Any, Tuple + +from strands.types.session import SessionMessage + +from .protocol import exceeds_conversational_limit + +logger = logging.getLogger(__name__) + + +def _bedrock_to_openai(message: dict) -> dict: + """Convert a Bedrock Converse message dict to OpenAI message format. + + Args: + message: Bedrock Converse format message with role and content list. + + Returns: + OpenAI format message dict. + """ + role = message.get("role", "user") + content = message.get("content", []) + + # Check for toolResult — maps to OpenAI "tool" role + if content and "toolResult" in content[0]: + tool_result = content[0]["toolResult"] + text_parts = [c.get("text", "") for c in tool_result.get("content", []) if "text" in c] + result = { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": "\n".join(text_parts), + } + if "status" in tool_result: + result["status"] = tool_result["status"] + return result + + # Separate text and toolUse items + text_parts = [] + tool_calls = [] + for item in content: + if "text" in item: + text = item["text"].strip() + if text: + text_parts.append(text) + elif "toolUse" in item: + tu = item["toolUse"] + tool_calls.append({ + "id": tu["toolUseId"], + "type": "function", + "function": { + "name": tu["name"], + "arguments": json.dumps(tu.get("input", {})), + }, + }) + + result: dict[str, Any] = {"role": role} + + if tool_calls: + # OpenAI convention: content is null when only tool_calls, string when both + result["content"] = "\n".join(text_parts) if text_parts else None + result["tool_calls"] = tool_calls + else: + result["content"] = "\n".join(text_parts) if text_parts else "" + + return result + + +def _openai_to_bedrock(openai_msg: dict) -> dict: + """Convert an OpenAI message dict to Bedrock Converse format. + + Args: + openai_msg: OpenAI format message dict. + + Returns: + Bedrock Converse format message dict with role and content list. + """ + role = openai_msg.get("role", "user") + content_items: list[dict[str, Any]] = [] + + # Handle tool response + if role == "tool": + tool_result: dict[str, Any] = { + "toolUseId": openai_msg["tool_call_id"], + "content": [{"text": openai_msg.get("content", "")}], + } + if "status" in openai_msg: + tool_result["status"] = openai_msg["status"] + return { + "role": "user", + "content": [{"toolResult": tool_result}], + } + + # Handle system messages — Bedrock Converse has no system role + if role == "system": + return { + "role": "user", + "content": [{"text": openai_msg.get("content", "")}], + } + + # Handle text content + text_content = openai_msg.get("content") + if text_content and isinstance(text_content, str): + content_items.append({"text": text_content}) + + # Handle tool_calls + for tc in openai_msg.get("tool_calls", []): + fn = tc.get("function", {}) + args_str = fn.get("arguments", "{}") + try: + args = json.loads(args_str) + except (json.JSONDecodeError, ValueError): + args = {} + content_items.append({ + "toolUse": { + "toolUseId": tc["id"], + "name": fn["name"], + "input": args, + } + }) + + # Map role + bedrock_role = "assistant" if role == "assistant" else "user" + + return {"role": bedrock_role, "content": content_items} + + +class OpenAIConverseConverter: + """Converts between Strands SessionMessages (Bedrock Converse) and OpenAI message format in STM.""" + + @staticmethod + def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]: + """Convert a SessionMessage (Bedrock Converse) to OpenAI-format STM payload. + + Args: + session_message: The Strands session message to convert. + + Returns: + List of (json_str, role) tuples for STM storage. Empty list if no content. + """ + message = session_message.message + content = message.get("content", []) + if not content: + return [] + + # Filter empty/whitespace-only text items + has_non_empty = any( + ("text" in item and item["text"].strip()) + or "toolUse" in item + or "toolResult" in item + for item in content + ) + if not has_non_empty: + return [] + + openai_msg = _bedrock_to_openai(message) + role = openai_msg.get("role", "user") + + return [(json.dumps(openai_msg), role)] + + @staticmethod + def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: + """Convert STM events containing OpenAI-format messages to SessionMessages (Bedrock Converse). + + Args: + events: List of STM events. Each may contain conversational or blob payloads + with OpenAI-format JSON. + + Returns: + List of SessionMessage objects in Bedrock Converse format. + """ + messages: list[SessionMessage] = [] + + for event in events: + for payload_item in event.get("payload", []): + openai_msg = None + + if "conversational" in payload_item: + conv = payload_item["conversational"] + try: + openai_msg = json.loads(conv["content"]["text"]) + except (json.JSONDecodeError, KeyError, ValueError): + logger.error("Failed to parse conversational payload as OpenAI message") + continue + + elif "blob" in payload_item: + try: + blob_data = json.loads(payload_item["blob"]) + if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2: + openai_msg = json.loads(blob_data[0]) + except (json.JSONDecodeError, ValueError): + logger.error("Failed to parse blob payload: %s", payload_item) + continue + + if openai_msg and isinstance(openai_msg, dict): + bedrock_msg = _openai_to_bedrock(openai_msg) + if bedrock_msg.get("content"): + session_msg = SessionMessage( + message=bedrock_msg, + message_id=0, + ) + messages.append(session_msg) + + return list(reversed(messages)) + + @staticmethod + def exceeds_conversational_limit(message: tuple[str, str]) -> bool: + """Check if message exceeds conversational payload size limit.""" + return exceeds_conversational_limit(message) diff --git a/src/bedrock_agentcore/memory/integrations/strands/converters/protocol.py b/src/bedrock_agentcore/memory/integrations/strands/converters/protocol.py new file mode 100644 index 00000000..dccafbdc --- /dev/null +++ b/src/bedrock_agentcore/memory/integrations/strands/converters/protocol.py @@ -0,0 +1,38 @@ +"""Shared protocol and utilities for memory converters.""" + +from typing import Any, Protocol, Tuple + +from strands.types.session import SessionMessage + +CONVERSATIONAL_MAX_SIZE = 9000 + + +class MemoryConverter(Protocol): + """Protocol for converting between Strands messages and STM event payloads.""" + + @staticmethod + def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]: + """Convert SessionMessage to STM event payload format.""" + ... + + @staticmethod + def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: + """Convert STM events to SessionMessages.""" + ... + + @staticmethod + def exceeds_conversational_limit(message: tuple[str, str]) -> bool: + """Check if message exceeds conversational payload size limit.""" + ... + + +def exceeds_conversational_limit(message: tuple[str, str]) -> bool: + """Check if message exceeds the conversational payload size limit. + + Args: + message: A (text, role) tuple. + + Returns: + True if the total length meets or exceeds CONVERSATIONAL_MAX_SIZE. + """ + return sum(len(text) for text in message) >= CONVERSATIONAL_MAX_SIZE diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index c3912a73..6bcb17fc 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -20,7 +20,7 @@ from bedrock_agentcore.memory.client import MemoryClient -from .bedrock_converter import AgentCoreMemoryConverter +from .converters import BedrockConverseConverter, MemoryConverter from .config import AgentCoreMemoryConfig, RetrievalConfig if TYPE_CHECKING: @@ -84,6 +84,7 @@ def _get_monotonic_timestamp(cls, desired_timestamp: Optional[datetime] = None) def __init__( self, agentcore_memory_config: AgentCoreMemoryConfig, + converter: Optional[type[MemoryConverter]] = None, region_name: Optional[str] = None, boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, @@ -93,12 +94,15 @@ def __init__( Args: agentcore_memory_config (AgentCoreMemoryConfig): Configuration for AgentCore Memory integration. + converter (Optional[type[MemoryConverter]], optional): Custom converter for message format conversion. + Defaults to BedrockConverseConverter. region_name (Optional[str], optional): AWS region for Bedrock AgentCore Memory. Defaults to None. boto_session (Optional[boto3.Session], optional): Optional boto3 session. Defaults to None. boto_client_config (Optional[BotocoreConfig], optional): Optional boto3 client configuration. Defaults to None. **kwargs (Any): Additional keyword arguments. """ + self.converter = converter or BedrockConverseConverter self.config = agentcore_memory_config self.memory_client = MemoryClient(region_name=region_name) session = boto_session or boto3.Session(region_name=region_name) @@ -351,7 +355,7 @@ def create_message( raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") try: - messages = AgentCoreMemoryConverter.message_to_payload(session_message) + messages = self.converter.message_to_payload(session_message) if not messages: return @@ -359,7 +363,7 @@ def create_message( original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) - if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]): + if not self.converter.exceeds_conversational_limit(messages[0]): event = self.memory_client.create_event( memory_id=self.config.memory_id, actor_id=self.config.actor_id, @@ -462,7 +466,7 @@ def list_messages( session_id=session_id, max_results=max_results, ) - messages = AgentCoreMemoryConverter.events_to_messages(events) + messages = self.converter.events_to_messages(events) if limit is not None: return messages[offset : offset + limit] else: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index c05f79d3..96866045 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -10,7 +10,7 @@ from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType -from bedrock_agentcore.memory.integrations.strands.bedrock_converter import AgentCoreMemoryConverter +from bedrock_agentcore.memory.integrations.strands.converters import BedrockConverseConverter from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig, RetrievalConfig from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager @@ -118,7 +118,7 @@ def test_events_to_messages(self, session_manager): } ] - messages = AgentCoreMemoryConverter.events_to_messages(events) + messages = BedrockConverseConverter.events_to_messages(events) assert messages[0].message["role"] == "user" assert messages[0].message["content"][0]["text"] == "Hello" @@ -298,7 +298,7 @@ def test_events_to_messages_empty_payload(self, session_manager): } ] - messages = AgentCoreMemoryConverter.events_to_messages(events) + messages = BedrockConverseConverter.events_to_messages(events) assert len(messages) == 0 @@ -1116,3 +1116,143 @@ def test_list_messages_with_limit_calculates_max_results(self, session_manager, mock_memory_client.list_events.assert_called_once() call_kwargs = mock_memory_client.list_events.call_args[1] assert call_kwargs["max_results"] == 550 # limit + offset + + +class TestInjectableConverter: + """Tests for injectable message converter support.""" + + def _make_manager(self, config, mock_memory_client, converter=None): + """Helper to create a session manager with optional custom converter.""" + with patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_memory_client, + ): + with patch("boto3.Session") as mock_boto_session: + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + with patch( + "strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None + ): + kwargs = {} + if converter is not None: + kwargs["converter"] = converter + manager = AgentCoreMemorySessionManager(config, **kwargs) + manager.session_id = config.session_id + manager.session = Session(session_id=config.session_id, session_type=SessionType.AGENT) + return manager + + def test_default_converter_is_agentcore(self, agentcore_config, mock_memory_client): + """When no converter is provided, BedrockConverseConverter is used.""" + manager = self._make_manager(agentcore_config, mock_memory_client) + assert manager.converter is BedrockConverseConverter + + def test_custom_converter_is_stored(self, agentcore_config, mock_memory_client): + """When a custom converter is provided, it is stored on the manager.""" + + class CustomConverter: + @staticmethod + def message_to_payload(session_message): + return [] + + @staticmethod + def events_to_messages(events): + return [] + + @staticmethod + def exceeds_conversational_limit(message): + return False + + manager = self._make_manager(agentcore_config, mock_memory_client, converter=CustomConverter) + assert manager.converter is CustomConverter + + def test_create_message_uses_custom_converter(self, agentcore_config, mock_memory_client): + """create_message() should call the custom converter's message_to_payload and exceeds_conversational_limit.""" + mock_memory_client.create_event.return_value = {"eventId": "event-123"} + + custom_converter = Mock() + custom_converter.message_to_payload.return_value = [('{"message": {"role": "user"}}', "user")] + custom_converter.exceeds_conversational_limit.return_value = False + + manager = self._make_manager(agentcore_config, mock_memory_client, converter=custom_converter) + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, message_id=1, created_at="2024-01-01T12:00:00Z" + ) + + manager.create_message("test-session-456", "test-agent-123", message) + + custom_converter.message_to_payload.assert_called_once_with(message) + custom_converter.exceeds_conversational_limit.assert_called_once() + + def test_list_messages_uses_custom_converter(self, agentcore_config, mock_memory_client): + """list_messages() should call the custom converter's events_to_messages.""" + events = [ + { + "eventId": "event-1", + "payload": [ + { + "conversational": { + "content": {"text": '{"message": {"role": "user", "content": [{"text": "Hi"}]}}'}, + "role": "USER", + } + } + ], + } + ] + mock_memory_client.list_events.return_value = events + + expected_messages = [ + SessionMessage(message={"role": "user", "content": [{"text": "Hi"}]}, message_id=1) + ] + custom_converter = Mock() + custom_converter.events_to_messages.return_value = expected_messages + + manager = self._make_manager(agentcore_config, mock_memory_client, converter=custom_converter) + + result = manager.list_messages("test-session-456", "test-agent-123") + + custom_converter.events_to_messages.assert_called_once_with(events) + assert result == expected_messages + + def test_create_message_empty_payload_from_custom_converter(self, agentcore_config, mock_memory_client): + """When custom converter returns empty payload, create_message should return None.""" + custom_converter = Mock() + custom_converter.message_to_payload.return_value = [] + + manager = self._make_manager(agentcore_config, mock_memory_client, converter=custom_converter) + + message = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, message_id=1, created_at="2024-01-01T12:00:00Z" + ) + + result = manager.create_message("test-session-456", "test-agent-123", message) + + assert result is None + mock_memory_client.create_event.assert_not_called() + + def test_create_message_oversized_uses_blob_fallback(self, agentcore_config, mock_memory_client): + """When custom converter says message exceeds limit, blob fallback is used.""" + mock_memory_client.gmdp_client.create_event.return_value = { + "event": {"eventId": "event-blob-123"} + } + + custom_converter = Mock() + custom_converter.message_to_payload.return_value = [('{"large": "payload"}', "user")] + custom_converter.exceeds_conversational_limit.return_value = True + + manager = self._make_manager(agentcore_config, mock_memory_client, converter=custom_converter) + + message = SessionMessage( + message={"role": "user", "content": [{"text": "x" * 10000}]}, + message_id=1, + created_at="2024-01-01T12:00:00Z", + ) + + manager.create_message("test-session-456", "test-agent-123", message) + + # Should use gmdp_client.create_event (blob path), not memory_client.create_event + mock_memory_client.create_event.assert_not_called() + mock_memory_client.gmdp_client.create_event.assert_called_once() diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py index e15a457f..7b7465dd 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py @@ -1,15 +1,15 @@ -"""Tests for AgentCoreMemoryConverter.""" +"""Tests for BedrockConverseConverter.""" import json from unittest.mock import patch from strands.types.session import SessionMessage -from bedrock_agentcore.memory.integrations.strands.bedrock_converter import AgentCoreMemoryConverter +from bedrock_agentcore.memory.integrations.strands.converters import BedrockConverseConverter -class TestAgentCoreMemoryConverter: - """Test cases for AgentCoreMemoryConverter.""" +class TestBedrockConverseConverter: + """Test cases for BedrockConverseConverter.""" def test_message_to_payload(self): """Test converting SessionMessage to payload format.""" @@ -17,7 +17,7 @@ def test_message_to_payload(self): message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" ) - result = AgentCoreMemoryConverter.message_to_payload(message) + result = BedrockConverseConverter.message_to_payload(message) assert len(result) == 1 assert result[0][1] == "user" @@ -38,7 +38,7 @@ def test_events_to_messages_conversational(self): } ] - result = AgentCoreMemoryConverter.events_to_messages(events) + result = BedrockConverseConverter.events_to_messages(events) assert len(result) == 1 assert result[0].message["role"] == "user" @@ -52,28 +52,28 @@ def test_events_to_messages_blob_valid(self): blob_data = [json.dumps(session_message.to_dict()), "user"] events = [{"payload": [{"blob": json.dumps(blob_data)}]}] - result = AgentCoreMemoryConverter.events_to_messages(events) + result = BedrockConverseConverter.events_to_messages(events) assert len(result) == 1 assert result[0].message["role"] == "user" - @patch("bedrock_agentcore.memory.integrations.strands.bedrock_converter.logger") + @patch("bedrock_agentcore.memory.integrations.strands.converters.bedrock.logger") def test_events_to_messages_blob_invalid_json(self, mock_logger): """Test handling invalid JSON in blob events.""" events = [{"payload": [{"blob": "invalid json"}]}] - result = AgentCoreMemoryConverter.events_to_messages(events) + result = BedrockConverseConverter.events_to_messages(events) assert len(result) == 0 mock_logger.error.assert_called() - @patch("bedrock_agentcore.memory.integrations.strands.bedrock_converter.logger") + @patch("bedrock_agentcore.memory.integrations.strands.converters.bedrock.logger") def test_events_to_messages_blob_invalid_session_message(self, mock_logger): """Test handling invalid SessionMessage in blob events.""" blob_data = ["invalid", "user"] events = [{"payload": [{"blob": json.dumps(blob_data)}]}] - result = AgentCoreMemoryConverter.events_to_messages(events) + result = BedrockConverseConverter.events_to_messages(events) assert len(result) == 0 mock_logger.error.assert_called() @@ -81,47 +81,47 @@ def test_events_to_messages_blob_invalid_session_message(self, mock_logger): def test_total_length(self): """Test calculating total length of message tuple.""" message = ("hello", "world") - result = AgentCoreMemoryConverter.total_length(message) + result = BedrockConverseConverter.total_length(message) assert result == 10 def test_exceeds_conversational_limit_false(self): """Test message under conversational limit.""" message = ("short", "message") - result = AgentCoreMemoryConverter.exceeds_conversational_limit(message) + result = BedrockConverseConverter.exceeds_conversational_limit(message) assert result is False def test_exceeds_conversational_limit_true(self): """Test message over conversational limit.""" long_text = "x" * 5000 message = (long_text, long_text) - result = AgentCoreMemoryConverter.exceeds_conversational_limit(message) + result = BedrockConverseConverter.exceeds_conversational_limit(message) assert result is True def test_filter_empty_text_removes_empty_string(self): """Test filtering removes empty text items.""" message = {"role": "user", "content": [{"text": ""}, {"text": "hello"}]} - result = AgentCoreMemoryConverter._filter_empty_text(message) + result = BedrockConverseConverter._filter_empty_text(message) assert len(result["content"]) == 1 assert result["content"][0]["text"] == "hello" def test_filter_empty_text_removes_whitespace_only(self): """Test filtering removes whitespace-only text items.""" message = {"role": "user", "content": [{"text": " "}, {"text": "hello"}]} - result = AgentCoreMemoryConverter._filter_empty_text(message) + result = BedrockConverseConverter._filter_empty_text(message) assert len(result["content"]) == 1 assert result["content"][0]["text"] == "hello" def test_filter_empty_text_keeps_non_text_items(self): """Test filtering keeps non-text items like toolUse.""" message = {"role": "user", "content": [{"text": ""}, {"toolUse": {"name": "test"}}]} - result = AgentCoreMemoryConverter._filter_empty_text(message) + result = BedrockConverseConverter._filter_empty_text(message) assert len(result["content"]) == 1 assert "toolUse" in result["content"][0] def test_filter_empty_text_all_empty_returns_empty_content(self): """Test filtering all empty text returns empty content array.""" message = {"role": "user", "content": [{"text": ""}]} - result = AgentCoreMemoryConverter._filter_empty_text(message) + result = BedrockConverseConverter._filter_empty_text(message) assert result["content"] == [] def test_message_to_payload_skips_all_empty_text(self): @@ -129,7 +129,7 @@ def test_message_to_payload_skips_all_empty_text(self): message = SessionMessage( message_id=1, message={"role": "user", "content": [{"text": ""}]}, created_at="2023-01-01T00:00:00Z" ) - result = AgentCoreMemoryConverter.message_to_payload(message) + result = BedrockConverseConverter.message_to_payload(message) assert result == [] def test_message_to_payload_filters_empty_text_items(self): @@ -139,7 +139,7 @@ def test_message_to_payload_filters_empty_text_items(self): message={"role": "user", "content": [{"text": ""}, {"text": "hello"}]}, created_at="2023-01-01T00:00:00Z", ) - result = AgentCoreMemoryConverter.message_to_payload(message) + result = BedrockConverseConverter.message_to_payload(message) assert len(result) == 1 parsed = json.loads(result[0][0]) assert len(parsed["message"]["content"]) == 1 @@ -159,7 +159,7 @@ def test_events_to_messages_filters_empty_text_conversational(self): ] } ] - result = AgentCoreMemoryConverter.events_to_messages(events) + result = BedrockConverseConverter.events_to_messages(events) assert len(result) == 1 assert len(result[0].message["content"]) == 1 assert result[0].message["content"][0]["text"] == "hello" @@ -172,7 +172,7 @@ def test_events_to_messages_drops_all_empty_conversational(self): events = [ {"payload": [{"conversational": {"content": {"text": json.dumps(empty_msg.to_dict())}, "role": "USER"}}]} ] - result = AgentCoreMemoryConverter.events_to_messages(events) + result = BedrockConverseConverter.events_to_messages(events) assert len(result) == 0 def test_events_to_messages_filters_empty_text_blob(self): @@ -183,7 +183,7 @@ def test_events_to_messages_filters_empty_text_blob(self): created_at="2023-01-01T00:00:00Z", ) events = [{"payload": [{"blob": json.dumps([json.dumps(msg_with_empty.to_dict()), "user"])}]}] - result = AgentCoreMemoryConverter.events_to_messages(events) + result = BedrockConverseConverter.events_to_messages(events) assert len(result) == 1 assert len(result[0].message["content"]) == 1 assert result[0].message["content"][0]["text"] == "hello" @@ -209,7 +209,7 @@ def test_message_to_payload_with_bytes_encodes_before_filtering(self): ) # This should not raise "Object of type bytes is not JSON serializable" - result = AgentCoreMemoryConverter.message_to_payload(message) + result = BedrockConverseConverter.message_to_payload(message) assert len(result) == 1 # Verify json.dumps succeeded and bytes were encoded diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_openai_converter.py b/tests/bedrock_agentcore/memory/integrations/strands/test_openai_converter.py new file mode 100644 index 00000000..4addf439 --- /dev/null +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_openai_converter.py @@ -0,0 +1,504 @@ +"""Tests for OpenAIConverseConverter.""" + +import json +from unittest.mock import patch + +import pytest +from strands.types.session import SessionMessage + +from bedrock_agentcore.memory.integrations.strands.converters import OpenAIConverseConverter + + +class TestOpenAIConverseConverterMessageToPayload: + """Test converting Strands SessionMessages (Bedrock Converse) to OpenAI-format STM payloads.""" + + def test_user_text_message(self): + """Convert a simple user text message to OpenAI format payload.""" + msg = SessionMessage( + message={"role": "user", "content": [{"text": "Hello"}]}, + message_id=1, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + + assert len(result) == 1 + payload_json, role = result[0] + assert role == "user" + + payload = json.loads(payload_json) + assert payload["role"] == "user" + assert payload["content"] == "Hello" + + def test_assistant_text_message(self): + """Convert a simple assistant text message to OpenAI format payload.""" + msg = SessionMessage( + message={"role": "assistant", "content": [{"text": "Hi there"}]}, + message_id=2, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + + assert len(result) == 1 + payload_json, role = result[0] + assert role == "assistant" + + payload = json.loads(payload_json) + assert payload["role"] == "assistant" + assert payload["content"] == "Hi there" + + def test_assistant_tool_use_message(self): + """Convert an assistant message with toolUse to OpenAI tool_calls format.""" + msg = SessionMessage( + message={ + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "call_123", + "name": "get_weather", + "input": {"city": "Seattle"}, + } + } + ], + }, + message_id=3, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + + assert len(result) == 1 + payload_json, role = result[0] + assert role == "assistant" + + payload = json.loads(payload_json) + assert payload["role"] == "assistant" + assert payload.get("content") is None + assert len(payload["tool_calls"]) == 1 + + tc = payload["tool_calls"][0] + assert tc["id"] == "call_123" + assert tc["type"] == "function" + assert tc["function"]["name"] == "get_weather" + assert json.loads(tc["function"]["arguments"]) == {"city": "Seattle"} + + def test_assistant_text_and_tool_use(self): + """Convert an assistant message with both text and toolUse.""" + msg = SessionMessage( + message={ + "role": "assistant", + "content": [ + {"text": "Let me check that for you."}, + { + "toolUse": { + "toolUseId": "call_456", + "name": "search", + "input": {"q": "test"}, + } + }, + ], + }, + message_id=4, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + + assert len(result) == 1 + payload_json, role = result[0] + payload = json.loads(payload_json) + + assert payload["role"] == "assistant" + assert payload["content"] == "Let me check that for you." + assert len(payload["tool_calls"]) == 1 + assert payload["tool_calls"][0]["function"]["name"] == "search" + + def test_tool_result_message(self): + """Convert a toolResult message to OpenAI tool response format.""" + msg = SessionMessage( + message={ + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "call_123", + "content": [{"text": "72°F and sunny"}], + "status": "success", + } + } + ], + }, + message_id=5, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + + assert len(result) == 1 + payload_json, role = result[0] + payload = json.loads(payload_json) + + assert payload["role"] == "tool" + assert payload["tool_call_id"] == "call_123" + assert payload["content"] == "72°F and sunny" + + def test_empty_content_returns_empty(self): + """A message with empty content list returns empty payload.""" + msg = SessionMessage( + message={"role": "user", "content": []}, + message_id=6, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + assert result == [] + + def test_empty_text_filtered(self): + """A message with only empty/whitespace text returns empty payload.""" + msg = SessionMessage( + message={"role": "user", "content": [{"text": " "}]}, + message_id=7, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + assert result == [] + + def test_multiple_tool_calls(self): + """Convert an assistant message with multiple tool calls.""" + msg = SessionMessage( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "c1", "name": "fn_a", "input": {}}}, + {"toolUse": {"toolUseId": "c2", "name": "fn_b", "input": {"x": 1}}}, + ], + }, + message_id=8, + ) + result = OpenAIConverseConverter.message_to_payload(msg) + + payload = json.loads(result[0][0]) + assert len(payload["tool_calls"]) == 2 + assert payload["tool_calls"][0]["id"] == "c1" + assert payload["tool_calls"][1]["id"] == "c2" + + +class TestOpenAIConverseConverterEventsToMessages: + """Test converting STM events (OpenAI format) to Strands SessionMessages (Bedrock Converse).""" + + def _make_conversational_event(self, openai_msg: dict, role: str = "USER") -> dict: + """Helper to create an STM event with a conversational payload.""" + return { + "eventId": "event-1", + "payload": [ + { + "conversational": { + "content": {"text": json.dumps(openai_msg)}, + "role": role, + } + } + ], + } + + def _make_blob_event(self, openai_msg: dict, role: str = "user") -> dict: + """Helper to create an STM event with a blob payload.""" + return { + "eventId": "event-1", + "payload": [ + {"blob": json.dumps((json.dumps(openai_msg), role))} + ], + } + + def test_user_text_event(self): + """Convert an OpenAI user text event to Bedrock Converse SessionMessage.""" + events = [self._make_conversational_event({"role": "user", "content": "Hello"}, "USER")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 1 + msg = messages[0].message + assert msg["role"] == "user" + assert msg["content"] == [{"text": "Hello"}] + + def test_assistant_text_event(self): + """Convert an OpenAI assistant text event to Bedrock Converse SessionMessage.""" + events = [self._make_conversational_event({"role": "assistant", "content": "Hi"}, "ASSISTANT")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 1 + msg = messages[0].message + assert msg["role"] == "assistant" + assert msg["content"] == [{"text": "Hi"}] + + def test_assistant_tool_calls_event(self): + """Convert OpenAI assistant tool_calls event to Bedrock Converse toolUse.""" + openai_msg = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "NYC"}', + }, + } + ], + } + events = [self._make_conversational_event(openai_msg, "ASSISTANT")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 1 + msg = messages[0].message + assert msg["role"] == "assistant" + content = msg["content"] + assert len(content) == 1 + assert content[0]["toolUse"]["toolUseId"] == "call_abc" + assert content[0]["toolUse"]["name"] == "get_weather" + assert content[0]["toolUse"]["input"] == {"city": "NYC"} + + def test_assistant_text_and_tool_calls_event(self): + """Convert OpenAI assistant with both text and tool_calls.""" + openai_msg = { + "role": "assistant", + "content": "Let me look that up.", + "tool_calls": [ + { + "id": "call_xyz", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ], + } + events = [self._make_conversational_event(openai_msg, "ASSISTANT")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + msg = messages[0].message + assert msg["role"] == "assistant" + # Should have text + toolUse + assert msg["content"][0] == {"text": "Let me look that up."} + assert msg["content"][1]["toolUse"]["name"] == "search" + + def test_tool_response_event(self): + """Convert OpenAI tool response to Bedrock Converse toolResult.""" + openai_msg = { + "role": "tool", + "tool_call_id": "call_abc", + "content": "72°F and sunny", + } + events = [self._make_conversational_event(openai_msg, "USER")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 1 + msg = messages[0].message + assert msg["role"] == "user" + tool_result = msg["content"][0]["toolResult"] + assert tool_result["toolUseId"] == "call_abc" + assert tool_result["content"] == [{"text": "72°F and sunny"}] + + def test_system_message_event(self): + """Convert OpenAI system message to user message in Bedrock Converse.""" + openai_msg = {"role": "system", "content": "You are helpful."} + events = [self._make_conversational_event(openai_msg, "USER")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 1 + msg = messages[0].message + assert msg["role"] == "user" + assert msg["content"] == [{"text": "You are helpful."}] + + def test_multiple_events_in_order(self): + """Multiple events are returned in correct order (reversed from STM).""" + events = [ + self._make_conversational_event({"role": "user", "content": "First"}, "USER"), + self._make_conversational_event({"role": "assistant", "content": "Second"}, "ASSISTANT"), + ] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 2 + # STM returns newest first; converter reverses + assert messages[0].message["content"][0]["text"] == "Second" + assert messages[1].message["content"][0]["text"] == "First" + + def test_empty_events(self): + """Empty event list returns empty message list.""" + assert OpenAIConverseConverter.events_to_messages([]) == [] + + def test_event_with_empty_payload(self): + """Event with no payload is skipped.""" + events = [{"eventId": "event-1"}] + assert OpenAIConverseConverter.events_to_messages(events) == [] + + def test_blob_event(self): + """Convert a blob-format STM event with OpenAI data.""" + openai_msg = {"role": "user", "content": "From blob"} + events = [self._make_blob_event(openai_msg)] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 1 + assert messages[0].message["role"] == "user" + assert messages[0].message["content"] == [{"text": "From blob"}] + + def test_multiple_tool_calls_event(self): + """Convert OpenAI message with multiple tool calls.""" + openai_msg = { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": "c1", "type": "function", "function": {"name": "fn_a", "arguments": "{}"}}, + {"id": "c2", "type": "function", "function": {"name": "fn_b", "arguments": '{"x":1}'}}, + ], + } + events = [self._make_conversational_event(openai_msg, "ASSISTANT")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + content = messages[0].message["content"] + assert len(content) == 2 + assert content[0]["toolUse"]["toolUseId"] == "c1" + assert content[1]["toolUse"]["toolUseId"] == "c2" + + @patch("bedrock_agentcore.memory.integrations.strands.converters.openai.logger") + def test_blob_invalid_json(self, mock_logger): + """Invalid JSON in blob payload is handled gracefully.""" + events = [{"eventId": "e1", "payload": [{"blob": "not valid json"}]}] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert messages == [] + mock_logger.error.assert_called_once() + + @patch("bedrock_agentcore.memory.integrations.strands.converters.openai.logger") + def test_conversational_invalid_json(self, mock_logger): + """Invalid JSON in conversational payload is handled gracefully.""" + events = [ + { + "eventId": "e1", + "payload": [ + {"conversational": {"content": {"text": "not valid json"}, "role": "USER"}} + ], + } + ] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert messages == [] + mock_logger.error.assert_called_once() + + def test_tool_calls_with_malformed_arguments(self): + """Tool calls with non-JSON arguments default to empty dict.""" + openai_msg = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_bad", + "type": "function", + "function": {"name": "broken_fn", "arguments": "not json"}, + } + ], + } + events = [self._make_conversational_event(openai_msg, "ASSISTANT")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + assert len(messages) == 1 + tool_use = messages[0].message["content"][0]["toolUse"] + assert tool_use["name"] == "broken_fn" + assert tool_use["input"] == {} + + def test_assistant_null_content_no_tool_calls(self): + """Assistant message with null content and no tool_calls produces empty content.""" + openai_msg = {"role": "assistant", "content": None} + events = [self._make_conversational_event(openai_msg, "ASSISTANT")] + + messages = OpenAIConverseConverter.events_to_messages(events) + + # No content items → filtered out + assert messages == [] + + +class TestOpenAIConverseConverterExceedsLimit: + """Test conversational size limit check.""" + + def test_small_message_does_not_exceed(self): + result = OpenAIConverseConverter.exceeds_conversational_limit(("short", "user")) + assert result is False + + def test_large_message_exceeds(self): + big = "x" * 9000 + result = OpenAIConverseConverter.exceeds_conversational_limit((big, "user")) + assert result is True + + +class TestOpenAIConverseConverterRoundTrip: + """Test that converting Bedrock→OpenAI→Bedrock preserves message semantics.""" + + def test_roundtrip_user_text(self): + """User text survives Bedrock→OpenAI→Bedrock round trip.""" + original = SessionMessage( + message={"role": "user", "content": [{"text": "Hello world"}]}, + message_id=1, + ) + payload = OpenAIConverseConverter.message_to_payload(original) + openai_json = payload[0][0] + + # Simulate STM storage and retrieval + event = { + "eventId": "e1", + "payload": [{"conversational": {"content": {"text": openai_json}, "role": "USER"}}], + } + restored = OpenAIConverseConverter.events_to_messages([event]) + + assert len(restored) == 1 + assert restored[0].message["role"] == "user" + assert restored[0].message["content"] == [{"text": "Hello world"}] + + def test_roundtrip_tool_use(self): + """Tool use survives Bedrock→OpenAI→Bedrock round trip.""" + original = SessionMessage( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "call_1", "name": "calc", "input": {"expr": "2+2"}}}, + ], + }, + message_id=2, + ) + payload = OpenAIConverseConverter.message_to_payload(original) + openai_json = payload[0][0] + + event = { + "eventId": "e2", + "payload": [{"conversational": {"content": {"text": openai_json}, "role": "ASSISTANT"}}], + } + restored = OpenAIConverseConverter.events_to_messages([event]) + + assert len(restored) == 1 + tu = restored[0].message["content"][0]["toolUse"] + assert tu["toolUseId"] == "call_1" + assert tu["name"] == "calc" + assert tu["input"] == {"expr": "2+2"} + + def test_roundtrip_tool_result(self): + """Tool result survives Bedrock→OpenAI→Bedrock round trip.""" + original = SessionMessage( + message={ + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "call_1", "content": [{"text": "4"}], "status": "success"}}, + ], + }, + message_id=3, + ) + payload = OpenAIConverseConverter.message_to_payload(original) + openai_json = payload[0][0] + + event = { + "eventId": "e3", + "payload": [{"conversational": {"content": {"text": openai_json}, "role": "USER"}}], + } + restored = OpenAIConverseConverter.events_to_messages([event]) + + assert len(restored) == 1 + tr = restored[0].message["content"][0]["toolResult"] + assert tr["toolUseId"] == "call_1" + assert tr["content"] == [{"text": "4"}] + assert tr["status"] == "success"