Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions src/google/adk/memory/vertex_ai_memory_bank_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,18 +534,36 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str):
logger.info('Search memory response received.')

memory_events: list[MemoryEntry] = []
async for retrieved_memory in retrieved_memories_iterator:
# TODO: add more complex error handling
logger.debug('Retrieved memory: %s', retrieved_memory)
memory_events.append(
MemoryEntry(
author='user',
content=types.Content(
parts=[types.Part(text=retrieved_memory.memory.fact)],
role='user',
),
timestamp=retrieved_memory.memory.update_time.isoformat(),
try:
async for retrieved_memory in retrieved_memories_iterator:
try:
memory = retrieved_memory.memory
if memory is None:
logger.warning('Skipping memory entry with missing memory object.')
continue
fact = memory.fact
if not fact:
logger.warning('Skipping memory entry with empty or missing fact.')
continue
update_time = memory.update_time
memory_events.append(
MemoryEntry(
author='user',
content=types.Content(
parts=[types.Part(text=fact)],
role='user',
),
timestamp=update_time.isoformat() if update_time else None,
)
)
except AttributeError:
logger.warning(
'Skipping malformed memory entry: %s', retrieved_memory
)
except Exception:
logger.exception(
'Error while iterating memory results. Returning %d partial results.',
len(memory_events),
)
return SearchMemoryResponse(memories=memory_events)

Expand Down
108 changes: 107 additions & 1 deletion tests/unittests/memory/test_vertex_ai_memory_bank_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,6 @@ async def test_search_memory_empty_results(mock_vertexai_client):
assert len(result.memories) == 0


@pytest.mark.asyncio
async def test_search_memory_uses_async_client_path():
sync_client = mock.MagicMock()
sync_client.agent_engines.memories.retrieve.side_effect = AssertionError(
Expand Down Expand Up @@ -1039,3 +1038,110 @@ async def test_search_memory_uses_async_client_path():
similarity_search_params={'search_query': 'query'},
)
sync_client.agent_engines.memories.retrieve.assert_not_called()


@pytest.mark.asyncio
async def test_search_memory_skips_entry_with_none_memory(mock_vertexai_client):
bad_entry = mock.MagicMock()
bad_entry.memory = None

good_entry = mock.MagicMock()
good_entry.memory.fact = 'good fact'
good_entry.memory.update_time = datetime.datetime(2024, 1, 1)

mock_vertexai_client.agent_engines.memories.retrieve.return_value = (
_AsyncListIterator([bad_entry, good_entry])
)
memory_service = mock_vertex_ai_memory_bank_service()

result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)

assert len(result.memories) == 1
assert result.memories[0].content.parts[0].text == 'good fact'


@pytest.mark.asyncio
async def test_search_memory_skips_entry_with_empty_fact(mock_vertexai_client):
for empty_fact in [None, '']:
bad_entry = mock.MagicMock()
bad_entry.memory.fact = empty_fact
bad_entry.memory.update_time = datetime.datetime(2024, 1, 1)

mock_vertexai_client.agent_engines.memories.retrieve.return_value = (
_AsyncListIterator([bad_entry])
)
memory_service = mock_vertex_ai_memory_bank_service()

result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)

assert len(result.memories) == 0


@pytest.mark.asyncio
async def test_search_memory_handles_missing_update_time(mock_vertexai_client):
entry = mock.MagicMock()
entry.memory.fact = 'some fact'
entry.memory.update_time = None

mock_vertexai_client.agent_engines.memories.retrieve.return_value = (
_AsyncListIterator([entry])
)
memory_service = mock_vertex_ai_memory_bank_service()

result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)

assert len(result.memories) == 1
assert result.memories[0].content.parts[0].text == 'some fact'
assert result.memories[0].timestamp is None


@pytest.mark.asyncio
async def test_search_memory_skips_malformed_entry(mock_vertexai_client):
malformed = mock.MagicMock(spec=[]) # no attributes → AttributeError

good_entry = mock.MagicMock()
good_entry.memory.fact = 'good fact'
good_entry.memory.update_time = datetime.datetime(2024, 1, 1)

mock_vertexai_client.agent_engines.memories.retrieve.return_value = (
_AsyncListIterator([malformed, good_entry])
)
memory_service = mock_vertex_ai_memory_bank_service()

result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)

assert len(result.memories) == 1
assert result.memories[0].content.parts[0].text == 'good fact'


@pytest.mark.asyncio
async def test_search_memory_returns_partial_results_on_iterator_error(
mock_vertexai_client,
):
good_entry = mock.MagicMock()
good_entry.memory.fact = 'good fact'
good_entry.memory.update_time = datetime.datetime(2024, 1, 1)

async def failing_async_iterator():
yield good_entry
raise RuntimeError('API stream error')

mock_vertexai_client.agent_engines.memories.retrieve.return_value = (
failing_async_iterator()
)
memory_service = mock_vertex_ai_memory_bank_service()

result = await memory_service.search_memory(
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
)

assert len(result.memories) == 1
assert result.memories[0].content.parts[0].text == 'good fact'