Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bedrock_agentcore/memory/integrations/strands/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Strands integration for Bedrock AgentCore Memory."""

from .converters import BedrockConverseConverter, MemoryConverter, OpenAIConverseConverter

__all__ = ["BedrockConverseConverter", "MemoryConverter", "OpenAIConverseConverter"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Memory converters for AgentCore Memory STM events."""

from .bedrock import BedrockConverseConverter
from .openai import OpenAIConverseConverter
from .protocol import CONVERSATIONAL_MAX_SIZE, MemoryConverter, exceeds_conversational_limit

__all__ = [
"BedrockConverseConverter",
"CONVERSATIONAL_MAX_SIZE",
"MemoryConverter",
"OpenAIConverseConverter",
"exceeds_conversational_limit",
]
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Bedrock AgentCore Memory conversion utilities."""
"""Bedrock Converse format converter for AgentCore Memory."""

import json
import logging
from typing import Any, Tuple

from strands.types.session import SessionMessage

logger = logging.getLogger(__name__)
from .protocol import exceeds_conversational_limit

CONVERSATIONAL_MAX_SIZE = 9000
logger = logging.getLogger(__name__)


class AgentCoreMemoryConverter:
"""Handles conversion between Strands and Bedrock AgentCore Memory formats."""
class BedrockConverseConverter:
"""Handles conversion between Strands SessionMessages and Bedrock Converse event payloads."""

@staticmethod
def _filter_empty_text(message: dict) -> dict:
Expand All @@ -35,7 +35,7 @@ def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]
# First convert to dict (which encodes bytes to base64),
# then filter empty text on the encoded version
session_dict = session_message.to_dict()
filtered_message = AgentCoreMemoryConverter._filter_empty_text(session_dict["message"])
filtered_message = BedrockConverseConverter._filter_empty_text(session_dict["message"])
if not filtered_message.get("content"):
logger.debug("Skipping message with no content after filtering empty text")
return []
Expand Down Expand Up @@ -77,7 +77,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
if "conversational" in payload_item:
conv = payload_item["conversational"]
session_msg = SessionMessage.from_dict(json.loads(conv["content"]["text"]))
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message)
session_msg.message = BedrockConverseConverter._filter_empty_text(session_msg.message)
if session_msg.message.get("content"):
messages.append(session_msg)
elif "blob" in payload_item:
Expand All @@ -86,7 +86,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2:
try:
session_msg = SessionMessage.from_dict(json.loads(blob_data[0]))
session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message)
session_msg.message = BedrockConverseConverter._filter_empty_text(session_msg.message)
if session_msg.message.get("content"):
messages.append(session_msg)
except (json.JSONDecodeError, ValueError):
Expand All @@ -103,4 +103,4 @@ def total_length(message: tuple[str, str]) -> int:
@staticmethod
def exceeds_conversational_limit(message: tuple[str, str]) -> bool:
"""Check if message exceeds conversational size limit."""
return AgentCoreMemoryConverter.total_length(message) >= CONVERSATIONAL_MAX_SIZE
return exceeds_conversational_limit(message)
214 changes: 214 additions & 0 deletions src/bedrock_agentcore/memory/integrations/strands/converters/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
"""OpenAI-format converter for AgentCore Memory.

Converts between Strands SessionMessages (Bedrock Converse format) and OpenAI message format
stored in AgentCore Memory STM events.
"""

import json
import logging
from typing import Any, Tuple

from strands.types.session import SessionMessage

from .protocol import exceeds_conversational_limit

logger = logging.getLogger(__name__)


def _bedrock_to_openai(message: dict) -> dict:
    """Translate a Bedrock Converse message dict into OpenAI message format.

    Args:
        message: Bedrock Converse format message with role and content list.

    Returns:
        OpenAI format message dict.
    """
    role = message.get("role", "user")
    content = message.get("content", [])

    # A leading toolResult block maps to an OpenAI "tool" role message.
    if content and "toolResult" in content[0]:
        tool_result = content[0]["toolResult"]
        texts = [part.get("text", "") for part in tool_result.get("content", []) if "text" in part]
        tool_msg = {
            "role": "tool",
            "tool_call_id": tool_result["toolUseId"],
            "content": "\n".join(texts),
        }
        if "status" in tool_result:
            tool_msg["status"] = tool_result["status"]
        return tool_msg

    # Split remaining blocks into text fragments and tool invocations.
    texts = []
    calls = []
    for block in content:
        if "text" in block:
            stripped = block["text"].strip()
            if stripped:
                texts.append(stripped)
        elif "toolUse" in block:
            use = block["toolUse"]
            calls.append(
                {
                    "id": use["toolUseId"],
                    "type": "function",
                    "function": {
                        "name": use["name"],
                        "arguments": json.dumps(use.get("input", {})),
                    },
                }
            )

    converted: dict[str, Any] = {"role": role}
    joined = "\n".join(texts)
    if calls:
        # OpenAI convention: content is None for a pure tool-call message,
        # a string when text accompanies the calls.
        converted["content"] = joined if texts else None
        converted["tool_calls"] = calls
    else:
        converted["content"] = joined  # "" when no non-empty text survived
    return converted


def _openai_to_bedrock(openai_msg: dict) -> dict:
    """Convert an OpenAI message dict to Bedrock Converse format.

    Args:
        openai_msg: OpenAI format message dict.

    Returns:
        Bedrock Converse format message dict with role and content list.
    """
    role = openai_msg.get("role", "user")
    content_items: list[dict[str, Any]] = []

    # Handle tool response — becomes a user message carrying a toolResult block.
    if role == "tool":
        tool_result: dict[str, Any] = {
            "toolUseId": openai_msg["tool_call_id"],
            # Coerce an explicit null content to "" so the text block is always a string.
            "content": [{"text": openai_msg.get("content") or ""}],
        }
        if "status" in openai_msg:
            tool_result["status"] = openai_msg["status"]
        return {
            "role": "user",
            "content": [{"toolResult": tool_result}],
        }

    # Handle system messages — Bedrock Converse has no system role; store as user text.
    if role == "system":
        return {
            "role": "user",
            # Coerce an explicit null content to "" so the text block is always a string.
            "content": [{"text": openai_msg.get("content") or ""}],
        }

    # Handle plain string text content (list-form multimodal content is not handled here).
    text_content = openai_msg.get("content")
    if text_content and isinstance(text_content, str):
        content_items.append({"text": text_content})

    # Handle tool_calls — each becomes a Bedrock toolUse block.
    for tc in openai_msg.get("tool_calls", []):
        fn = tc.get("function", {})
        args_str = fn.get("arguments", "{}")
        try:
            args = json.loads(args_str)
        except (json.JSONDecodeError, TypeError, ValueError):
            # TypeError covers arguments=None; malformed arguments degrade to
            # an empty input rather than crashing the conversion.
            args = {}
        content_items.append({
            "toolUse": {
                "toolUseId": tc["id"],
                "name": fn["name"],
                "input": args,
            }
        })

    # Bedrock Converse only defines user/assistant roles.
    bedrock_role = "assistant" if role == "assistant" else "user"

    return {"role": bedrock_role, "content": content_items}


class OpenAIConverseConverter:
    """Converts between Strands SessionMessages (Bedrock Converse) and OpenAI message format in STM."""

    @staticmethod
    def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]:
        """Convert a SessionMessage (Bedrock Converse) to OpenAI-format STM payload.

        Args:
            session_message: The Strands session message to convert.

        Returns:
            List of (json_str, role) tuples for STM storage. Empty list if no content.
        """
        message = session_message.message
        content = message.get("content", [])

        # An item carries payload if it is a tool block or non-whitespace text.
        def _carries_payload(item: dict) -> bool:
            if "toolUse" in item or "toolResult" in item:
                return True
            return "text" in item and bool(item["text"].strip())

        # Drop messages that are empty or contain only whitespace text items.
        if not content or not any(_carries_payload(item) for item in content):
            return []

        converted = _bedrock_to_openai(message)
        return [(json.dumps(converted), converted.get("role", "user"))]

    @staticmethod
    def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
        """Convert STM events containing OpenAI-format messages to SessionMessages (Bedrock Converse).

        Args:
            events: List of STM events. Each may contain conversational or blob payloads
                with OpenAI-format JSON.

        Returns:
            List of SessionMessage objects in Bedrock Converse format.
        """
        collected: list[SessionMessage] = []

        for event in events:
            for item in event.get("payload", []):
                parsed = None

                if "conversational" in item:
                    try:
                        parsed = json.loads(item["conversational"]["content"]["text"])
                    except (json.JSONDecodeError, KeyError, ValueError):
                        logger.error("Failed to parse conversational payload as OpenAI message")
                        continue
                elif "blob" in item:
                    try:
                        blob = json.loads(item["blob"])
                        # Only (json_str, role) pairs are recognized blob shapes.
                        if isinstance(blob, (tuple, list)) and len(blob) == 2:
                            parsed = json.loads(blob[0])
                    except (json.JSONDecodeError, ValueError):
                        logger.error("Failed to parse blob payload: %s", item)
                        continue

                # Skip anything that is not a non-empty dict.
                if not isinstance(parsed, dict) or not parsed:
                    continue

                bedrock_msg = _openai_to_bedrock(parsed)
                if bedrock_msg.get("content"):
                    collected.append(SessionMessage(message=bedrock_msg, message_id=0))

        # Reversed so output is chronological — presumably events arrive
        # newest-first from STM; verify against the list_events caller.
        return list(reversed(collected))

    @staticmethod
    def exceeds_conversational_limit(message: tuple[str, str]) -> bool:
        """Check if message exceeds conversational payload size limit."""
        return exceeds_conversational_limit(message)
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Shared protocol and utilities for memory converters."""

from typing import Any, Protocol, Tuple

from strands.types.session import SessionMessage

CONVERSATIONAL_MAX_SIZE = 9000


class MemoryConverter(Protocol):
    """Protocol for converting between Strands messages and STM event payloads.

    Structural (duck-typed) interface: any class providing these three static
    methods — e.g. BedrockConverseConverter or OpenAIConverseConverter —
    satisfies it without explicit inheritance.
    """

    @staticmethod
    def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]:
        """Convert SessionMessage to STM event payload format.

        Returns a list of (text, role) tuples; an empty list means the
        message carries nothing worth storing.
        """
        ...

    @staticmethod
    def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
        """Convert STM events to SessionMessages."""
        ...

    @staticmethod
    def exceeds_conversational_limit(message: tuple[str, str]) -> bool:
        """Check if message exceeds conversational payload size limit.

        The argument is one (text, role) tuple as produced by
        message_to_payload.
        """
        ...


def exceeds_conversational_limit(message: tuple[str, str]) -> bool:
    """Check if message exceeds the conversational payload size limit.

    Args:
        message: A (text, role) tuple.

    Returns:
        True if the total length meets or exceeds CONVERSATIONAL_MAX_SIZE.
    """
    # Both tuple fields (text and role) count toward the limit.
    combined = sum(map(len, message))
    return combined >= CONVERSATIONAL_MAX_SIZE
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from bedrock_agentcore.memory.client import MemoryClient

from .bedrock_converter import AgentCoreMemoryConverter
from .converters import BedrockConverseConverter, MemoryConverter
from .config import AgentCoreMemoryConfig, RetrievalConfig

if TYPE_CHECKING:
Expand Down Expand Up @@ -84,6 +84,7 @@ def _get_monotonic_timestamp(cls, desired_timestamp: Optional[datetime] = None)
def __init__(
self,
agentcore_memory_config: AgentCoreMemoryConfig,
converter: Optional[type[MemoryConverter]] = None,
region_name: Optional[str] = None,
boto_session: Optional[boto3.Session] = None,
boto_client_config: Optional[BotocoreConfig] = None,
Expand All @@ -93,12 +94,15 @@ def __init__(

Args:
agentcore_memory_config (AgentCoreMemoryConfig): Configuration for AgentCore Memory integration.
converter (Optional[type[MemoryConverter]], optional): Custom converter for message format conversion.
Defaults to BedrockConverseConverter.
region_name (Optional[str], optional): AWS region for Bedrock AgentCore Memory. Defaults to None.
boto_session (Optional[boto3.Session], optional): Optional boto3 session. Defaults to None.
boto_client_config (Optional[BotocoreConfig], optional): Optional boto3 client configuration.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
self.converter = converter or BedrockConverseConverter
self.config = agentcore_memory_config
self.memory_client = MemoryClient(region_name=region_name)
session = boto_session or boto3.Session(region_name=region_name)
Expand Down Expand Up @@ -351,15 +355,15 @@ def create_message(
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")

try:
messages = AgentCoreMemoryConverter.message_to_payload(session_message)
messages = self.converter.message_to_payload(session_message)
if not messages:
return

# Parse the original timestamp and use it as desired timestamp
original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00"))
monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp)

if not AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0]):
if not self.converter.exceeds_conversational_limit(messages[0]):
event = self.memory_client.create_event(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
Expand Down Expand Up @@ -462,7 +466,7 @@ def list_messages(
session_id=session_id,
max_results=max_results,
)
messages = AgentCoreMemoryConverter.events_to_messages(events)
messages = self.converter.events_to_messages(events)
if limit is not None:
return messages[offset : offset + limit]
else:
Expand Down
Loading
Loading